diff --git a/.checklist.yaml b/.checklist.yaml new file mode 100644 index 000000000..f0c211714 --- /dev/null +++ b/.checklist.yaml @@ -0,0 +1,30 @@ +apiVersion: quintoandar.com.br/checklist/v2 +kind: ServiceChecklist +metadata: + name: butterfree +spec: + description: >- + A solution for Feature Stores. + + costCenter: C055 + department: engineering + lifecycle: production + docs: true + + ownership: + team: data_products_mlops + line: tech_platform + owner: otavio.cals@quintoandar.com.br + + libraries: + - name: butterfree + type: common-usage + path: https://quintoandar.github.io/python-package-server/ + description: A lib to build Feature Stores. + registries: + - github-packages + tier: T0 + + channels: + squad: 'mlops' + alerts: 'data-products-reports' diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 7dff34a78..8b4d9c73c 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -4,16 +4,26 @@ on: paths: - 'setup.py' - jobs: Pipeline: - if: github.ref == 'refs/heads/master' || contains(github.ref, 'hotfix/') - - runs-on: ubuntu-16.04 - container: quintoandar/python-3-7-java + if: github.ref == 'refs/heads/master' + runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 + - uses: actions/setup-python@v5 + with: + python-version: '3.9' + + - uses: actions/setup-java@v4 + with: + java-version: '17' + distribution: microsoft + + - uses: vemonet/setup-spark@v1 + with: + spark-version: '3.5.1' + hadoop-version: '3' - name: Install dependencies run: make ci-install @@ -24,6 +34,10 @@ jobs: - name: Get version run: echo "version=$(grep __version__ setup.py | head -1 | cut -d \" -f2 | cut -d \' -f2)" >> $GITHUB_ENV + - name: Get release notes + id: get_release_notes + uses: ffurrer2/extract-release-notes@v1 + - name: Create release uses: actions/create-release@v1 env: @@ -31,6 +45,7 @@ jobs: with: tag_name: ${{ env.version }} release_name: Release ${{ env.version }} + body: ${{ steps.get_release_notes.outputs.release_notes }} - name: Release already exist if: ${{ failure() }} diff --git a/.github/workflows/skip_lint.yml b/.github/workflows/skip_lint.yml new file mode 100644 index 000000000..1c768a238 --- /dev/null +++ b/.github/workflows/skip_lint.yml @@ -0,0 +1,17 @@ +# This step is used only because we want to mark the runner-linter check as required +# for PRs to develop, but not for the merge queue to merge into develop, +# github does not have this functionality yet + +name: 'Skip github-actions/runner-linter check at merge queue' + +on: + merge_group: + +jobs: + empty_job: + name: 'github-actions/runner-linter' + runs-on: github-actions-developers-runner + steps: + - name: Skip github-actions/runner-linter check at merge queue + run: | + echo "Done" diff --git a/.github/workflows/staging.yml b/.github/workflows/staging.yml new file mode 100644 index 000000000..9885ba688 --- /dev/null +++ b/.github/workflows/staging.yml @@ -0,0 +1,55 @@ +name: "Publish Dev Package" +on: + push: + paths: + - "setup.py" + +jobs: + Pipeline: + if: github.ref == 'refs/heads/staging' + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v2 + - uses: actions/setup-python@v5 + with: + python-version: '3.9' + + - uses: actions/setup-java@v4 + with: + java-version: '17' + distribution: microsoft + + - uses: vemonet/setup-spark@v1 + with: + spark-version: '3.5.1' + hadoop-version: '3' + + - name: Install dependencies + run: make ci-install + + - name: Get version + run: echo "version=$(grep __version__ setup.py | head -1 | cut -d \" -f2 | cut -d \' 
-f2 )" >> $GITHUB_ENV + + - name: Build package + run: make package + + - name: Create release + uses: actions/create-release@v1 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + tag_name: ${{ env.version }} + release_name: Release ${{ env.version }} + prerelease: true + + - name: Release already exist + if: ${{ failure() }} + run: echo Release already exist + + - name: Publish release to pypi.org + if: ${{ success() }} + env: + PYPI_USERNAME: ${{ secrets.PYPI_USERNAME }} + PYPI_PASSWORD: ${{ secrets.PYPI_PASSWORD }} + run: PYTHONPATH=./pip/deps python -m twine upload -u $PYPI_USERNAME -p $PYPI_PASSWORD --verbose dist/* diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 29394a0e8..96ad666f9 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -3,16 +3,29 @@ on: push: branches: - master + - staging - hotfix/** pull_request: jobs: Pipeline: - runs-on: ubuntu-16.04 - container: quintoandar/python-3-7-java + runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 + - uses: actions/setup-python@v5 + with: + python-version: '3.9' + + - uses: actions/setup-java@v4 + with: + java-version: '17' + distribution: microsoft + + - uses: vemonet/setup-spark@v1 + with: + spark-version: '3.5.1' + hadoop-version: '3' - name: Install dependencies run: make ci-install diff --git a/.gitignore b/.gitignore index 72b591f39..0c59b49ab 100644 --- a/.gitignore +++ b/.gitignore @@ -47,6 +47,7 @@ coverage.xml *.cover .hypothesis/ *cov.xml +test_folder/ # Translations *.mo @@ -65,6 +66,7 @@ instance/ # PyBuilder target/ +pip/ # Jupyter Notebook .ipynb_checkpoints diff --git a/CHANGELOG.md b/CHANGELOG.md index 48b5cbf1a..19d9b5f41 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,108 @@ All notable changes to this project will be documented in this file. Preferably use **Added**, **Changed**, **Removed** and **Fixed** topics in each release or unreleased log for a better organization. 
+## [Unreleased] + +## [1.4.0](https://github.com/quintoandar/butterfree/releases/tag/1.4.0) +* Add Delta support ([#370](https://github.com/quintoandar/butterfree/pull/370)) + +## [1.3.5](https://github.com/quintoandar/butterfree/releases/tag/1.3.5) +* Auto create feature sets ([#368](https://github.com/quintoandar/butterfree/pull/368)) + +## [1.3.4](https://github.com/quintoandar/butterfree/releases/tag/1.3.4) +* Fix Cassandra Config and tests ([#366](https://github.com/quintoandar/butterfree/pull/366)) + +## [1.3.3](https://github.com/quintoandar/butterfree/releases/tag/1.3.3) +* Fix Cassandra Config and Numpy version ([#364](https://github.com/quintoandar/butterfree/pull/364)) + +## [1.3.2](https://github.com/quintoandar/butterfree/releases/tag/1.3.2) +* Fix publish script ([#362](https://github.com/quintoandar/butterfree/pull/362)) + +## [1.3.2](https://github.com/quintoandar/butterfree/releases/tag/1.3.2) +* Fix publish script ([#360](https://github.com/quintoandar/butterfree/pull/362)) + +## [1.3.1](https://github.com/quintoandar/butterfree/releases/tag/1.3.1) +* Timestamp NTZ available ([#360](https://github.com/quintoandar/butterfree/pull/360)) + +## [1.3.0](https://github.com/quintoandar/butterfree/releases/tag/1.3.0) +* Bump versions ([#355](https://github.com/quintoandar/butterfree/pull/355)) +* Sphinx version ([#356](https://github.com/quintoandar/butterfree/pull/356)) + +## [1.2.4](https://github.com/quintoandar/butterfree/releases/tag/1.2.4) +* Auto create feature sets ([#351](https://github.com/quintoandar/butterfree/pull/351)) + +## [1.2.3](https://github.com/quintoandar/butterfree/releases/tag/1.2.3) +* Optional params ([#347](https://github.com/quintoandar/butterfree/pull/347)) + +### Changed +* Optional row count validation ([#340](https://github.com/quintoandar/butterfree/pull/340)) + +## [1.2.2](https://github.com/quintoandar/butterfree/releases/tag/1.2.2) + +### Changed +* Optional row count validation ([#284](https://github.com/quintoandar/butterfree/pull/284)) +* Bump several libs versions ([#333](https://github.com/quintoandar/butterfree/pull/333)) + +## [1.2.1](https://github.com/quintoandar/butterfree/releases/tag/1.2.1) + +### Changed +* Update README.md ([#331](https://github.com/quintoandar/butterfree/pull/331)) +* Update Github Actions Workflow runner ([#332](https://github.com/quintoandar/butterfree/pull/332)) +* Delete sphinx version. 
([#334](https://github.com/quintoandar/butterfree/pull/334)) + +### Fixed +* Add the missing link for H3 geohash ([#330](https://github.com/quintoandar/butterfree/pull/330)) + +## [1.2.0](https://github.com/quintoandar/butterfree/releases/tag/1.2.0) + +### Added +* [MLOP-636] Create migration classes ([#282](https://github.com/quintoandar/butterfree/pull/282)) +* [MLOP-635] Rebase Incremental Job/Interval Run branch for test on selected feature sets ([#278](https://github.com/quintoandar/butterfree/pull/278)) +* Allow slide selection ([#293](https://github.com/quintoandar/butterfree/pull/293)) +* [MLOP-637] Implement diff method ([#292](https://github.com/quintoandar/butterfree/pull/292)) +* [MLOP-640] Create CLI with migrate command ([#298](https://github.com/quintoandar/butterfree/pull/298)) +* [MLOP-645] Implement query method, cassandra ([#291](https://github.com/quintoandar/butterfree/pull/291)) +* [MLOP-671] Implement get_schema on Spark client ([#301](https://github.com/quintoandar/butterfree/pull/301)) +* [MLOP-648] Implement query method, metastore ([#294](https://github.com/quintoandar/butterfree/pull/294)) +* [MLOP-647] / [MLOP-646] Apply migrations ([#300](https://github.com/quintoandar/butterfree/pull/300)) +* [MLOP-639] Track logs in S3 ([#306](https://github.com/quintoandar/butterfree/pull/306)) +* [MLOP-702] Debug mode for Automate Migration ([#322](https://github.com/quintoandar/butterfree/pull/322)) + +### Changed +* Keep milliseconds when using 'from_ms' argument in timestamp feature ([#284](https://github.com/quintoandar/butterfree/pull/284)) +* Read and write consistency level options ([#309](https://github.com/quintoandar/butterfree/pull/309)) +* [MLOP-691] Include step to add partition to SparkMetastore during writing of Butterfree ([#327](https://github.com/quintoandar/butterfree/pull/327)) + +### Fixed +* [BUG] Apply create_partitions to historical validate ([#303](https://github.com/quintoandar/butterfree/pull/303)) +* [BUG] Fix key path for validate read ([#304](https://github.com/quintoandar/butterfree/pull/304)) +* [FIX] Add Partition types for Metastore ([#305](https://github.com/quintoandar/butterfree/pull/305)) +* Change solution for tracking logs ([#308](https://github.com/quintoandar/butterfree/pull/308)) +* [BUG] Fix Cassandra Connect Session ([#316](https://github.com/quintoandar/butterfree/pull/316)) +* Fix method to generate agg feature name. 
([#326](https://github.com/quintoandar/butterfree/pull/326)) + +## [1.1.3](https://github.com/quintoandar/butterfree/releases/tag/1.1.3) +### Added +* [MLOP-599] Apply mypy to ButterFree ([#273](https://github.com/quintoandar/butterfree/pull/273)) + +### Changed +* [MLOP-634] Butterfree dev workflow, set triggers for branches staging and master ([#280](https://github.com/quintoandar/butterfree/pull/280)) +* Keep milliseconds when using 'from_ms' argument in timestamp feature ([#284](https://github.com/quintoandar/butterfree/pull/284)) +* [MLOP-633] Butterfree dev workflow, update documentation ([#281](https://github.com/quintoandar/butterfree/commit/74278986a49f1825beee0fd8df65a585764e5524)) +* [MLOP-632] Butterfree dev workflow, automate release description ([#279](https://github.com/quintoandar/butterfree/commit/245eaa594846166972241b03fddc61ee5117b1f7)) + +### Fixed +* Change trigger for pipeline staging ([#287](https://github.com/quintoandar/butterfree/pull/287)) + +## [1.1.2](https://github.com/quintoandar/butterfree/releases/tag/1.1.2) +### Fixed +* [HOTFIX] Add both cache and count back to Butterfree ([#274](https://github.com/quintoandar/butterfree/pull/274)) +* [MLOP-606] Change docker image in Github Actions Pipeline ([#275](https://github.com/quintoandar/butterfree/pull/275)) +* FIX Read the Docs build ([#272](https://github.com/quintoandar/butterfree/pull/272)) +* [BUG] Fix style ([#271](https://github.com/quintoandar/butterfree/pull/271)) +* [MLOP-594] Remove from_column in some transforms ([#270](https://github.com/quintoandar/butterfree/pull/270)) +* [MLOP-536] Rename S3 config to Metastore config ([#269](https://github.com/quintoandar/butterfree/pull/269)) + ## [1.1.1](https://github.com/quintoandar/butterfree/releases/tag/1.2.0) ### Added * [MLOP-590] Adapt KafkaConfig to receive a custom topic name ([#266](https://github.com/quintoandar/butterfree/pull/266)) @@ -15,11 +117,11 @@ Preferably use **Added**, **Changed**, **Removed** and **Fixed** topics in each * Update README ([#257](https://github.com/quintoandar/butterfree/pull/257)) ### Fixed -* Fix Butterfree's workflow ([#262](https://github.com/quintoandar/butterfree/pull/262)) +* Fix Butterfree's workflow ([#262](https://github.com/quintoandar/butterfree/pull/262)) * [FIX] Downgrade Python Version in Pyenv ([#227](https://github.com/quintoandar/butterfree/pull/227)) -* [FIX] Fix docs ([#229](https://github.com/quintoandar/butterfree/pull/229)) +* [FIX] Fix docs ([#229](https://github.com/quintoandar/butterfree/pull/229)) * [FIX] Fix Docs - Add more dependencies ([#230](https://github.com/quintoandar/butterfree/pull/230)) -* Fix broken notebook URL ([#236](https://github.com/quintoandar/butterfree/pull/236)) +* Fix broken notebook URL ([#236](https://github.com/quintoandar/butterfree/pull/236)) * Issue #77 Fix ([#245](https://github.com/quintoandar/butterfree/pull/245)) ## [1.0.2](https://github.com/quintoandar/butterfree/releases/tag/1.0.2) @@ -30,7 +132,7 @@ Preferably use **Added**, **Changed**, **Removed** and **Fixed** topics in each * [MLOP-426] Change branching strategy on butterfree to use only master branch ([#216](https://github.com/quintoandar/butterfree/pull/216)) ### Fixed -* [MLOP-440] Python 3.7 bump and Fixing Dependencies ([#220](https://github.com/quintoandar/butterfree/pull/220)) +* [MLOP-440] Python 3.7 bump and Fixing Dependencies ([#220](https://github.com/quintoandar/butterfree/pull/220)) ## [1.0.1](https://github.com/quintoandar/butterfree/releases/tag/1.0.1) ### Added @@ -229,4 +331,4 @@ 
Preferably use **Added**, **Changed**, **Removed** and **Fixed** topics in each * [MLOP-143] Fix Bugs for HouseMain FeatureSet ([#62](https://github.com/quintoandar/butterfree/pull/62)) ## [0.1.0](https://github.com/quintoandar/butterfree/releases/tag/0.1.0) -* First modules and entities of butterfree package. \ No newline at end of file +* First modules and entities of butterfree package. diff --git a/Makefile b/Makefile index 41ad00ab4..a93104ab3 100644 --- a/Makefile +++ b/Makefile @@ -1,8 +1,16 @@ +# globals + +PACKAGE_NAME := $(shell grep __package_name__ setup.py | head -1 | cut -d \" -f2 | cut -d \' -f2) +VERSION := $(shell grep __version__ setup.py | head -1 | cut -d \" -f2 | cut -d \' -f2) + + +#custom targets + .PHONY: environment ## create virtual environment for butterfree environment: - @pyenv install -s 3.7.6 - @pyenv virtualenv 3.7.6 butterfree + @pyenv install -s 3.9.19 + @pyenv virtualenv 3.9.19 butterfree @pyenv local butterfree @PYTHONPATH=. python -m pip install --upgrade pip @@ -28,7 +36,7 @@ minimum-requirements: .PHONY: requirements ## install all requirements -requirements: requirements-test requirements-lint dev-requirements minimum-requirements +requirements: minimum-requirements dev-requirements requirements-test requirements-lint .PHONY: ci-install ci-install: @@ -68,7 +76,7 @@ style-check: @echo "Code Style" @echo "==========" @echo "" - @python -m black --check -t py36 --exclude="build/|buck-out/|dist/|_build/|pip/|\.pip/|\.git/|\.hg/|\.mypy_cache/|\.tox/|\.venv/" . && echo "\n\nSuccess" || (echo "\n\nFailure\n\nYou need to run \"make apply-style\" to apply style formatting to your code"; exit 1) + @python -m black --check -t py39 --exclude="build/|buck-out/|dist/|_build/|pip/|\.pip/|\.git/|\.hg/|\.mypy_cache/|\.tox/|\.venv/" . && echo "\n\nSuccess" || (echo "\n\nFailure\n\nYou need to run \"make apply-style\" to apply style formatting to your code"; exit 1) .PHONY: quality-check ## run code quality checks with flake8 @@ -96,8 +104,8 @@ checks: style-check quality-check type-check .PHONY: apply-style ## fix stylistic errors with black apply-style: - @python -m black -t py36 --exclude="build/|buck-out/|dist/|_build/|pip/|\.pip/|\.git/|\.hg/|\.mypy_cache/|\.tox/|\.venv/" . - @python -m isort -rc butterfree/ tests/ + @python -m black -t py39 --exclude="build/|buck-out/|dist/|_build/|pip/|\.pip/|\.git/|\.hg/|\.mypy_cache/|\.tox/|\.venv/" . 
+ @python -m isort --atomic butterfree/ tests/ .PHONY: clean ## clean unused artifacts @@ -114,32 +122,38 @@ clean: @find ./ -type f -name 'coverage.xml' -exec rm -f {} \; @find ./ -type f -name '.coverage*' -exec rm -f {} \; @find ./ -type f -name '*derby.log' -exec rm -f {} \; + @find ./ -type f -name 'logging.json' -exec rm -f {} \; @find ./ -name '*.pyc' -exec rm -f {} \; @find ./ -name '*.pyo' -exec rm -f {} \; @find ./ -name '*~' -exec rm -f {} \; .PHONY: version -## dump package name into VERSION env variable and show +## show version version: - @export VERSION=$(grep __version__ setup.py | head -1 | cut -d \" -f2 | cut -d \' -f2) - @$(info VERSION is [${VERSION}]) + @echo "VERSION: $(VERSION)" + +.PHONY: change-version +## change the version to string received in the NEW_VERSION variable and show +change-version: + @sed -i 's/$(VERSION)/$(NEW_VERSION)/g' setup.py + @echo "VERSION: $(NEW_VERSION)" .PHONY: package-name -## dump package name into PACKAGE_NAME env variable and show +## show package name package-name: - @PACKAGE_NAME=$(grep __package_name__ setup.py | head -1 | cut -d \" -f2 | cut -d \' -f2 | sed 's/.*/&${build}/') - @echo $PACKAGE_NAME + @echo "PACKAGE_NAME: $(PACKAGE_NAME)" .PHONY: package ## build butterfree package wheel package: + @PYTHONPATH=. pip3 install wheel @PYTHONPATH=. python -m setup sdist bdist_wheel .PHONY: update-docs ## update Butterfree API docs update-docs: cd ./docs; rm -rf source/butterfree.* - cd ./docs; sphinx-apidoc -T -E -o source/ ../butterfree + cd ./docs; sphinx-apidoc -o source/ ../butterfree cd ./docs; make coverage .PHONY: docs @@ -208,4 +222,4 @@ help: } \ printf "\n"; \ }' \ - | more $(shell test $(shell uname) = Darwin && echo '--no-init --raw-control-chars') \ No newline at end of file + | more $(shell test $(shell uname) = Darwin && echo '--no-init --raw-control-chars') diff --git a/README.md b/README.md index d221d8666..7b93f000f 100644 --- a/README.md +++ b/README.md @@ -32,7 +32,7 @@ To learn how to use Butterfree in practice, see [Butterfree's notebook examples] ## Requirements and Installation Butterfree depends on **Python 3.7+** and it is **Spark 3.0 ready** :heavy_check_mark: -[Python Package Index](https://quintoandar.github.io/python-package-server/) hosts reference to a pip-installable module of this library, using it is as straightforward as including it on your project's requirements. +[PyPI hosts reference to a pip-installable module of this library](https://pypi.org/project/butterfree/), using it is as straightforward as including it on your project's requirements. ```bash pip install butterfree @@ -44,6 +44,8 @@ Or after listing `butterfree` in your `requirements.txt` file: pip install -r requirements.txt ``` +Dev Package are available for testing using the .devN versions of the Butterfree on PyPi. + ## License [Apache License 2.0](https://github.com/quintoandar/butterfree/blob/staging/LICENSE) diff --git a/WORKFLOW.md b/WORKFLOW.md index 601e37932..5eaa18cdd 100644 --- a/WORKFLOW.md +++ b/WORKFLOW.md @@ -2,20 +2,18 @@ ## Features -A feature is based on the `master` branch and merged back into the `master` branch. - -![](https://docs.microsoft.com/en-us/azure/devops/repos/git/media/branching-guidance/featurebranching.png?view=azure-devops) +A feature is based on the `staging` branch and merged back into the `staging` branch. 
### Working Locally ``` -# checkout master, fetch the latest changes and pull them from remote into local -git checkout master +# checkout staging, fetch the latest changes and pull them from remote into local +git checkout staging git fetch -git pull origin master +git pull origin staging -# create a feature branch that is based off master +# create a feature branch that is based off staging git checkout -b /some-description # do your work @@ -24,10 +22,10 @@ git commit -m "first commit" git add another git commit -m "second commit" -# rebase against master to pull in any changes that have been made +# rebase against staging to pull in any changes that have been made # since you started your feature branch. git fetch -git rebase origin/master +git rebase origin/staging # push your local changes up to the remote git push @@ -35,41 +33,71 @@ git push # if you've already pushed changes and have rebased, your history has changed # so you will need to force the push git fetch -git rebase origin/master +git rebase origin/staging git push --force-with-lease ```` ### GitHub workflow -- Open a Pull Request against `master`. Check our PR guidelines [here](https://github.com/quintoandar/butterfree/blob/master/CONTRIBUTING.md#pull-request-guideline). +- Open a Pull Request against `staging`. Check our PR guidelines [here](https://github.com/quintoandar/butterfree/blob/master/CONTRIBUTING.md#pull-request-guideline). - When the Pull Request has been approved, merge using `squash and merge`, adding a brief description: ie, ` Enable stream pipelines in Butterfree`. - This squashes all your commits into a single clean commit. Remember to clean detailed descriptions, otherwise our git logs will be a mess. -If you are unable to squash merge because of conflicts, you need to rebase against `master` again: +If you are unable to squash merge because of conflicts, you need to rebase against `staging` again: ``` # in your feature branch git fetch -git rebase origin/master +git rebase origin/staging # fix conflicts if they exist git push --force-with-lease ``` +## Pre-Releases + +The pre-release will always occur when we change the version in the setup.py file to staging branch. + + +### Working Locally + +``` +# create a feature branch +git checkout staging +git fetch +git pull origin staging +git checkout -b pre-release/ + +# finalize the changelog in Unreleased and bump the version into setup.py then: +git add CHANGELOG.md +git add setup.py +git commit -m "pre-release " + +# push the new version +git fetch +git push --force-with-lease +``` + +### Github workflow + +- Open a Pull Request against `staging`. +- When the PR's approved and the code is tested, `squash and merge` to squash your commits into a single commit. +- The creation of the pre-release tag and the update of the PyPi version will be done +automatically from the Publish Dev Package workflow, you can follow [here](https://github.com/quintoandar/butterfree/actions?query=workflow%3A%22Publish+Dev+Package%22). ## Releases -The release will always occur when we change the version in the setup.py file. +The release will always occur when we change the version in the setup.py file to master branch. 
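Both the Makefile (`version`, `change-version`) and the publish workflows read that version from `setup.py` with a `grep`/`cut` pipeline. Below is a hedged Python equivalent of that lookup, assuming a conventional `__version__ = "x.y.z"` assignment; it is illustrative only and not part of the project tooling.

```python
# Parse the first __version__ assignment in setup.py, mirroring the
# grep/cut pipeline used by the Makefile and the publish workflows.
import re
from pathlib import Path


def read_version(setup_py: str = "setup.py") -> str:
    text = Path(setup_py).read_text()
    match = re.search(r'__version__\s*=\s*["\']([^"\']+)["\']', text)
    if match is None:
        raise ValueError("no __version__ assignment found in setup.py")
    return match.group(1)


print(read_version())  # e.g. "1.4.0" on master, or a ".devN" pre-release on staging
```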
### Working Locally ``` # create a feature branch -git checkout master +git checkout staging git fetch -git pull origin master +git pull origin staging git checkout -b release/ # finalize the changelog, bump the version into setup.py and update the documentation then: @@ -121,7 +149,6 @@ git checkout master@ git fetch git pull origin master git checkout -b hotfix/ -git checkout -b describe-the-problem git add patch.fix git add setup.py @@ -133,7 +160,7 @@ Don't forget to update the Changelog and the version in `setup.py`. ### Github workflow -- Open a Pull Request against `hotfix/` +- Open a Pull Request against `master`. - When the PR's approved and the code is tested, `squash and merge` to squash your commits into a single commit. - A tag will automatically be triggered in our CI/CD. This tag/release will use the version for its title and push a new version of Butterfree's python package to our private server. diff --git a/butterfree/_cli/__init__.py b/butterfree/_cli/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/butterfree/_cli/main.py b/butterfree/_cli/main.py new file mode 100644 index 000000000..b8b12f14d --- /dev/null +++ b/butterfree/_cli/main.py @@ -0,0 +1,9 @@ +import typer + +from butterfree._cli import migrate + +app = typer.Typer(no_args_is_help=True) +app.add_typer(migrate.app, name="migrate") + +if __name__ == "__main__": + app() diff --git a/butterfree/_cli/migrate.py b/butterfree/_cli/migrate.py new file mode 100644 index 000000000..f51615097 --- /dev/null +++ b/butterfree/_cli/migrate.py @@ -0,0 +1,194 @@ +import datetime +import importlib +import inspect +import os +import pkgutil +import sys +from typing import Set + +import boto3 +import setuptools +import typer +from botocore.exceptions import ClientError + +from butterfree.configs import environment +from butterfree.configs.logger import __logger +from butterfree.migrations.database_migration import ALLOWED_DATABASE +from butterfree.pipelines import FeatureSetPipeline + +app = typer.Typer( + help="Apply the automatic migrations in a database.", no_args_is_help=True +) + +logger = __logger("migrate", True) + + +def __find_modules(path: str) -> Set[str]: + modules = set() + for pkg in setuptools.find_packages(path): + modules.add(pkg) + pkg_path = path + "/" + pkg.replace(".", "/") + + # different usage for older python3 versions + if sys.version_info.minor < 6: + for _, name, is_pkg in pkgutil.iter_modules([pkg_path]): + if not is_pkg: + modules.add(pkg + "." + name) + else: + for info in pkgutil.iter_modules([pkg_path]): + if not info.ispkg: + modules.add(pkg + "." 
+ info.name) + return modules + + +def __fs_objects(path: str) -> Set[FeatureSetPipeline]: + logger.info(f"Looking for python modules under {path}...") + modules = __find_modules(path) + if not modules: + logger.error(f"Path: {path} not found!") + return set() + + logger.info("Importing modules...") + package = ".".join(path.strip("/").split("/")) + imported = set( + importlib.import_module(f".{name}", package=package) for name in modules + ) + + logger.info("Scanning modules...") + content = { + module: set( + filter( + lambda x: not x.startswith("__"), # filter "__any__" attributes + set(item for item in dir(module)), + ) + ) + for module in imported + } + + instances = set() + for module, items in content.items(): + for item in items: + value = getattr(module, item) + if not value: + continue + + # filtering non-classes + if not inspect.isclass(value): + continue + + # filtering abstractions + if inspect.isabstract(value): + continue + + # filtering classes that doesn't inherit from FeatureSetPipeline + if not issubclass(value, FeatureSetPipeline): + continue + + # filtering FeatureSetPipeline itself + if value == FeatureSetPipeline: + continue + + instances.add(value) + + logger.info("Creating instances...") + return set(value() for value in instances) # type: ignore + + +PATH = typer.Argument( + ..., + help="Full or relative path to where feature set pipelines are being defined.", +) + +GENERATE_LOGS = typer.Option( + False, help="To generate the logs in local file 'logging.json'." +) + +DEBUG_MODE = typer.Option( + False, + help="To view the queries resulting from the migration, DON'T apply the migration.", +) + + +class Migrate: + """Execute migration operations in a Database based on pipeline Writer. + + Attributes: + pipelines: list of Feature Set Pipelines to use to migration. 
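To make the scanning logic above concrete: `__fs_objects` imports every module under the given path, keeps non-abstract subclasses of `FeatureSetPipeline`, and instantiates each one with no arguments. A hedged sketch of a module layout the scanner would pick up follows; the file and class names are placeholders, and a real pipeline would wire a source, feature set and sink through `super().__init__`.

```python
# pipelines/user_features.py (placeholder module under the scanned path).
# The scanner instantiates discovered classes with no arguments, so the
# subclass must be constructible without parameters.
from butterfree.pipelines import FeatureSetPipeline


class UserFeaturesPipeline(FeatureSetPipeline):
    """Placeholder pipeline; a real one passes source, feature_set and sink."""

    def __init__(self) -> None:
        # A real implementation would call
        # super().__init__(source=..., feature_set=..., sink=...);
        # omitted here because those objects are project-specific.
        ...
```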
+ """ + + def __init__( + self, + pipelines: Set[FeatureSetPipeline], + ) -> None: + self.pipelines = pipelines + + def _send_logs_to_s3(self, file_local: bool, debug_mode: bool) -> None: + """Send all migration logs to S3.""" + file_name = "../logging.json" + + if not file_local and os.path.exists(file_name): + s3_client = boto3.client("s3") + + timestamp = datetime.datetime.now() + + if debug_mode: + object_name = ( + f"logs/migrate-debug-mode/" + f"{timestamp.strftime('%Y-%m-%d')}" + f"/logging-{timestamp.strftime('%H:%M:%S')}.json" + ) + else: + object_name = ( + f"logs/migrate/" + f"{timestamp.strftime('%Y-%m-%d')}" + f"/logging-{timestamp.strftime('%H:%M:%S')}.json" + ) + bucket = environment.get_variable("FEATURE_STORE_S3_BUCKET") + + try: + s3_client.upload_file( + file_name, + bucket, + object_name, + ExtraArgs={"ACL": "bucket-owner-full-control"}, + ) + except ClientError: + raise + + os.remove(file_name) + elif os.path.exists(file_name): + print("Logs written to ../logging.json") + else: + print("No logs were generated.") + + def run(self, generate_logs: bool = False, debug_mode: bool = False) -> None: + """Construct and apply the migrations.""" + for pipeline in self.pipelines: + for writer in pipeline.sink.writers: + db = writer.db_config.database + if db == "cassandra": + migration = ALLOWED_DATABASE[db] + migration.apply_migration(pipeline.feature_set, writer, debug_mode) + else: + logger.warning(f"Butterfree not supporting {db} Migrations yet.") + + self._send_logs_to_s3(generate_logs, debug_mode) + + +@app.command("apply") +def migrate( + path: str = PATH, generate_logs: bool = GENERATE_LOGS, debug_mode: bool = DEBUG_MODE +) -> Set[FeatureSetPipeline]: + """Scan and run database migrations for feature set pipelines defined under PATH. + + Butterfree will scan a given path for classes that inherit from its + FeatureSetPipeline and create dry instances of it to extract schema and writer + information. By doing this, Butterfree can compare all defined feature set schemas + to their current state on each sink being used. + + All pipelines must be under python modules inside path, so we can dynamically + import and instantiate them. 
+ """ + pipe_set = __fs_objects(path) + Migrate(pipe_set).run(generate_logs, debug_mode) + return pipe_set diff --git a/butterfree/automated/__init__.py b/butterfree/automated/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/butterfree/automated/feature_set_creation.py b/butterfree/automated/feature_set_creation.py new file mode 100644 index 000000000..4a078135c --- /dev/null +++ b/butterfree/automated/feature_set_creation.py @@ -0,0 +1,199 @@ +import re +from dataclasses import dataclass +from typing import List, Optional, Tuple + +from pyspark.sql import DataFrame + +from butterfree.constants.data_type import DataType + +BUTTERFREE_DTYPES = { + "string": DataType.STRING.spark_sql, + "long": DataType.BIGINT.spark_sql, + "double": DataType.DOUBLE.spark_sql, + "boolean": DataType.BOOLEAN.spark_sql, + "integer": DataType.INTEGER.spark_sql, + "date": DataType.DATE.spark_sql, + "timestamp": DataType.TIMESTAMP.spark_sql, + "array": { + "long": DataType.ARRAY_BIGINT.spark_sql, + "float": DataType.ARRAY_FLOAT.spark_sql, + "string": DataType.ARRAY_STRING.spark_sql, + }, +} + + +@dataclass(frozen=True) +class Table: # noqa: D101 + id: str + database: str + name: str + + +class FeatureSetCreation: + """Class to auto-generate readers and features.""" + + def _get_features_with_regex(self, sql_query: str) -> List[str]: + features = [] + sql_query = " ".join(sql_query.split()) + first_pattern = re.compile("[(]?([\w.*]+)[)]?,", re.IGNORECASE) + second_pattern = re.compile("(\w+)\s(from)", re.IGNORECASE) + + for pattern in [first_pattern, second_pattern]: + matches = pattern.finditer(sql_query) + for match in matches: + feature = match.group(1) + + if "." in feature: + feature = feature.split(".")[1] + + features.append(feature) + + return features + + def _get_data_type(self, field_name: str, df: DataFrame) -> str: + for field in df.schema.jsonValue()["fields"]: + if field["name"] == field_name: + + field_type = field["type"] + + if isinstance(field_type, dict): + + field_type_keys = field_type.keys() + + if "type" in field_type_keys and "elementType" in field_type_keys: + return ( + "." + + BUTTERFREE_DTYPES[field_type["type"]][ # type: ignore + field_type["elementType"] + ] + ) + + return "." + BUTTERFREE_DTYPES[field["type"]] + + return "" + + def _get_tables_with_regex(self, sql_query: str) -> Tuple[List[Table], str]: + + modified_sql_query = sql_query + tables = [] + stop_words = [ + "left", + "right", + "full outer", + "inner", + "where", + "join", + "on", + "as", + ] + keywords = ["from", "join"] + + for keyword in keywords: + pattern = re.compile( + rf"\b{keyword}\s+(\w+\.\w+|\w+)\s+(\w+)", re.IGNORECASE + ) + matches = pattern.finditer(sql_query) + + for match in matches: + full_table_name = match.group(1) + id = match.group(2).strip() + + if id in stop_words: + id = full_table_name + + if "." in full_table_name: + database, table = full_table_name.split(".") + + modified_sql_query = re.sub( + rf"\b{database}\.{table}\b", table, modified_sql_query + ) + + tables.append(Table(id=id, database=database, name=table)) + else: + modified_sql_query = re.sub( + rf"\b{full_table_name}\b", full_table_name, modified_sql_query + ) + tables.append(Table(id=id, database="TBD", name=full_table_name)) + + return tables, modified_sql_query + + def get_readers(self, sql_query: str) -> str: + """ + Extracts table readers from a SQL query and formats them as a string. + + Args: + sql_query (str): The SQL query from which to extract table readers. 
+ + Returns: + str: A formatted string containing the table readers. + """ + tables, modified_sql_query = self._get_tables_with_regex(sql_query.lower()) + readers = [] + for table in tables: + table_reader_string = f""" + TableReader( + id="{table.id}", + database="{table.database}", + table="{table.name}" + ), + """ + readers.append(table_reader_string) + + final_string = """ + source=Source( + readers=[ + {} + ], + query=( + \"\"\" + {} + \"\"\" + ), + ), + """.format( + "".join(readers), modified_sql_query.replace("\n", "\n\t\t") + ) + + return final_string + + def get_features(self, sql_query: str, df: Optional[DataFrame] = None) -> str: + """ + Extract features from a SQL query and return them formatted as a string. + + Args: + sql_query (str): The SQL query used to extract features. + df (Optional[DataFrame], optional): Optional DataFrame used to infer data types. Defaults to None. + + Returns: + str: A formatted string containing the extracted features. + + This sould be used on Databricks. + + Especially if you want automatic type inference without passing a reference dataframe. + The utility will only work in an environment where a spark session is available in the environment + """ # noqa: E501 + + features = self._get_features_with_regex(sql_query) + features_formatted = [] + for feature in features: + description = feature.replace("__", " ").replace("_", " ").capitalize() + + data_type = "." + + if df is None: + df = spark.sql(sql_query) # type: ignore # noqa: F821 + + data_type = self._get_data_type(feature, df) + + feature_string = f""" + Feature( + name="{feature}", + description="{description}", + dtype=DataType{data_type}, + ), + """ + features_formatted.append(feature_string) + + final_string = ("features=[\t{}],\n),").format("".join(features_formatted)) + + return final_string diff --git a/butterfree/clients/__init__.py b/butterfree/clients/__init__.py index 5f6f0ffa6..7e8d1a95b 100644 --- a/butterfree/clients/__init__.py +++ b/butterfree/clients/__init__.py @@ -1,4 +1,5 @@ """Holds connection clients.""" + from butterfree.clients.abstract_client import AbstractClient from butterfree.clients.cassandra_client import CassandraClient from butterfree.clients.spark_client import SparkClient diff --git a/butterfree/clients/abstract_client.py b/butterfree/clients/abstract_client.py index 265706e68..b9027bd88 100644 --- a/butterfree/clients/abstract_client.py +++ b/butterfree/clients/abstract_client.py @@ -1,6 +1,7 @@ """Abstract class for database clients.""" + from abc import ABC, abstractmethod -from typing import Any +from typing import Any, Optional class AbstractClient(ABC): @@ -23,3 +24,17 @@ def sql(self, query: str) -> Any: Set of records. """ pass + + @abstractmethod + def get_schema(self, table: str, database: Optional[str] = None) -> Any: + """Returns desired table schema. + + Attributes: + table: desired table. + + Returns: + A list of dictionaries in the format + [{"column_name": "example1", type: "Spark_type"}, ...] 
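As an illustration of the automated helper above, the sketch below generates reader and feature scaffolding from a toy query; the table, alias and column names are made up, and a reference DataFrame is passed so that type inference does not require running `spark.sql` against a real table.

```python
# Hedged usage sketch of FeatureSetCreation on a placeholder query.
from pyspark.sql import SparkSession

from butterfree.automated.feature_set_creation import FeatureSetCreation

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame([(1, 10.0)], ["user_id", "total_spent"])

creator = FeatureSetCreation()
query = "SELECT user_id, total_spent FROM finance.orders o"
print(creator.get_readers(query))       # TableReader(...) scaffolding
print(creator.get_features(query, df))  # Feature(...) scaffolding with inferred dtypes
spark.stop()
```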
+ + """ + pass diff --git a/butterfree/clients/cassandra_client.py b/butterfree/clients/cassandra_client.py index 1e5416886..714e82483 100644 --- a/butterfree/clients/cassandra_client.py +++ b/butterfree/clients/cassandra_client.py @@ -1,11 +1,18 @@ """CassandraClient entity.""" + from ssl import CERT_REQUIRED, PROTOCOL_TLSv1 from typing import Dict, List, Optional from cassandra.auth import PlainTextAuthProvider -from cassandra.cluster import Cluster, ResponseFuture, Session -from cassandra.policies import RoundRobinPolicy -from cassandra.query import dict_factory +from cassandra.cluster import ( + EXEC_PROFILE_DEFAULT, + Cluster, + ExecutionProfile, + ResponseFuture, + Session, +) +from cassandra.policies import DCAwareRoundRobinPolicy +from cassandra.query import ConsistencyLevel, dict_factory from typing_extensions import TypedDict from butterfree.clients import AbstractClient @@ -33,53 +40,58 @@ class CassandraClient(AbstractClient): """Cassandra Client. Attributes: - cassandra_user: username to use in connection. - cassandra_password: password to use in connection. - cassandra_key_space: key space used in connection. - cassandra_host: cassandra endpoint used in connection. + user: username to use in connection. + password: password to use in connection. + keyspace: key space used in connection. + host: cassandra endpoint used in connection. """ def __init__( self, - cassandra_host: List[str], - cassandra_key_space: str, - cassandra_user: Optional[str] = None, - cassandra_password: Optional[str] = None, + host: List[str], + keyspace: str, + user: Optional[str] = None, + password: Optional[str] = None, ) -> None: - self.cassandra_host = cassandra_host - self.cassandra_key_space = cassandra_key_space - self.cassandra_user = cassandra_user - self.cassandra_password = cassandra_password + self.host = host + self.keyspace = keyspace + self.user = user + self.password = password self._session: Optional[Session] = None @property def conn(self, *, ssl_path: str = None) -> Session: # type: ignore """Establishes a Cassandra connection.""" - auth_provider = ( - PlainTextAuthProvider( - username=self.cassandra_user, password=self.cassandra_password + if not self._session: + auth_provider = ( + PlainTextAuthProvider(username=self.user, password=self.password) + if self.user is not None + else None + ) + ssl_opts = ( + { + "ca_certs": ssl_path, + "ssl_version": PROTOCOL_TLSv1, + "cert_reqs": CERT_REQUIRED, + } + if ssl_path is not None + else None ) - if self.cassandra_user is not None - else None - ) - ssl_opts = ( - { - "ca_certs": ssl_path, - "ssl_version": PROTOCOL_TLSv1, - "cert_reqs": CERT_REQUIRED, - } - if ssl_path is not None - else None - ) - cluster = Cluster( - contact_points=self.cassandra_host, - auth_provider=auth_provider, - ssl_options=ssl_opts, - load_balancing_policy=RoundRobinPolicy(), - ) - self._session = cluster.connect(self.cassandra_key_space) - self._session.row_factory = dict_factory + execution_profiles = { + EXEC_PROFILE_DEFAULT: ExecutionProfile( + load_balancing_policy=DCAwareRoundRobinPolicy(), + consistency_level=ConsistencyLevel.LOCAL_QUORUM, + row_factory=dict_factory, + ) + } + cluster = Cluster( + contact_points=self.host, + auth_provider=auth_provider, + ssl_options=ssl_opts, + execution_profiles=execution_profiles, + ) + self._session = cluster.connect(self.keyspace) return self._session def sql(self, query: str) -> ResponseFuture: @@ -89,11 +101,11 @@ def sql(self, query: str) -> ResponseFuture: query: desired query. 
""" - if not self._session: - raise RuntimeError("There's no session available for this query.") - return self._session.execute(query) + return self.conn.execute(query) - def get_schema(self, table: str) -> List[Dict[str, str]]: + def get_schema( + self, table: str, database: Optional[str] = None + ) -> List[Dict[str, str]]: """Returns desired table schema. Attributes: @@ -106,7 +118,7 @@ def get_schema(self, table: str) -> List[Dict[str, str]]: """ query = ( f"SELECT column_name, type FROM system_schema.columns " # noqa - f"WHERE keyspace_name = '{self.cassandra_key_space}' " # noqa + f"WHERE keyspace_name = '{self.keyspace}' " # noqa f" AND table_name = '{table}';" # noqa ) @@ -114,14 +126,15 @@ def get_schema(self, table: str) -> List[Dict[str, str]]: if not response: raise RuntimeError( - f"No columns found for table: {table}" - f"in key space: {self.cassandra_key_space}" + f"No columns found for table: {table}" f"in key space: {self.keyspace}" ) return response def _get_create_table_query( - self, columns: List[CassandraColumn], table: str, + self, + columns: List[CassandraColumn], + table: str, ) -> str: """Creates CQL statement to create a table.""" parsed_columns = [] @@ -143,7 +156,7 @@ def _get_create_table_query( else: columns_str = joined_parsed_columns - query = f"CREATE TABLE {self.cassandra_key_space}.{table} " f"({columns_str}); " + query = f"CREATE TABLE {self.keyspace}.{table} " f"({columns_str}); " return query diff --git a/butterfree/clients/spark_client.py b/butterfree/clients/spark_client.py index 0a8c717c5..f4b6ea652 100644 --- a/butterfree/clients/spark_client.py +++ b/butterfree/clients/spark_client.py @@ -1,5 +1,6 @@ """SparkClient entity.""" +import json from typing import Any, Dict, List, Optional, Union from pyspark.sql import DataFrame, DataFrameReader, SparkSession @@ -29,14 +30,16 @@ def conn(self) -> SparkSession: """ if not self._session: self._session = SparkSession.builder.getOrCreate() + return self._session def read( self, format: str, - options: Dict[str, Any], + path: Optional[Union[str, List[str]]] = None, schema: Optional[StructType] = None, stream: bool = False, + **options: Any, ) -> DataFrame: """Use the SparkSession.read interface to load data into a dataframe. @@ -45,9 +48,10 @@ def read( Args: format: string with the format to be used by the DataframeReader. - options: options to setup the DataframeReader. + path: optional string or a list of string for file-system. stream: flag to indicate if data must be read in stream mode. schema: an optional pyspark.sql.types.StructType for the input schema. + options: options to setup the DataframeReader. 
Returns: Dataframe @@ -55,16 +59,18 @@ def read( """ if not isinstance(format, str): raise ValueError("format needs to be a string with the desired read format") - if not isinstance(options, dict): - raise ValueError("options needs to be a dict with the setup configurations") + if path and not isinstance(path, (str, list)): + raise ValueError("path needs to be a string or a list of string") + + df_reader: Union[DataStreamReader, DataFrameReader] = ( + self.conn.readStream if stream else self.conn.read + ) - df_reader: Union[ - DataStreamReader, DataFrameReader - ] = self.conn.readStream if stream else self.conn.read df_reader = df_reader.schema(schema) if schema else df_reader - return df_reader.format(format).options(**options).load() - def read_table(self, table: str, database: str = None) -> DataFrame: + return df_reader.format(format).load(path=path, **options) # type: ignore + + def read_table(self, table: str, database: Optional[str] = None) -> DataFrame: """Use the SparkSession.read interface to read a metastore table. Args: @@ -174,9 +180,9 @@ def write_table( database: Optional[str], table_name: str, path: str, - format_: str = None, - mode: str = None, - partition_by: List[str] = None, + format_: Optional[str] = None, + mode: Optional[str] = None, + partition_by: Optional[List[str]] = None, **options: Any, ) -> None: """Receive a spark DataFrame and write it as a table in metastore. @@ -212,7 +218,8 @@ def write_table( **options, ) - def create_temporary_view(self, dataframe: DataFrame, name: str) -> Any: + @staticmethod + def create_temporary_view(dataframe: DataFrame, name: str) -> Any: """Create a temporary view from a given dataframe. Args: @@ -223,3 +230,116 @@ def create_temporary_view(self, dataframe: DataFrame, name: str) -> Any: if not dataframe.isStreaming: return dataframe.createOrReplaceTempView(name) return dataframe.writeStream.format("memory").queryName(name).start() + + def add_table_partitions( + self, + partitions: List[Dict[str, Any]], + table: str, + database: Optional[str] = None, + ) -> None: + """Add partitions to an existing table. + + Args: + partitions: partitions to add to the table. + It's expected a list of partition dicts to add to the table. + Example: `[{"year": 2020, "month": 8, "day": 14}, ...]` + table: table to add the partitions. + database: name of the database where the table is saved. + """ + for partition_dict in partitions: + if not all( + ( + isinstance(key, str) + and (isinstance(value, str) or isinstance(value, int)) + ) + for key, value in partition_dict.items() + ): + raise ValueError( + "Partition keys must be column names " + "and values must be string or int." + ) + + database_expr = f"`{database}`." if database else "" + key_values_expr = [ + ", ".join( + [ + ( + "{} = {}".format(k, v) + if not isinstance(v, str) + else "{} = '{}'".format(k, v) + ) + for k, v in partition.items() + ] + ) + for partition in partitions + ] + partitions_expr = " ".join(f"PARTITION ( {expr} )" for expr in key_values_expr) + command = ( + f"ALTER TABLE {database_expr}`{table}` ADD IF NOT EXISTS {partitions_expr}" + ) + + self.conn.sql(command) + + @staticmethod + def _filter_schema(schema: DataFrame) -> List[str]: + """Returns filtered schema with the desired information. + + Attributes: + schema: desired table. + + Returns: + A list of strings in the format + ['{"column_name": "example1", type: "Spark_type"}', ...] 
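With the reworked `read` signature, the path is a separate argument and reader options become keyword arguments, while `add_table_partitions` registers partitions on an existing metastore table. A hedged sketch is below; the bucket, table and database names are placeholders.

```python
# Hedged sketch of the new SparkClient.read signature and add_table_partitions.
from butterfree.clients import SparkClient

spark_client = SparkClient()

# path is now passed directly; extra reader options go as keyword arguments.
df = spark_client.read(
    format="parquet",
    path="s3a://feature-store-bucket/historical/user_features",  # placeholder
)

# Register partitions on an existing table (ALTER TABLE ... ADD IF NOT EXISTS).
spark_client.add_table_partitions(
    partitions=[{"year": 2024, "month": 8, "day": 14}],
    table="user_features",     # placeholder table
    database="feature_store",  # placeholder database
)
```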
+ + """ + return ( + schema.filter( + ~schema.col_name.isin( + ["# Partition Information", "# col_name", "year", "month", "day"] + ) + ) + .toJSON() + .collect() + ) + + def _convert_schema(self, schema: DataFrame) -> List[Dict[str, str]]: + """Returns schema with the desired information. + + Attributes: + schema: desired table. + + Returns: + A list of dictionaries in the format + [{"column_name": "example1", type: "Spark_type"}, ...] + + """ + schema_list = self._filter_schema(schema) + converted_schema = [] + for row in schema_list: + converted_schema.append(json.loads(row)) + + return converted_schema + + def get_schema( + self, table: str, database: Optional[str] = None + ) -> List[Dict[str, str]]: + """Returns desired table schema. + + Attributes: + table: desired table. + + Returns: + A list of dictionaries in the format + [{"column_name": "example1", type: "Spark_type"}, ...] + + """ + query = f"DESCRIBE {database}.{table} " # noqa + + response = self.sql(query) + + if not response: + raise RuntimeError( + f"No columns found for table: {table}" f"in database: {database}" + ) + + return self._convert_schema(response) diff --git a/butterfree/configs/db/abstract_config.py b/butterfree/configs/db/abstract_config.py index 8e98aab61..fbd48c534 100644 --- a/butterfree/configs/db/abstract_config.py +++ b/butterfree/configs/db/abstract_config.py @@ -7,6 +7,11 @@ class AbstractWriteConfig(ABC): """Abstract class for database write configurations with spark.""" + @property + @abstractmethod + def database(self) -> str: + """Database name.""" + @property @abstractmethod def mode(self) -> Any: diff --git a/butterfree/configs/db/cassandra_config.py b/butterfree/configs/db/cassandra_config.py index b58a2e0a2..919fee8e8 100644 --- a/butterfree/configs/db/cassandra_config.py +++ b/butterfree/configs/db/cassandra_config.py @@ -1,4 +1,5 @@ """Holds configurations to read and write with Spark to Cassandra DB.""" + from typing import Any, Dict, List, Optional from butterfree.configs import environment @@ -21,6 +22,8 @@ class CassandraConfig(AbstractWriteConfig): stream_processing_time: processing time interval for streaming jobs. stream_output_mode: specify the mode from writing streaming data. stream_checkpoint_path: path on S3 to save checkpoints for the stream job. + read_consistency_level: read consistency level used in connection. + write_consistency_level: write consistency level used in connection. 
More information about processing_time, output_mode and checkpoint_path can be found in Spark documentation: @@ -30,15 +33,18 @@ class CassandraConfig(AbstractWriteConfig): def __init__( self, - username: str = None, - password: str = None, - host: str = None, - keyspace: str = None, - mode: str = None, - format_: str = None, - stream_processing_time: str = None, - stream_output_mode: str = None, - stream_checkpoint_path: str = None, + username: Optional[str] = None, + password: Optional[str] = None, + host: Optional[str] = None, + keyspace: Optional[str] = None, + mode: Optional[str] = None, + format_: Optional[str] = None, + stream_processing_time: Optional[str] = None, + stream_output_mode: Optional[str] = None, + stream_checkpoint_path: Optional[str] = None, + read_consistency_level: Optional[str] = None, + write_consistency_level: Optional[str] = None, + local_dc: Optional[str] = None, ): self.username = username self.password = password @@ -49,6 +55,14 @@ def __init__( self.stream_processing_time = stream_processing_time self.stream_output_mode = stream_output_mode self.stream_checkpoint_path = stream_checkpoint_path + self.read_consistency_level = read_consistency_level + self.write_consistency_level = write_consistency_level + self.local_dc = local_dc + + @property + def database(self) -> str: + """Database name.""" + return "cassandra" @property def username(self) -> Optional[str]: @@ -145,6 +159,37 @@ def stream_checkpoint_path(self, value: str) -> None: "STREAM_CHECKPOINT_PATH" ) + @property + def read_consistency_level(self) -> Optional[str]: + """Read consistency level for Cassandra.""" + return self.__read_consistency_level + + @read_consistency_level.setter + def read_consistency_level(self, value: str) -> None: + self.__read_consistency_level = value or environment.get_variable( + "CASSANDRA_READ_CONSISTENCY_LEVEL", "LOCAL_ONE" + ) + + @property + def write_consistency_level(self) -> Optional[str]: + """Write consistency level for Cassandra.""" + return self.__write_consistency_level + + @write_consistency_level.setter + def write_consistency_level(self, value: str) -> None: + self.__write_consistency_level = value or environment.get_variable( + "CASSANDRA_WRITE_CONSISTENCY_LEVEL", "LOCAL_QUORUM" + ) + + @property + def local_dc(self) -> Optional[str]: + """Local DC for Cassandra connection.""" + return self.__local_dc + + @local_dc.setter + def local_dc(self, value: str) -> None: + self.__local_dc = value or environment.get_variable("CASSANDRA_LOCAL_DC") + def get_options(self, table: str) -> Dict[Optional[str], Optional[str]]: """Get options for connect to Cassandra DB. 
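The new `read_consistency_level`, `write_consistency_level` and `local_dc` settings fall back to the `CASSANDRA_*` environment variables when omitted and are forwarded to the Spark Cassandra connector through `get_options`. A hedged configuration sketch, with placeholder host, keyspace, credentials and data center:

```python
# Hedged sketch of the extended CassandraConfig options.
from butterfree.configs.db.cassandra_config import CassandraConfig

config = CassandraConfig(
    username="butterfree",
    password="secret",                      # placeholder credential
    host="cassandra.example.internal",      # placeholder endpoint
    keyspace="feature_store",               # placeholder keyspace
    read_consistency_level="LOCAL_ONE",
    write_consistency_level="LOCAL_QUORUM",
    local_dc="datacenter1",                 # placeholder data center name
)
# Includes spark.cassandra.connection.localDC and both consistency levels.
print(config.get_options(table="user_features"))
```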
@@ -164,6 +209,9 @@ def get_options(self, table: str) -> Dict[Optional[str], Optional[str]]: "spark.cassandra.auth.username": self.username, "spark.cassandra.auth.password": self.password, "spark.cassandra.connection.host": self.host, + "spark.cassandra.connection.localDC": self.local_dc, + "spark.cassandra.input.consistency.level": self.read_consistency_level, + "spark.cassandra.output.consistency.level": self.write_consistency_level, } def translate(self, schema: List[Dict[str, Any]]) -> List[Dict[str, Any]]: @@ -180,26 +228,29 @@ def translate(self, schema: List[Dict[str, Any]]) -> List[Dict[str, Any]]: """ cassandra_mapping = { - "TimestampType": "timestamp", - "BinaryType": "boolean", - "BooleanType": "boolean", - "DateType": "timestamp", - "DecimalType": "decimal", - "DoubleType": "double", - "FloatType": "float", - "IntegerType": "int", - "LongType": "bigint", - "StringType": "text", - "ArrayType(LongType,true)": "frozen>", - "ArrayType(StringType,true)": "frozen>", - "ArrayType(FloatType,true)": "frozen>", + "timestamptype": "timestamp", + "timestampntztype": "timestamp", + "binarytype": "boolean", + "booleantype": "boolean", + "datetype": "timestamp", + "decimaltype": "decimal", + "doubletype": "double", + "floattype": "float", + "integertype": "int", + "longtype": "bigint", + "stringtype": "text", + "arraytype(longtype, true)": "frozen>", + "arraytype(stringtype, true)": "frozen>", + "arraytype(floattype, true)": "frozen>", } cassandra_schema = [] for features in schema: cassandra_schema.append( { "column_name": features["column_name"], - "type": cassandra_mapping[str(features["type"])], + "type": cassandra_mapping[ + str(features["type"]).replace("()", "").lower() + ], "primary_key": features["primary_key"], } ) diff --git a/butterfree/configs/db/kafka_config.py b/butterfree/configs/db/kafka_config.py index 67b2dc57c..e0c14baf3 100644 --- a/butterfree/configs/db/kafka_config.py +++ b/butterfree/configs/db/kafka_config.py @@ -1,4 +1,5 @@ """Holds configurations to read and write with Spark to Kafka.""" + from typing import Any, Dict, List, Optional from butterfree.configs import environment @@ -25,13 +26,13 @@ class KafkaConfig(AbstractWriteConfig): def __init__( self, - kafka_topic: str = None, - kafka_connection_string: str = None, - mode: str = None, - format_: str = None, - stream_processing_time: str = None, - stream_output_mode: str = None, - stream_checkpoint_path: str = None, + kafka_topic: Optional[str] = None, + kafka_connection_string: Optional[str] = None, + mode: Optional[str] = None, + format_: Optional[str] = None, + stream_processing_time: Optional[str] = None, + stream_output_mode: Optional[str] = None, + stream_checkpoint_path: Optional[str] = None, ): self.kafka_topic = kafka_topic self.kafka_connection_string = kafka_connection_string @@ -41,6 +42,11 @@ def __init__( self.stream_output_mode = stream_output_mode self.stream_checkpoint_path = stream_checkpoint_path + @property + def database(self) -> str: + """Database name.""" + return "kafka" + @property def kafka_topic(self) -> Optional[str]: """Kafka topic name.""" @@ -142,4 +148,4 @@ def translate(self, schema: List[Dict[str, Any]]) -> List[Dict[str, Any]]: Kafka schema. 
""" - pass + return [{}] diff --git a/butterfree/configs/db/metastore_config.py b/butterfree/configs/db/metastore_config.py index d94b792c8..323aded0c 100644 --- a/butterfree/configs/db/metastore_config.py +++ b/butterfree/configs/db/metastore_config.py @@ -3,8 +3,11 @@ import os from typing import Any, Dict, List, Optional +from pyspark.sql import DataFrame + from butterfree.configs import environment from butterfree.configs.db import AbstractWriteConfig +from butterfree.dataframe_service import extract_partition_values class MetastoreConfig(AbstractWriteConfig): @@ -22,16 +25,21 @@ class MetastoreConfig(AbstractWriteConfig): def __init__( self, - path: str = None, - mode: str = None, - format_: str = None, - file_system: str = None, + path: Optional[str] = None, + mode: Optional[str] = None, + format_: Optional[str] = None, + file_system: Optional[str] = None, ): self.path = path self.mode = mode self.format_ = format_ self.file_system = file_system + @property + def database(self) -> str: + """Database name.""" + return "metastore" + @property def path(self) -> Optional[str]: """Bucket name.""" @@ -87,6 +95,56 @@ def get_options(self, key: str) -> Dict[Optional[str], Optional[str]]: "path": os.path.join(f"{self.file_system}://{self.path}/", key), } + def get_path_with_partitions(self, key: str, dataframe: DataFrame) -> List: + """Get options for AWS S3 from partitioned parquet file. + + Options will be a dictionary with the write and read configuration for + Spark to AWS S3. + + Args: + key: path to save data into AWS S3 bucket. + dataframe: spark dataframe containing data from a feature set. + + Returns: + A list of string for file-system backed data sources. + """ + path_list = [] + dataframe_values = extract_partition_values( + dataframe, partition_columns=["year", "month", "day"] + ) + for row in dataframe_values: + path_list.append( + f"{self.file_system}://{self.path}/{key}/year={row['year']}/" + f"month={row['month']}/day={row['day']}" + ) + + return path_list + def translate(self, schema: List[Dict[str, Any]]) -> List[Dict[str, Any]]: """Translate feature set spark schema to the corresponding database.""" - pass + spark_sql_mapping = { + "TimestampType": "TIMESTAMP", + "BinaryType": "BINARY", + "BooleanType": "BOOLEAN", + "DateType": "DATE", + "DecimalType": "DECIMAL", + "DoubleType": "DOUBLE", + "FloatType": "FLOAT", + "IntegerType": "INT", + "LongType": "BIGINT", + "StringType": "STRING", + "ArrayType(LongType,true)": "ARRAY", + "ArrayType(StringType,true)": "ARRAY", + "ArrayType(FloatType,true)": "ARRAY", + } + sql_schema = [] + for features in schema: + sql_schema.append( + { + "column_name": features["column_name"], + "type": spark_sql_mapping[str(features["type"])], + "primary_key": features["primary_key"], + } + ) + + return sql_schema diff --git a/butterfree/configs/environment.py b/butterfree/configs/environment.py index 6f5accbc5..f6ba18a5b 100644 --- a/butterfree/configs/environment.py +++ b/butterfree/configs/environment.py @@ -1,4 +1,5 @@ """Holds functions for managing the running environment.""" + import os from typing import Optional @@ -12,6 +13,9 @@ "FEATURE_STORE_HISTORICAL_DATABASE": "test", "KAFKA_CONSUMER_CONNECTION_STRING": "test_host:1234,test_host2:1234", "STREAM_CHECKPOINT_PATH": None, + "CASSANDRA_READ_CONSISTENCY_LEVEL": None, + "CASSANDRA_WRITE_CONSISTENCY_LEVEL": None, + "CASSANDRA_LOCAL_DC": None, } @@ -31,12 +35,14 @@ def __init__(self, variable_name: str): ) -def get_variable(variable_name: str, default_value: str = None) -> Optional[str]: +def 
get_variable(
+    variable_name: str, default_value: Optional[str] = None
+) -> Optional[str]:
     """Gets an environment variable.

     The variable comes from its explicitly declared value in the running
-    environment or from the default value declared in the environment.yaml
-    specification or from the default_value.
+    environment or from the default value declared in the specification or from
+    the default_value.

     Args:
         variable_name: environment variable name.
diff --git a/butterfree/configs/logger.py b/butterfree/configs/logger.py
new file mode 100644
index 000000000..60dab67c7
--- /dev/null
+++ b/butterfree/configs/logger.py
@@ -0,0 +1,24 @@
+"""Logger function."""
+
+import logging
+
+
+def __config(json_file_logs: bool = False) -> None:
+
+    if json_file_logs:
+        return logging.basicConfig(
+            format='{"name": "%(name)s", "timestamp": "%(asctime)-15s", '
+            '"level": "%(levelname)s", "message": "%(message)s"}',
+            level=logging.INFO,
+            filename="../logging.json",
+        )
+    return logging.basicConfig(
+        format="%(name)s:%(asctime)-15s:%(levelname)s:< %(message)s >",
+        level=logging.INFO,
+    )
+
+
+def __logger(name: str, file_logs: bool = False) -> logging.Logger:
+
+    __config(file_logs)
+    return logging.getLogger(name)
diff --git a/butterfree/constants/__init__.py b/butterfree/constants/__init__.py
index ec70d41b5..aa0c76e65 100644
--- a/butterfree/constants/__init__.py
+++ b/butterfree/constants/__init__.py
@@ -1,4 +1,5 @@
 """Holds constant attributes that are common for Butterfree."""
+
 from butterfree.constants.data_type import DataType

 __all__ = ["DataType"]
diff --git a/butterfree/constants/data_type.py b/butterfree/constants/data_type.py
index 157d4a1fe..6166f1fc6 100644
--- a/butterfree/constants/data_type.py
+++ b/butterfree/constants/data_type.py
@@ -12,6 +12,7 @@
     IntegerType,
     LongType,
     StringType,
+    TimestampNTZType,
     TimestampType,
 )
 from typing_extensions import final
@@ -21,20 +22,22 @@ class DataType(Enum):
     """Holds constants for data types within Butterfree."""

-    TIMESTAMP = (TimestampType(), "timestamp")
-    BINARY = (BinaryType(), "boolean")
-    BOOLEAN = (BooleanType(), "boolean")
-    DATE = (DateType(), "timestamp")
-    DECIMAL = (DecimalType(), "decimal")
-    DOUBLE = (DoubleType(), "double")
-    FLOAT = (FloatType(), "float")
-    INTEGER = (IntegerType(), "int")
-    BIGINT = (LongType(), "bigint")
-    STRING = (StringType(), "text")
-    ARRAY_BIGINT = (ArrayType(LongType()), "frozen<list<bigint>>")
-    ARRAY_STRING = (ArrayType(StringType()), "frozen<list<text>>")
-    ARRAY_FLOAT = (ArrayType(FloatType()), "frozen<list<float>>")
+    TIMESTAMP_NTZ = (TimestampNTZType(), "timestamp", "TIMESTAMP_NTZ")
+    TIMESTAMP = (TimestampType(), "timestamp", "TIMESTAMP")
+    BINARY = (BinaryType(), "boolean", "BINARY")
+    BOOLEAN = (BooleanType(), "boolean", "BOOLEAN")
+    DATE = (DateType(), "timestamp", "DATE")
+    DECIMAL = (DecimalType(), "decimal", "DECIMAL")
+    DOUBLE = (DoubleType(), "double", "DOUBLE")
+    FLOAT = (FloatType(), "float", "FLOAT")
+    INTEGER = (IntegerType(), "int", "INT")
+    BIGINT = (LongType(), "bigint", "BIGINT")
+    STRING = (StringType(), "text", "STRING")
+    ARRAY_BIGINT = (ArrayType(LongType()), "frozen<list<bigint>>", "ARRAY<BIGINT>")
+    ARRAY_STRING = (ArrayType(StringType()), "frozen<list<text>>", "ARRAY<STRING>")
+    ARRAY_FLOAT = (ArrayType(FloatType()), "frozen<list<float>>", "ARRAY<FLOAT>")

-    def __init__(self, spark: PySparkDataType, cassandra: str) -> None:
+    def __init__(self, spark: PySparkDataType, cassandra: str, spark_sql: str) -> None:
         self.spark = spark
         self.cassandra = cassandra
+        self.spark_sql = spark_sql
diff --git a/butterfree/constants/migrations.py b/butterfree/constants/migrations.py
new file mode 100644 index 000000000..f31d08418 --- /dev/null +++ b/butterfree/constants/migrations.py @@ -0,0 +1,9 @@ +"""Migrations' Constants.""" + +from butterfree.constants import columns + +PARTITION_BY = [ + {"column_name": columns.PARTITION_YEAR, "type": "INT"}, + {"column_name": columns.PARTITION_MONTH, "type": "INT"}, + {"column_name": columns.PARTITION_DAY, "type": "INT"}, +] diff --git a/butterfree/constants/window_definitions.py b/butterfree/constants/window_definitions.py new file mode 100644 index 000000000..560904f75 --- /dev/null +++ b/butterfree/constants/window_definitions.py @@ -0,0 +1,16 @@ +"""Allowed windows units and lengths in seconds.""" + +ALLOWED_WINDOWS = { + "second": 1, + "seconds": 1, + "minute": 60, + "minutes": 60, + "hour": 3600, + "hours": 3600, + "day": 86400, + "days": 86400, + "week": 604800, + "weeks": 604800, + "year": 29030400, + "years": 29030400, +} diff --git a/butterfree/dataframe_service/__init__.py b/butterfree/dataframe_service/__init__.py index 5116261d6..5fd02d453 100644 --- a/butterfree/dataframe_service/__init__.py +++ b/butterfree/dataframe_service/__init__.py @@ -1,4 +1,12 @@ """Dataframe optimization components regarding Butterfree.""" + +from butterfree.dataframe_service.incremental_strategy import IncrementalStrategy +from butterfree.dataframe_service.partitioning import extract_partition_values from butterfree.dataframe_service.repartition import repartition_df, repartition_sort_df -__all__ = ["repartition_df", "repartition_sort_df"] +__all__ = [ + "extract_partition_values", + "IncrementalStrategy", + "repartition_df", + "repartition_sort_df", +] diff --git a/butterfree/dataframe_service/incremental_strategy.py b/butterfree/dataframe_service/incremental_strategy.py new file mode 100644 index 000000000..957064f15 --- /dev/null +++ b/butterfree/dataframe_service/incremental_strategy.py @@ -0,0 +1,125 @@ +"""IncrementalStrategy entity.""" + +from __future__ import annotations + +from typing import Optional + +from pyspark.sql import DataFrame + + +class IncrementalStrategy: + """Define an incremental strategy to be used on data sources. + + Entity responsible for defining a column expression that will be used to + filter the original data source. The purpose is to get only the data related + to a specific pipeline execution time interval. + + Attributes: + column: column expression on which incremental filter will be applied. + The expression need to result on a date or timestamp format, so the + filter can properly work with the defined upper and lower bounds. + """ + + def __init__(self, column: Optional[str] = None): + self.column = column + + def from_milliseconds(self, column_name: str) -> IncrementalStrategy: + """Create a column expression from ts column defined as milliseconds. + + Args: + column_name: column name where the filter will be applied. + + Returns: + `IncrementalStrategy` with the defined column expression. + """ + return IncrementalStrategy(column=f"from_unixtime({column_name}/ 1000.0)") + + def from_string( + self, column_name: str, mask: Optional[str] = None + ) -> IncrementalStrategy: + """Create a column expression from ts column defined as a simple string. + + Args: + column_name: column name where the filter will be applied. + mask: mask defining the date/timestamp format on the string. + + Returns: + `IncrementalStrategy` with the defined column expression. 
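+
+            A minimal, illustrative sketch (the column name "dt" and the mask
+            below are assumptions, not values provided by the library):
+
+            >>> strategy = IncrementalStrategy().from_string("dt", mask="dd/MM/yyyy")
+            >>> strategy.get_expression(start_date="2021-01-01")
+            "date(to_date(dt, 'dd/MM/yyyy')) >= date('2021-01-01')"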
+ """ + return IncrementalStrategy(column=f"to_date({column_name}, '{mask}')") + + def from_year_month_day_partitions( + self, + year_column: str = "year", + month_column: str = "month", + day_column: str = "day", + ) -> IncrementalStrategy: + """Create a column expression from year, month and day partitions. + + Args: + year_column: column name from the year partition. + month_column: column name from the month partition. + day_column: column name from the day partition. + + Returns: + `IncrementalStrategy` with the defined column expression. + """ + return IncrementalStrategy( + column=f"concat(string({year_column}), " + f"'-', string({month_column}), " + f"'-', string({day_column}))" + ) + + def get_expression( + self, start_date: Optional[str] = None, end_date: Optional[str] = None + ) -> str: + """Get the incremental filter expression using the defined dates. + + Both arguments can be set to defined a specific date interval, but it's + only necessary to set one of the arguments for this method to work. + + Args: + start_date: date lower bound to use in the filter. + end_date: date upper bound to use in the filter. + + Returns: + Filter expression based on defined column and bounds. + + Raises: + ValuerError: If both arguments, start_date and end_date, are None. + ValueError: If the column expression was not defined. + """ + if not self.column: + raise ValueError("column parameter can't be None") + if not (start_date or end_date): + raise ValueError("Both arguments start_date and end_date can't be None.") + if start_date: + expression = f"date({self.column}) >= date('{start_date}')" + if end_date: + expression += f" and date({self.column}) <= date('{end_date}')" + return expression + return f"date({self.column}) <= date('{end_date}')" + + def filter_with_incremental_strategy( + self, + dataframe: DataFrame, + start_date: Optional[str] = None, + end_date: Optional[str] = None, + ) -> DataFrame: + """Filters the dataframe according to the date boundaries. + + Args: + dataframe: dataframe that will be filtered. + start_date: date lower bound to use in the filter. + end_date: date upper bound to use in the filter. + + Returns: + Filtered dataframe based on defined time boundaries. + """ + return ( + dataframe.where( + self.get_expression(start_date=start_date, end_date=end_date) + ) + if start_date or end_date + else dataframe + ) diff --git a/butterfree/dataframe_service/partitioning.py b/butterfree/dataframe_service/partitioning.py new file mode 100644 index 000000000..21e9b0ab7 --- /dev/null +++ b/butterfree/dataframe_service/partitioning.py @@ -0,0 +1,25 @@ +"""Module defining partitioning methods.""" + +from typing import Any, Dict, List + +from pyspark.sql import DataFrame + + +def extract_partition_values( + dataframe: DataFrame, partition_columns: List[str] +) -> List[Dict[str, Any]]: + """Extract distinct partition values from a given dataframe. + + Args: + dataframe: dataframe from where to extract partition values. + partition_columns: name of partition columns presented on the dataframe. + + Returns: + distinct partition values. 
+ """ + return ( + dataframe.select(*partition_columns) + .distinct() + .rdd.map(lambda row: row.asDict(True)) + .collect() + ) diff --git a/butterfree/dataframe_service/repartition.py b/butterfree/dataframe_service/repartition.py index 8635557f9..e84202ba7 100644 --- a/butterfree/dataframe_service/repartition.py +++ b/butterfree/dataframe_service/repartition.py @@ -1,5 +1,6 @@ """Module where there are repartition methods.""" -from typing import List + +from typing import List, Optional from pyspark.sql.dataframe import DataFrame @@ -10,7 +11,7 @@ def _num_partitions_definition( - num_processors: int = None, num_partitions: int = None + num_processors: Optional[int] = None, num_partitions: Optional[int] = None ) -> int: num_partitions = ( num_processors * PARTITION_PROCESSOR_RATIO @@ -24,8 +25,8 @@ def _num_partitions_definition( def repartition_df( dataframe: DataFrame, partition_by: List[str], - num_partitions: int = None, - num_processors: int = None, + num_partitions: Optional[int] = None, + num_processors: Optional[int] = None, ) -> DataFrame: """Partition the DataFrame. @@ -47,8 +48,8 @@ def repartition_sort_df( dataframe: DataFrame, partition_by: List[str], order_by: List[str], - num_processors: int = None, - num_partitions: int = None, + num_processors: Optional[int] = None, + num_partitions: Optional[int] = None, ) -> DataFrame: """Partition and Sort the DataFrame. diff --git a/butterfree/extract/__init__.py b/butterfree/extract/__init__.py index bb056255b..64c8ae4a1 100644 --- a/butterfree/extract/__init__.py +++ b/butterfree/extract/__init__.py @@ -1,4 +1,5 @@ """The Source Component of a Feature Set.""" + from butterfree.extract.source import Source __all__ = ["Source"] diff --git a/butterfree/extract/pre_processing/__init__.py b/butterfree/extract/pre_processing/__init__.py index 72b37c4db..e142de6d1 100644 --- a/butterfree/extract/pre_processing/__init__.py +++ b/butterfree/extract/pre_processing/__init__.py @@ -1,4 +1,5 @@ """Pre Processing Components regarding Readers.""" + from butterfree.extract.pre_processing.explode_json_column_transform import ( explode_json_column, ) diff --git a/butterfree/extract/pre_processing/explode_json_column_transform.py b/butterfree/extract/pre_processing/explode_json_column_transform.py index db79b5ce0..76c90f739 100644 --- a/butterfree/extract/pre_processing/explode_json_column_transform.py +++ b/butterfree/extract/pre_processing/explode_json_column_transform.py @@ -1,4 +1,5 @@ """Explode json column for dataframes.""" + from pyspark.sql.dataframe import DataFrame, StructType from pyspark.sql.functions import from_json, get_json_object diff --git a/butterfree/extract/pre_processing/filter_transform.py b/butterfree/extract/pre_processing/filter_transform.py index 78e5df78f..a7e4fff81 100644 --- a/butterfree/extract/pre_processing/filter_transform.py +++ b/butterfree/extract/pre_processing/filter_transform.py @@ -1,4 +1,5 @@ """Module where filter DataFrames coming from readers.""" + from pyspark.sql.dataframe import DataFrame diff --git a/butterfree/extract/pre_processing/forward_fill_transform.py b/butterfree/extract/pre_processing/forward_fill_transform.py index 96d9bcdda..2d3a232d6 100644 --- a/butterfree/extract/pre_processing/forward_fill_transform.py +++ b/butterfree/extract/pre_processing/forward_fill_transform.py @@ -1,6 +1,7 @@ """Forward Fill Transform for dataframes.""" + import sys -from typing import List, Union +from typing import List, Optional, Union from pyspark.sql import DataFrame, Window, functions @@ -10,7 +11,7 @@ def 
forward_fill( partition_by: Union[str, List[str]], order_by: Union[str, List[str]], fill_column: str, - filled_column: str = None, + filled_column: Optional[str] = None, ) -> DataFrame: """Applies a forward fill to a single column. diff --git a/butterfree/extract/pre_processing/pivot_transform.py b/butterfree/extract/pre_processing/pivot_transform.py index 078b47464..f255f4577 100644 --- a/butterfree/extract/pre_processing/pivot_transform.py +++ b/butterfree/extract/pre_processing/pivot_transform.py @@ -1,5 +1,6 @@ """Pivot Transform for dataframes.""" -from typing import Callable, List, Union + +from typing import Callable, List, Optional, Union from pyspark.sql import DataFrame, functions from pyspark.sql.types import DataType @@ -13,8 +14,8 @@ def pivot( pivot_column: str, agg_column: str, aggregation: Callable, - mock_value: Union[float, str] = None, - mock_type: Union[DataType, str] = None, + mock_value: Optional[Union[float, str]] = None, + mock_type: Optional[Union[DataType, str]] = None, with_forward_fill: bool = False, ) -> DataFrame: """Defines a pivot transformation. diff --git a/butterfree/extract/pre_processing/replace_transform.py b/butterfree/extract/pre_processing/replace_transform.py index a7dd1d67a..3127c6d9d 100644 --- a/butterfree/extract/pre_processing/replace_transform.py +++ b/butterfree/extract/pre_processing/replace_transform.py @@ -1,4 +1,5 @@ """Replace transformer for dataframes.""" + from itertools import chain from typing import Dict diff --git a/butterfree/extract/readers/__init__.py b/butterfree/extract/readers/__init__.py index 37da63a6c..8c7bd74e0 100644 --- a/butterfree/extract/readers/__init__.py +++ b/butterfree/extract/readers/__init__.py @@ -1,4 +1,5 @@ """The Reader Component of a Source.""" + from butterfree.extract.readers.file_reader import FileReader from butterfree.extract.readers.kafka_reader import KafkaReader from butterfree.extract.readers.table_reader import TableReader diff --git a/butterfree/extract/readers/file_reader.py b/butterfree/extract/readers/file_reader.py index 17f68f1cb..da046f083 100644 --- a/butterfree/extract/readers/file_reader.py +++ b/butterfree/extract/readers/file_reader.py @@ -1,5 +1,6 @@ """FileReader entity.""" -from typing import Any, Dict + +from typing import Any, Dict, Optional from pyspark.sql import DataFrame from pyspark.sql.types import StructType @@ -75,8 +76,8 @@ def __init__( id: str, path: str, format: str, - schema: StructType = None, - format_options: Dict[Any, Any] = None, + schema: Optional[StructType] = None, + format_options: Optional[Dict[Any, Any]] = None, stream: bool = False, ): super().__init__(id) @@ -87,9 +88,7 @@ def __init__( self.path = path self.format = format self.schema = schema - self.options = dict( - {"path": self.path}, **format_options if format_options else {} - ) + self.options = dict(format_options if format_options else {}) self.stream = stream def consume(self, client: SparkClient) -> DataFrame: @@ -106,11 +105,15 @@ def consume(self, client: SparkClient) -> DataFrame: """ schema = ( - client.read(format=self.format, options=self.options,).schema + client.read(format=self.format, path=self.path, **self.options).schema if (self.stream and not self.schema) else self.schema ) return client.read( - format=self.format, options=self.options, schema=schema, stream=self.stream, + format=self.format, + schema=schema, + stream=self.stream, + path=self.path, + **self.options, ) diff --git a/butterfree/extract/readers/kafka_reader.py b/butterfree/extract/readers/kafka_reader.py index 
8cac4c198..44731d207 100644 --- a/butterfree/extract/readers/kafka_reader.py +++ b/butterfree/extract/readers/kafka_reader.py @@ -1,5 +1,6 @@ """KafkaSource entity.""" -from typing import Any, Dict + +from typing import Any, Dict, Optional from pyspark.sql.dataframe import DataFrame, StructType from pyspark.sql.functions import col, struct @@ -107,8 +108,8 @@ def __init__( id: str, topic: str, value_schema: StructType, - connection_string: str = None, - topic_options: Dict[Any, Any] = None, + connection_string: Optional[str] = None, + topic_options: Optional[Dict[Any, Any]] = None, stream: bool = True, ): super().__init__(id) @@ -174,7 +175,7 @@ def consume(self, client: SparkClient) -> DataFrame: """ # read using client and cast key and value columns from binary to string raw_df = ( - client.read(format="kafka", options=self.options, stream=self.stream) + client.read(format="kafka", stream=self.stream, **self.options) .withColumn("key", col("key").cast("string")) .withColumn("value", col("value").cast("string")) ) diff --git a/butterfree/extract/readers/reader.py b/butterfree/extract/readers/reader.py index 78be28232..5053d82c4 100644 --- a/butterfree/extract/readers/reader.py +++ b/butterfree/extract/readers/reader.py @@ -2,14 +2,16 @@ from abc import ABC, abstractmethod from functools import reduce -from typing import Any, Callable, Dict, List +from typing import Any, Callable, Dict, List, Optional from pyspark.sql import DataFrame from butterfree.clients import SparkClient +from butterfree.dataframe_service import IncrementalStrategy +from butterfree.hooks import HookableComponent -class Reader(ABC): +class Reader(ABC, HookableComponent): """Abstract base class for Readers. Attributes: @@ -19,9 +21,13 @@ class Reader(ABC): """ - def __init__(self, id: str): + def __init__( + self, id: str, incremental_strategy: Optional[IncrementalStrategy] = None + ): + super().__init__() self.id = id self.transformations: List[Dict[str, Any]] = [] + self.incremental_strategy = incremental_strategy def with_( self, transformer: Callable[..., DataFrame], *args: Any, **kwargs: Any @@ -48,14 +54,19 @@ def with_( self.transformations.append(new_transformation) return self - def _apply_transformations(self, df: DataFrame) -> Any: - return reduce( - lambda result_df, transformation: transformation["transformer"]( - result_df, *transformation["args"], **transformation["kwargs"] - ), - self.transformations, - df, - ) + def with_incremental_strategy( + self, incremental_strategy: IncrementalStrategy + ) -> "Reader": + """Define the incremental strategy for the Reader. + + Args: + incremental_strategy: definition of the incremental strategy. + + Returns: + Reader with defined incremental strategy. + """ + self.incremental_strategy = incremental_strategy + return self @abstractmethod def consume(self, client: SparkClient) -> DataFrame: @@ -70,24 +81,61 @@ def consume(self, client: SparkClient) -> DataFrame: :return: Spark dataframe """ - def build(self, client: SparkClient, columns: List[Any] = None) -> None: + def build( + self, + client: SparkClient, + columns: Optional[List[Any]] = None, + start_date: Optional[str] = None, + end_date: Optional[str] = None, + ) -> None: """Register the data got from the reader in the Spark metastore. Create a temporary view in Spark metastore referencing the data extracted from the target origin after the application of all the defined pre-processing transformations. 
+ The arguments start_date and end_date are going to be use only when there + is a defined `IncrementalStrategy` on the `Reader`. + Args: client: client responsible for connecting to Spark session. - columns: list of tuples for renaming/filtering the dataset. + columns: list of tuples for selecting/renaming columns on the df. + start_date: lower bound to use in the filter expression. + end_date: upper bound to use in the filter expression. """ - transformed_df = self._apply_transformations(self.consume(client)) - - if columns: - select_expression = [] - for old_expression, new_column_name in columns: - select_expression.append(f"{old_expression} as {new_column_name}") - transformed_df = transformed_df.selectExpr(*select_expression) + column_selection_df = self._select_columns(columns, client) + transformed_df = self._apply_transformations(column_selection_df) + + if self.incremental_strategy: + transformed_df = self.incremental_strategy.filter_with_incremental_strategy( + transformed_df, start_date, end_date + ) + + post_hook_df = self.run_post_hooks(transformed_df) + + post_hook_df.createOrReplaceTempView(self.id) + + def _select_columns( + self, columns: Optional[List[Any]], client: SparkClient + ) -> DataFrame: + df = self.consume(client) + return df.selectExpr( + *( + [ + f"{old_expression} as {new_column_name}" + for old_expression, new_column_name in columns + ] + if columns + else df.columns + ) + ) - transformed_df.createOrReplaceTempView(self.id) + def _apply_transformations(self, df: DataFrame) -> DataFrame: + return reduce( + lambda result_df, transformation: transformation["transformer"]( + result_df, *transformation["args"], **transformation["kwargs"] + ), + self.transformations, + df, + ) diff --git a/butterfree/extract/readers/table_reader.py b/butterfree/extract/readers/table_reader.py index 343f25f38..b5decfc1f 100644 --- a/butterfree/extract/readers/table_reader.py +++ b/butterfree/extract/readers/table_reader.py @@ -1,5 +1,7 @@ """TableSource entity.""" +from typing import Optional + from pyspark.sql import DataFrame from butterfree.clients import SparkClient @@ -44,7 +46,7 @@ class TableReader(Reader): __name__ = "Table Reader" - def __init__(self, id: str, table: str, database: str = None): + def __init__(self, id: str, table: str, database: Optional[str] = None): super().__init__(id) if not isinstance(table, str): raise ValueError( diff --git a/butterfree/extract/source.py b/butterfree/extract/source.py index 00ac9e43f..bfc15271f 100644 --- a/butterfree/extract/source.py +++ b/butterfree/extract/source.py @@ -1,14 +1,15 @@ """Holds the SourceSelector class.""" -from typing import List +from typing import List, Optional from pyspark.sql import DataFrame from butterfree.clients import SparkClient from butterfree.extract.readers.reader import Reader +from butterfree.hooks import HookableComponent -class Source: +class Source(HookableComponent): """The definition of the the entry point data for the ETL pipeline. A FeatureSet (the next step in the pipeline) expects a single dataframe as @@ -48,34 +49,62 @@ class Source: temporary views regarding each reader and, after, will run the desired query and return a dataframe. + The `eager_evaluation` param forces Spark to apply the currently + mapped changes to the DataFrame. When this parameter is set to + False, Spark follows its standard behaviour of lazy evaluation. + Lazy evaluation can improve Spark's performance as it allows + Spark to build the best version of the execution plan. 
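+
+    A minimal, illustrative sketch (the reader id, database, table, query and
+    the `client` SparkClient instance below are assumptions):
+
+    >>> table_reader = TableReader(
+    ...     id="orders", database="datalake", table="orders"
+    ... ).with_incremental_strategy(IncrementalStrategy(column="timestamp"))
+    >>> source = Source(
+    ...     readers=[table_reader],
+    ...     query="select * from orders",
+    ...     eager_evaluation=False,
+    ... )
+    >>> df = source.construct(client, start_date="2021-01-01", end_date="2021-01-31")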
+ """ - def __init__(self, readers: List[Reader], query: str) -> None: + def __init__( + self, + readers: List[Reader], + query: str, + eager_evaluation: bool = True, + ) -> None: + super().__init__() + self.enable_pre_hooks = False self.readers = readers self.query = query - - def construct(self, client: SparkClient) -> DataFrame: + self.eager_evaluation = eager_evaluation + + def construct( + self, + client: SparkClient, + start_date: Optional[str] = None, + end_date: Optional[str] = None, + ) -> DataFrame: """Construct an entry point dataframe for a feature set. This method will assemble multiple readers, by building each one and - querying them using a Spark SQL. + querying them using a Spark SQL. It's important to highlight that in + order to filter a dataframe regarding date boundaries, it's important + to define a IncrementalStrategy, otherwise your data will not be filtered. + Besides, both start and end dates parameters are optional. After that, there's the caching of the dataframe, however since cache() in Spark is lazy, an action is triggered in order to force persistence. Args: client: client responsible for connecting to Spark session. + start_date: user defined start date for filtering. + end_date: user defined end date for filtering. Returns: DataFrame with the query result against all readers. """ for reader in self.readers: - reader.build(client) # create temporary views for each reader + reader.build( + client=client, start_date=start_date, end_date=end_date + ) # create temporary views for each reader dataframe = client.sql(self.query) - if not dataframe.isStreaming: + if not dataframe.isStreaming and self.eager_evaluation: dataframe.cache().count() - return dataframe + post_hook_df = self.run_post_hooks(dataframe) + + return post_hook_df diff --git a/butterfree/hooks/__init__.py b/butterfree/hooks/__init__.py new file mode 100644 index 000000000..e4a32170c --- /dev/null +++ b/butterfree/hooks/__init__.py @@ -0,0 +1,6 @@ +"""Holds Hooks definitions.""" + +from butterfree.hooks.hook import Hook +from butterfree.hooks.hookable_component import HookableComponent + +__all__ = ["Hook", "HookableComponent"] diff --git a/butterfree/hooks/hook.py b/butterfree/hooks/hook.py new file mode 100644 index 000000000..f7d8c562f --- /dev/null +++ b/butterfree/hooks/hook.py @@ -0,0 +1,20 @@ +"""Hook abstract class entity.""" + +from abc import ABC, abstractmethod + +from pyspark.sql import DataFrame + + +class Hook(ABC): + """Definition of a hook function to call on a Dataframe.""" + + @abstractmethod + def run(self, dataframe: DataFrame) -> DataFrame: + """Run interface for Hook. + + Args: + dataframe: dataframe to use in the Hook. + + Returns: + dataframe result from the Hook. + """ diff --git a/butterfree/hooks/hookable_component.py b/butterfree/hooks/hookable_component.py new file mode 100644 index 000000000..d89babcea --- /dev/null +++ b/butterfree/hooks/hookable_component.py @@ -0,0 +1,148 @@ +"""Definition of hookable component.""" + +from __future__ import annotations + +from typing import List + +from pyspark.sql import DataFrame + +from butterfree.hooks.hook import Hook + + +class HookableComponent: + """Defines a component with the ability to hold pre and post hook functions. + + All main module of Butterfree have a common object that enables their integration: + dataframes. Spark's dataframe is the glue that enables the transmission of data + between the main modules. Hooks have a simple interface, they are functions that + accepts a dataframe and outputs a dataframe. 
These Hooks can be triggered before or + after the main execution of a component. + + Components from Butterfree that inherit HookableComponent entity, are components + that can define a series of steps to occur before or after the execution of their + main functionality. + + Attributes: + pre_hooks: function steps to trigger before component main functionality. + post_hooks: function steps to trigger after component main functionality. + enable_pre_hooks: property to indicate if the component can define pre_hooks. + enable_post_hooks: property to indicate if the component can define post_hooks. + """ + + def __init__(self) -> None: + self.pre_hooks = [] + self.post_hooks = [] + self.enable_pre_hooks = True + self.enable_post_hooks = True + + @property + def pre_hooks(self) -> List[Hook]: + """Function steps to trigger before component main functionality.""" + return self.__pre_hook + + @pre_hooks.setter + def pre_hooks(self, value: List[Hook]) -> None: + if not isinstance(value, list): + raise ValueError("pre_hooks should be a list of Hooks.") + if not all(isinstance(item, Hook) for item in value): + raise ValueError( + "All items on pre_hooks list should be an instance of Hook." + ) + self.__pre_hook = value + + @property + def post_hooks(self) -> List[Hook]: + """Function steps to trigger after component main functionality.""" + return self.__post_hook + + @post_hooks.setter + def post_hooks(self, value: List[Hook]) -> None: + if not isinstance(value, list): + raise ValueError("post_hooks should be a list of Hooks.") + if not all(isinstance(item, Hook) for item in value): + raise ValueError( + "All items on post_hooks list should be an instance of Hook." + ) + self.__post_hook = value + + @property + def enable_pre_hooks(self) -> bool: + """Property to indicate if the component can define pre_hooks.""" + return self.__enable_pre_hooks + + @enable_pre_hooks.setter + def enable_pre_hooks(self, value: bool) -> None: + if not isinstance(value, bool): + raise ValueError("enable_pre_hooks accepts only boolean values.") + self.__enable_pre_hooks = value + + @property + def enable_post_hooks(self) -> bool: + """Property to indicate if the component can define post_hooks.""" + return self.__enable_post_hooks + + @enable_post_hooks.setter + def enable_post_hooks(self, value: bool) -> None: + if not isinstance(value, bool): + raise ValueError("enable_post_hooks accepts only boolean values.") + self.__enable_post_hooks = value + + def add_pre_hook(self, *hooks: Hook) -> HookableComponent: + """Add a pre-hook steps to the component. + + Args: + hooks: Hook steps to add to pre_hook list. + + Returns: + Component with the Hook inserted in pre_hook list. + + Raises: + ValueError: if the component does not accept pre-hooks. + """ + if not self.enable_pre_hooks: + raise ValueError("This component does not enable adding pre-hooks") + self.pre_hooks += list(hooks) + return self + + def add_post_hook(self, *hooks: Hook) -> HookableComponent: + """Add a post-hook steps to the component. + + Args: + hooks: Hook steps to add to post_hook list. + + Returns: + Component with the Hook inserted in post_hook list. + + Raises: + ValueError: if the component does not accept post-hooks. + """ + if not self.enable_post_hooks: + raise ValueError("This component does not enable adding post-hooks") + self.post_hooks += list(hooks) + return self + + def run_pre_hooks(self, dataframe: DataFrame) -> DataFrame: + """Run all defined pre-hook steps from a given dataframe. 
+ + Args: + dataframe: data to input in the defined pre-hook steps. + + Returns: + dataframe after passing for all defined pre-hooks. + """ + for hook in self.pre_hooks: + dataframe = hook.run(dataframe) + return dataframe + + def run_post_hooks(self, dataframe: DataFrame) -> DataFrame: + """Run all defined post-hook steps from a given dataframe. + + Args: + dataframe: data to input in the defined post-hook steps. + + Returns: + dataframe after passing for all defined post-hooks. + """ + for hook in self.post_hooks: + dataframe = hook.run(dataframe) + return dataframe diff --git a/butterfree/hooks/schema_compatibility/__init__.py b/butterfree/hooks/schema_compatibility/__init__.py new file mode 100644 index 000000000..a00adef8d --- /dev/null +++ b/butterfree/hooks/schema_compatibility/__init__.py @@ -0,0 +1,10 @@ +"""Holds Schema Compatibility Hooks definitions.""" + +from butterfree.hooks.schema_compatibility.cassandra_table_schema_compatibility_hook import ( # noqa + CassandraTableSchemaCompatibilityHook, +) +from butterfree.hooks.schema_compatibility.spark_table_schema_compatibility_hook import ( # noqa + SparkTableSchemaCompatibilityHook, +) + +__all__ = ["SparkTableSchemaCompatibilityHook", "CassandraTableSchemaCompatibilityHook"] diff --git a/butterfree/hooks/schema_compatibility/cassandra_table_schema_compatibility_hook.py b/butterfree/hooks/schema_compatibility/cassandra_table_schema_compatibility_hook.py new file mode 100644 index 000000000..cdb40472b --- /dev/null +++ b/butterfree/hooks/schema_compatibility/cassandra_table_schema_compatibility_hook.py @@ -0,0 +1,58 @@ +"""Cassandra table schema compatibility Hook definition.""" + +from pyspark.sql import DataFrame + +from butterfree.clients import CassandraClient +from butterfree.constants import DataType +from butterfree.hooks.hook import Hook + + +class CassandraTableSchemaCompatibilityHook(Hook): + """Hook to verify the schema compatibility with a Cassandra's table. + + Verifies if all columns presented on the dataframe exists and are the same + type on the target Cassandra's table. + + Attributes: + cassandra_client: client to connect to Cassandra DB. + table: table name. + """ + + def __init__(self, cassandra_client: CassandraClient, table: str): + self.cassandra_client = cassandra_client + self.table = table + + def run(self, dataframe: DataFrame) -> DataFrame: + """Check the schema compatibility from a given Dataframe. + + This method does not change anything on the Dataframe. + + Args: + dataframe: dataframe to verify schema compatibility. + + Returns: + unchanged dataframe. + + Raises: + ValueError if the schemas are incompatible. 
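+
+        A usage sketch (the `cassandra_client` instance and table name below are
+        assumptions): the hook is typically attached as a pre-hook, so this check
+        runs before the write happens.
+
+        >>> hook = CassandraTableSchemaCompatibilityHook(cassandra_client, "my_table")
+        >>> writer = OnlineFeatureStoreWriter().add_pre_hook(hook)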
+ """ + table_schema = self.cassandra_client.get_schema(self.table) + type_cassandra = [ + type.cassandra + for field_id in range(len(dataframe.schema.fieldNames())) + for type in DataType + if dataframe.schema.fields.__getitem__(field_id).dataType == type.spark + ] + schema = [ + {"column_name": f"{column}", "type": f"{type}"} + for column, type in zip(dataframe.columns, type_cassandra) + ] + + if not all([column in table_schema for column in schema]): + raise ValueError( + "There's a schema incompatibility " + "between the defined dataframe and the Cassandra table.\n" + f"Dataframe schema = {schema}" + f"Target table schema = {table_schema}" + ) + return dataframe diff --git a/butterfree/hooks/schema_compatibility/spark_table_schema_compatibility_hook.py b/butterfree/hooks/schema_compatibility/spark_table_schema_compatibility_hook.py new file mode 100644 index 000000000..eea50c06d --- /dev/null +++ b/butterfree/hooks/schema_compatibility/spark_table_schema_compatibility_hook.py @@ -0,0 +1,50 @@ +"""Spark table schema compatibility Hook definition.""" + +from typing import Optional + +from pyspark.sql import DataFrame + +from butterfree.clients import SparkClient +from butterfree.hooks.hook import Hook + + +class SparkTableSchemaCompatibilityHook(Hook): + """Hook to verify the schema compatibility with a Spark's table. + + Verifies if all columns presented on the dataframe exists and are the same + type on the target Spark's table. + + Attributes: + spark_client: client to connect to Spark's metastore. + table: table name. + database: database name. + """ + + def __init__( + self, spark_client: SparkClient, table: str, database: Optional[str] = None + ): + self.spark_client = spark_client + self.table_expression = (f"`{database}`." if database else "") + f"`{table}`" + + def run(self, dataframe: DataFrame) -> DataFrame: + """Check the schema compatibility from a given Dataframe. + + This method does not change anything on the Dataframe. + + Args: + dataframe: dataframe to verify schema compatibility. + + Returns: + unchanged dataframe. + + Raises: + ValueError if the schemas are incompatible. 
+ """ + table_schema = self.spark_client.conn.table(self.table_expression).schema + if not all([column in table_schema for column in dataframe.schema]): + raise ValueError( + "The dataframe has a schema incompatible with the defined table.\n" + f"Dataframe schema = {dataframe.schema}" + f"Target table schema = {table_schema}" + ) + return dataframe diff --git a/butterfree/load/processing/__init__.py b/butterfree/load/processing/__init__.py index e2ad51578..06c5cb450 100644 --- a/butterfree/load/processing/__init__.py +++ b/butterfree/load/processing/__init__.py @@ -1,4 +1,5 @@ """Pre Processing Components regarding Readers.""" + from butterfree.load.processing.json_transform import json_transform __all__ = ["json_transform"] diff --git a/butterfree/load/processing/json_transform.py b/butterfree/load/processing/json_transform.py index 19ddecae2..598064dba 100644 --- a/butterfree/load/processing/json_transform.py +++ b/butterfree/load/processing/json_transform.py @@ -1,4 +1,5 @@ """Json conversion for writers.""" + from pyspark.sql.dataframe import DataFrame from pyspark.sql.functions import struct, to_json diff --git a/butterfree/load/sink.py b/butterfree/load/sink.py index b4bf93e8c..59b001a53 100644 --- a/butterfree/load/sink.py +++ b/butterfree/load/sink.py @@ -1,17 +1,19 @@ """Holds the Sink class.""" + from typing import List, Optional from pyspark.sql.dataframe import DataFrame from pyspark.sql.streaming import StreamingQuery from butterfree.clients import SparkClient +from butterfree.hooks import HookableComponent from butterfree.load.writers.writer import Writer from butterfree.transform import FeatureSet from butterfree.validations import BasicValidation from butterfree.validations.validation import Validation -class Sink: +class Sink(HookableComponent): """Define the destinations for the feature set pipeline. A Sink is created from a set of writers. The main goal of the Sink is to @@ -26,6 +28,8 @@ class Sink: """ def __init__(self, writers: List[Writer], validation: Optional[Validation] = None): + super().__init__() + self.enable_post_hooks = False self.writers = writers self.validation = validation @@ -66,14 +70,15 @@ def validate( """ failures = [] for writer in self.writers: - try: - writer.validate( - feature_set=feature_set, - dataframe=dataframe, - spark_client=spark_client, - ) - except AssertionError as e: - failures.append(e) + if writer.row_count_validation: + try: + writer.validate( + feature_set=feature_set, + dataframe=dataframe, + spark_client=spark_client, + ) + except AssertionError as e: + failures.append(e) if failures: raise RuntimeError( @@ -94,12 +99,16 @@ def flush( Streaming handlers for each defined writer, if writing streaming dfs. 
""" + pre_hook_df = self.run_pre_hooks(dataframe) + if self.validation is not None: - self.validation.input(dataframe).check() + self.validation.input(pre_hook_df).check() handlers = [ writer.write( - feature_set=feature_set, dataframe=dataframe, spark_client=spark_client + feature_set=feature_set, + dataframe=pre_hook_df, + spark_client=spark_client, ) for writer in self.writers ] diff --git a/butterfree/load/writers/__init__.py b/butterfree/load/writers/__init__.py index 72945d278..f1f0e449b 100644 --- a/butterfree/load/writers/__init__.py +++ b/butterfree/load/writers/__init__.py @@ -1,8 +1,9 @@ """Holds data loaders for historical and online feature store.""" +from butterfree.load.writers.delta_writer import DeltaWriter from butterfree.load.writers.historical_feature_store_writer import ( HistoricalFeatureStoreWriter, ) from butterfree.load.writers.online_feature_store_writer import OnlineFeatureStoreWriter -__all__ = ["HistoricalFeatureStoreWriter", "OnlineFeatureStoreWriter"] +__all__ = ["HistoricalFeatureStoreWriter", "OnlineFeatureStoreWriter", "DeltaWriter"] diff --git a/butterfree/load/writers/delta_writer.py b/butterfree/load/writers/delta_writer.py new file mode 100644 index 000000000..933f1adb0 --- /dev/null +++ b/butterfree/load/writers/delta_writer.py @@ -0,0 +1,162 @@ +from delta.tables import DeltaTable +from pyspark.sql.dataframe import DataFrame + +from butterfree.clients import SparkClient +from butterfree.configs.logger import __logger + +logger = __logger("delta_writer", True) + + +class DeltaWriter: + """Control operations on Delta Tables. + + Resposible for merging and optimizing. + """ + + @staticmethod + def _get_full_table_name(table, database): + if database: + return "{}.{}".format(database, table) + else: + return table + + @staticmethod + def _convert_to_delta(client: SparkClient, table: str): + logger.info(f"Converting {table} to Delta...") + client.conn.sql(f"CONVERT TO DELTA {table}") + logger.info("Conversion done.") + + @staticmethod + def merge( + client: SparkClient, + database: str, + table: str, + merge_on: list, + source_df: DataFrame, + when_not_matched_insert_condition: str = None, + when_matched_update_condition: str = None, + when_matched_delete_condition: str = None, + ): + """ + Merge a source dataframe to a Delta table. + + By default, it will update when matched, and insert when + not matched (simple upsert). + + You can change this behavior by setting: + - when_not_matched_insert_condition: it will only insert + when this specified condition is true + - when_matched_update_condition: it will only update when this + specified condition is true. You can refer to the columns + in the source dataframe as source., and the columns + in the target table as target.. + - when_matched_delete_condition: it will add an operation to delete, + but only if this condition is true. Again, source and + target dataframe columns can be referred to respectively as + source. and target. 
+ """ + try: + full_table_name = DeltaWriter._get_full_table_name(table, database) + + table_exists = client.conn.catalog.tableExists(full_table_name) + + if table_exists: + pd_df = client.conn.sql( + f"DESCRIBE TABLE EXTENDED {full_table_name}" + ).toPandas() + provider = ( + pd_df.reset_index() + .groupby(["col_name"])["data_type"] + .aggregate("first") + .Provider + ) + table_is_delta = provider.lower() == "delta" + + if not table_is_delta: + DeltaWriter()._convert_to_delta(client, full_table_name) + + # For schema evolution + client.conn.conf.set( + "spark.databricks.delta.schema.autoMerge.enabled", "true" + ) + + target_table = DeltaTable.forName(client.conn, full_table_name) + join_condition = " AND ".join( + [f"source.{col} = target.{col}" for col in merge_on] + ) + merge_builder = target_table.alias("target").merge( + source_df.alias("source"), join_condition + ) + if when_matched_delete_condition: + merge_builder = merge_builder.whenMatchedDelete( + condition=when_matched_delete_condition + ) + + merge_builder.whenMatchedUpdateAll( + condition=when_matched_update_condition + ).whenNotMatchedInsertAll( + condition=when_not_matched_insert_condition + ).execute() + except Exception as e: + logger.error(f"Merge operation on {full_table_name} failed: {e}") + + @staticmethod + def vacuum(table: str, retention_hours: int, client: SparkClient): + """Vacuum a Delta table. + + Vacuum remove unused files (files not managed by Delta + files + that are not in the latest state). + After vacuum it's impossible to time travel to versions + older than the `retention` time. + Default retention is 7 days. Lower retentions will be warned, + unless it's set to false. + Set spark.databricks.delta.retentionDurationCheck.enabled + to false for low retentions. + https://docs.databricks.com/en/sql/language-manual/delta-vacuum.html + """ + + command = f"VACUUM {table} RETAIN {retention_hours} HOURS" + logger.info(f"Running vacuum with command {command}") + client.conn.sql(command) + logger.info(f"Vacuum successful for table {table}") + + @staticmethod + def optimize( + client: SparkClient, + table: str = None, + z_order: list = None, + date_column: str = "timestamp", + from_date: str = None, + auto_compact: bool = False, + optimize_write: bool = False, + ): + """Optimize a Delta table. + + For auto-compaction and optimize write DBR >= 14.3 LTS + and Delta >= 3.1.0 are MANDATORY. + For z-ordering DBR >= 13.3 LTS and Delta >= 2.0.0 are MANDATORY. + Auto-compaction (recommended) reduces the small file problem + (overhead due to lots of metadata). + Z-order by columns that is commonly used in queries + predicates and has a high cardinality. 
+ https://docs.delta.io/latest/optimizations-oss.html + """ + + if auto_compact: + client.conf.set("spark.databricks.delta.autoCompact.enabled", "true") + + if optimize_write: + client.conf.set("spark.databricks.delta.optimizeWrite.enabled", "true") + + if table: + command = f"OPTIMIZE {table}" + + if from_date: + command += f"WHERE {date_column} >= {from_date}" + + if z_order: + command += f" ZORDER BY {','.join(z_order)}" + + logger.info(f"Running optimize with command {command}...") + client.conn.sql(command) + logger.info(f"Optimize successful for table {table}.") diff --git a/butterfree/load/writers/historical_feature_store_writer.py b/butterfree/load/writers/historical_feature_store_writer.py index d70f68f0b..99bfe66a5 100644 --- a/butterfree/load/writers/historical_feature_store_writer.py +++ b/butterfree/load/writers/historical_feature_store_writer.py @@ -1,7 +1,7 @@ """Holds the Historical Feature Store writer class.""" import os -from typing import Union +from typing import Any, Optional from pyspark.sql.dataframe import DataFrame from pyspark.sql.functions import dayofmonth, month, year @@ -12,6 +12,9 @@ from butterfree.constants import columns from butterfree.constants.spark_constants import DEFAULT_NUM_PARTITIONS from butterfree.dataframe_service import repartition_df +from butterfree.hooks import Hook +from butterfree.hooks.schema_compatibility import SparkTableSchemaCompatibilityHook +from butterfree.load.writers.delta_writer import DeltaWriter from butterfree.load.writers.writer import Writer from butterfree.transform import FeatureSet @@ -60,6 +63,20 @@ class HistoricalFeatureStoreWriter(Writer): For what settings you can use on S3Config and default settings, to read S3Config class. + We can write with interval mode, where HistoricalFeatureStoreWrite + will need to use Dynamic Partition Inserts, + the behaviour of OVERWRITE keyword is controlled by + spark.sql.sources.partitionOverwriteMode configuration property. + The dynamic overwrite mode is enabled Spark will only delete the + partitions for which it has data to be written to. + All the other partitions remain intact. + + >>> spark_client = SparkClient() + >>> writer = HistoricalFeatureStoreWriter(interval_mode=True) + >>> writer.write(feature_set=feature_set, + ... dataframe=dataframe, + ... spark_client=spark_client) + We can instantiate HistoricalFeatureStoreWriter class to validate the df to be written. @@ -76,6 +93,15 @@ class HistoricalFeatureStoreWriter(Writer): improve queries performance. The data is stored in partition folders in AWS S3 based on time (per year, month and day). + >>> spark_client = SparkClient() + >>> writer = HistoricalFeatureStoreWriter() + >>> writer.write(feature_set=feature_set, + ... dataframe=dataframe, + ... spark_client=spark_client + ... merge_on=["id", "timestamp"]) + + This procedure will skip dataframe write and will activate Delta Merge. + Use it when the table already exist. 
""" PARTITION_BY = [ @@ -90,23 +116,36 @@ class HistoricalFeatureStoreWriter(Writer): def __init__( self, - db_config: Union[AbstractWriteConfig, MetastoreConfig] = None, - database: str = None, - num_partitions: int = None, + db_config: Optional[AbstractWriteConfig] = None, + database: Optional[str] = None, + num_partitions: Optional[int] = None, validation_threshold: float = DEFAULT_VALIDATION_THRESHOLD, debug_mode: bool = False, + interval_mode: bool = False, + check_schema_hook: Optional[Hook] = None, + row_count_validation: bool = True, + merge_on: list = None, ): - super(HistoricalFeatureStoreWriter, self).__init__() - self.db_config = db_config or MetastoreConfig() + super(HistoricalFeatureStoreWriter, self).__init__( + db_config or MetastoreConfig(), + debug_mode, + interval_mode, + False, + row_count_validation, + merge_on, + ) self.database = database or environment.get_variable( "FEATURE_STORE_HISTORICAL_DATABASE" ) self.num_partitions = num_partitions or DEFAULT_NUM_PARTITIONS self.validation_threshold = validation_threshold - self.debug_mode = debug_mode + self.check_schema_hook = check_schema_hook def write( - self, feature_set: FeatureSet, dataframe: DataFrame, spark_client: SparkClient, + self, + feature_set: FeatureSet, + dataframe: DataFrame, + spark_client: SparkClient, ) -> None: """Loads the data from a feature set into the Historical Feature Store. @@ -114,6 +153,7 @@ def write( feature_set: object processed with feature_set informations. dataframe: spark dataframe containing data from a feature set. spark_client: client for spark connections with external services. + merge_on: when filled, the writing is an upsert in a Delta table. If the debug_mode is set to True, a temporary table with a name in the format: historical_feature_store__{feature_set.name} will be created instead of writing @@ -124,6 +164,20 @@ def write( dataframe = self._apply_transformations(dataframe) + if self.interval_mode: + + partition_overwrite_mode = spark_client.conn.conf.get( # type: ignore + "spark.sql.sources.partitionOverwriteMode" + ).lower() + + if partition_overwrite_mode != "dynamic": + raise RuntimeError( + "m=load_incremental_table, " + "spark.sql.sources.partitionOverwriteMode={}, " + "msg=partitionOverwriteMode have to " + "be configured to 'dynamic'".format(partition_overwrite_mode) + ) + if self.debug_mode: spark_client.create_temporary_view( dataframe=dataframe, @@ -132,13 +186,23 @@ def write( return s3_key = os.path.join("historical", feature_set.entity, feature_set.name) - spark_client.write_table( - dataframe=dataframe, - database=self.database, - table_name=feature_set.name, - partition_by=self.PARTITION_BY, - **self.db_config.get_options(s3_key), - ) + + if self.merge_on: + DeltaWriter.merge( + client=spark_client, + database=self.database, + table=feature_set.name, + merge_on=self.merge_on, + source_df=dataframe, + ) + else: + spark_client.write_table( + dataframe=dataframe, + database=self.database, + table_name=feature_set.name, + partition_by=self.PARTITION_BY, + **self.db_config.get_options(s3_key), + ) def _assert_validation_count( self, table_name: str, written_count: int, dataframe_count: int @@ -166,15 +230,30 @@ def validate( Raises: AssertionError: if count of written data doesn't match count in current feature set dataframe. 
- """ table_name = ( - f"{self.database}.{feature_set.name}" - if not self.debug_mode - else f"historical_feature_store__{feature_set.name}" + os.path.join("historical", feature_set.entity, feature_set.name) + if self.interval_mode and not self.debug_mode + else ( + f"{self.database}.{feature_set.name}" + if not self.debug_mode + else f"historical_feature_store__{feature_set.name}" + ) ) - written_count = spark_client.read_table(table_name).count() + + written_count = ( + spark_client.read( + self.db_config.format_, + path=self.db_config.get_path_with_partitions( + table_name, self._create_partitions(dataframe) + ), + ).count() + if self.interval_mode and not self.debug_mode + else spark_client.read_table(table_name).count() + ) + dataframe_count = dataframe.count() + self._assert_validation_count(table_name, written_count, dataframe_count) def _create_partitions(self, dataframe: DataFrame) -> DataFrame: @@ -191,3 +270,25 @@ def _create_partitions(self, dataframe: DataFrame) -> DataFrame: columns.PARTITION_DAY, dayofmonth(dataframe[columns.TIMESTAMP_COLUMN]) ) return repartition_df(dataframe, self.PARTITION_BY, self.num_partitions) + + def check_schema( + self, + client: Any, + dataframe: DataFrame, + table_name: str, + database: Optional[str] = None, + ) -> DataFrame: + """Instantiate the schema check hook to check schema between dataframe and database. + + Args: + client: client for Spark or Cassandra connections with external services. + dataframe: Spark dataframe containing data from a feature set. + table_name: table name where the dataframe will be saved. + database: database name where the dataframe will be saved. + """ + if not self.check_schema_hook: + self.check_schema_hook = SparkTableSchemaCompatibilityHook( + client, table_name, database + ) + + return self.check_schema_hook.run(dataframe) diff --git a/butterfree/load/writers/online_feature_store_writer.py b/butterfree/load/writers/online_feature_store_writer.py index a81a1040e..bce5a3751 100644 --- a/butterfree/load/writers/online_feature_store_writer.py +++ b/butterfree/load/writers/online_feature_store_writer.py @@ -1,7 +1,7 @@ """Holds the Online Feature Store writer class.""" import os -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Optional, Union from pyspark.sql import DataFrame, Window from pyspark.sql.functions import col, row_number @@ -10,6 +10,8 @@ from butterfree.clients import SparkClient from butterfree.configs.db import AbstractWriteConfig, CassandraConfig from butterfree.constants.columns import TIMESTAMP_COLUMN +from butterfree.hooks import Hook +from butterfree.hooks.schema_compatibility import CassandraTableSchemaCompatibilityHook from butterfree.load.writers.writer import Writer from butterfree.transform import FeatureSet @@ -66,20 +68,30 @@ class OnlineFeatureStoreWriter(Writer): Both methods (writer and validate) will need the Spark Client, Feature Set and DataFrame, to write or to validate, according to OnlineFeatureStoreWriter class arguments. + + There's an important aspect to be highlighted here: if you're using + the incremental mode, we do not check if your data is the newest before + writing to the online feature store. + + This behavior is known and will be fixed soon. 
""" __name__ = "Online Feature Store Writer" def __init__( self, - db_config: Union[AbstractWriteConfig, CassandraConfig] = None, - debug_mode: bool = False, - write_to_entity: bool = False, + db_config: Optional[AbstractWriteConfig] = None, + database: Optional[str] = None, + debug_mode: Optional[bool] = False, + write_to_entity: Optional[bool] = False, + interval_mode: Optional[bool] = False, + check_schema_hook: Optional[Hook] = None, ): - super(OnlineFeatureStoreWriter, self).__init__() - self.db_config = db_config or CassandraConfig() - self.debug_mode = debug_mode - self.write_to_entity = write_to_entity + super(OnlineFeatureStoreWriter, self).__init__( + db_config or CassandraConfig(), debug_mode, interval_mode, write_to_entity + ) + self.check_schema_hook = check_schema_hook + self.database = database @staticmethod def filter_latest(dataframe: DataFrame, id_columns: List[Any]) -> DataFrame: @@ -104,7 +116,10 @@ def filter_latest(dataframe: DataFrame, id_columns: List[Any]) -> DataFrame: window = Window.partitionBy(*id_columns).orderBy(col(TIMESTAMP_COLUMN).desc()) return ( - dataframe.select(col("*"), row_number().over(window).alias("rn"),) + dataframe.select( + col("*"), + row_number().over(window).alias("rn"), + ) .filter(col("rn") == 1) .drop("rn") ) @@ -150,7 +165,10 @@ def _write_in_debug_mode( ) def write( - self, feature_set: FeatureSet, dataframe: DataFrame, spark_client: SparkClient, + self, + feature_set: FeatureSet, + dataframe: DataFrame, + spark_client: SparkClient, ) -> Union[StreamingQuery, None]: """Loads the latest data from a feature set into the Feature Store. @@ -236,3 +254,25 @@ def get_db_schema(self, feature_set: FeatureSet) -> List[Dict[Any, Any]]: """ db_schema = self.db_config.translate(feature_set.get_schema()) return db_schema + + def check_schema( + self, + client: Any, + dataframe: DataFrame, + table_name: str, + database: Optional[str] = None, + ) -> DataFrame: + """Instantiate the schema check hook to check schema between dataframe and database. + + Args: + client: client for Spark or Cassandra connections with external services. + dataframe: Spark dataframe containing data from a feature set. + table_name: table name where the dataframe will be saved. + database: database name where the dataframe will be saved. + """ + if not self.check_schema_hook: + self.check_schema_hook = CassandraTableSchemaCompatibilityHook( + client, table_name + ) + + return self.check_schema_hook.run(dataframe) diff --git a/butterfree/load/writers/writer.py b/butterfree/load/writers/writer.py index f76b4c253..a99514ae8 100644 --- a/butterfree/load/writers/writer.py +++ b/butterfree/load/writers/writer.py @@ -2,15 +2,17 @@ from abc import ABC, abstractmethod from functools import reduce -from typing import Any, Callable, Dict, List +from typing import Any, Callable, Dict, List, Optional from pyspark.sql.dataframe import DataFrame from butterfree.clients import SparkClient +from butterfree.configs.db import AbstractWriteConfig +from butterfree.hooks import HookableComponent from butterfree.transform import FeatureSet -class Writer(ABC): +class Writer(ABC, HookableComponent): """Abstract base class for Writers. 
Args: @@ -18,8 +20,23 @@ class Writer(ABC): """ - def __init__(self) -> None: + def __init__( + self, + db_config: AbstractWriteConfig, + debug_mode: Optional[bool] = False, + interval_mode: Optional[bool] = False, + write_to_entity: Optional[bool] = False, + row_count_validation: Optional[bool] = True, + merge_on: Optional[list] = None, + ) -> None: + super().__init__() + self.db_config = db_config self.transformations: List[Dict[str, Any]] = [] + self.debug_mode = debug_mode + self.interval_mode = interval_mode + self.write_to_entity = write_to_entity + self.row_count_validation = row_count_validation + self.merge_on = merge_on def with_( self, transformer: Callable[..., DataFrame], *args: Any, **kwargs: Any @@ -57,7 +74,10 @@ def _apply_transformations(self, df: DataFrame) -> DataFrame: @abstractmethod def write( - self, feature_set: FeatureSet, dataframe: DataFrame, spark_client: SparkClient, + self, + feature_set: FeatureSet, + dataframe: DataFrame, + spark_client: SparkClient, ) -> Any: """Loads the data from a feature set into the Feature Store. @@ -70,6 +90,23 @@ def write( """ + @abstractmethod + def check_schema( + self, + client: Any, + dataframe: DataFrame, + table_name: str, + database: Optional[str] = None, + ) -> DataFrame: + """Instantiate the schema check hook to check schema between dataframe and database. + + Args: + client: client for Spark or Cassandra connections with external services. + dataframe: Spark dataframe containing data from a feature set. + table_name: table name where the dataframe will be saved. + database: database name where the dataframe will be saved. + """ + @abstractmethod def validate( self, feature_set: FeatureSet, dataframe: DataFrame, spark_client: SparkClient diff --git a/butterfree/migrations/__init__.py b/butterfree/migrations/__init__.py new file mode 100644 index 000000000..791f5fefe --- /dev/null +++ b/butterfree/migrations/__init__.py @@ -0,0 +1 @@ +"""Holds available migrations.""" diff --git a/butterfree/migrations/database_migration/__init__.py b/butterfree/migrations/database_migration/__init__.py new file mode 100644 index 000000000..e31800884 --- /dev/null +++ b/butterfree/migrations/database_migration/__init__.py @@ -0,0 +1,17 @@ +"""Holds available database migrations.""" + +from butterfree.migrations.database_migration.cassandra_migration import ( + CassandraMigration, +) +from butterfree.migrations.database_migration.database_migration import Diff +from butterfree.migrations.database_migration.metastore_migration import ( + MetastoreMigration, +) + +__all__ = ["CassandraMigration", "MetastoreMigration", "Diff"] + + +ALLOWED_DATABASE = { + "cassandra": CassandraMigration(), + "metastore": MetastoreMigration(), +} diff --git a/butterfree/migrations/database_migration/cassandra_migration.py b/butterfree/migrations/database_migration/cassandra_migration.py new file mode 100644 index 000000000..5a4f755f9 --- /dev/null +++ b/butterfree/migrations/database_migration/cassandra_migration.py @@ -0,0 +1,143 @@ +"""Cassandra Migration entity.""" + +from typing import Any, Dict, List + +from butterfree.clients import CassandraClient +from butterfree.configs.db import CassandraConfig +from butterfree.migrations.database_migration.database_migration import ( + DatabaseMigration, + Diff, +) + + +class CassandraMigration(DatabaseMigration): + """Cassandra class for performing migrations. + + This class implements some methods of the parent DatabaseMigration class and + has specific methods for query building. 
+ + The CassandraMigration class will be used, as the name suggests, for applying + changes to a given Cassandra table. There are, however, some remarks that need + to be highlighted: + - If an existing feature has its type changed, then it's extremely important to + make sure that this conversion would not result in data loss; + - If new features are added to your feature set, then they're going to be added + to the corresponding Cassandra table; + - Since feature sets can be written both to a feature set and an entity table, + we're not going to automatically drop features when using entity tables, since + it means that some features belong to a different feature set. In summary, if + data is being loaded into an entity table, then users can drop columns manually. + + """ + + def __init__(self) -> None: + self._db_config = CassandraConfig() + super(CassandraMigration, self).__init__( + CassandraClient( + host=[self._db_config.host], + keyspace=self._db_config.keyspace, # type: ignore + user=self._db_config.username, + password=self._db_config.password, + ) + ) + + @staticmethod + def _get_parsed_columns(columns: List[Diff]) -> List[str]: + """Parse columns from a list of Diff objects. + + Args: + columns: list of Diff objects. + + Returns: + Parsed columns. + + """ + parsed_columns = [] + for col in columns: + parsed_columns.append(f"{col.column} {col.value}") + + parsed_columns = ", ".join(parsed_columns) # type: ignore + + return parsed_columns + + def _get_alter_table_add_query(self, columns: List[Diff], table_name: str) -> str: + """Creates CQL statement to add columns to a table. + + Args: + columns: list of Diff objects with ADD kind. + table_name: table name. + + Returns: + Alter table query. + + """ + parsed_columns = self._get_parsed_columns(columns) + + return f"ALTER TABLE {table_name} ADD ({parsed_columns});" + + def _get_alter_column_type_query(self, column: Diff, table_name: str) -> str: + """Creates CQL statement to alter columns' types. + + Args: + columns: list of Diff objects with ALTER_TYPE kind. + table_name: table name. + + Returns: + Alter column type query. + + """ + parsed_columns = self._get_parsed_columns([column]) + + return ( + f"ALTER TABLE {table_name} ALTER {parsed_columns.replace(' ', ' TYPE ')};" + ) + + @staticmethod + def _get_create_table_query(columns: List[Dict[str, Any]], table_name: str) -> str: + """Creates CQL statement to create a table. + + Args: + columns: object that contains column's schemas. + table_name: table name. + + Returns: + Create table query. + + """ + parsed_columns = [] + primary_keys = [] + + for col in columns: + col_str = f"{col['column_name']} {col['type']}" + if col["primary_key"]: + primary_keys.append(col["column_name"]) + parsed_columns.append(col_str) + + joined_parsed_columns = ", ".join(parsed_columns) + + if len(primary_keys) > 0: + joined_primary_keys = ", ".join(primary_keys) + columns_str = ( + f"{joined_parsed_columns}, PRIMARY KEY ({joined_primary_keys})" + ) + else: + columns_str = joined_parsed_columns + + keyspace = CassandraConfig().keyspace + + return f"CREATE TABLE {keyspace}.{table_name} " f"({columns_str});" + + def _get_alter_table_drop_query(self, columns: List[Diff], table_name: str) -> str: + """Creates CQL statement to drop columns from a table. + + Args: + columns: list of Diff objects with DROP kind. + table_name: table name. + + Returns: + Drop columns from a given table query. 
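For reference, this is roughly the shape of the CQL the builders above emit, reproduced here with plain string formatting; the table and column names are made up and no Cassandra connection is involved.

# hypothetical columns coming from a schema diff: (name, cassandra type)
new_columns = [
    ("feature1__avg_over_1_week_rolling_windows", "double"),
    ("feature1__stddev_pop_over_1_week_rolling_windows", "double"),
]
parsed = ", ".join(f"{name} {ctype}" for name, ctype in new_columns)

add_query = f"ALTER TABLE my_feature_set ADD ({parsed});"
alter_type_query = (
    "ALTER TABLE my_feature_set "
    "ALTER feature1__avg_over_1_week_rolling_windows TYPE double;"
)
print(add_query)
print(alter_type_query)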
+ + """ + parsed_columns = self._get_parsed_columns(columns) + + return f"ALTER TABLE {table_name} DROP ({parsed_columns});" diff --git a/butterfree/migrations/database_migration/database_migration.py b/butterfree/migrations/database_migration/database_migration.py new file mode 100644 index 000000000..351a47243 --- /dev/null +++ b/butterfree/migrations/database_migration/database_migration.py @@ -0,0 +1,307 @@ +"""Migration entity.""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from enum import Enum, auto +from typing import Any, Dict, List, Optional, Set + +from butterfree.clients import AbstractClient +from butterfree.configs.logger import __logger +from butterfree.load.writers.writer import Writer +from butterfree.transform import FeatureSet + +logger = __logger("database_migrate", True) + + +@dataclass +class Diff: + """DataClass to help identifying different types of diff between schemas.""" + + class Kind(Enum): + """Mapping actions to take given a difference between columns of a schema.""" + + ADD = auto() + ALTER_KEY = auto() + ALTER_TYPE = auto() + DROP = auto() + + column: str + kind: Kind + value: Any + + def __hash__(self) -> int: + return hash((self.column, self.kind, self.value)) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, type(self)): + raise NotImplementedError + return ( + self.column == other.column + and self.kind == other.kind + and self.value == other.value + ) + + +class DatabaseMigration(ABC): + """Abstract base class for Migrations.""" + + def __init__(self, client: AbstractClient) -> None: + self._client = client + + @abstractmethod + def _get_create_table_query( + self, columns: List[Dict[str, Any]], table_name: str + ) -> Any: + """Creates desired statement to create a table. + + Args: + columns: object that contains column's schemas. + table_name: table name. + + Returns: + Create table query. + + """ + pass + + @abstractmethod + def _get_alter_table_add_query(self, columns: List[Diff], table_name: str) -> str: + """Creates desired statement to add columns to a table. + + Args: + columns: list of Diff objects with ADD kind. + table_name: table name. + + Returns: + Alter table query. + + """ + pass + + @abstractmethod + def _get_alter_table_drop_query(self, columns: List[Diff], table_name: str) -> str: + """Creates desired statement to drop columns from a table. + + Args: + columns: list of Diff objects with DROP kind. + table_name: table name. + + Returns: + Drop columns from a given table query. + + """ + pass + + @abstractmethod + def _get_alter_column_type_query(self, column: Diff, table_name: str) -> str: + """Creates desired statement to alter columns' types. + + Args: + columns: list of Diff objects with ALTER_TYPE kind. + table_name: table name. + + Returns: + Alter column type query. + + """ + pass + + def _get_queries( + self, + schema_diff: Set[Diff], + table_name: str, + write_on_entity: Optional[bool] = None, + ) -> Any: + """Create the desired queries for migration. + + Args: + schema_diff: list of Diff objects. + table_name: table name. + + Returns: + List of queries. 
+ + """ + add_items = [] + drop_items = [] + alter_type_items = [] + alter_key_items = [] + + for diff in schema_diff: + if diff.kind == Diff.Kind.ADD: + add_items.append(diff) + elif diff.kind == Diff.Kind.ALTER_TYPE: + alter_type_items.append(diff) + elif diff.kind == Diff.Kind.DROP: + drop_items.append(diff) + elif diff.kind == Diff.Kind.ALTER_KEY: + alter_key_items.append(diff) + + queries = [] + if add_items: + alter_table_add_query = self._get_alter_table_add_query( + add_items, table_name + ) + queries.append(alter_table_add_query) + if drop_items: + if not write_on_entity: + drop_columns_query = self._get_alter_table_drop_query( + drop_items, table_name + ) + queries.append(drop_columns_query) + if alter_type_items: + for item in alter_type_items: + alter_column_types_query = self._get_alter_column_type_query( + item, table_name + ) + queries.append(alter_column_types_query) + if alter_key_items: + logger.warning( + "The 'change the primary key column' action is not supported by Spark." + ) + + return queries + + def create_query( + self, + fs_schema: List[Dict[str, Any]], + table_name: str, + db_schema: Optional[List[Dict[str, Any]]] = None, + write_on_entity: Optional[bool] = None, + ) -> Any: + """Create a query regarding a data source. + + Returns: + The desired queries for the given database. + + """ + if not db_schema: + return [self._get_create_table_query(fs_schema, table_name)] + + schema_diff = self._get_diff(fs_schema, db_schema) + + return self._get_queries(schema_diff, table_name, write_on_entity) + + @staticmethod + def _get_diff( + fs_schema: List[Dict[str, Any]], + db_schema: List[Dict[str, Any]], + ) -> Set[Diff]: + """Gets schema difference between feature set and the table of a given db. + + Args: + fs_schema: object that contains feature set's schemas. + db_schema: object that contains the table of a given db schema. + + Returns: + Object with schema differences. + + """ + db_columns = set(item.get("column_name") for item in db_schema) + fs_columns = set(item.get("column_name") for item in fs_schema) + + add_columns = fs_columns - db_columns + drop_columns = db_columns - fs_columns + + # This could be way easier to write (and to read) if the schemas were a simple + # Dict[str, Any] where each key would be the column name itself... + # but changing that could break things so: + # TODO version 2 change get schema to return a dict(columns, properties) + add_type_columns = dict() + alter_type_columns = dict() + alter_key_columns = dict() + for fs_item in fs_schema: + if fs_item.get("column_name") in add_columns: + add_type_columns.update( + {fs_item.get("column_name"): fs_item.get("type")} + ) + for db_item in db_schema: + if fs_item.get("column_name") == db_item.get("column_name"): + if fs_item.get("type") != db_item.get("type"): + if fs_item.get("primary_key") is True: + logger.warning( + "Type changes are not applied to " + "columns that are the primary key." 
+ ) + alter_type_columns.update( + {fs_item.get("column_name"): fs_item.get("type")} + ) + if fs_item.get("primary_key") != db_item.get("primary_key"): + alter_key_columns.update( + {fs_item.get("column_name"): fs_item.get("primary_key")} + ) + break + + schema_diff = set( + Diff(str(col), kind=Diff.Kind.ADD, value=value) + for col, value in add_type_columns.items() + ) + schema_diff |= set( + Diff(str(col), kind=Diff.Kind.DROP, value=None) for col in drop_columns + ) + schema_diff |= set( + Diff(str(col), kind=Diff.Kind.ALTER_TYPE, value=value) + for col, value in alter_type_columns.items() + ) + schema_diff |= set( + Diff(str(col), kind=Diff.Kind.ALTER_KEY, value=None) + for col, value in alter_key_columns.items() + ) + return schema_diff + + def _get_schema( + self, table_name: str, database: Optional[str] = None + ) -> List[Dict[str, Any]]: + """Get a table schema in the respective database. + + Args: + table_name: Table name to get schema. + + Returns: + Schema object. + """ + try: + db_schema = self._client.get_schema(table_name, database) + except Exception: # noqa + db_schema = [] + return db_schema + + def apply_migration( + self, feature_set: FeatureSet, writer: Writer, debug_mode: bool + ) -> None: + """Apply the migration in the respective database. + + Args: + feature_set: the feature set. + writer: the writer being used to load the feature set. + debug_mode: if active, it brings up the queries generated. + """ + logger.info(f"Migrating feature set: {feature_set.name}") + + table_name = ( + feature_set.name if not writer.write_to_entity else feature_set.entity + ) + + fs_schema = writer.db_config.translate(feature_set.get_schema()) + db_schema = self._get_schema(table_name, writer.database) + + queries = self.create_query( + fs_schema, table_name, db_schema, writer.write_to_entity + ) + + if debug_mode: + print( + "#### DEBUG MODE ###\n" + f"Feature set: {feature_set.name}\n" + "Queries:\n" + f"{queries}" + ) + else: + for q in queries: + logger.info(f"Applying this query: {q} ...") + self._client.sql(q) + + logger.info("Feature Set migration finished successfully.") + + # inform in drone console which feature set was migrated + print(f"The {feature_set.name} feature set was migrated.") diff --git a/butterfree/migrations/database_migration/metastore_migration.py b/butterfree/migrations/database_migration/metastore_migration.py new file mode 100644 index 000000000..07e2bd89f --- /dev/null +++ b/butterfree/migrations/database_migration/metastore_migration.py @@ -0,0 +1,135 @@ +"""Metastore Migration entity.""" + +from typing import Any, Dict, List, Optional + +from butterfree.clients import SparkClient +from butterfree.configs import environment +from butterfree.configs.db import MetastoreConfig +from butterfree.constants.migrations import PARTITION_BY +from butterfree.migrations.database_migration.database_migration import ( + DatabaseMigration, + Diff, +) + + +class MetastoreMigration(DatabaseMigration): + """MetastoreMigration class for performing migrations. + + This class implements some methods of the parent DatabaseMigration class and + has specific methods for query building. + The MetastoreMigration class will be used, as the name suggests, for applying + changes to a given Metastore table. 
There are, however, some remarks that need + to be highlighted: + - If an existing feature has its type changed, then it's extremely important to + make sure that this conversion would not result in data loss; + - If new features are added to your feature set, then they're going to be added + to the corresponding Metastore table; + - Since feature sets can be written both to a feature set and an entity table, + we're not going to automatically drop features when using entity tables, since + it means that some features belong to a different feature set. In summary, if + data is being loaded into an entity table, then users can drop columns manually. + """ + + def __init__( + self, + database: Optional[str] = None, + ) -> None: + self._db_config = MetastoreConfig() + self.database = database or environment.get_variable( + "FEATURE_STORE_HISTORICAL_DATABASE" + ) + super(MetastoreMigration, self).__init__(SparkClient()) + + @staticmethod + def _get_parsed_columns(columns: List[Diff]) -> List[str]: + """Parse columns from a list of Diff objects. + + Args: + columns: list of Diff objects. + + Returns: + Parsed columns. + + """ + parsed_columns = [] + for col in columns: + parsed_columns.append(f"{col.column} {col.value}") + + parsed_columns = ", ".join(parsed_columns) # type: ignore + + return parsed_columns + + def _get_alter_table_add_query(self, columns: List[Diff], table_name: str) -> str: + """Creates SQL statement to add columns to a table. + + Args: + columns: list of Diff objects with ADD kind. + table_name: table name. + + Returns: + Alter table query. + + """ + parsed_columns = self._get_parsed_columns(columns) + + return ( + f"ALTER TABLE {self.database}.{table_name} " + f"ADD IF NOT EXISTS columns ({parsed_columns});" + ) + + def _get_alter_column_type_query(self, column: Diff, table_name: str) -> str: + """Creates SQL statement to alter columns' types. + + Args: + columns: list of Diff objects with ALTER_TYPE kind. + table_name: table name. + + Returns: + Alter column type query. + + """ + parsed_columns = self._get_parsed_columns([column]) + + return f"ALTER TABLE {table_name} ALTER COLUMN {parsed_columns};" + + def _get_create_table_query( + self, columns: List[Dict[str, Any]], table_name: str + ) -> str: + """Creates SQL statement to create a table. + + Args: + columns: object that contains column's schemas. + table_name: table name. + + Returns: + Create table query. + + """ + parsed_columns = [] + for col in columns: + parsed_columns.append(f"{col['column_name']} {col['type']}") + parsed_columns = ", ".join(parsed_columns) # type: ignore + + return ( + f"CREATE TABLE IF NOT EXISTS " + f"{self.database}.{table_name} ({parsed_columns}) " + f"PARTITIONED BY (" + f"{PARTITION_BY[0]['column_name']} {PARTITION_BY[0]['type']}, " + f"{PARTITION_BY[1]['column_name']} {PARTITION_BY[1]['type']}, " + f"{PARTITION_BY[2]['column_name']} {PARTITION_BY[2]['type']});" + ) + + def _get_alter_table_drop_query(self, columns: List[Diff], table_name: str) -> str: + """Creates SQL statement to drop columns from a table. + + Args: + columns: list of Diff objects with DROP kind. + table_name: table name. + + Returns: + Drop columns from a given table query. 
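And the rough shape of the Hive DDL the Metastore builders emit, again as plain string formatting; the database, table, and feature names are made up, and the year/month/day INT partition columns are assumed to match butterfree.constants.migrations.PARTITION_BY.

columns = [
    ("id", "bigint"),
    ("timestamp", "timestamp"),
    ("feature1__avg_over_1_week_rolling_windows", "double"),
]
parsed = ", ".join(f"{name} {ctype}" for name, ctype in columns)

create_query = (
    f"CREATE TABLE IF NOT EXISTS feature_store.my_feature_set ({parsed}) "
    "PARTITIONED BY (year INT, month INT, day INT);"
)
print(create_query)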
+ + """ + parsed_columns = self._get_parsed_columns(columns) + + return f"ALTER TABLE {table_name} DROP IF EXISTS ({parsed_columns});" diff --git a/butterfree/pipelines/__init__.py b/butterfree/pipelines/__init__.py index a868e48f2..8bbc5c39e 100644 --- a/butterfree/pipelines/__init__.py +++ b/butterfree/pipelines/__init__.py @@ -1,4 +1,5 @@ """ETL Pipelines.""" + from butterfree.pipelines.feature_set_pipeline import FeatureSetPipeline __all__ = ["FeatureSetPipeline"] diff --git a/butterfree/pipelines/feature_set_pipeline.py b/butterfree/pipelines/feature_set_pipeline.py index ce1b7ba4d..8ba1a636c 100644 --- a/butterfree/pipelines/feature_set_pipeline.py +++ b/butterfree/pipelines/feature_set_pipeline.py @@ -1,5 +1,6 @@ """FeatureSetPipeline entity.""" -from typing import List + +from typing import List, Optional from butterfree.clients import SparkClient from butterfree.dataframe_service import repartition_sort_df @@ -40,11 +41,12 @@ class FeatureSetPipeline: ... ) >>> from butterfree.load import Sink >>> from butterfree.load.writers import HistoricalFeatureStoreWriter - >>> import pyspark.sql.functions as F + >>> from pyspark.sql import functions >>> def divide(df, fs, column1, column2): ... name = fs.get_output_columns()[0] - ... df = df.withColumn(name, F.col(column1) / F.col(column2)) + ... df = df.withColumn(name, + ... functions.col(column1) / functions.col(column2)) ... return df >>> pipeline = FeatureSetPipeline( @@ -67,7 +69,8 @@ class FeatureSetPipeline: ... name="feature1", ... description="test", ... transformation=SparkFunctionTransform( - ... functions=[F.avg, F.stddev_pop] + ... functions=[Function(functions.avg, DataType.DOUBLE), + ... Function(functions.stddev_pop, DataType.DOUBLE)], ... ).with_window( ... partition_by="id", ... order_by=TIMESTAMP_COLUMN, @@ -113,6 +116,19 @@ class FeatureSetPipeline: the defined sources, compute all the transformations and save the data to the specified locations. + We can run the pipeline over a range of dates by passing an end-date + and a start-date, where it will only bring data within this date range. + + >>> pipeline.run(end_date="2020-08-04", start_date="2020-07-04") + + Or run up to a date, where it will only bring data up to the specific date. + + >>> pipeline.run(end_date="2020-08-04") + + Or just a specific date, where you will only bring data for that day. + + >>> pipeline.run_for_date(execution_date="2020-08-04") + """ def __init__( @@ -120,7 +136,7 @@ def __init__( source: Source, feature_set: FeatureSet, sink: Sink, - spark_client: SparkClient = None, + spark_client: Optional[SparkClient] = None, ): self.source = source self.feature_set = feature_set @@ -175,10 +191,11 @@ def spark_client(self, spark_client: SparkClient) -> None: def run( self, - end_date: str = None, - partition_by: List[str] = None, - order_by: List[str] = None, - num_processors: int = None, + end_date: Optional[str] = None, + partition_by: Optional[List[str]] = None, + order_by: Optional[List[str]] = None, + num_processors: Optional[int] = None, + start_date: Optional[str] = None, ) -> None: """Runs the defined feature set pipeline. @@ -192,7 +209,11 @@ def run( soon. Use only if strictly necessary. 
""" - dataframe = self.source.construct(client=self.spark_client) + dataframe = self.source.construct( + client=self.spark_client, + start_date=self.feature_set.define_start_date(start_date), + end_date=end_date, + ) if partition_by: order_by = order_by or partition_by @@ -203,6 +224,7 @@ def run( dataframe = self.feature_set.construct( dataframe=dataframe, client=self.spark_client, + start_date=start_date, end_date=end_date, num_processors=num_processors, ) @@ -219,3 +241,30 @@ def run( feature_set=self.feature_set, spark_client=self.spark_client, ) + + def run_for_date( + self, + execution_date: Optional[str] = None, + partition_by: Optional[List[str]] = None, + order_by: Optional[List[str]] = None, + num_processors: Optional[int] = None, + ) -> None: + """Runs the defined feature set pipeline for a specific date. + + The pipeline consists in the following steps: + + - Constructs the input dataframe from the data source. + - Construct the feature set dataframe using the defined Features. + - Load the data to the configured sink locations. + + It's important to notice, however, that both parameters partition_by + and num_processors are WIP, we intend to enhance their functionality + soon. Use only if strictly necessary. + """ + self.run( + start_date=execution_date, + end_date=execution_date, + partition_by=partition_by, + order_by=order_by, + num_processors=num_processors, + ) diff --git a/butterfree/reports/__init__.py b/butterfree/reports/__init__.py index 4b57dafc2..d272943d9 100644 --- a/butterfree/reports/__init__.py +++ b/butterfree/reports/__init__.py @@ -1,4 +1,5 @@ """Reports module.""" + from butterfree.reports.metadata import Metadata __all__ = ["Metadata"] diff --git a/butterfree/reports/metadata.py b/butterfree/reports/metadata.py index d54bbba9d..dc1f7cbb4 100644 --- a/butterfree/reports/metadata.py +++ b/butterfree/reports/metadata.py @@ -162,7 +162,7 @@ def to_json(self) -> Any: "features": [ { "column_name": c["column_name"], - "data_type": str(c["type"]), + "data_type": str(c["type"]).replace("()", ""), "description": desc, } for c, desc in params._features @@ -208,7 +208,7 @@ def to_markdown(self) -> Any: features = ["Column name", "Data type", "Description"] for c, desc in params._features: - features.extend([c["column_name"], str(c["type"]), desc]) + features.extend([c["column_name"], str(c["type"]).replace("()", ""), desc]) count_rows = len(features) // 3 diff --git a/butterfree/testing/dataframe/__init__.py b/butterfree/testing/dataframe/__init__.py index 15481a54a..5b465bc64 100644 --- a/butterfree/testing/dataframe/__init__.py +++ b/butterfree/testing/dataframe/__init__.py @@ -1,6 +1,7 @@ """Methods to assert properties regarding Apache Spark Dataframes.""" + from json import dumps -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional from pyspark import SparkContext from pyspark.sql import Column, DataFrame, SparkSession @@ -72,7 +73,7 @@ def create_df_from_collection( data: List[Dict[Any, Any]], spark_context: SparkContext, spark_session: SparkSession, - schema: StructType = None, + schema: Optional[StructType] = None, ) -> DataFrame: """Creates a dataframe from a list of dicts.""" return spark_session.read.json( diff --git a/butterfree/transform/aggregated_feature_set.py b/butterfree/transform/aggregated_feature_set.py index f43c12d5d..6706bf8cf 100644 --- a/butterfree/transform/aggregated_feature_set.py +++ b/butterfree/transform/aggregated_feature_set.py @@ -1,6 +1,7 @@ """AggregatedFeatureSet entity.""" + import itertools -from 
datetime import timedelta +from datetime import datetime, timedelta from functools import reduce from typing import Any, Dict, List, Optional, Union @@ -8,6 +9,7 @@ from pyspark.sql import DataFrame, functions from butterfree.clients import SparkClient +from butterfree.constants.window_definitions import ALLOWED_WINDOWS from butterfree.dataframe_service import repartition_df from butterfree.transform import FeatureSet from butterfree.transform.features import Feature, KeyFeature, TimestampFeature @@ -196,6 +198,8 @@ def __init__( keys: List[KeyFeature], timestamp: TimestampFeature, features: List[Feature], + deduplicate_rows: bool = True, + eager_evaluation: bool = True, ): self._windows: List[Any] = [] self._pivot_column: Optional[str] = None @@ -203,7 +207,14 @@ def __init__( self._distinct_subset: List[Any] = [] self._distinct_keep: Optional[str] = None super(AggregatedFeatureSet, self).__init__( - name, entity, description, keys, timestamp, features, + name, + entity, + description, + keys, + timestamp, + features, + deduplicate_rows, + eager_evaluation, ) @property @@ -251,8 +262,8 @@ def _has_aggregated_transform_only(features: List[Feature]) -> bool: @staticmethod def _build_feature_column_name( feature_column: str, - pivot_value: Union[float, str] = None, - window: Window = None, + pivot_value: Optional[Union[float, str]] = None, + window: Optional[Window] = None, ) -> str: base_name = feature_column if pivot_value is not None: @@ -300,7 +311,9 @@ def with_distinct(self, subset: List, keep: str = "last") -> "AggregatedFeatureS return self - def with_windows(self, definitions: List[str]) -> "AggregatedFeatureSet": + def with_windows( + self, definitions: List[str], slide: Optional[str] = None + ) -> "AggregatedFeatureSet": """Create a list with windows defined.""" self._windows = [ Window( @@ -308,6 +321,7 @@ def with_windows(self, definitions: List[str]) -> "AggregatedFeatureSet": order_by=None, mode="rolling_windows", window_definition=definition, + slide=slide, ) for definition in definitions ] @@ -354,7 +368,7 @@ def _dataframe_join( right: DataFrame, on: List[str], how: str, - num_processors: int = None, + num_processors: Optional[int] = None, ) -> DataFrame: # make both tables co-partitioned to improve join performance left = repartition_df(left, partition_by=on, num_processors=num_processors) @@ -366,7 +380,7 @@ def _aggregate( dataframe: DataFrame, features: List[Feature], window: Optional[Window] = None, - num_processors: int = None, + num_processors: Optional[int] = None, ) -> DataFrame: aggregations = [ c.function for f in features for c in f.transformation.aggregations @@ -399,7 +413,9 @@ def _aggregate( # repartition to have all rows for each group at the same partition # by doing that, we won't have to shuffle data on grouping by id dataframe = repartition_df( - dataframe, partition_by=groupby, num_processors=num_processors, + dataframe, + partition_by=groupby, + num_processors=num_processors, ) grouped_data = dataframe.groupby(*groupby) @@ -488,12 +504,45 @@ def get_schema(self) -> List[Dict[str, Any]]: return schema + @staticmethod + def _get_biggest_window_in_days(definitions: List[str]) -> float: + windows_list = [] + for window in definitions: + windows_list.append( + int(window.split()[0]) * ALLOWED_WINDOWS[window.split()[1]] + ) + return max(windows_list) / (60 * 60 * 24) + + def define_start_date(self, start_date: Optional[str] = None) -> Optional[str]: + """Get aggregated feature set start date. + + Args: + start_date: start date regarding source dataframe. 
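To make the start-date adjustment concrete, here is the same arithmetic as a standalone sketch: the largest rolling window, plus one day, is subtracted from the requested start date so the earliest aggregations still see a full window of history. The window definitions and dates below are illustrative, and only a subset of the real ALLOWED_WINDOWS mapping is reproduced.

from datetime import datetime, timedelta

# subset of butterfree.constants.window_definitions.ALLOWED_WINDOWS (seconds per unit)
ALLOWED_WINDOWS = {"days": 86400, "weeks": 604800}
definitions = ["7 days", "2 weeks"]  # illustrative rolling windows

biggest_window_days = max(
    int(d.split()[0]) * ALLOWED_WINDOWS[d.split()[1]] for d in definitions
) / 86400

start_date = datetime.strptime("2020-08-04", "%Y-%m-%d") - timedelta(
    days=int(biggest_window_days) + 1
)
print(start_date.strftime("%Y-%m-%d"))  # 2020-07-20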
+ + Returns: + start date. + """ + if self._windows and start_date: + window_definition = [ + definition.frame_boundaries.window_definition + for definition in self._windows + ] + biggest_window = self._get_biggest_window_in_days(window_definition) + + return ( + datetime.strptime(start_date, "%Y-%m-%d") + - timedelta(days=int(biggest_window) + 1) + ).strftime("%Y-%m-%d") + + return start_date + def construct( self, dataframe: DataFrame, client: SparkClient, - end_date: str = None, - num_processors: int = None, + end_date: Optional[str] = None, + num_processors: Optional[int] = None, + start_date: Optional[str] = None, ) -> DataFrame: """Use all the features to build the feature set dataframe. @@ -506,6 +555,7 @@ def construct( client: client responsible for connecting to Spark session. end_date: user defined max date for having aggregated data (exclusive). num_processors: cluster total number of processors for repartitioning. + start_date: user defined min date for having aggregated data. Returns: Spark dataframe with all the feature columns. @@ -519,19 +569,15 @@ def construct( if not isinstance(dataframe, DataFrame): raise ValueError("source_df must be a dataframe") + pre_hook_df = self.run_pre_hooks(dataframe) + output_df = reduce( lambda df, feature: feature.transform(df), self.keys + [self.timestamp], - dataframe, + pre_hook_df, ) if self._windows and end_date is not None: - # prepare our left table, a cartesian product between distinct keys - # and dates in range for this feature set - base_df = self._get_base_dataframe( - client=client, dataframe=output_df, end_date=end_date - ) - # run aggregations for each window agg_list = [ self._aggregate( @@ -543,26 +589,60 @@ def construct( for w in self._windows ] - # left join each aggregation result to our base dataframe - output_df = reduce( - lambda left, right: self._dataframe_join( - left, - right, - on=self.keys_columns + [self.timestamp_column], - how="left", - num_processors=num_processors, - ), - agg_list, - base_df, - ) + # prepare our left table, a cartesian product between distinct keys + # and dates in range for this feature set + + # todo next versions won't use this logic anymore, + # leaving for the client to correct the usage of aggregations + # without events + + # keeping this logic to maintain the same behavior for already implemented + # feature sets + + if self._windows[0].slide == "1 day": + base_df = self._get_base_dataframe( + client=client, dataframe=output_df, end_date=end_date + ) + + # left join each aggregation result to our base dataframe + output_df = reduce( + lambda left, right: self._dataframe_join( + left, + right, + on=self.keys_columns + [self.timestamp_column], + how="left", + num_processors=num_processors, + ), + agg_list, + base_df, + ) + else: + output_df = reduce( + lambda left, right: self._dataframe_join( + left, + right, + on=self.keys_columns + [self.timestamp_column], + how="full_outer", + num_processors=num_processors, + ), + agg_list, + ) else: output_df = self._aggregate(output_df, features=self.features) + output_df = self.incremental_strategy.filter_with_incremental_strategy( + dataframe=output_df, start_date=start_date, end_date=end_date + ) + output_df = output_df.select(*self.columns).replace( # type: ignore float("nan"), None ) if not output_df.isStreaming: - output_df = self._filter_duplicated_rows(output_df) - output_df.cache().count() + if self.deduplicate_rows: + output_df = self._filter_duplicated_rows(output_df) + if self.eager_evaluation: + output_df.cache().count() + + 
post_hook_df = self.run_post_hooks(output_df) - return output_df + return post_hook_df diff --git a/butterfree/transform/feature_set.py b/butterfree/transform/feature_set.py index c35e90fa1..369eaf290 100644 --- a/butterfree/transform/feature_set.py +++ b/butterfree/transform/feature_set.py @@ -1,7 +1,8 @@ """FeatureSet entity.""" + import itertools from functools import reduce -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional import pyspark.sql.functions as F from pyspark.sql import Window @@ -9,6 +10,8 @@ from butterfree.clients import SparkClient from butterfree.constants.columns import TIMESTAMP_COLUMN +from butterfree.dataframe_service import IncrementalStrategy +from butterfree.hooks import HookableComponent from butterfree.transform.features import Feature, KeyFeature, TimestampFeature from butterfree.transform.transformations import ( AggregatedTransform, @@ -16,7 +19,7 @@ ) -class FeatureSet: +class FeatureSet(HookableComponent): """Holds metadata about the feature set and constructs the final dataframe. Attributes: @@ -95,6 +98,12 @@ class FeatureSet: values over key columns and timestamp column, we do this in order to reduce our dataframe (regarding the number of rows). A detailed explation of this method can be found at filter_duplicated_rows docstring. + + The `eager_evaluation` param forces Spark to apply the currently + mapped changes to the DataFrame. When this parameter is set to + False, Spark follows its standard behaviour of lazy evaluation. + Lazy evaluation can improve Spark's performance as it allows + Spark to build the best version of the execution plan. """ def __init__( @@ -105,13 +114,19 @@ def __init__( keys: List[KeyFeature], timestamp: TimestampFeature, features: List[Feature], + deduplicate_rows: bool = True, + eager_evaluation: bool = True, ) -> None: + super().__init__() self.name = name self.entity = entity self.description = description self.keys = keys self.timestamp = timestamp self.features = features + self.incremental_strategy = IncrementalStrategy(column=TIMESTAMP_COLUMN) + self.deduplicate_rows = deduplicate_rows + self.eager_evaluation = eager_evaluation @property def name(self) -> str: @@ -243,9 +258,6 @@ def columns(self) -> List[str]: def get_schema(self) -> List[Dict[str, Any]]: """Get feature set schema. - Args: - feature_set: object processed with feature set metadata. - Returns: List of dicts regarding cassandra feature set schema. @@ -378,12 +390,24 @@ def _filter_duplicated_rows(self, df: DataFrame) -> DataFrame: return df.select([column for column in self.columns]) + def define_start_date(self, start_date: Optional[str] = None) -> Optional[str]: + """Get feature set start date. + + Args: + start_date: start date regarding source dataframe. + + Returns: + start date. + """ + return start_date + def construct( self, dataframe: DataFrame, client: SparkClient, - end_date: str = None, - num_processors: int = None, + end_date: Optional[str] = None, + num_processors: Optional[int] = None, + start_date: Optional[str] = None, ) -> DataFrame: """Use all the features to build the feature set dataframe. @@ -393,7 +417,8 @@ def construct( Args: dataframe: input dataframe to be transformed by the features. client: client responsible for connecting to Spark session. - end_date: user defined base date. + start_date: user defined start date. + end_date: user defined end date. num_processors: cluster total number of processors for repartitioning. 
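A minimal sketch of what the new eager_evaluation flag boils down to in Spark terms; the flag name is the only butterfree-specific piece here, everything else is plain PySpark.

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.range(10).withColumnRenamed("id", "feature1")

eager_evaluation = True  # mirrors the new FeatureSet flag
if eager_evaluation:
    # force Spark to execute the plan now and keep the result cached,
    # instead of relying on lazy evaluation by a downstream action
    df.cache().count()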
Returns: @@ -403,14 +428,24 @@ def construct( if not isinstance(dataframe, DataFrame): raise ValueError("source_df must be a dataframe") + pre_hook_df = self.run_pre_hooks(dataframe) + output_df = reduce( lambda df, feature: feature.transform(df), self.keys + [self.timestamp] + self.features, - dataframe, + pre_hook_df, ).select(*self.columns) if not output_df.isStreaming: - output_df = self._filter_duplicated_rows(output_df) - output_df.cache().count() + if self.deduplicate_rows: + output_df = self._filter_duplicated_rows(output_df) + if self.eager_evaluation: + output_df.cache().count() + + output_df = self.incremental_strategy.filter_with_incremental_strategy( + dataframe=output_df, start_date=start_date, end_date=end_date + ) + + post_hook_df = self.run_post_hooks(output_df) - return output_df + return post_hook_df diff --git a/butterfree/transform/features/feature.py b/butterfree/transform/features/feature.py index 612fc4a2f..cfd8a2f61 100644 --- a/butterfree/transform/features/feature.py +++ b/butterfree/transform/features/feature.py @@ -1,6 +1,7 @@ """Feature entity.""" + import warnings -from typing import Any, List +from typing import Any, List, Optional from pyspark.sql import DataFrame from pyspark.sql.functions import col @@ -41,9 +42,9 @@ def __init__( self, name: str, description: str, - dtype: DataType = None, - from_column: str = None, - transformation: TransformComponent = None, + dtype: Optional[DataType] = None, + from_column: Optional[str] = None, + transformation: Optional[TransformComponent] = None, ) -> None: self.name = name self.description = description diff --git a/butterfree/transform/features/key_feature.py b/butterfree/transform/features/key_feature.py index a7ad350cb..74626d6fa 100644 --- a/butterfree/transform/features/key_feature.py +++ b/butterfree/transform/features/key_feature.py @@ -1,5 +1,7 @@ """KeyFeature entity.""" +from typing import Optional + from butterfree.constants.data_type import DataType from butterfree.transform.features.feature import Feature from butterfree.transform.transformations import TransformComponent @@ -31,8 +33,8 @@ def __init__( name: str, description: str, dtype: DataType, - from_column: str = None, - transformation: TransformComponent = None, + from_column: Optional[str] = None, + transformation: Optional[TransformComponent] = None, ) -> None: super(KeyFeature, self).__init__( name=name, diff --git a/butterfree/transform/features/timestamp_feature.py b/butterfree/transform/features/timestamp_feature.py index 2aac8925a..b4aee71e5 100644 --- a/butterfree/transform/features/timestamp_feature.py +++ b/butterfree/transform/features/timestamp_feature.py @@ -1,6 +1,9 @@ """TimestampFeature entity.""" + +from typing import Optional + from pyspark.sql import DataFrame -from pyspark.sql.functions import from_unixtime, to_timestamp +from pyspark.sql.functions import to_timestamp from butterfree.constants import DataType from butterfree.constants.columns import TIMESTAMP_COLUMN @@ -38,17 +41,18 @@ class TimestampFeature(Feature): def __init__( self, - from_column: str = None, - transformation: TransformComponent = None, + dtype: Optional[DataType] = DataType.TIMESTAMP, + from_column: Optional[str] = None, + transformation: Optional[TransformComponent] = None, from_ms: bool = False, - mask: str = None, + mask: Optional[str] = None, ) -> None: description = "Time tag for the state of all features." 
super(TimestampFeature, self).__init__( name=TIMESTAMP_COLUMN, description=description, from_column=from_column, - dtype=DataType.TIMESTAMP, + dtype=dtype, transformation=transformation, ) self.from_ms = from_ms @@ -65,13 +69,12 @@ def transform(self, dataframe: DataFrame) -> DataFrame: """ column_name = self.from_column if self.from_column else self.name + ts_column = dataframe[column_name] if self.from_ms: - dataframe = dataframe.withColumn( - column_name, from_unixtime(dataframe[column_name] / 1000.0) - ) - if self.mask: - dataframe = dataframe.withColumn( - column_name, to_timestamp(dataframe[column_name], self.mask) - ) + ts_column = ts_column / 1000 + + dataframe = dataframe.withColumn( + column_name, to_timestamp(ts_column, self.mask) # type: ignore + ) return super().transform(dataframe) diff --git a/butterfree/transform/transformations/aggregated_transform.py b/butterfree/transform/transformations/aggregated_transform.py index 2c7a8ced2..406ca72a9 100644 --- a/butterfree/transform/transformations/aggregated_transform.py +++ b/butterfree/transform/transformations/aggregated_transform.py @@ -1,6 +1,7 @@ """Aggregated Transform entity.""" + from collections import namedtuple -from typing import List, Tuple +from typing import List, Optional, Tuple from pyspark.sql import DataFrame from pyspark.sql.functions import col, expr, when @@ -56,7 +57,9 @@ class AggregatedTransform(TransformComponent): NotImplementedError: ... """ - def __init__(self, functions: List[Function], filter_expression: str = None): + def __init__( + self, functions: List[Function], filter_expression: Optional[str] = None + ): super(AggregatedTransform, self).__init__() self.functions = functions self.filter_expression = filter_expression @@ -76,7 +79,11 @@ def aggregations(self) -> List[Tuple]: Function = namedtuple("Function", ["function", "data_type"]) return [ - Function(f.func(expression), f.data_type.spark,) for f in self.functions + Function( + f.func(expression), + f.data_type.spark, + ) + for f in self.functions ] def _get_output_name(self, function: object) -> str: @@ -88,7 +95,7 @@ def _get_output_name(self, function: object) -> str: """ ) - base_name = "__".join([self._parent.name, function.__name__]) + base_name = "__".join([self._parent.name, str(function.__name__).lower()]) return base_name @property diff --git a/butterfree/transform/transformations/custom_transform.py b/butterfree/transform/transformations/custom_transform.py index 9b5ae23b1..a12310127 100644 --- a/butterfree/transform/transformations/custom_transform.py +++ b/butterfree/transform/transformations/custom_transform.py @@ -69,7 +69,7 @@ def transformer(self) -> Callable[..., Any]: @transformer.setter def transformer(self, method: Callable[..., Any]) -> None: - if not method: + if method is None: raise ValueError("A method must be provided to CustomTransform") self._transformer = method @@ -89,6 +89,8 @@ def transform(self, dataframe: DataFrame) -> DataFrame: """ dataframe = self.transformer( - dataframe, self.parent, **self.transformer__kwargs, + dataframe, + self.parent, + **self.transformer__kwargs, ) return dataframe diff --git a/butterfree/transform/transformations/h3_transform.py b/butterfree/transform/transformations/h3_transform.py index 8ccd3bb38..7a98294ec 100644 --- a/butterfree/transform/transformations/h3_transform.py +++ b/butterfree/transform/transformations/h3_transform.py @@ -84,7 +84,10 @@ class H3HashTransform(TransformComponent): """ def __init__( - self, h3_resolutions: List[int], lat_column: str, lng_column: str, + 
self, + h3_resolutions: List[int], + lat_column: str, + lng_column: str, ): super().__init__() self.h3_resolutions = h3_resolutions diff --git a/butterfree/transform/transformations/spark_function_transform.py b/butterfree/transform/transformations/spark_function_transform.py index 8fb24dd79..34384518d 100644 --- a/butterfree/transform/transformations/spark_function_transform.py +++ b/butterfree/transform/transformations/spark_function_transform.py @@ -1,5 +1,6 @@ """Spark Function Transform entity.""" -from typing import Any, List + +from typing import Any, List, Optional from pyspark.sql import DataFrame @@ -87,8 +88,8 @@ def with_window( self, partition_by: str, window_definition: List[str], - order_by: str = None, - mode: str = None, + order_by: Optional[str] = None, + mode: Optional[str] = None, ) -> "SparkFunctionTransform": """Create a list with windows defined.""" if mode is not None: @@ -103,7 +104,9 @@ def with_window( ] return self - def _get_output_name(self, function: object, window: Window = None) -> str: + def _get_output_name( + self, function: object, window: Optional[Window] = None + ) -> str: base_name = ( "__".join([self._parent.name, function.__name__]) if hasattr(function, "__name__") diff --git a/butterfree/transform/transformations/sql_expression_transform.py b/butterfree/transform/transformations/sql_expression_transform.py index 0199c23ae..80cd41ea9 100644 --- a/butterfree/transform/transformations/sql_expression_transform.py +++ b/butterfree/transform/transformations/sql_expression_transform.py @@ -54,7 +54,8 @@ class SQLExpressionTransform(TransformComponent): """ def __init__( - self, expression: str, + self, + expression: str, ): super().__init__() self.expression = expression diff --git a/butterfree/transform/transformations/transform_component.py b/butterfree/transform/transformations/transform_component.py index 7ecec332a..94bc19f8c 100644 --- a/butterfree/transform/transformations/transform_component.py +++ b/butterfree/transform/transformations/transform_component.py @@ -1,4 +1,5 @@ """Transform Abstract Class.""" + from abc import ABC, abstractmethod from typing import Any, List diff --git a/butterfree/transform/transformations/user_defined_functions/mode.py b/butterfree/transform/transformations/user_defined_functions/mode.py index 65790b939..5b6c7f17d 100644 --- a/butterfree/transform/transformations/user_defined_functions/mode.py +++ b/butterfree/transform/transformations/user_defined_functions/mode.py @@ -1,4 +1,5 @@ """Method to compute mode aggregation.""" + import pandas as pd from pyspark.sql.functions import pandas_udf from pyspark.sql.types import StringType diff --git a/butterfree/transform/transformations/user_defined_functions/most_frequent_set.py b/butterfree/transform/transformations/user_defined_functions/most_frequent_set.py index 20ccd3ba3..6dd6779f1 100644 --- a/butterfree/transform/transformations/user_defined_functions/most_frequent_set.py +++ b/butterfree/transform/transformations/user_defined_functions/most_frequent_set.py @@ -1,4 +1,5 @@ """Method to compute most frequent set aggregation.""" + from typing import Any import pandas as pd diff --git a/butterfree/transform/utils/__init__.py b/butterfree/transform/utils/__init__.py index abf7ed3fb..66004a374 100644 --- a/butterfree/transform/utils/__init__.py +++ b/butterfree/transform/utils/__init__.py @@ -1,4 +1,5 @@ """This module holds utils to be used by transformations.""" + from butterfree.transform.utils.function import Function from butterfree.transform.utils.window_spec 
import Window diff --git a/butterfree/transform/utils/date_range.py b/butterfree/transform/utils/date_range.py index 78e0e6e3c..4bdd29772 100644 --- a/butterfree/transform/utils/date_range.py +++ b/butterfree/transform/utils/date_range.py @@ -1,7 +1,7 @@ """Utils for date range generation.""" from datetime import datetime -from typing import Union +from typing import Optional, Union from pyspark.sql import DataFrame, functions @@ -14,7 +14,7 @@ def get_date_range( client: SparkClient, start_date: Union[str, datetime], end_date: Union[str, datetime], - step: int = None, + step: Optional[int] = None, ) -> DataFrame: """Create a date range dataframe. @@ -44,7 +44,7 @@ def get_date_range( for c in ("start_date", "end_date") ] ) - start_date, end_date = date_df.first() + start_date, end_date = date_df.first() # type: ignore return client.conn.range( start_date, end_date + day_in_seconds, step # type: ignore ).select(functions.col("id").cast(DataType.TIMESTAMP.spark).alias(TIMESTAMP_COLUMN)) diff --git a/butterfree/transform/utils/function.py b/butterfree/transform/utils/function.py index fcf6679fb..951a232ca 100644 --- a/butterfree/transform/utils/function.py +++ b/butterfree/transform/utils/function.py @@ -32,9 +32,9 @@ def func(self) -> Callable: @func.setter def func(self, value: Callable) -> None: """Definitions to be used in the transformation.""" - if not value: + if value is None: raise ValueError("Function must not be empty.") - if not callable(value): + if callable(value) is False: raise TypeError("Function must be callable.") self._func = value diff --git a/butterfree/transform/utils/window_spec.py b/butterfree/transform/utils/window_spec.py index f3a392f6a..b95dd73a6 100644 --- a/butterfree/transform/utils/window_spec.py +++ b/butterfree/transform/utils/window_spec.py @@ -1,10 +1,12 @@ """Holds function for defining window in DataFrames.""" + from typing import Any, List, Optional, Union from pyspark import sql from pyspark.sql import Column, WindowSpec, functions from butterfree.constants.columns import TIMESTAMP_COLUMN +from butterfree.constants.window_definitions import ALLOWED_WINDOWS class FrameBoundaries: @@ -16,21 +18,6 @@ class FrameBoundaries: it can be second(s), minute(s), hour(s), day(s), week(s) and year(s), """ - __ALLOWED_WINDOWS = { - "second": 1, - "seconds": 1, - "minute": 60, - "minutes": 60, - "hour": 3600, - "hours": 3600, - "day": 86400, - "days": 86400, - "week": 604800, - "weeks": 604800, - "year": 29030400, - "years": 29030400, - } - def __init__(self, mode: Optional[str], window_definition: str): self.mode = mode self.window_definition = window_definition @@ -46,7 +33,7 @@ def window_size(self) -> int: def window_unit(self) -> str: """Returns window unit.""" unit = self.window_definition.split()[1] - if unit not in self.__ALLOWED_WINDOWS and self.mode != "row_windows": + if unit not in ALLOWED_WINDOWS and self.mode != "row_windows": raise ValueError("Not allowed") return unit @@ -59,7 +46,7 @@ def get(self, window: WindowSpec) -> Any: span = self.window_size - 1 return window.rowsBetween(-span, 0) if self.mode == "fixed_windows": - span = self.__ALLOWED_WINDOWS[self.window_unit] * self.window_size + span = ALLOWED_WINDOWS[self.window_unit] * self.window_size return window.rangeBetween(-span, 0) @@ -76,18 +63,20 @@ class Window: Use the static methods in :class:`Window` to create a :class:`WindowSpec`. 
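The new slide argument introduced below is handed to Spark's slideDuration; this is a plain PySpark sketch of that primitive with illustrative timestamps and feature names, and it does not touch butterfree itself.

from pyspark.sql import SparkSession, functions

spark = SparkSession.builder.master("local[1]").getOrCreate()

df = spark.createDataFrame(
    [("2020-08-01 00:00:00", 1.0), ("2020-08-03 00:00:00", 3.0)],
    ["timestamp", "feature1"],
).withColumn("timestamp", functions.col("timestamp").cast("timestamp"))

windowed = df.groupBy(
    functions.window("timestamp", "1 week", slideDuration="1 day")
).agg(functions.avg("feature1").alias("feature1__avg_over_1_week_rolling_windows"))
windowed.show(truncate=False)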
""" - SLIDE_DURATION: str = "1 day" + DEFAULT_SLIDE_DURATION: str = "1 day" def __init__( self, window_definition: str, partition_by: Optional[Union[Column, str, List[str]]] = None, order_by: Optional[Union[Column, str]] = None, - mode: str = None, + mode: Optional[str] = None, + slide: Optional[str] = None, ): self.partition_by = partition_by self.order_by = order_by or TIMESTAMP_COLUMN self.frame_boundaries = FrameBoundaries(mode, window_definition) + self.slide = slide or self.DEFAULT_SLIDE_DURATION def get_name(self) -> str: """Return window suffix name based on passed criteria.""" @@ -103,15 +92,10 @@ def get_name(self) -> str: def get(self) -> Any: """Defines a common window to be used both in time and rows windows.""" if self.frame_boundaries.mode == "rolling_windows": - if int(self.frame_boundaries.window_definition.split()[0]) <= 0: - raise KeyError( - f"{self.frame_boundaries.window_definition} " - f"have negative element." - ) return functions.window( TIMESTAMP_COLUMN, self.frame_boundaries.window_definition, - slideDuration=self.SLIDE_DURATION, + slideDuration=self.slide, ) elif self.order_by == TIMESTAMP_COLUMN: w = sql.Window.partitionBy(self.partition_by).orderBy( # type: ignore diff --git a/butterfree/validations/basic_validaton.py b/butterfree/validations/basic_validaton.py index d3a5558c7..01bc9ec21 100644 --- a/butterfree/validations/basic_validaton.py +++ b/butterfree/validations/basic_validaton.py @@ -1,5 +1,7 @@ """Validation implementing basic checks over the dataframe.""" +from typing import Optional + from pyspark.sql.dataframe import DataFrame from butterfree.constants.columns import TIMESTAMP_COLUMN @@ -14,7 +16,7 @@ class BasicValidation(Validation): """ - def __init__(self, dataframe: DataFrame = None): + def __init__(self, dataframe: Optional[DataFrame] = None): super().__init__(dataframe) def check(self) -> None: diff --git a/butterfree/validations/validation.py b/butterfree/validations/validation.py index 9915906cd..551859d82 100644 --- a/butterfree/validations/validation.py +++ b/butterfree/validations/validation.py @@ -1,5 +1,7 @@ """Abstract Validation class.""" + from abc import ABC, abstractmethod +from typing import Optional from pyspark.sql.dataframe import DataFrame @@ -12,7 +14,7 @@ class Validation(ABC): """ - def __init__(self, dataframe: DataFrame = None): + def __init__(self, dataframe: Optional[DataFrame] = None): self.dataframe = dataframe def input(self, dataframe: DataFrame) -> "Validation": diff --git a/docs/requirements.txt b/docs/requirements.txt index 501e17cdf..7eaabf11a 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,9 +1,7 @@ recommonmark==0.6.0 -Sphinx==3.1.1 sphinx-rtd-theme==0.4.3 sphinxemoji==0.1.6 typing-extensions==3.7.4.2 cmake==3.18.4 h3==3.7.0 -pyarrow==0.15.1 - +pyarrow==16.1.0 diff --git a/docs/source/butterfree.automated.rst b/docs/source/butterfree.automated.rst new file mode 100644 index 000000000..9c01ac54e --- /dev/null +++ b/docs/source/butterfree.automated.rst @@ -0,0 +1,21 @@ +butterfree.automated package +============================ + +Submodules +---------- + +butterfree.automated.feature\_set\_creation module +-------------------------------------------------- + +.. automodule:: butterfree.automated.feature_set_creation + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. 
automodule:: butterfree.automated + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/butterfree.clients.rst b/docs/source/butterfree.clients.rst index 3409d43a4..b1e1029a3 100644 --- a/docs/source/butterfree.clients.rst +++ b/docs/source/butterfree.clients.rst @@ -4,25 +4,30 @@ butterfree.clients package Submodules ---------- +butterfree.clients.abstract\_client module +------------------------------------------ .. automodule:: butterfree.clients.abstract_client :members: :undoc-members: :show-inheritance: +butterfree.clients.cassandra\_client module +------------------------------------------- .. automodule:: butterfree.clients.cassandra_client :members: :undoc-members: :show-inheritance: +butterfree.clients.spark\_client module +--------------------------------------- .. automodule:: butterfree.clients.spark_client :members: :undoc-members: :show-inheritance: - Module contents --------------- diff --git a/docs/source/butterfree.configs.db.rst b/docs/source/butterfree.configs.db.rst index a9973c561..6e23dc1c3 100644 --- a/docs/source/butterfree.configs.db.rst +++ b/docs/source/butterfree.configs.db.rst @@ -4,31 +4,38 @@ butterfree.configs.db package Submodules ---------- +butterfree.configs.db.abstract\_config module +--------------------------------------------- .. automodule:: butterfree.configs.db.abstract_config :members: :undoc-members: :show-inheritance: +butterfree.configs.db.cassandra\_config module +---------------------------------------------- .. automodule:: butterfree.configs.db.cassandra_config :members: :undoc-members: :show-inheritance: +butterfree.configs.db.kafka\_config module +------------------------------------------ .. automodule:: butterfree.configs.db.kafka_config :members: :undoc-members: :show-inheritance: +butterfree.configs.db.metastore\_config module +---------------------------------------------- -.. automodule:: butterfree.configs.db.s3_config +.. automodule:: butterfree.configs.db.metastore_config :members: :undoc-members: :show-inheritance: - Module contents --------------- diff --git a/docs/source/butterfree.configs.rst b/docs/source/butterfree.configs.rst index dc8a8c774..2a5cc07f1 100644 --- a/docs/source/butterfree.configs.rst +++ b/docs/source/butterfree.configs.rst @@ -12,12 +12,41 @@ Subpackages Submodules ---------- +butterfree.configs.environment module +------------------------------------- .. automodule:: butterfree.configs.environment :members: :undoc-members: :show-inheritance: +butterfree.configs.logger module +-------------------------------- + +.. automodule:: butterfree.configs.logger + :members: + :undoc-members: + :show-inheritance: + +.. automodule:: butterfree.configs.logger + :members: + :undoc-members: + :show-inheritance: + +.. automodule:: butterfree.configs.logger + :members: + :undoc-members: + :show-inheritance: + +.. automodule:: butterfree.configs.logger + :members: + :undoc-members: + :show-inheritance: + +.. automodule:: butterfree.configs.logger + :members: + :undoc-members: + :show-inheritance: Module contents --------------- diff --git a/docs/source/butterfree.constants.rst b/docs/source/butterfree.constants.rst index 083d20d78..de6f1ceed 100644 --- a/docs/source/butterfree.constants.rst +++ b/docs/source/butterfree.constants.rst @@ -4,24 +4,88 @@ butterfree.constants package Submodules ---------- +butterfree.constants.columns module +----------------------------------- .. 
automodule:: butterfree.constants.columns :members: :undoc-members: :show-inheritance: +butterfree.constants.data\_type module +-------------------------------------- .. automodule:: butterfree.constants.data_type :members: :undoc-members: :show-inheritance: +butterfree.constants.migrations module +-------------------------------------- + +.. automodule:: butterfree.constants.migrations + :members: + :undoc-members: + :show-inheritance: + +butterfree.constants.spark\_constants module +-------------------------------------------- + +.. automodule:: butterfree.constants.migrations + :members: + :undoc-members: + :show-inheritance: + + +.. automodule:: butterfree.constants.migrations + :members: + :undoc-members: + :show-inheritance: + + +.. automodule:: butterfree.constants.migrations + :members: + :undoc-members: + :show-inheritance: + + +.. automodule:: butterfree.constants.migrations + :members: + :undoc-members: + :show-inheritance: .. automodule:: butterfree.constants.spark_constants :members: :undoc-members: :show-inheritance: +butterfree.constants.window\_definitions module +----------------------------------------------- + +.. automodule:: butterfree.constants.window_definitions + :members: + :undoc-members: + :show-inheritance: + +.. automodule:: butterfree.constants.window_definitions + :members: + :undoc-members: + :show-inheritance: + +.. automodule:: butterfree.constants.window_definitions + :members: + :undoc-members: + :show-inheritance: + +.. automodule:: butterfree.constants.window_definitions + :members: + :undoc-members: + :show-inheritance: + +.. automodule:: butterfree.constants.window_definitions + :members: + :undoc-members: + :show-inheritance: Module contents --------------- diff --git a/docs/source/butterfree.dataframe_service.rst b/docs/source/butterfree.dataframe_service.rst index b3c4cfc86..faf9cf546 100644 --- a/docs/source/butterfree.dataframe_service.rst +++ b/docs/source/butterfree.dataframe_service.rst @@ -4,12 +4,34 @@ butterfree.dataframe\_service package Submodules ---------- +butterfree.dataframe\_service.incremental\_strategy module +---------------------------------------------------------- + +.. automodule:: butterfree.dataframe_service.incremental_strategy + :members: + :undoc-members: + :show-inheritance: + +butterfree.dataframe\_service.partitioning module +------------------------------------------------- + +.. automodule:: butterfree.dataframe_service.partitioning + :members: + :undoc-members: + :show-inheritance: + +butterfree.dataframe\_service.repartition module +------------------------------------------------ .. automodule:: butterfree.dataframe_service.repartition :members: :undoc-members: :show-inheritance: +.. automodule:: butterfree.dataframe_service.repartition + :members: + :undoc-members: + :show-inheritance: Module contents --------------- diff --git a/docs/source/butterfree.extract.pre_processing.rst b/docs/source/butterfree.extract.pre_processing.rst index 9420cd7ee..e8e66e3d0 100644 --- a/docs/source/butterfree.extract.pre_processing.rst +++ b/docs/source/butterfree.extract.pre_processing.rst @@ -4,37 +4,46 @@ butterfree.extract.pre\_processing package Submodules ---------- +butterfree.extract.pre\_processing.explode\_json\_column\_transform module +-------------------------------------------------------------------------- .. 
automodule:: butterfree.extract.pre_processing.explode_json_column_transform :members: :undoc-members: :show-inheritance: +butterfree.extract.pre\_processing.filter\_transform module +----------------------------------------------------------- .. automodule:: butterfree.extract.pre_processing.filter_transform :members: :undoc-members: :show-inheritance: +butterfree.extract.pre\_processing.forward\_fill\_transform module +------------------------------------------------------------------ .. automodule:: butterfree.extract.pre_processing.forward_fill_transform :members: :undoc-members: :show-inheritance: +butterfree.extract.pre\_processing.pivot\_transform module +---------------------------------------------------------- .. automodule:: butterfree.extract.pre_processing.pivot_transform :members: :undoc-members: :show-inheritance: +butterfree.extract.pre\_processing.replace\_transform module +------------------------------------------------------------ .. automodule:: butterfree.extract.pre_processing.replace_transform :members: :undoc-members: :show-inheritance: - Module contents --------------- diff --git a/docs/source/butterfree.extract.readers.rst b/docs/source/butterfree.extract.readers.rst index 6f7ee7b8b..40df200eb 100644 --- a/docs/source/butterfree.extract.readers.rst +++ b/docs/source/butterfree.extract.readers.rst @@ -4,31 +4,38 @@ butterfree.extract.readers package Submodules ---------- +butterfree.extract.readers.file\_reader module +---------------------------------------------- .. automodule:: butterfree.extract.readers.file_reader :members: :undoc-members: :show-inheritance: +butterfree.extract.readers.kafka\_reader module +----------------------------------------------- .. automodule:: butterfree.extract.readers.kafka_reader :members: :undoc-members: :show-inheritance: +butterfree.extract.readers.reader module +---------------------------------------- .. automodule:: butterfree.extract.readers.reader :members: :undoc-members: :show-inheritance: +butterfree.extract.readers.table\_reader module +----------------------------------------------- .. automodule:: butterfree.extract.readers.table_reader :members: :undoc-members: :show-inheritance: - Module contents --------------- diff --git a/docs/source/butterfree.extract.rst b/docs/source/butterfree.extract.rst index 4454d6e90..455f02d5f 100644 --- a/docs/source/butterfree.extract.rst +++ b/docs/source/butterfree.extract.rst @@ -13,13 +13,14 @@ Subpackages Submodules ---------- +butterfree.extract.source module +-------------------------------- .. automodule:: butterfree.extract.source :members: :undoc-members: :show-inheritance: - Module contents --------------- diff --git a/docs/source/butterfree.hooks.rst b/docs/source/butterfree.hooks.rst new file mode 100644 index 000000000..c633cade3 --- /dev/null +++ b/docs/source/butterfree.hooks.rst @@ -0,0 +1,37 @@ +butterfree.hooks package +======================== + +Subpackages +----------- + +.. toctree:: + :maxdepth: 4 + + butterfree.hooks.schema_compatibility + +Submodules +---------- + +butterfree.hooks.hook module +---------------------------- + +.. automodule:: butterfree.hooks.hook + :members: + :undoc-members: + :show-inheritance: + +butterfree.hooks.hookable\_component module +------------------------------------------- + +.. automodule:: butterfree.hooks.hookable_component + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. 
automodule:: butterfree.hooks + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/butterfree.hooks.schema_compatibility.rst b/docs/source/butterfree.hooks.schema_compatibility.rst new file mode 100644 index 000000000..2d3de66ce --- /dev/null +++ b/docs/source/butterfree.hooks.schema_compatibility.rst @@ -0,0 +1,29 @@ +butterfree.hooks.schema\_compatibility package +============================================== + +Submodules +---------- + +butterfree.hooks.schema\_compatibility.cassandra\_table\_schema\_compatibility\_hook module +------------------------------------------------------------------------------------------- + +.. automodule:: butterfree.hooks.schema_compatibility.cassandra_table_schema_compatibility_hook + :members: + :undoc-members: + :show-inheritance: + +butterfree.hooks.schema\_compatibility.spark\_table\_schema\_compatibility\_hook module +--------------------------------------------------------------------------------------- + +.. automodule:: butterfree.hooks.schema_compatibility.spark_table_schema_compatibility_hook + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: butterfree.hooks.schema_compatibility + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/butterfree.load.processing.rst b/docs/source/butterfree.load.processing.rst index 79ae36b9b..d16182cb1 100644 --- a/docs/source/butterfree.load.processing.rst +++ b/docs/source/butterfree.load.processing.rst @@ -4,13 +4,14 @@ butterfree.load.processing package Submodules ---------- +butterfree.load.processing.json\_transform module +------------------------------------------------- .. automodule:: butterfree.load.processing.json_transform :members: :undoc-members: :show-inheritance: - Module contents --------------- diff --git a/docs/source/butterfree.load.rst b/docs/source/butterfree.load.rst index 2498b6f29..e4b56fbcb 100644 --- a/docs/source/butterfree.load.rst +++ b/docs/source/butterfree.load.rst @@ -13,13 +13,14 @@ Subpackages Submodules ---------- +butterfree.load.sink module +--------------------------- .. automodule:: butterfree.load.sink :members: :undoc-members: :show-inheritance: - Module contents --------------- diff --git a/docs/source/butterfree.load.writers.rst b/docs/source/butterfree.load.writers.rst index 88aa9e64f..b20eb85ea 100644 --- a/docs/source/butterfree.load.writers.rst +++ b/docs/source/butterfree.load.writers.rst @@ -4,25 +4,38 @@ butterfree.load.writers package Submodules ---------- +butterfree.load.writers.delta\_writer module +-------------------------------------------- + +.. automodule:: butterfree.load.writers.delta_writer + :members: + :undoc-members: + :show-inheritance: + +butterfree.load.writers.historical\_feature\_store\_writer module +----------------------------------------------------------------- .. automodule:: butterfree.load.writers.historical_feature_store_writer :members: :undoc-members: :show-inheritance: +butterfree.load.writers.online\_feature\_store\_writer module +------------------------------------------------------------- .. automodule:: butterfree.load.writers.online_feature_store_writer :members: :undoc-members: :show-inheritance: +butterfree.load.writers.writer module +------------------------------------- .. 
automodule:: butterfree.load.writers.writer :members: :undoc-members: :show-inheritance: - Module contents --------------- diff --git a/docs/source/butterfree.migrations.database_migration.rst b/docs/source/butterfree.migrations.database_migration.rst new file mode 100644 index 000000000..32ba4d4d0 --- /dev/null +++ b/docs/source/butterfree.migrations.database_migration.rst @@ -0,0 +1,37 @@ +butterfree.migrations.database\_migration package +================================================= + +Submodules +---------- + +butterfree.migrations.database\_migration.cassandra\_migration module +--------------------------------------------------------------------- + +.. automodule:: butterfree.migrations.database_migration.cassandra_migration + :members: + :undoc-members: + :show-inheritance: + +butterfree.migrations.database\_migration.database\_migration module +-------------------------------------------------------------------- + +.. automodule:: butterfree.migrations.database_migration.database_migration + :members: + :undoc-members: + :show-inheritance: + +butterfree.migrations.database\_migration.metastore\_migration module +--------------------------------------------------------------------- + +.. automodule:: butterfree.migrations.database_migration.metastore_migration + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: butterfree.migrations.database_migration + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/butterfree.migrations.rst b/docs/source/butterfree.migrations.rst new file mode 100644 index 000000000..4770fd8ea --- /dev/null +++ b/docs/source/butterfree.migrations.rst @@ -0,0 +1,18 @@ +butterfree.migrations package +============================= + +Subpackages +----------- + +.. toctree:: + :maxdepth: 4 + + butterfree.migrations.database_migration + +Module contents +--------------- + +.. automodule:: butterfree.migrations + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/butterfree.pipelines.rst b/docs/source/butterfree.pipelines.rst index d5c65f4d9..e70a4d89e 100644 --- a/docs/source/butterfree.pipelines.rst +++ b/docs/source/butterfree.pipelines.rst @@ -4,13 +4,14 @@ butterfree.pipelines package Submodules ---------- +butterfree.pipelines.feature\_set\_pipeline module +-------------------------------------------------- .. automodule:: butterfree.pipelines.feature_set_pipeline :members: :undoc-members: :show-inheritance: - Module contents --------------- diff --git a/docs/source/butterfree.reports.rst b/docs/source/butterfree.reports.rst index d49a701d1..a95a7e7fd 100644 --- a/docs/source/butterfree.reports.rst +++ b/docs/source/butterfree.reports.rst @@ -4,13 +4,14 @@ butterfree.reports package Submodules ---------- +butterfree.reports.metadata module +---------------------------------- .. automodule:: butterfree.reports.metadata :members: :undoc-members: :show-inheritance: - Module contents --------------- diff --git a/docs/source/butterfree.rst b/docs/source/butterfree.rst index 76e664b4b..e108be6e7 100644 --- a/docs/source/butterfree.rst +++ b/docs/source/butterfree.rst @@ -7,12 +7,15 @@ Subpackages .. 
toctree:: :maxdepth: 4 + butterfree.automated butterfree.clients butterfree.configs butterfree.constants butterfree.dataframe_service butterfree.extract + butterfree.hooks butterfree.load + butterfree.migrations butterfree.pipelines butterfree.reports butterfree.testing diff --git a/docs/source/butterfree.transform.features.rst b/docs/source/butterfree.transform.features.rst index e4c9a926b..837e0fcf7 100644 --- a/docs/source/butterfree.transform.features.rst +++ b/docs/source/butterfree.transform.features.rst @@ -4,25 +4,30 @@ butterfree.transform.features package Submodules ---------- +butterfree.transform.features.feature module +-------------------------------------------- .. automodule:: butterfree.transform.features.feature :members: :undoc-members: :show-inheritance: +butterfree.transform.features.key\_feature module +------------------------------------------------- .. automodule:: butterfree.transform.features.key_feature :members: :undoc-members: :show-inheritance: +butterfree.transform.features.timestamp\_feature module +------------------------------------------------------- .. automodule:: butterfree.transform.features.timestamp_feature :members: :undoc-members: :show-inheritance: - Module contents --------------- diff --git a/docs/source/butterfree.transform.rst b/docs/source/butterfree.transform.rst index 26d180939..12c346aed 100644 --- a/docs/source/butterfree.transform.rst +++ b/docs/source/butterfree.transform.rst @@ -14,19 +14,22 @@ Subpackages Submodules ---------- +butterfree.transform.aggregated\_feature\_set module +---------------------------------------------------- .. automodule:: butterfree.transform.aggregated_feature_set :members: :undoc-members: :show-inheritance: +butterfree.transform.feature\_set module +---------------------------------------- .. automodule:: butterfree.transform.feature_set :members: :undoc-members: :show-inheritance: - Module contents --------------- diff --git a/docs/source/butterfree.transform.transformations.rst b/docs/source/butterfree.transform.transformations.rst index 870c84686..f17818d3e 100644 --- a/docs/source/butterfree.transform.transformations.rst +++ b/docs/source/butterfree.transform.transformations.rst @@ -12,49 +12,62 @@ Subpackages Submodules ---------- +butterfree.transform.transformations.aggregated\_transform module +----------------------------------------------------------------- .. automodule:: butterfree.transform.transformations.aggregated_transform :members: :undoc-members: :show-inheritance: +butterfree.transform.transformations.custom\_transform module +------------------------------------------------------------- .. automodule:: butterfree.transform.transformations.custom_transform :members: :undoc-members: :show-inheritance: +butterfree.transform.transformations.h3\_transform module +--------------------------------------------------------- .. automodule:: butterfree.transform.transformations.h3_transform :members: :undoc-members: :show-inheritance: +butterfree.transform.transformations.spark\_function\_transform module +---------------------------------------------------------------------- .. automodule:: butterfree.transform.transformations.spark_function_transform :members: :undoc-members: :show-inheritance: +butterfree.transform.transformations.sql\_expression\_transform module +---------------------------------------------------------------------- .. 
automodule:: butterfree.transform.transformations.sql_expression_transform :members: :undoc-members: :show-inheritance: +butterfree.transform.transformations.stack\_transform module +------------------------------------------------------------ .. automodule:: butterfree.transform.transformations.stack_transform :members: :undoc-members: :show-inheritance: +butterfree.transform.transformations.transform\_component module +---------------------------------------------------------------- .. automodule:: butterfree.transform.transformations.transform_component :members: :undoc-members: :show-inheritance: - Module contents --------------- diff --git a/docs/source/butterfree.transform.transformations.user_defined_functions.rst b/docs/source/butterfree.transform.transformations.user_defined_functions.rst index becc5d6eb..b79e81386 100644 --- a/docs/source/butterfree.transform.transformations.user_defined_functions.rst +++ b/docs/source/butterfree.transform.transformations.user_defined_functions.rst @@ -4,19 +4,22 @@ butterfree.transform.transformations.user\_defined\_functions package Submodules ---------- +butterfree.transform.transformations.user\_defined\_functions.mode module +------------------------------------------------------------------------- .. automodule:: butterfree.transform.transformations.user_defined_functions.mode :members: :undoc-members: :show-inheritance: +butterfree.transform.transformations.user\_defined\_functions.most\_frequent\_set module +---------------------------------------------------------------------------------------- .. automodule:: butterfree.transform.transformations.user_defined_functions.most_frequent_set :members: :undoc-members: :show-inheritance: - Module contents --------------- diff --git a/docs/source/butterfree.transform.utils.rst b/docs/source/butterfree.transform.utils.rst index bd8c15323..d1d7206ce 100644 --- a/docs/source/butterfree.transform.utils.rst +++ b/docs/source/butterfree.transform.utils.rst @@ -4,25 +4,30 @@ butterfree.transform.utils package Submodules ---------- +butterfree.transform.utils.date\_range module +--------------------------------------------- .. automodule:: butterfree.transform.utils.date_range :members: :undoc-members: :show-inheritance: +butterfree.transform.utils.function module +------------------------------------------ .. automodule:: butterfree.transform.utils.function :members: :undoc-members: :show-inheritance: +butterfree.transform.utils.window\_spec module +---------------------------------------------- .. automodule:: butterfree.transform.utils.window_spec :members: :undoc-members: :show-inheritance: - Module contents --------------- diff --git a/docs/source/butterfree.validations.rst b/docs/source/butterfree.validations.rst index 9fd015574..2aa0053ec 100644 --- a/docs/source/butterfree.validations.rst +++ b/docs/source/butterfree.validations.rst @@ -4,19 +4,22 @@ butterfree.validations package Submodules ---------- +butterfree.validations.basic\_validaton module +---------------------------------------------- .. automodule:: butterfree.validations.basic_validaton :members: :undoc-members: :show-inheritance: +butterfree.validations.validation module +---------------------------------------- .. 
automodule:: butterfree.validations.validation :members: :undoc-members: :show-inheritance: - Module contents --------------- diff --git a/docs/source/cli.md b/docs/source/cli.md new file mode 100644 index 000000000..ba07428fc --- /dev/null +++ b/docs/source/cli.md @@ -0,0 +1,32 @@ +# Command-line Interface (CLI) + +Butterfree has now a command-line interface, introduced with the new automatic migration ability. + +As soon as you install butterfree, you can check what's available through butterfree's cli with: + +```shell +$~ butterfree --help +``` + +### Automated Database Schema Migration + +When developing your feature sets, you need also to prepare your database for the changes +to come into your Feature Store. Normally, when creating a new feature set, you needed +to manually create a new table in cassandra. Or, when creating a new feature in an existing +feature set, you needed to create new column in cassandra too. + +Now, you can just use `butterfree migrate apply ...`, butterfree will scan your python +files, looking for classes that inherit from `butterfree.pipelines.FeatureSetPipeline`, +then compare its schema with the database schema where the feature set would be written. +Then it will prepare migration queries and run against the databases. + +For more information, please, check `butterfree migrate apply --help` :) + +### Supported databases + +This functionality currently supports only the **Cassandra** database, which is the default +storage for an Online Feature Store built with Butterfree. Nonetheless, it was made with +the intent to be easily extended for other databases. + +Also, each database has its own rules for schema migration commands. Some changes may +still require manual interference. \ No newline at end of file diff --git a/docs/source/conf.py b/docs/source/conf.py index 77fdc1256..0a5377392 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -1,4 +1,5 @@ """Sphinx Configuration.""" + # -*- coding: utf-8 -*- # # Configuration file for the Sphinx documentation builder. diff --git a/docs/source/extract.md b/docs/source/extract.md index 2d9f9fabe..2b4f2e521 100644 --- a/docs/source/extract.md +++ b/docs/source/extract.md @@ -53,4 +53,4 @@ source = Source( ) ``` -It's important to state that we have some pre-processing methods as well, such as filter and pivot. Feel free to check them [here](https://github.com/quintoandar/butterfree/tree/staging/butterfree/core/extract/pre_processing). \ No newline at end of file +It's important to state that we have some pre-processing methods as well, such as filter and pivot. Feel free to check them [here](https://github.com/quintoandar/butterfree/tree/master/butterfree/extract/pre_processing). \ No newline at end of file diff --git a/docs/source/home.md b/docs/source/home.md index eada17390..fc297d2b6 100644 --- a/docs/source/home.md +++ b/docs/source/home.md @@ -10,6 +10,7 @@ The main idea is for this repository to be a set of tools for easing [ETLs](http - [Load](#load) - [Streaming](#streaming) - [Setup Configuration](#setup-configuration) +- [Command-line Interface](#command-line-interface) ## What is going on here @@ -61,3 +62,8 @@ We also support streaming pipelines in Butterfree. More information is available ## Setup Configuration Some configurations are needed to run your ETL pipelines. Detailed information is provided at the [Configuration Section](configuration.md) + +## Command-line Interface + +Butterfree has its own command-line interface, to manage your feature sets. 
Detailed information +provided by the [Command-line Interface](cli.md) section. diff --git a/docs/source/index.rst b/docs/source/index.rst index 6548f9adc..12bf1609a 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -22,3 +22,4 @@ Navigation stream configuration modules + cli diff --git a/examples/interval_runs/interval_runs.ipynb b/examples/interval_runs/interval_runs.ipynb new file mode 100644 index 000000000..e234da8ac --- /dev/null +++ b/examples/interval_runs/interval_runs.ipynb @@ -0,0 +1,2152 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# #5 Discovering Butterfree - Interval Runs\n", + "\n", + "Welcome to Discovering Butterfree tutorial series!\n", + "\n", + "This is the fifth tutorial of this series: its goal is to cover interval runs.\n", + "\n", + "Before diving into the tutorial make sure you have a basic understanding of these main data concepts: features, feature sets and the \"Feature Store Architecture\", you can read more about this [here]." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Example:\n", + "\n", + "Simulating the following scenario (the same from previous tutorials):\n", + "\n", + "- We want to create a feature set with features about houses for rent (listings).\n", + "\n", + "\n", + "We have an input dataset:\n", + "\n", + "- Table: `listing_events`. Table with data about events of house listings.\n", + "\n", + "\n", + "Our desire is to have three resulting datasets with the following schema:\n", + "\n", + "* id: **int**;\n", + "* timestamp: **timestamp**;\n", + "* rent__avg_over_1_day_rolling_windows: **double**;\n", + "* rent__stddev_pop_over_1_day_rolling_windows: **double**.\n", + " \n", + "The first dataset will be computed with just an end date time limit. The second one, on the other hand, uses both start and end date in order to filter data. Finally, the third one will be the result of a daily run. You can understand more about these definitions in our documentation.\n", + "\n", + "The following code blocks will show how to generate this feature set using Butterfree library:" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "# setup spark\n", + "from pyspark import SparkContext, SparkConf\n", + "from pyspark.sql import session\n", + "\n", + "conf = SparkConf().setAll([('spark.driver.host','127.0.0.1'), ('spark.sql.session.timeZone', 'UTC')])\n", + "sc = SparkContext(conf=conf)\n", + "spark = session.SparkSession(sc)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# fix working dir\n", + "import pathlib\n", + "import os\n", + "path = os.path.join(pathlib.Path().absolute(), '../..')\n", + "os.chdir(path)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Showing test data" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "listing_events_df = spark.read.json(f\"{path}/examples/data/listing_events.json\")\n", + "listing_events_df.createOrReplaceTempView(\"listing_events\") # creating listing_events view\n", + "\n", + "region = spark.read.json(f\"{path}/examples/data/region.json\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Listing events table:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + " area bathrooms bedrooms id region_id rent timestamp\n", + "0 50 1 1 1 1 1300 1588302000000\n", + "1 50 1 1 1 1 2000 1588647600000\n", + "2 100 1 2 2 2 1500 1588734000000\n", + "3 100 1 2 2 2 2500 1589252400000\n", + "4 150 2 2 3 3 3000 1589943600000\n", + "5 175 2 2 4 4 3200 1589943600000\n", + "6 250 3 3 5 5 3200 1590030000000\n", + "7 225 3 2 6 6 3200 1590116400000" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "listing_events_df.toPandas()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Region table:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + " city id lat lng region\n", + "0 Cerulean 1 73.44489 31.75030 Kanto\n", + "1 Veridian 2 -9.43510 -167.11772 Kanto\n", + "2 Cinnabar 3 29.73043 117.66164 Kanto\n", + "3 Pallet 4 -52.95717 -81.15251 Kanto\n", + "4 Violet 5 -47.35798 -178.77255 Johto\n", + "5 Olivine 6 51.72820 46.21958 Johto" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "region.toPandas()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Extract\n", + "\n", + "- For the extract part, we need the `Source` entity and the `FileReader` for the data we have;\n", + "- We need to declare a query in order to bring the results from our lonely reader (it's as simples as a select all statement)." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "from butterfree.clients import SparkClient\n", + "from butterfree.extract import Source\n", + "from butterfree.extract.readers import FileReader, TableReader\n", + "from butterfree.extract.pre_processing import filter\n", + "\n", + "readers = [\n", + " TableReader(id=\"listing_events\", table=\"listing_events\",),\n", + " FileReader(id=\"region\", path=f\"{path}/examples/data/region.json\", format=\"json\",)\n", + "]\n", + "\n", + "query = \"\"\"\n", + "select\n", + " listing_events.*,\n", + " region.city,\n", + " region.region,\n", + " region.lat,\n", + " region.lng,\n", + " region.region as region_name\n", + "from\n", + " listing_events\n", + " join region\n", + " on listing_events.region_id = region.id\n", + "\"\"\"\n", + "\n", + "source = Source(readers=readers, query=query)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "spark_client = SparkClient()\n", + "source_df = source.construct(spark_client)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "And, finally, it's possible to see the results from building our souce dataset:" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + " area bathrooms bedrooms id region_id rent timestamp city \\\n", + "0 50 1 1 1 1 1300 1588302000000 Cerulean \n", + "1 50 1 1 1 1 2000 1588647600000 Cerulean \n", + "2 100 1 2 2 2 1500 1588734000000 Veridian \n", + "3 100 1 2 2 2 2500 1589252400000 Veridian \n", + "4 150 2 2 3 3 3000 1589943600000 Cinnabar \n", + "5 175 2 2 4 4 3200 1589943600000 Pallet \n", + "6 250 3 3 5 5 3200 1590030000000 Violet \n", + "7 225 3 2 6 6 3200 1590116400000 Olivine \n", + "\n", + " region lat lng region_name \n", + "0 Kanto 73.44489 31.75030 Kanto \n", + "1 Kanto 73.44489 31.75030 Kanto \n", + "2 Kanto -9.43510 -167.11772 Kanto \n", + "3 Kanto -9.43510 -167.11772 Kanto \n", + "4 Kanto 29.73043 117.66164 Kanto \n", + "5 Kanto -52.95717 -81.15251 Kanto \n", + "6 Johto -47.35798 -178.77255 Johto \n", + "7 Johto 51.72820 46.21958 Johto " + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "source_df.toPandas()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Transform\n", + "- At the transform part, a set of `Feature` objects is declared;\n", + "- An Instance of `AggregatedFeatureSet` is used to hold the features;\n", + "- An `AggregatedFeatureSet` can only be created when it is possible to define a unique tuple formed by key columns and a time reference. This is an **architectural requirement** for the data. So least one `KeyFeature` and one `TimestampFeature` is needed;\n", + "- Every `Feature` needs a unique name, a description, and a data-type definition. Besides, in the case of the `AggregatedFeatureSet`, it's also mandatory to have an `AggregatedTransform` operator;\n", + "- An `AggregatedTransform` operator is used, as the name suggests, to define aggregation functions." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "from pyspark.sql import functions as F\n", + "\n", + "from butterfree.transform.aggregated_feature_set import AggregatedFeatureSet\n", + "from butterfree.transform.features import Feature, KeyFeature, TimestampFeature\n", + "from butterfree.transform.transformations import AggregatedTransform\n", + "from butterfree.constants import DataType\n", + "from butterfree.transform.utils import Function\n", + "\n", + "keys = [\n", + " KeyFeature(\n", + " name=\"id\",\n", + " description=\"Unique identificator code for houses.\",\n", + " dtype=DataType.BIGINT,\n", + " )\n", + "]\n", + "\n", + "# from_ms = True because the data originally is not in a Timestamp format.\n", + "ts_feature = TimestampFeature(from_ms=True)\n", + "\n", + "features = [\n", + " Feature(\n", + " name=\"rent\",\n", + " description=\"Rent value by month described in the listing.\",\n", + " transformation=AggregatedTransform(\n", + " functions=[\n", + " Function(F.avg, DataType.DOUBLE),\n", + " Function(F.stddev_pop, DataType.DOUBLE),\n", + " ],\n", + " filter_expression=\"region_name = 'Kanto'\",\n", + " ),\n", + " )\n", + "]\n", + "\n", + "aggregated_feature_set = AggregatedFeatureSet(\n", + " name=\"house_listings\",\n", + " entity=\"house\", # entity: to which \"business context\" this feature set belongs\n", + " description=\"Features describring a house listing.\",\n", + " keys=keys,\n", + " timestamp=ts_feature,\n", + " features=features,\n", + ").with_windows(definitions=[\"1 day\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here, we'll define out first aggregated feature set, with just an `end date` parameter:" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "aggregated_feature_set_windows_df = aggregated_feature_set.construct(\n", + " source_df, \n", + " spark_client, \n", + " end_date=\"2020-05-30\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The resulting dataset is:" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + " id timestamp rent__avg_over_1_day_rolling_windows \\\n", + "0 1 2020-05-01 NaN \n", + "1 1 2020-05-02 1300.0 \n", + "2 1 2020-05-03 NaN \n", + "3 1 2020-05-06 2000.0 \n", + "4 1 2020-05-07 NaN \n", + "5 2 2020-05-01 NaN \n", + "6 2 2020-05-07 1500.0 \n", + "7 2 2020-05-08 NaN \n", + "8 2 2020-05-13 2500.0 \n", + "9 2 2020-05-14 NaN \n", + "10 3 2020-05-01 NaN \n", + "11 3 2020-05-21 3000.0 \n", + "12 3 2020-05-22 NaN \n", + "13 4 2020-05-01 NaN \n", + "14 4 2020-05-21 3200.0 \n", + "15 4 2020-05-22 NaN \n", + "16 5 2020-05-01 NaN \n", + "17 6 2020-05-01 NaN \n", + "\n", + " rent__stddev_pop_over_1_day_rolling_windows \n", + "0 NaN \n", + "1 0.0 \n", + "2 NaN \n", + "3 0.0 \n", + "4 NaN \n", + "5 NaN \n", + "6 0.0 \n", + "7 NaN \n", + "8 0.0 \n", + "9 NaN \n", + "10 NaN \n", + "11 0.0 \n", + "12 NaN \n", + "13 NaN \n", + "14 0.0 \n", + "15 NaN \n", + "16 NaN \n", + "17 NaN " + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "aggregated_feature_set_windows_df.orderBy(\"id\", \"timestamp\").toPandas()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "It's possible to see that if we use both a `start date` and `end_date` values. Then we'll achieve a time slice of the last dataframe, as it's possible to see:" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + " id timestamp rent__avg_over_1_day_rolling_windows \\\n", + "0 1 2020-05-06 2000.0 \n", + "1 1 2020-05-07 NaN \n", + "2 2 2020-05-06 NaN \n", + "3 2 2020-05-07 1500.0 \n", + "4 2 2020-05-08 NaN \n", + "5 2 2020-05-13 2500.0 \n", + "6 2 2020-05-14 NaN \n", + "7 3 2020-05-06 NaN \n", + "8 3 2020-05-21 3000.0 \n", + "9 4 2020-05-06 NaN \n", + "10 4 2020-05-21 3200.0 \n", + "11 5 2020-05-06 NaN \n", + "12 6 2020-05-06 NaN \n", + "\n", + " rent__stddev_pop_over_1_day_rolling_windows \n", + "0 0.0 \n", + "1 NaN \n", + "2 NaN \n", + "3 0.0 \n", + "4 NaN \n", + "5 0.0 \n", + "6 NaN \n", + "7 NaN \n", + "8 0.0 \n", + "9 NaN \n", + "10 0.0 \n", + "11 NaN \n", + "12 NaN " + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "aggregated_feature_set.construct(\n", + " source_df, \n", + " spark_client, \n", + " end_date=\"2020-05-21\",\n", + " start_date=\"2020-05-06\",\n", + ").orderBy(\"id\", \"timestamp\").toPandas()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Load\n", + "\n", + "- For the load part we need `Writer` instances and a `Sink`;\n", + "- `writers` define where to load the data;\n", + "- The `Sink` gets the transformed data (feature set) and trigger the load to all the defined `writers`;\n", + "- `debug_mode` will create a temporary view instead of trying to write in a real data store." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "from butterfree.load.writers import (\n", + " HistoricalFeatureStoreWriter,\n", + " OnlineFeatureStoreWriter,\n", + ")\n", + "from butterfree.load import Sink\n", + "\n", + "writers = [HistoricalFeatureStoreWriter(debug_mode=True, interval_mode=True), \n", + " OnlineFeatureStoreWriter(debug_mode=True, interval_mode=True)]\n", + "sink = Sink(writers=writers)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Pipeline\n", + "\n", + "- The `Pipeline` entity wraps all the other defined elements.\n", + "- `run` command will trigger the execution of the pipeline, end-to-end." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "from butterfree.pipelines import FeatureSetPipeline\n", + "\n", + "pipeline = FeatureSetPipeline(source=source, feature_set=aggregated_feature_set, sink=sink)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The first run will use just an `end_date` as parameter:" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "result_df = pipeline.run(end_date=\"2020-05-30\")" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + " id timestamp rent__avg_over_1_day_rolling_windows \\\n", + "0 1 2020-05-01 NaN \n", + "1 1 2020-05-02 1300.0 \n", + "2 1 2020-05-03 NaN \n", + "3 1 2020-05-06 2000.0 \n", + "4 1 2020-05-07 NaN \n", + "5 2 2020-05-01 NaN \n", + "6 2 2020-05-07 1500.0 \n", + "7 2 2020-05-08 NaN \n", + "8 2 2020-05-13 2500.0 \n", + "9 2 2020-05-14 NaN \n", + "10 3 2020-05-01 NaN \n", + "11 3 2020-05-21 3000.0 \n", + "12 3 2020-05-22 NaN \n", + "13 4 2020-05-01 NaN \n", + "14 4 2020-05-21 3200.0 \n", + "15 4 2020-05-22 NaN \n", + "16 5 2020-05-01 NaN \n", + "17 6 2020-05-01 NaN \n", + "\n", + " rent__stddev_pop_over_1_day_rolling_windows year month day \n", + "0 NaN 2020 5 1 \n", + "1 0.0 2020 5 2 \n", + "2 NaN 2020 5 3 \n", + "3 0.0 2020 5 6 \n", + "4 NaN 2020 5 7 \n", + "5 NaN 2020 5 1 \n", + "6 0.0 2020 5 7 \n", + "7 NaN 2020 5 8 \n", + "8 0.0 2020 5 13 \n", + "9 NaN 2020 5 14 \n", + "10 NaN 2020 5 1 \n", + "11 0.0 2020 5 21 \n", + "12 NaN 2020 5 22 \n", + "13 NaN 2020 5 1 \n", + "14 0.0 2020 5 21 \n", + "15 NaN 2020 5 22 \n", + "16 NaN 2020 5 1 \n", + "17 NaN 2020 5 1 " + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "spark.table(\"historical_feature_store__house_listings\").orderBy(\n", + " \"id\", \"timestamp\"\n", + ").orderBy(\"id\", \"timestamp\").toPandas()" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + " id timestamp rent__avg_over_1_day_rolling_windows \\\n", + "0 1 2020-05-07 NaN \n", + "1 2 2020-05-14 NaN \n", + "2 3 2020-05-22 NaN \n", + "3 4 2020-05-22 NaN \n", + "4 5 2020-05-01 NaN \n", + "5 6 2020-05-01 NaN \n", + "\n", + " rent__stddev_pop_over_1_day_rolling_windows \n", + "0 NaN \n", + "1 NaN \n", + "2 NaN \n", + "3 NaN \n", + "4 NaN \n", + "5 NaN " + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "spark.table(\"online_feature_store__house_listings\").orderBy(\"id\", \"timestamp\").toPandas()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "- We can see that we were able to create all the desired features in an easy way\n", + "- The **historical feature set** holds all the data, and we can see that it is partitioned by year, month and day (columns added in the `HistoricalFeatureStoreWriter`)\n", + "- In the **online feature set** there is only the latest data for each id" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The second run, on the other hand, will use both a `start_date` and `end_date` as parameters." + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "result_df = pipeline.run(end_date=\"2020-05-21\", start_date=\"2020-05-06\")" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + " id timestamp rent__avg_over_1_day_rolling_windows \\\n", + "0 1 2020-05-06 2000.0 \n", + "1 1 2020-05-07 NaN \n", + "2 2 2020-05-06 NaN \n", + "3 2 2020-05-07 1500.0 \n", + "4 2 2020-05-08 NaN \n", + "5 2 2020-05-13 2500.0 \n", + "6 2 2020-05-14 NaN \n", + "7 3 2020-05-06 NaN \n", + "8 3 2020-05-21 3000.0 \n", + "9 4 2020-05-06 NaN \n", + "10 4 2020-05-21 3200.0 \n", + "11 5 2020-05-06 NaN \n", + "12 6 2020-05-06 NaN \n", + "\n", + " rent__stddev_pop_over_1_day_rolling_windows year month day \n", + "0 0.0 2020 5 6 \n", + "1 NaN 2020 5 7 \n", + "2 NaN 2020 5 6 \n", + "3 0.0 2020 5 7 \n", + "4 NaN 2020 5 8 \n", + "5 0.0 2020 5 13 \n", + "6 NaN 2020 5 14 \n", + "7 NaN 2020 5 6 \n", + "8 0.0 2020 5 21 \n", + "9 NaN 2020 5 6 \n", + "10 0.0 2020 5 21 \n", + "11 NaN 2020 5 6 \n", + "12 NaN 2020 5 6 " + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "spark.table(\"historical_feature_store__house_listings\").orderBy(\n", + " \"id\", \"timestamp\"\n", + ").orderBy(\"id\", \"timestamp\").toPandas()" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + " id timestamp rent__avg_over_1_day_rolling_windows \\\n", + "0 1 2020-05-07 NaN \n", + "1 2 2020-05-14 NaN \n", + "2 3 2020-05-21 3000.0 \n", + "3 4 2020-05-21 3200.0 \n", + "4 5 2020-05-06 NaN \n", + "5 6 2020-05-06 NaN \n", + "\n", + " rent__stddev_pop_over_1_day_rolling_windows \n", + "0 NaN \n", + "1 NaN \n", + "2 0.0 \n", + "3 0.0 \n", + "4 NaN \n", + "5 NaN " + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "spark.table(\"online_feature_store__house_listings\").orderBy(\"id\", \"timestamp\").toPandas()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Finally, the third run, will use only an `execution_date` as a parameter." + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "result_df = pipeline.run_for_date(execution_date=\"2020-05-21\")" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + " id timestamp rent__avg_over_1_day_rolling_windows \\\n", + "0 1 2020-05-21 NaN \n", + "1 2 2020-05-21 NaN \n", + "2 3 2020-05-21 3000.0 \n", + "3 4 2020-05-21 3200.0 \n", + "4 5 2020-05-21 NaN \n", + "5 6 2020-05-21 NaN \n", + "\n", + " rent__stddev_pop_over_1_day_rolling_windows year month day \n", + "0 NaN 2020 5 21 \n", + "1 NaN 2020 5 21 \n", + "2 0.0 2020 5 21 \n", + "3 0.0 2020 5 21 \n", + "4 NaN 2020 5 21 \n", + "5 NaN 2020 5 21 " + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "spark.table(\"historical_feature_store__house_listings\").orderBy(\n", + " \"id\", \"timestamp\"\n", + ").orderBy(\"id\", \"timestamp\").toPandas()" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
idtimestamprent__avg_over_1_day_rolling_windowsrent__stddev_pop_over_1_day_rolling_windows
012020-05-21NaNNaN
122020-05-21NaNNaN
232020-05-213000.00.0
342020-05-213200.00.0
452020-05-21NaNNaN
562020-05-21NaNNaN
\n", + "
" + ], + "text/plain": [ + " id timestamp rent__avg_over_1_day_rolling_windows \\\n", + "0 1 2020-05-21 NaN \n", + "1 2 2020-05-21 NaN \n", + "2 3 2020-05-21 3000.0 \n", + "3 4 2020-05-21 3200.0 \n", + "4 5 2020-05-21 NaN \n", + "5 6 2020-05-21 NaN \n", + "\n", + " rent__stddev_pop_over_1_day_rolling_windows \n", + "0 NaN \n", + "1 NaN \n", + "2 0.0 \n", + "3 0.0 \n", + "4 NaN \n", + "5 NaN " + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "spark.table(\"online_feature_store__house_listings\").orderBy(\"id\", \"timestamp\").toPandas()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/examples/simple_feature_set/simple_feature_set.ipynb b/examples/simple_feature_set/simple_feature_set.ipynb index b217fcdf7..c5ed9ae55 100644 --- a/examples/simple_feature_set/simple_feature_set.ipynb +++ b/examples/simple_feature_set/simple_feature_set.ipynb @@ -89,7 +89,7 @@ "| - | - | - | - | - | - | - | - | - | - | - | - | - | - |\n", "| int | timestamp | float | float | int | int | float | float | float | double | double | string | string | string |\n", "\n", - "For more information about H3 geohash click [here]()\n", + "For more information about H3 geohash click [here](https://h3geo.org/docs/)\n", "\n", "The following code blocks will show how to generate this feature set using Butterfree library:\n", "\n" diff --git a/examples/spark_function_and_window/spark_function_and_window.ipynb b/examples/spark_function_and_window/spark_function_and_window.ipynb index a4472e245..dcf715524 100644 --- a/examples/spark_function_and_window/spark_function_and_window.ipynb +++ b/examples/spark_function_and_window/spark_function_and_window.ipynb @@ -50,7 +50,7 @@ "\n", "Note that we're going to compute two aggregated features, rent average and standard deviation, considering the two last occurrences (or events). 
It'd also be possible to define time windows, instead of windows based on events.\n", "\n", - "For more information about H3 geohash click [here]().\n", + "For more information about H3 geohash click [here](https://h3geo.org/docs/).\n", "\n", "The following code blocks will show how to generate this feature set using Butterfree library:\n", "\n" diff --git a/examples/test_examples.py b/examples/test_examples.py index b40b6e1a4..7180e080d 100644 --- a/examples/test_examples.py +++ b/examples/test_examples.py @@ -36,9 +36,9 @@ _, error = p.communicate() if p.returncode != 0: errors.append({"notebook": path, "error": error}) - print(f" >>> Error in execution!\n") + print(" >>> Error in execution!\n") else: - print(f" >>> Successful execution\n") + print(" >>> Successful execution\n") if errors: print(">>> Errors in the following notebooks:") diff --git a/logging.json b/logging.json new file mode 100644 index 000000000..e69de29bb diff --git a/mypy.ini b/mypy.ini index c67bd3a89..fc2931493 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,5 +1,5 @@ [mypy] -python_version = 3.7 +python_version = 3.9 ignore_missing_imports = True disallow_untyped_calls = False disallow_untyped_defs = True @@ -9,3 +9,42 @@ show_error_codes = True show_error_context = True disable_error_code = attr-defined, list-item, operator pretty = True + +[mypy-butterfree.pipelines.*] +ignore_errors = True + +[mypy-butterfree.load.*] +ignore_errors = True + +[mypy-butterfree.transform.*] +ignore_errors = True + +[mypy-butterfree.extract.*] +ignore_errors = True + +[mypy-butterfree.config.*] +ignore_errors = True + +[mypy-butterfree.clients.*] +ignore_errors = True + +[mypy-butterfree.configs.*] +ignore_errors = True + +[mypy-butterfree.dataframe_service.*] +ignore_errors = True + +[mypy-butterfree.validations.*] +ignore_errors = True + +[mypy-butterfree.migrations.*] +ignore_errors = True + +[mypy-butterfree.testing.*] +ignore_errors = True + +[mypy-butterfree.hooks.*] +ignore_errors = True + +[mypy-butterfree._cli.*] +ignore_errors = True diff --git a/requirements.dev.txt b/requirements.dev.txt index 8ebfa5108..bf4b4b2b8 100644 --- a/requirements.dev.txt +++ b/requirements.dev.txt @@ -1,7 +1,11 @@ -cmake==3.18.4 -h3==3.7.0 -pyarrow==0.15.1 +h3==3.7.7 jupyter==1.0.0 twine==3.1.1 -mypy==0.790 -pyspark-stubs==3.0.0 +mypy==1.10.0 +sphinx==6.2.1 +sphinxemoji==0.3.1 +sphinx-rtd-theme==1.3.0 +recommonmark==0.7.1 +pyarrow==16.1.0 +setuptools==70.0.0 +wheel==0.43.0 diff --git a/requirements.lint.txt b/requirements.lint.txt index 161f7911f..1ad6499d1 100644 --- a/requirements.lint.txt +++ b/requirements.lint.txt @@ -1,7 +1,7 @@ -black==19.10b0 -flake8==3.7.9 -flake8-isort==2.8.0 -isort<5 # temporary fix +black==24.3.0 +flake8==4.0.1 +flake8-isort==4.1.1 flake8-docstrings==1.5.0 flake8-bugbear==20.1.0 flake8-bandit==2.1.2 +bandit==1.7.2 diff --git a/requirements.test.txt b/requirements.test.txt index b0c4032a8..651700b80 100644 --- a/requirements.test.txt +++ b/requirements.test.txt @@ -2,4 +2,4 @@ pytest==5.3.2 pytest-cov==2.8.1 pytest-xdist==1.31.0 pytest-mock==2.0.0 -pytest-spark==0.5.2 +pytest-spark==0.6.0 diff --git a/requirements.txt b/requirements.txt index e55289f4d..9c9eea640 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,10 @@ -cassandra-driver>=3.22.0,<4.0 +cassandra-driver==3.24.0 mdutils>=1.2.2,<2.0 -pandas>=0.24,<1.1 +pandas>=0.24,<2.0 parameters-validation>=1.1.5,<2.0 -pyspark==3.* +pyspark==3.5.1 +typer==0.4.2 +typing-extensions>3.7.4,<5 +boto3==1.17.* +numpy==1.26.4 +delta-spark==3.2.0 diff --git a/setup.cfg 
b/setup.cfg index 7b1c62bd2..8206c6ae3 100644 --- a/setup.cfg +++ b/setup.cfg @@ -2,7 +2,7 @@ docstring-convention = google max-line-length = 88 max-complexity = 12 -ignore = W503, E203, D203, D401, D107, S101 +ignore = W503, E203, D203, D401, D107, S101, D105, D100, W605, D202, D212, D104, E261 exclude = dist/*,build/*,.pytest_cache/*,.git/*,pip/* per-file-ignores = # We will not check for docstrings or the use of asserts in tests @@ -10,13 +10,13 @@ per-file-ignores = setup.py:D,S101 [isort] +profile = black line_length = 88 known_first_party = butterfree default_section = THIRDPARTY multi_line_output = 3 indent = ' ' skip_glob = pip -use_parantheses = True include_trailing_comma = True [tool:pytest] @@ -24,6 +24,7 @@ spark_options = spark.sql.session.timeZone: UTC spark.driver.bindAddress: 127.0.0.1 spark.sql.legacy.timeParserPolicy: LEGACY + spark.sql.legacy.createHiveTableByDefault: false [mypy] # suppress errors about unsatisfied imports @@ -40,3 +41,9 @@ disallow_any_generics = True disallow_untyped_defs = True check_untyped_defs = True disallow_untyped_calls = True + +[build_sphinx] +all-files = 1 +source-dir = docs/source +build-dir = docs/build +warning-is-error = 0 diff --git a/setup.py b/setup.py index 47ba0b989..e6b9f7618 100644 --- a/setup.py +++ b/setup.py @@ -1,7 +1,7 @@ from setuptools import find_packages, setup __package_name__ = "butterfree" -__version__ = "1.1.1" +__version__ = "1.4.0" __repository_url__ = "https://github.com/quintoandar/butterfree" with open("requirements.txt") as f: @@ -34,6 +34,8 @@ license="Copyright", author="QuintoAndar", install_requires=requirements, - extras_require={"h3": ["cmake==3.16.3", "h3==3.4.2"]}, - python_requires=">=3.7, <4", + extras_require={"h3": ["h3>=3.7.4,<4"]}, + python_requires=">=3.9, <4", + entry_points={"console_scripts": ["butterfree=butterfree._cli.main:app"]}, + include_package_data=True, ) diff --git a/tests/integration/butterfree/extract/test_source.py b/tests/integration/butterfree/extract/test_source.py index c465ebd05..3ab991ab2 100644 --- a/tests/integration/butterfree/extract/test_source.py +++ b/tests/integration/butterfree/extract/test_source.py @@ -1,11 +1,11 @@ from typing import List from pyspark.sql import DataFrame -from tests.integration import INPUT_PATH from butterfree.clients import SparkClient from butterfree.extract import Source from butterfree.extract.readers import FileReader, TableReader +from tests.integration import INPUT_PATH def create_temp_view(dataframe: DataFrame, name): @@ -13,10 +13,11 @@ def create_temp_view(dataframe: DataFrame, name): def create_db_and_table(spark, table_reader_id, table_reader_db, table_reader_table): - spark.sql(f"create database if not exists {table_reader_db}") + spark.sql(f"drop schema if exists {table_reader_db} cascade") + spark.sql(f"create database {table_reader_db}") spark.sql(f"use {table_reader_db}") spark.sql( - f"create table if not exists {table_reader_db}.{table_reader_table} " # noqa + f"create table {table_reader_db}.{table_reader_table} " # noqa f"as select * from {table_reader_id}" # noqa ) @@ -33,7 +34,10 @@ def compare_dataframes( class TestSource: def test_source( - self, target_df_source, target_df_table_reader, spark_session, + self, + target_df_source, + target_df_table_reader, + spark_session, ): # given spark_client = SparkClient() @@ -66,6 +70,7 @@ def test_source( query=f"select a.*, b.feature2 " # noqa f"from {table_reader_id} a " # noqa f"inner join {file_reader_id} b on a.id = b.id ", # noqa + eager_evaluation=False, ) result_df = 
source.construct(client=spark_client) diff --git a/tests/integration/butterfree/load/conftest.py b/tests/integration/butterfree/load/conftest.py index 418b6d2ac..60101f1ac 100644 --- a/tests/integration/butterfree/load/conftest.py +++ b/tests/integration/butterfree/load/conftest.py @@ -51,7 +51,7 @@ def feature_set(): ] ts_feature = TimestampFeature(from_column="timestamp") features = [ - Feature(name="feature", description="Description", dtype=DataType.FLOAT), + Feature(name="feature", description="Description", dtype=DataType.INTEGER), ] return FeatureSet( "test_sink_feature_set", diff --git a/tests/integration/butterfree/load/test_sink.py b/tests/integration/butterfree/load/test_sink.py index d00f48062..f73f5f7ce 100644 --- a/tests/integration/butterfree/load/test_sink.py +++ b/tests/integration/butterfree/load/test_sink.py @@ -12,6 +12,7 @@ def test_sink(input_dataframe, feature_set): # arrange client = SparkClient() + client.conn.conf.set("spark.sql.sources.partitionOverwriteMode", "dynamic") feature_set_df = feature_set.construct(input_dataframe, client) target_latest_df = OnlineFeatureStoreWriter.filter_latest( feature_set_df, id_columns=[key.name for key in feature_set.keys] @@ -20,14 +21,21 @@ def test_sink(input_dataframe, feature_set): # setup historical writer s3config = Mock() + s3config.mode = "overwrite" + s3config.format_ = "parquet" s3config.get_options = Mock( return_value={ - "mode": "overwrite", - "format_": "parquet", "path": "test_folder/historical/entity/feature_set", + "mode": "overwrite", } ) - historical_writer = HistoricalFeatureStoreWriter(db_config=s3config) + s3config.get_path_with_partitions = Mock( + return_value="spark-warehouse/test.db/test_folder/historical/entity/feature_set" + ) + + historical_writer = HistoricalFeatureStoreWriter( + db_config=s3config, interval_mode=True + ) # setup online writer # TODO: Change for CassandraConfig when Cassandra for test is ready @@ -47,13 +55,14 @@ def test_sink(input_dataframe, feature_set): sink.flush(feature_set, feature_set_df, client) # get historical results - historical_result_df = client.read_table( - feature_set.name, historical_writer.database + historical_result_df = client.read( + s3config.format_, + path=s3config.get_path_with_partitions(feature_set.name, feature_set_df), ) # get online results online_result_df = client.read( - online_config.format_, options=online_config.get_options(feature_set.name) + online_config.format_, **online_config.get_options(feature_set.name) ) # assert diff --git a/tests/integration/butterfree/pipelines/conftest.py b/tests/integration/butterfree/pipelines/conftest.py index 798941761..1466a8d98 100644 --- a/tests/integration/butterfree/pipelines/conftest.py +++ b/tests/integration/butterfree/pipelines/conftest.py @@ -1,7 +1,19 @@ import pytest +from pyspark.sql import DataFrame +from pyspark.sql import functions as F from butterfree.constants import DataType from butterfree.constants.columns import TIMESTAMP_COLUMN +from butterfree.dataframe_service.incremental_strategy import IncrementalStrategy +from butterfree.extract import Source +from butterfree.extract.readers import TableReader +from butterfree.load import Sink +from butterfree.load.writers import HistoricalFeatureStoreWriter +from butterfree.pipelines.feature_set_pipeline import FeatureSetPipeline +from butterfree.transform import FeatureSet +from butterfree.transform.features import Feature, KeyFeature, TimestampFeature +from butterfree.transform.transformations import SparkFunctionTransform +from 
butterfree.transform.utils import Function @pytest.fixture() @@ -74,3 +86,197 @@ def fixed_windows_output_feature_set_dataframe(spark_context, spark_session): df = df.withColumn(TIMESTAMP_COLUMN, df.timestamp.cast(DataType.TIMESTAMP.spark)) return df + + +@pytest.fixture() +def mocked_date_df(spark_context, spark_session): + data = [ + {"id": 1, "ts": "2016-04-11 11:31:11", "feature": 200}, + {"id": 1, "ts": "2016-04-12 11:44:12", "feature": 300}, + {"id": 1, "ts": "2016-04-13 11:46:24", "feature": 400}, + {"id": 1, "ts": "2016-04-14 12:03:21", "feature": 500}, + ] + df = spark_session.read.json(spark_context.parallelize(data, 1)) + df = df.withColumn(TIMESTAMP_COLUMN, df.ts.cast(DataType.TIMESTAMP.spark)) + + return df + + +@pytest.fixture() +def fixed_windows_output_feature_set_date_dataframe(spark_context, spark_session): + data = [ + { + "id": 1, + "timestamp": "2016-04-12 11:44:12", + "feature__avg_over_1_day_fixed_windows": 300, + "feature__stddev_pop_over_1_day_fixed_windows": 0, + "year": 2016, + "month": 4, + "day": 12, + }, + { + "id": 1, + "timestamp": "2016-04-13 11:46:24", + "feature__avg_over_1_day_fixed_windows": 400, + "feature__stddev_pop_over_1_day_fixed_windows": 0, + "year": 2016, + "month": 4, + "day": 13, + }, + ] + df = spark_session.read.json(spark_context.parallelize(data, 1)) + df = df.withColumn(TIMESTAMP_COLUMN, df.timestamp.cast(DataType.TIMESTAMP.spark)) + + return df + + +@pytest.fixture() +def feature_set_pipeline( + spark_context, + spark_session, +): + + feature_set_pipeline = FeatureSetPipeline( + source=Source( + readers=[ + TableReader( + id="b_source", + table="b_table", + ).with_incremental_strategy( + incremental_strategy=IncrementalStrategy(column="timestamp") + ), + ], + query=f"select * from b_source ", # noqa + ), + feature_set=FeatureSet( + name="feature_set", + entity="entity", + description="description", + features=[ + Feature( + name="feature", + description="test", + transformation=SparkFunctionTransform( + functions=[ + Function(F.avg, DataType.FLOAT), + Function(F.stddev_pop, DataType.FLOAT), + ], + ).with_window( + partition_by="id", + order_by=TIMESTAMP_COLUMN, + mode="fixed_windows", + window_definition=["1 day"], + ), + ), + ], + keys=[ + KeyFeature( + name="id", + description="The user's Main ID or device ID", + dtype=DataType.INTEGER, + ) + ], + timestamp=TimestampFeature(), + ), + sink=Sink(writers=[HistoricalFeatureStoreWriter(debug_mode=True)]), + ) + + return feature_set_pipeline + + +@pytest.fixture() +def pipeline_interval_run_target_dfs( + spark_session, spark_context +) -> (DataFrame, DataFrame, DataFrame): + first_data = [ + { + "id": 1, + "timestamp": "2016-04-11 11:31:11", + "feature": 200, + "run_id": 1, + "year": 2016, + "month": 4, + "day": 11, + }, + { + "id": 1, + "timestamp": "2016-04-12 11:44:12", + "feature": 300, + "run_id": 1, + "year": 2016, + "month": 4, + "day": 12, + }, + { + "id": 1, + "timestamp": "2016-04-13 11:46:24", + "feature": 400, + "run_id": 1, + "year": 2016, + "month": 4, + "day": 13, + }, + ] + + second_data = first_data + [ + { + "id": 1, + "timestamp": "2016-04-14 12:03:21", + "feature": 500, + "run_id": 2, + "year": 2016, + "month": 4, + "day": 14, + }, + ] + + third_data = [ + { + "id": 1, + "timestamp": "2016-04-11 11:31:11", + "feature": 200, + "run_id": 3, + "year": 2016, + "month": 4, + "day": 11, + }, + { + "id": 1, + "timestamp": "2016-04-12 11:44:12", + "feature": 300, + "run_id": 1, + "year": 2016, + "month": 4, + "day": 12, + }, + { + "id": 1, + "timestamp": "2016-04-13 11:46:24", + 
"feature": 400, + "run_id": 1, + "year": 2016, + "month": 4, + "day": 13, + }, + { + "id": 1, + "timestamp": "2016-04-14 12:03:21", + "feature": 500, + "run_id": 2, + "year": 2016, + "month": 4, + "day": 14, + }, + ] + + first_run_df = spark_session.read.json( + spark_context.parallelize(first_data, 1) + ).withColumn("timestamp", F.col("timestamp").cast(DataType.TIMESTAMP.spark)) + second_run_df = spark_session.read.json( + spark_context.parallelize(second_data, 1) + ).withColumn("timestamp", F.col("timestamp").cast(DataType.TIMESTAMP.spark)) + third_run_df = spark_session.read.json( + spark_context.parallelize(third_data, 1) + ).withColumn("timestamp", F.col("timestamp").cast(DataType.TIMESTAMP.spark)) + + return first_run_df, second_run_df, third_run_df diff --git a/tests/integration/butterfree/pipelines/test_feature_set_pipeline.py b/tests/integration/butterfree/pipelines/test_feature_set_pipeline.py index 23d200c16..16eb08e2f 100644 --- a/tests/integration/butterfree/pipelines/test_feature_set_pipeline.py +++ b/tests/integration/butterfree/pipelines/test_feature_set_pipeline.py @@ -5,29 +5,56 @@ from pyspark.sql import functions as F from butterfree.configs import environment +from butterfree.configs.db import MetastoreConfig from butterfree.constants import DataType from butterfree.constants.columns import TIMESTAMP_COLUMN +from butterfree.dataframe_service.incremental_strategy import IncrementalStrategy from butterfree.extract import Source from butterfree.extract.readers import TableReader +from butterfree.hooks import Hook from butterfree.load import Sink from butterfree.load.writers import HistoricalFeatureStoreWriter from butterfree.pipelines.feature_set_pipeline import FeatureSetPipeline from butterfree.testing.dataframe import assert_dataframe_equality from butterfree.transform import FeatureSet from butterfree.transform.features import Feature, KeyFeature, TimestampFeature -from butterfree.transform.transformations import CustomTransform, SparkFunctionTransform +from butterfree.transform.transformations import ( + CustomTransform, + SparkFunctionTransform, + SQLExpressionTransform, +) from butterfree.transform.utils import Function +class AddHook(Hook): + def __init__(self, value): + self.value = value + + def run(self, dataframe): + return dataframe.withColumn("feature", F.expr(f"feature + {self.value}")) + + +class RunHook(Hook): + def __init__(self, id): + self.id = id + + def run(self, dataframe): + return dataframe.withColumn( + "run_id", + F.when(F.lit(self.id).isNotNull(), F.lit(self.id)).otherwise(F.lit(None)), + ) + + def create_temp_view(dataframe: DataFrame, name): dataframe.createOrReplaceTempView(name) def create_db_and_table(spark, table_reader_id, table_reader_db, table_reader_table): - spark.sql(f"create database if not exists {table_reader_db}") + spark.sql(f"drop schema if exists {table_reader_db} cascade") + spark.sql(f"create database {table_reader_db}") spark.sql(f"use {table_reader_db}") spark.sql( - f"create table if not exists {table_reader_db}.{table_reader_table} " # noqa + f"create table {table_reader_db}.{table_reader_table} " # noqa f"as select * from {table_reader_id}" # noqa ) @@ -38,14 +65,27 @@ def divide(df, fs, column1, column2): return df +def create_ymd(dataframe): + return ( + dataframe.withColumn("year", F.year(F.col("timestamp"))) + .withColumn("month", F.month(F.col("timestamp"))) + .withColumn("day", F.dayofmonth(F.col("timestamp"))) + ) + + class TestFeatureSetPipeline: def test_feature_set_pipeline( - self, mocked_df, spark_session, 
fixed_windows_output_feature_set_dataframe + self, + mocked_df, + spark_session, + fixed_windows_output_feature_set_dataframe, ): # arrange + table_reader_id = "a_source" table_reader_table = "table" table_reader_db = environment.get_variable("FEATURE_STORE_HISTORICAL_DATABASE") + create_temp_view(dataframe=mocked_df, name=table_reader_id) create_db_and_table( spark=spark_session, @@ -53,13 +93,16 @@ def test_feature_set_pipeline( table_reader_db=table_reader_db, table_reader_table=table_reader_table, ) - dbconfig = Mock() + + path = "spark-warehouse/test.db/test_folder/historical/entity/feature_set" + + dbconfig = MetastoreConfig() dbconfig.get_options = Mock( - return_value={ - "mode": "overwrite", - "format_": "parquet", - "path": "test_folder/historical/entity/feature_set", - } + return_value={"mode": "overwrite", "format_": "parquet", "path": path} + ) + + historical_writer = HistoricalFeatureStoreWriter( + db_config=dbconfig, debug_mode=True ) # act @@ -99,7 +142,9 @@ def test_feature_set_pipeline( description="unit test", dtype=DataType.FLOAT, transformation=CustomTransform( - transformer=divide, column1="feature1", column2="feature2", + transformer=divide, + column1="feature1", + column2="feature2", ), ), ], @@ -112,13 +157,17 @@ def test_feature_set_pipeline( ], timestamp=TimestampFeature(), ), - sink=Sink(writers=[HistoricalFeatureStoreWriter(db_config=dbconfig)],), + sink=Sink(writers=[historical_writer]), ) test_pipeline.run() + # act and assert + dbconfig.get_path_with_partitions = Mock( + return_value=["historical/entity/feature_set"] + ) + # assert - path = dbconfig.get_options("historical/entity/feature_set").get("path") - df = spark_session.read.parquet(path).orderBy(TIMESTAMP_COLUMN) + df = spark_session.sql("select * from historical_feature_store__feature_set") target_df = fixed_windows_output_feature_set_dataframe.orderBy( test_pipeline.feature_set.timestamp_column @@ -127,5 +176,262 @@ def test_feature_set_pipeline( # assert assert_dataframe_equality(df, target_df) + def test_feature_set_pipeline_with_dates( + self, + mocked_date_df, + spark_session, + fixed_windows_output_feature_set_date_dataframe, + feature_set_pipeline, + ): + # arrange + table_reader_table = "b_table" + create_temp_view(dataframe=mocked_date_df, name=table_reader_table) + + historical_writer = HistoricalFeatureStoreWriter(debug_mode=True) + + feature_set_pipeline.sink.writers = [historical_writer] + + # act + feature_set_pipeline.run(start_date="2016-04-12", end_date="2016-04-13") + + df = spark_session.sql("select * from historical_feature_store__feature_set") + + # assert + assert_dataframe_equality(df, fixed_windows_output_feature_set_date_dataframe) + + def test_feature_set_pipeline_with_execution_date( + self, + mocked_date_df, + spark_session, + fixed_windows_output_feature_set_date_dataframe, + feature_set_pipeline, + ): + # arrange + table_reader_table = "b_table" + create_temp_view(dataframe=mocked_date_df, name=table_reader_table) + + target_df = fixed_windows_output_feature_set_date_dataframe.filter( + "timestamp < '2016-04-13'" + ) + + historical_writer = HistoricalFeatureStoreWriter(debug_mode=True) + + feature_set_pipeline.sink.writers = [historical_writer] + + # act + feature_set_pipeline.run_for_date(execution_date="2016-04-12") + + df = spark_session.sql("select * from historical_feature_store__feature_set") + + # assert + assert_dataframe_equality(df, target_df) + + def test_pipeline_with_hooks(self, spark_session): + # arrange + hook1 = AddHook(value=1) + + spark_session.sql( + 
"select 1 as id, timestamp('2020-01-01') as timestamp, 0 as feature" + ).createOrReplaceTempView("test") + + target_df = spark_session.sql( + "select 1 as id, timestamp('2020-01-01') as timestamp, 6 as feature, 2020 " + "as year, 1 as month, 1 as day" + ) + + historical_writer = HistoricalFeatureStoreWriter(debug_mode=True) + + test_pipeline = FeatureSetPipeline( + source=Source( + readers=[ + TableReader( + id="reader", + table="test", + ).add_post_hook(hook1) + ], + query="select * from reader", + ).add_post_hook(hook1), + feature_set=FeatureSet( + name="feature_set", + entity="entity", + description="description", + features=[ + Feature( + name="feature", + description="test", + transformation=SQLExpressionTransform(expression="feature + 1"), + dtype=DataType.INTEGER, + ), + ], + keys=[ + KeyFeature( + name="id", + description="The user's Main ID or device ID", + dtype=DataType.INTEGER, + ) + ], + timestamp=TimestampFeature(), + ) + .add_pre_hook(hook1) + .add_post_hook(hook1), + sink=Sink( + writers=[historical_writer], + ).add_pre_hook(hook1), + ) + + # act + test_pipeline.run() + output_df = spark_session.table("historical_feature_store__feature_set") + + # assert + output_df.show() + assert_dataframe_equality(output_df, target_df) + + def test_pipeline_interval_run( + self, mocked_date_df, pipeline_interval_run_target_dfs, spark_session + ): + """Testing pipeline's idempotent interval run feature. + Source data: + +-------+---+-------------------+-------------------+ + |feature| id| ts| timestamp| + +-------+---+-------------------+-------------------+ + | 200| 1|2016-04-11 11:31:11|2016-04-11 11:31:11| + | 300| 1|2016-04-12 11:44:12|2016-04-12 11:44:12| + | 400| 1|2016-04-13 11:46:24|2016-04-13 11:46:24| + | 500| 1|2016-04-14 12:03:21|2016-04-14 12:03:21| + +-------+---+-------------------+-------------------+ + The test executes 3 runs for different time intervals. The input data has 4 data + points: 2016-04-11, 2016-04-12, 2016-04-13 and 2016-04-14. The following run + specifications are: + 1) Interval: from 2016-04-11 to 2016-04-13 + Target table result: + +---+-------+---+-----+------+-------------------+----+ + |day|feature| id|month|run_id| timestamp|year| + +---+-------+---+-----+------+-------------------+----+ + | 11| 200| 1| 4| 1|2016-04-11 11:31:11|2016| + | 12| 300| 1| 4| 1|2016-04-12 11:44:12|2016| + | 13| 400| 1| 4| 1|2016-04-13 11:46:24|2016| + +---+-------+---+-----+------+-------------------+----+ + 2) Interval: only 2016-04-14. + Target table result: + +---+-------+---+-----+------+-------------------+----+ + |day|feature| id|month|run_id| timestamp|year| + +---+-------+---+-----+------+-------------------+----+ + | 11| 200| 1| 4| 1|2016-04-11 11:31:11|2016| + | 12| 300| 1| 4| 1|2016-04-12 11:44:12|2016| + | 13| 400| 1| 4| 1|2016-04-13 11:46:24|2016| + | 14| 500| 1| 4| 2|2016-04-14 12:03:21|2016| + +---+-------+---+-----+------+-------------------+----+ + 3) Interval: only 2016-04-11. 
+ Target table result: + +---+-------+---+-----+------+-------------------+----+ + |day|feature| id|month|run_id| timestamp|year| + +---+-------+---+-----+------+-------------------+----+ + | 11| 200| 1| 4| 3|2016-04-11 11:31:11|2016| + | 12| 300| 1| 4| 1|2016-04-12 11:44:12|2016| + | 13| 400| 1| 4| 1|2016-04-13 11:46:24|2016| + | 14| 500| 1| 4| 2|2016-04-14 12:03:21|2016| + +---+-------+---+-----+------+-------------------+----+ + """ + # arrange + create_temp_view(dataframe=mocked_date_df, name="input_data") + + db = environment.get_variable("FEATURE_STORE_HISTORICAL_DATABASE") + path = "test_folder/historical/entity/feature_set" + read_path = "spark-warehouse/test.db/" + path + + spark_session.conf.set("spark.sql.sources.partitionOverwriteMode", "dynamic") + spark_session.sql(f"drop schema {db} cascade") + spark_session.sql(f"create database {db}") + spark_session.sql( + f"create table {db}.feature_set_interval " + f"(id int, timestamp timestamp, feature int, " + f"run_id int, year int, month int, day int);" + ) + + dbconfig = MetastoreConfig() + dbconfig.get_options = Mock( + return_value={"mode": "overwrite", "format_": "parquet", "path": path} + ) + + historical_writer = HistoricalFeatureStoreWriter( + db_config=dbconfig, interval_mode=True, row_count_validation=False + ) + + first_run_hook = RunHook(id=1) + second_run_hook = RunHook(id=2) + third_run_hook = RunHook(id=3) + + ( + first_run_target_df, + second_run_target_df, + third_run_target_df, + ) = pipeline_interval_run_target_dfs + + test_pipeline = FeatureSetPipeline( + source=Source( + readers=[ + TableReader( + id="id", + table="input_data", + ).with_incremental_strategy(IncrementalStrategy("ts")), + ], + query="select * from id ", + ), + feature_set=FeatureSet( + name="feature_set_interval", + entity="entity", + description="", + keys=[ + KeyFeature( + name="id", + description="", + dtype=DataType.INTEGER, + ) + ], + timestamp=TimestampFeature(from_column="ts"), + features=[ + Feature(name="feature", description="", dtype=DataType.INTEGER), + Feature(name="run_id", description="", dtype=DataType.INTEGER), + ], + ), + sink=Sink( + [historical_writer], + ), + ) + + # act and assert + dbconfig.get_path_with_partitions = Mock( + return_value=[ + "spark-warehouse/test.db/test_folder/historical/entity/feature_set/year=2016/month=4/day=11", # noqa + "spark-warehouse/test.db/test_folder/historical/entity/feature_set/year=2016/month=4/day=12", # noqa + "spark-warehouse/test.db/test_folder/historical/entity/feature_set/year=2016/month=4/day=13", # noqa + ] + ) + test_pipeline.feature_set.add_pre_hook(first_run_hook) + test_pipeline.run(end_date="2016-04-13", start_date="2016-04-11") + first_run_output_df = spark_session.read.parquet(read_path) + assert_dataframe_equality(first_run_output_df, first_run_target_df) + + dbconfig.get_path_with_partitions = Mock( + return_value=[ + "spark-warehouse/test.db/test_folder/historical/entity/feature_set/year=2016/month=4/day=14", # noqa + ] + ) + test_pipeline.feature_set.add_pre_hook(second_run_hook) + test_pipeline.run_for_date("2016-04-14") + second_run_output_df = spark_session.read.parquet(read_path) + assert_dataframe_equality(second_run_output_df, second_run_target_df) + + dbconfig.get_path_with_partitions = Mock( + return_value=[ + "spark-warehouse/test.db/test_folder/historical/entity/feature_set/year=2016/month=4/day=11", # noqa + ] + ) + test_pipeline.feature_set.add_pre_hook(third_run_hook) + test_pipeline.run_for_date("2016-04-11") + third_run_output_df = 
spark_session.read.parquet(read_path) + assert_dataframe_equality(third_run_output_df, third_run_target_df) + # tear down - shutil.rmtree("test_folder") + shutil.rmtree("spark-warehouse/test.db/test_folder") diff --git a/tests/integration/butterfree/transform/conftest.py b/tests/integration/butterfree/transform/conftest.py index 6621c9a35..fe0cc5727 100644 --- a/tests/integration/butterfree/transform/conftest.py +++ b/tests/integration/butterfree/transform/conftest.py @@ -395,3 +395,58 @@ def rolling_windows_output_feature_set_dataframe_base_date( df = df.withColumn(TIMESTAMP_COLUMN, df.origin_ts.cast(DataType.TIMESTAMP.spark)) return df + + +@fixture +def feature_set_dates_dataframe(spark_context, spark_session): + data = [ + {"id": 1, "ts": "2016-04-11 11:31:11", "feature": 200}, + {"id": 1, "ts": "2016-04-12 11:44:12", "feature": 300}, + {"id": 1, "ts": "2016-04-13 11:46:24", "feature": 400}, + {"id": 1, "ts": "2016-04-14 12:03:21", "feature": 500}, + ] + df = spark_session.read.json(spark_context.parallelize(data, 1)) + df = df.withColumn(TIMESTAMP_COLUMN, df.ts.cast(DataType.TIMESTAMP.spark)) + df = df.withColumn("ts", df.ts.cast(DataType.TIMESTAMP.spark)) + + return df + + +@fixture +def feature_set_dates_output_dataframe(spark_context, spark_session): + data = [ + {"id": 1, "timestamp": "2016-04-11 11:31:11", "feature": 200}, + {"id": 1, "timestamp": "2016-04-12 11:44:12", "feature": 300}, + ] + df = spark_session.read.json(spark_context.parallelize(data, 1)) + df = df.withColumn("timestamp", df.timestamp.cast(DataType.TIMESTAMP.spark)) + + return df + + +@fixture +def rolling_windows_output_date_boundaries(spark_context, spark_session): + data = [ + { + "id": 1, + "ts": "2016-04-11 00:00:00", + "feature__avg_over_1_day_rolling_windows": None, + "feature__avg_over_1_week_rolling_windows": None, + "feature__stddev_pop_over_1_day_rolling_windows": None, + "feature__stddev_pop_over_1_week_rolling_windows": None, + }, + { + "id": 1, + "ts": "2016-04-12 00:00:00", + "feature__avg_over_1_day_rolling_windows": 200.0, + "feature__avg_over_1_week_rolling_windows": 200.0, + "feature__stddev_pop_over_1_day_rolling_windows": 0.0, + "feature__stddev_pop_over_1_week_rolling_windows": 0.0, + }, + ] + df = spark_session.read.json( + spark_context.parallelize(data).map(lambda x: json.dumps(x)) + ) + df = df.withColumn(TIMESTAMP_COLUMN, df.ts.cast(DataType.TIMESTAMP.spark)) + + return df diff --git a/tests/integration/butterfree/transform/test_aggregated_feature_set.py b/tests/integration/butterfree/transform/test_aggregated_feature_set.py index 559dbcb89..413077619 100644 --- a/tests/integration/butterfree/transform/test_aggregated_feature_set.py +++ b/tests/integration/butterfree/transform/test_aggregated_feature_set.py @@ -19,7 +19,9 @@ def divide(df, fs, column1, column2): class TestAggregatedFeatureSet: def test_construct_without_window( - self, feature_set_dataframe, target_df_without_window, + self, + feature_set_dataframe, + target_df_without_window, ): # given @@ -157,7 +159,9 @@ def test_construct_rolling_windows_without_end_date( ) ], timestamp=TimestampFeature(), - ).with_windows(definitions=["1 day", "1 week"],) + ).with_windows( + definitions=["1 day", "1 week"], + ) # act & assert with pytest.raises(ValueError): @@ -201,7 +205,9 @@ def test_h3_feature_set(self, h3_input_df, h3_target_df): assert_dataframe_equality(output_df, h3_target_df) def test_construct_with_pivot( - self, feature_set_df_pivot, target_df_pivot_agg, + self, + feature_set_df_pivot, + target_df_pivot_agg, ): # given 
@@ -241,3 +247,55 @@ def test_construct_with_pivot( # assert assert_dataframe_equality(output_df, target_df_pivot_agg) + + def test_construct_rolling_windows_with_date_boundaries( + self, + feature_set_dates_dataframe, + rolling_windows_output_date_boundaries, + ): + # given + + spark_client = SparkClient() + + # arrange + + feature_set = AggregatedFeatureSet( + name="feature_set", + entity="entity", + description="description", + features=[ + Feature( + name="feature", + description="test", + transformation=AggregatedTransform( + functions=[ + Function(F.avg, DataType.DOUBLE), + Function(F.stddev_pop, DataType.DOUBLE), + ], + ), + ), + ], + keys=[ + KeyFeature( + name="id", + description="The user's Main ID or device ID", + dtype=DataType.INTEGER, + ) + ], + timestamp=TimestampFeature(), + ).with_windows(definitions=["1 day", "1 week"]) + + # act + output_df = feature_set.construct( + feature_set_dates_dataframe, + client=spark_client, + start_date="2016-04-11", + end_date="2016-04-12", + ).orderBy("timestamp") + + target_df = rolling_windows_output_date_boundaries.orderBy( + feature_set.timestamp_column + ).select(feature_set.columns) + + # assert + assert_dataframe_equality(output_df, target_df) diff --git a/tests/integration/butterfree/transform/test_feature_set.py b/tests/integration/butterfree/transform/test_feature_set.py index 4872ded24..6c5f7f1d8 100644 --- a/tests/integration/butterfree/transform/test_feature_set.py +++ b/tests/integration/butterfree/transform/test_feature_set.py @@ -51,7 +51,9 @@ def test_construct( description="unit test", dtype=DataType.FLOAT, transformation=CustomTransform( - transformer=divide, column1="feature1", column2="feature2", + transformer=divide, + column1="feature1", + column2="feature2", ), ), ], @@ -77,3 +79,51 @@ def test_construct( # assert assert_dataframe_equality(output_df, target_df) + + def test_construct_with_date_boundaries( + self, feature_set_dates_dataframe, feature_set_dates_output_dataframe + ): + # given + + spark_client = SparkClient() + + # arrange + + feature_set = FeatureSet( + name="feature_set", + entity="entity", + description="description", + features=[ + Feature( + name="feature", + description="test", + dtype=DataType.FLOAT, + ), + ], + keys=[ + KeyFeature( + name="id", + description="The user's Main ID or device ID", + dtype=DataType.INTEGER, + ) + ], + timestamp=TimestampFeature(), + ) + + output_df = ( + feature_set.construct( + feature_set_dates_dataframe, + client=spark_client, + start_date="2016-04-11", + end_date="2016-04-12", + ) + .orderBy(feature_set.timestamp_column) + .select(feature_set.columns) + ) + + target_df = feature_set_dates_output_dataframe.orderBy( + feature_set.timestamp_column + ).select(feature_set.columns) + + # assert + assert_dataframe_equality(output_df, target_df) diff --git a/tests/mocks/__init__.py b/tests/mocks/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/mocks/entities/__init__.py b/tests/mocks/entities/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/mocks/entities/first/__init__.py b/tests/mocks/entities/first/__init__.py new file mode 100644 index 000000000..e69592de4 --- /dev/null +++ b/tests/mocks/entities/first/__init__.py @@ -0,0 +1,3 @@ +from .first_pipeline import FirstPipeline + +__all__ = ["FirstPipeline"] diff --git a/tests/mocks/entities/first/first_pipeline.py b/tests/mocks/entities/first/first_pipeline.py new file mode 100644 index 000000000..938c880c7 --- /dev/null +++ 
b/tests/mocks/entities/first/first_pipeline.py @@ -0,0 +1,55 @@ +from butterfree.constants.data_type import DataType +from butterfree.extract import Source +from butterfree.extract.readers import TableReader +from butterfree.load import Sink +from butterfree.load.writers import ( + HistoricalFeatureStoreWriter, + OnlineFeatureStoreWriter, +) +from butterfree.pipelines import FeatureSetPipeline +from butterfree.transform import FeatureSet +from butterfree.transform.features import Feature, KeyFeature, TimestampFeature + + +class FirstPipeline(FeatureSetPipeline): + def __init__(self): + super(FirstPipeline, self).__init__( + source=Source( + readers=[ + TableReader( + id="t", + database="db", + table="table", + ) + ], + query=f"select * from t", # noqa + ), + feature_set=FeatureSet( + name="first", + entity="entity", + description="description", + features=[ + Feature( + name="feature1", + description="test", + dtype=DataType.FLOAT, + ), + Feature( + name="feature2", + description="another test", + dtype=DataType.STRING, + ), + ], + keys=[ + KeyFeature( + name="id", + description="identifier", + dtype=DataType.BIGINT, + ) + ], + timestamp=TimestampFeature(), + ), + sink=Sink( + writers=[HistoricalFeatureStoreWriter(), OnlineFeatureStoreWriter()] + ), + ) diff --git a/tests/mocks/entities/second/__init__.py b/tests/mocks/entities/second/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/mocks/entities/second/deeper/__init__.py b/tests/mocks/entities/second/deeper/__init__.py new file mode 100644 index 000000000..9f70be75d --- /dev/null +++ b/tests/mocks/entities/second/deeper/__init__.py @@ -0,0 +1,3 @@ +from .second_pipeline import SecondPipeline + +__all__ = ["SecondPipeline"] diff --git a/tests/mocks/entities/second/deeper/second_pipeline.py b/tests/mocks/entities/second/deeper/second_pipeline.py new file mode 100644 index 000000000..a59ba2e5d --- /dev/null +++ b/tests/mocks/entities/second/deeper/second_pipeline.py @@ -0,0 +1,55 @@ +from butterfree.constants.data_type import DataType +from butterfree.extract import Source +from butterfree.extract.readers import TableReader +from butterfree.load import Sink +from butterfree.load.writers import ( + HistoricalFeatureStoreWriter, + OnlineFeatureStoreWriter, +) +from butterfree.pipelines import FeatureSetPipeline +from butterfree.transform import FeatureSet +from butterfree.transform.features import Feature, KeyFeature, TimestampFeature + + +class SecondPipeline(FeatureSetPipeline): + def __init__(self): + super(SecondPipeline, self).__init__( + source=Source( + readers=[ + TableReader( + id="t", + database="db", + table="table", + ) + ], + query=f"select * from t", # noqa + ), + feature_set=FeatureSet( + name="second", + entity="entity", + description="description", + features=[ + Feature( + name="feature1", + description="test", + dtype=DataType.STRING, + ), + Feature( + name="feature2", + description="another test", + dtype=DataType.FLOAT, + ), + ], + keys=[ + KeyFeature( + name="id", + description="identifier", + dtype=DataType.BIGINT, + ) + ], + timestamp=TimestampFeature(), + ), + sink=Sink( + writers=[HistoricalFeatureStoreWriter(), OnlineFeatureStoreWriter()] + ), + ) diff --git a/tests/unit/butterfree/_cli/__init__.py b/tests/unit/butterfree/_cli/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/butterfree/_cli/test_migrate.py b/tests/unit/butterfree/_cli/test_migrate.py new file mode 100644 index 000000000..c0751c888 --- /dev/null +++ 
b/tests/unit/butterfree/_cli/test_migrate.py @@ -0,0 +1,43 @@ +from unittest.mock import call + +from typer.testing import CliRunner + +from butterfree._cli import migrate +from butterfree._cli.main import app +from butterfree.migrations.database_migration import CassandraMigration +from butterfree.pipelines import FeatureSetPipeline + +runner = CliRunner() + + +class TestMigrate: + def test_migrate_success(self, mocker): + mocker.patch.object(migrate.Migrate, "run") + all_fs = migrate.migrate("tests/mocks/entities/") + assert all(isinstance(fs, FeatureSetPipeline) for fs in all_fs) + assert sorted([fs.feature_set.name for fs in all_fs]) == ["first", "second"] + + def test_migrate_run_methods(self, mocker): + mocker.patch.object(CassandraMigration, "apply_migration") + mocker.patch.object(migrate.Migrate, "_send_logs_to_s3") + + all_fs = migrate.migrate("tests/mocks/entities/", False, False) + + assert CassandraMigration.apply_migration.call_count == 2 + + cassandra_pairs = [ + call(pipe.feature_set, pipe.sink.writers[1], False) for pipe in all_fs + ] + CassandraMigration.apply_migration.assert_has_calls( + cassandra_pairs, any_order=True + ) + migrate.Migrate._send_logs_to_s3.assert_called_once() + + def test_app_cli(self): + result = runner.invoke(app, "migrate") + assert result.exit_code == 0 + + def test_app_migrate(self, mocker): + mocker.patch.object(migrate.Migrate, "run") + result = runner.invoke(app, ["migrate", "apply", "tests/mocks/entities/"]) + assert result.exit_code == 0 diff --git a/tests/unit/butterfree/automated/__init__.py b/tests/unit/butterfree/automated/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/butterfree/automated/test_feature_set_creation.py b/tests/unit/butterfree/automated/test_feature_set_creation.py new file mode 100644 index 000000000..cfb5101e6 --- /dev/null +++ b/tests/unit/butterfree/automated/test_feature_set_creation.py @@ -0,0 +1,28 @@ +import unittest +from unittest.mock import MagicMock + +from butterfree.automated.feature_set_creation import FeatureSetCreation + + +class TestFeatureSetCreation(unittest.TestCase): + def setUp(self): + self.feature_set_creation = FeatureSetCreation() + + def test_get_features_with_regex(self): + sql_query = "SELECT column1, column2 FROM table1" + expected_features = ["column1", "column2"] + + features = self.feature_set_creation._get_features_with_regex(sql_query) + + self.assertEqual(features, expected_features) + + def test_get_data_type(self): + field_name = "column1" + df_mock = MagicMock() + df_mock.schema.jsonValue.return_value = { + "fields": [{"name": "column1", "type": "string"}] + } + + data_type = self.feature_set_creation._get_data_type(field_name, df_mock) + + self.assertEqual(data_type, ".STRING") diff --git a/tests/unit/butterfree/clients/conftest.py b/tests/unit/butterfree/clients/conftest.py index fda11f8ef..ffb2db881 100644 --- a/tests/unit/butterfree/clients/conftest.py +++ b/tests/unit/butterfree/clients/conftest.py @@ -46,11 +46,16 @@ def mocked_stream_df() -> Mock: return mock +@pytest.fixture() +def mock_spark_sql() -> Mock: + mock = Mock() + mock.sql = mock + return mock + + @pytest.fixture def cassandra_client() -> CassandraClient: - return CassandraClient( - cassandra_host=["mock"], cassandra_key_space="dummy_keyspace" - ) + return CassandraClient(host=["mock"], keyspace="dummy_keyspace") @pytest.fixture diff --git a/tests/unit/butterfree/clients/test_cassandra_client.py b/tests/unit/butterfree/clients/test_cassandra_client.py index 8785485be..0356e43f9 
100644 --- a/tests/unit/butterfree/clients/test_cassandra_client.py +++ b/tests/unit/butterfree/clients/test_cassandra_client.py @@ -1,8 +1,6 @@ from typing import Any, Dict, List from unittest.mock import MagicMock -import pytest - from butterfree.clients import CassandraClient from butterfree.clients.cassandra_client import CassandraColumn @@ -15,9 +13,7 @@ def sanitize_string(query: str) -> str: class TestCassandraClient: def test_conn(self, cassandra_client: CassandraClient) -> None: # arrange - cassandra_client = CassandraClient( - cassandra_host=["mock"], cassandra_key_space="dummy_keyspace" - ) + cassandra_client = CassandraClient(host=["mock"], keyspace="dummy_keyspace") # act start_conn = cassandra_client._session @@ -90,31 +86,3 @@ def test_cassandra_create_table( query = cassandra_client.sql.call_args[0][0] assert sanitize_string(query) == sanitize_string(expected_query) - - def test_cassandra_without_session(self, cassandra_client: CassandraClient) -> None: - cassandra_client = cassandra_client - - with pytest.raises( - RuntimeError, match="There's no session available for this query." - ): - cassandra_client.sql( - query="select feature1, feature2 from cassandra_feature_set" - ) - with pytest.raises( - RuntimeError, match="There's no session available for this query." - ): - cassandra_client.create_table( - [ - {"column_name": "id", "type": "int", "primary_key": True}, - { - "column_name": "rent_per_month", - "type": "float", - "primary_key": False, - }, - ], - "test", - ) - with pytest.raises( - RuntimeError, match="There's no session available for this query." - ): - cassandra_client.get_schema("test") diff --git a/tests/unit/butterfree/clients/test_spark_client.py b/tests/unit/butterfree/clients/test_spark_client.py index 58d53a401..b2418a7c6 100644 --- a/tests/unit/butterfree/clients/test_spark_client.py +++ b/tests/unit/butterfree/clients/test_spark_client.py @@ -1,4 +1,5 @@ -from typing import Any, Dict, Optional, Union +from datetime import datetime +from typing import Any, Optional, Union from unittest.mock import Mock import pytest @@ -14,6 +15,15 @@ def create_temp_view(dataframe: DataFrame, name: str) -> None: dataframe.createOrReplaceTempView(name) +def create_db_and_table(spark, database, table, view): + spark.sql(f"create database if not exists {database}") + spark.sql(f"use {database}") + spark.sql( + f"create table if not exists {database}.{table} " # noqa + f"as select * from {view}" # noqa + ) + + class TestSparkClient: def test_conn(self) -> None: # arrange @@ -26,19 +36,20 @@ def test_conn(self) -> None: assert start_conn is None @pytest.mark.parametrize( - "format, options, stream, schema", + "format, path, stream, schema, options", [ - ("parquet", {"path": "path/to/file"}, False, None), - ("csv", {"path": "path/to/file", "header": True}, False, None), - ("json", {"path": "path/to/file"}, True, None), + ("parquet", ["path/to/file"], False, None, {}), + ("csv", "path/to/file", False, None, {"header": True}), + ("json", "path/to/file", True, None, {}), ], ) def test_read( self, format: str, - options: Dict[str, Any], stream: bool, schema: Optional[StructType], + path: Any, + options: Any, target_df: DataFrame, mocked_spark_read: Mock, ) -> None: @@ -48,26 +59,26 @@ def test_read( spark_client._session = mocked_spark_read # act - result_df = spark_client.read(format, options, schema, stream) + result_df = spark_client.read( + format=format, schema=schema, stream=stream, path=path, **options + ) # assert mocked_spark_read.format.assert_called_once_with(format) 
- mocked_spark_read.options.assert_called_once_with(**options) + mocked_spark_read.load.assert_called_once_with(path=path, **options) assert target_df.collect() == result_df.collect() @pytest.mark.parametrize( - "format, options", - [(None, {"path": "path/to/file"}), ("csv", "not a valid options")], + "format, path", + [(None, "path/to/file"), ("csv", 123)], ) - def test_read_invalid_params( - self, format: Optional[str], options: Union[Dict[str, Any], str] - ) -> None: + def test_read_invalid_params(self, format: Optional[str], path: Any) -> None: # arrange spark_client = SparkClient() # act and assert with pytest.raises(ValueError): - spark_client.read(format, options) # type: ignore + spark_client.read(format=format, path=path) # type: ignore def test_sql(self, target_df: DataFrame) -> None: # arrange @@ -105,7 +116,8 @@ def test_read_table( assert target_df == result_df @pytest.mark.parametrize( - "database, table", [("database", None), ("database", 123)], + "database, table", + [("database", None), ("database", 123)], ) def test_read_table_invalid_params( self, database: str, table: Optional[int] @@ -118,7 +130,8 @@ def test_read_table_invalid_params( spark_client.read_table(table, database) # type: ignore @pytest.mark.parametrize( - "format, mode", [("parquet", "append"), ("csv", "overwrite")], + "format, mode", + [("parquet", "append"), ("csv", "overwrite")], ) def test_write_dataframe( self, format: str, mode: str, mocked_spark_write: Mock @@ -127,7 +140,8 @@ def test_write_dataframe( mocked_spark_write.save.assert_called_with(format=format, mode=mode) @pytest.mark.parametrize( - "format, mode", [(None, "append"), ("parquet", 1)], + "format, mode", + [(None, "append"), ("parquet", 1)], ) def test_write_dataframe_invalid_params( self, target_df: DataFrame, format: Optional[str], mode: Union[str, int] @@ -252,3 +266,67 @@ def test_create_temporary_view( # assert assert_dataframe_equality(target_df, result_df) + + def test_add_table_partitions(self, mock_spark_sql: Mock): + # arrange + target_command = ( + f"ALTER TABLE `db`.`table` ADD IF NOT EXISTS " # noqa + f"PARTITION ( year = 2020, month = 8, day = 14 ) " + f"PARTITION ( year = 2020, month = 8, day = 15 ) " + f"PARTITION ( year = 2020, month = 8, day = 16 )" + ) + + spark_client = SparkClient() + spark_client._session = mock_spark_sql + partitions = [ + {"year": 2020, "month": 8, "day": 14}, + {"year": 2020, "month": 8, "day": 15}, + {"year": 2020, "month": 8, "day": 16}, + ] + + # act + spark_client.add_table_partitions(partitions, "table", "db") + + # assert + mock_spark_sql.assert_called_once_with(target_command) + + @pytest.mark.parametrize( + "partition", + [ + [{"float_partition": 2.72}], + [{123: 2020}], + [{"date": datetime(year=2020, month=8, day=18)}], + ], + ) + def test_add_invalid_partitions(self, mock_spark_sql: Mock, partition): + # arrange + spark_client = SparkClient() + spark_client._session = mock_spark_sql + + # act and assert + with pytest.raises(ValueError): + spark_client.add_table_partitions(partition, "table", "db") + + def test_get_schema( + self, target_df: DataFrame, spark_session: SparkSession + ) -> None: + # arrange + spark_client = SparkClient() + create_temp_view(dataframe=target_df, name="temp_view") + create_db_and_table( + spark=spark_session, + database="test_db", + table="test_table", + view="temp_view", + ) + + expected_schema = [ + {"col_name": "col1", "data_type": "string"}, + {"col_name": "col2", "data_type": "bigint"}, + ] + + # act + schema = spark_client.get_schema(table="test_table", 
database="test_db") + + # assert + assert schema, expected_schema diff --git a/tests/unit/butterfree/configs/db/test_cassandra_config.py b/tests/unit/butterfree/configs/db/test_cassandra_config.py index f51ffe8cc..fa907a07a 100644 --- a/tests/unit/butterfree/configs/db/test_cassandra_config.py +++ b/tests/unit/butterfree/configs/db/test_cassandra_config.py @@ -159,11 +159,77 @@ def test_stream_checkpoint_path_custom(self, cassandra_config): # then assert cassandra_config.stream_checkpoint_path == value + def test_read_consistency_level(self, cassandra_config): + # expecting + default = "LOCAL_ONE" + assert cassandra_config.read_consistency_level == default + + def test_read_consistency_level_custom(self, cassandra_config): + # given + value = "Custom Config" + cassandra_config.read_consistency_level = value + + # then + assert cassandra_config.read_consistency_level == value + + def test_read_consistency_level_custom_env_var(self, mocker, cassandra_config): + # given + value = "Custom Config" + mocker.patch("butterfree.configs.environment.get_variable", return_value=value) + cassandra_config.read_consistency_level = value + + # then + assert cassandra_config.read_consistency_level == value + + def test_write_consistency_level(self, cassandra_config): + # expecting + default = "LOCAL_QUORUM" + assert cassandra_config.write_consistency_level == default + + def test_write_consistency_level_custom(self, cassandra_config): + # given + value = "Custom Config" + cassandra_config.write_consistency_level = value + + # then + assert cassandra_config.write_consistency_level == value + + def test_write_consistency_level_custom_env_var(self, mocker, cassandra_config): + # given + value = "Custom Config" + mocker.patch("butterfree.configs.environment.get_variable", return_value=value) + cassandra_config.write_consistency_level = value + + # then + assert cassandra_config.write_consistency_level == value + + def test_local_dc(self, cassandra_config): + # expecting + default = None + assert cassandra_config.local_dc == default + + def test_local_dc_custom(self, cassandra_config): + # given + value = "VPC_1" + cassandra_config.local_dc = value + + # then + assert cassandra_config.local_dc == value + + def test_local_dc_custom_env_var(self, mocker, cassandra_config): + # given + value = "VPC_1" + mocker.patch("butterfree.configs.environment.get_variable", return_value=value) + cassandra_config.local_dc = value + + # then + assert cassandra_config.local_dc == value + def test_set_credentials_on_instantiation(self): cassandra_config = CassandraConfig( # noqa: S106 username="username", password="password", host="host", keyspace="keyspace" ) assert cassandra_config.username == "username" - assert cassandra_config.password == "password" + assert cassandra_config.password == "password" # noqa: S105 assert cassandra_config.host == "host" assert cassandra_config.keyspace == "keyspace" diff --git a/tests/unit/butterfree/dataframe_service/conftest.py b/tests/unit/butterfree/dataframe_service/conftest.py index 867bc80a3..09470c9a4 100644 --- a/tests/unit/butterfree/dataframe_service/conftest.py +++ b/tests/unit/butterfree/dataframe_service/conftest.py @@ -25,3 +25,17 @@ def input_df(spark_context, spark_session): return spark_session.read.json( spark_context.parallelize(data, 1), schema="timestamp timestamp" ) + + +@pytest.fixture() +def test_partitioning_input_df(spark_context, spark_session): + data = [ + {"feature": 1, "year": 2009, "month": 8, "day": 20}, + {"feature": 2, "year": 2009, "month": 8, "day": 20}, + 
{"feature": 3, "year": 2020, "month": 8, "day": 20}, + {"feature": 4, "year": 2020, "month": 9, "day": 20}, + {"feature": 5, "year": 2020, "month": 9, "day": 20}, + {"feature": 6, "year": 2020, "month": 8, "day": 20}, + {"feature": 7, "year": 2020, "month": 8, "day": 21}, + ] + return spark_session.read.json(spark_context.parallelize(data, 1)) diff --git a/tests/unit/butterfree/dataframe_service/test_incremental_srategy.py b/tests/unit/butterfree/dataframe_service/test_incremental_srategy.py new file mode 100644 index 000000000..a140ceb30 --- /dev/null +++ b/tests/unit/butterfree/dataframe_service/test_incremental_srategy.py @@ -0,0 +1,70 @@ +from butterfree.dataframe_service import IncrementalStrategy + + +class TestIncrementalStrategy: + def test_from_milliseconds(self): + # arrange + incremental_strategy = IncrementalStrategy().from_milliseconds("ts") + target_expression = "date(from_unixtime(ts/ 1000.0)) >= date('2020-01-01')" + + # act + result_expression = incremental_strategy.get_expression(start_date="2020-01-01") + + # assert + assert target_expression.split() == result_expression.split() + + def test_from_string(self): + # arrange + incremental_strategy = IncrementalStrategy().from_string( + "dt", mask="dd/MM/yyyy" + ) + target_expression = "date(to_date(dt, 'dd/MM/yyyy')) >= date('2020-01-01')" + + # act + result_expression = incremental_strategy.get_expression(start_date="2020-01-01") + + # assert + assert target_expression.split() == result_expression.split() + + def test_from_year_month_day_partitions(self): + # arrange + incremental_strategy = IncrementalStrategy().from_year_month_day_partitions( + year_column="y", month_column="m", day_column="d" + ) + target_expression = ( + "date(concat(string(y), " + "'-', string(m), " + "'-', string(d))) >= date('2020-01-01')" + ) + + # act + result_expression = incremental_strategy.get_expression(start_date="2020-01-01") + + # assert + assert target_expression.split() == result_expression.split() + + def test_get_expression_with_just_end_date(self): + # arrange + incremental_strategy = IncrementalStrategy(column="dt") + target_expression = "date(dt) <= date('2020-01-01')" + + # act + result_expression = incremental_strategy.get_expression(end_date="2020-01-01") + + # assert + assert target_expression.split() == result_expression.split() + + def test_get_expression_with_start_and_end_date(self): + # arrange + incremental_strategy = IncrementalStrategy(column="dt") + target_expression = ( + "date(dt) >= date('2019-12-30') and date(dt) <= date('2020-01-01')" + ) + + # act + result_expression = incremental_strategy.get_expression( + start_date="2019-12-30", end_date="2020-01-01" + ) + + # assert + assert target_expression.split() == result_expression.split() diff --git a/tests/unit/butterfree/dataframe_service/test_partitioning.py b/tests/unit/butterfree/dataframe_service/test_partitioning.py new file mode 100644 index 000000000..3a6b5b406 --- /dev/null +++ b/tests/unit/butterfree/dataframe_service/test_partitioning.py @@ -0,0 +1,20 @@ +from butterfree.dataframe_service import extract_partition_values + + +class TestPartitioning: + def test_extract_partition_values(self, test_partitioning_input_df): + # arrange + target_values = [ + {"year": 2009, "month": 8, "day": 20}, + {"year": 2020, "month": 8, "day": 20}, + {"year": 2020, "month": 9, "day": 20}, + {"year": 2020, "month": 8, "day": 21}, + ] + + # act + result_values = extract_partition_values( + test_partitioning_input_df, partition_columns=["year", "month", "day"] + ) + + # assert + 
assert result_values == target_values diff --git a/tests/unit/butterfree/extract/conftest.py b/tests/unit/butterfree/extract/conftest.py index ab6f525c7..3d0e763d3 100644 --- a/tests/unit/butterfree/extract/conftest.py +++ b/tests/unit/butterfree/extract/conftest.py @@ -1,6 +1,7 @@ from unittest.mock import Mock import pytest +from pyspark.sql.functions import col, to_date from butterfree.constants.columns import TIMESTAMP_COLUMN @@ -17,6 +18,60 @@ def target_df(spark_context, spark_session): return spark_session.read.json(spark_context.parallelize(data, 1)) +@pytest.fixture() +def incremental_source_df(spark_context, spark_session): + data = [ + { + "id": 1, + "feature": 100, + "date_str": "28/07/2020", + "milliseconds": 1595894400000, + "year": 2020, + "month": 7, + "day": 28, + }, + { + "id": 1, + "feature": 110, + "date_str": "29/07/2020", + "milliseconds": 1595980800000, + "year": 2020, + "month": 7, + "day": 29, + }, + { + "id": 1, + "feature": 120, + "date_str": "30/07/2020", + "milliseconds": 1596067200000, + "year": 2020, + "month": 7, + "day": 30, + }, + { + "id": 2, + "feature": 150, + "date_str": "31/07/2020", + "milliseconds": 1596153600000, + "year": 2020, + "month": 7, + "day": 31, + }, + { + "id": 2, + "feature": 200, + "date_str": "01/08/2020", + "milliseconds": 1596240000000, + "year": 2020, + "month": 8, + "day": 1, + }, + ] + return spark_session.read.json(spark_context.parallelize(data, 1)).withColumn( + "date", to_date(col("date_str"), "dd/MM/yyyy") + ) + + @pytest.fixture() def spark_client(): return Mock() diff --git a/tests/unit/butterfree/extract/pre_processing/test_filter_transform.py b/tests/unit/butterfree/extract/pre_processing/test_filter_transform.py index 669fd0336..fed20f2d4 100644 --- a/tests/unit/butterfree/extract/pre_processing/test_filter_transform.py +++ b/tests/unit/butterfree/extract/pre_processing/test_filter_transform.py @@ -28,7 +28,8 @@ def test_filter(self, feature_set_dataframe, spark_context, spark_session): assert result_df.collect() == target_df.collect() @pytest.mark.parametrize( - "condition", [None, 100], + "condition", + [None, 100], ) def test_filter_with_invalidations( self, feature_set_dataframe, condition, spark_context, spark_session diff --git a/tests/unit/butterfree/extract/pre_processing/test_pivot_transform.py b/tests/unit/butterfree/extract/pre_processing/test_pivot_transform.py index e716f9d65..cfe730d3a 100644 --- a/tests/unit/butterfree/extract/pre_processing/test_pivot_transform.py +++ b/tests/unit/butterfree/extract/pre_processing/test_pivot_transform.py @@ -9,7 +9,9 @@ class TestPivotTransform: def test_pivot_transformation( - self, input_df, pivot_df, + self, + input_df, + pivot_df, ): result_df = pivot( dataframe=input_df, @@ -20,10 +22,15 @@ def test_pivot_transformation( ) # assert - assert compare_dataframes(actual_df=result_df, expected_df=pivot_df,) + assert compare_dataframes( + actual_df=result_df, + expected_df=pivot_df, + ) def test_pivot_transformation_with_forward_fill( - self, input_df, pivot_ffill_df, + self, + input_df, + pivot_ffill_df, ): result_df = pivot( dataframe=input_df, @@ -35,10 +42,15 @@ def test_pivot_transformation_with_forward_fill( ) # assert - assert compare_dataframes(actual_df=result_df, expected_df=pivot_ffill_df,) + assert compare_dataframes( + actual_df=result_df, + expected_df=pivot_ffill_df, + ) def test_pivot_transformation_with_forward_fill_and_mock( - self, input_df, pivot_ffill_mock_df, + self, + input_df, + pivot_ffill_mock_df, ): result_df = pivot( dataframe=input_df, @@ 
-52,10 +64,15 @@ def test_pivot_transformation_with_forward_fill_and_mock( ) # assert - assert compare_dataframes(actual_df=result_df, expected_df=pivot_ffill_mock_df,) + assert compare_dataframes( + actual_df=result_df, + expected_df=pivot_ffill_mock_df, + ) def test_pivot_transformation_mock_without_type( - self, input_df, pivot_ffill_mock_df, + self, + input_df, + pivot_ffill_mock_df, ): with pytest.raises(AttributeError): _ = pivot( @@ -83,4 +100,7 @@ def test_apply_pivot_transformation(self, input_df, pivot_df): result_df = file_reader._apply_transformations(input_df) # assert - assert compare_dataframes(actual_df=result_df, expected_df=pivot_df,) + assert compare_dataframes( + actual_df=result_df, + expected_df=pivot_df, + ) diff --git a/tests/unit/butterfree/extract/readers/test_file_reader.py b/tests/unit/butterfree/extract/readers/test_file_reader.py index d337d4fef..136c8fd62 100644 --- a/tests/unit/butterfree/extract/readers/test_file_reader.py +++ b/tests/unit/butterfree/extract/readers/test_file_reader.py @@ -7,7 +7,15 @@ class TestFileReader: @pytest.mark.parametrize( - "path, format", [(None, "parquet"), ("path/to/file.json", 123), (123, None,)], + "path, format", + [ + (None, "parquet"), + ("path/to/file.json", 123), + ( + 123, + None, + ), + ], ) def test_init_invalid_params(self, path, format): # act and assert @@ -36,11 +44,11 @@ def test_consume( # act output_df = file_reader.consume(spark_client) - options = dict({"path": path}, **format_options if format_options else {}) + options = dict(format_options if format_options else {}) # assert spark_client.read.assert_called_once_with( - format=format, options=options, schema=schema, stream=False + format=format, schema=schema, stream=False, path=path, **options ) assert target_df.collect() == output_df.collect() @@ -51,7 +59,7 @@ def test_consume_with_stream_without_schema(self, spark_client, target_df): schema = None format_options = None stream = True - options = dict({"path": path}) + options = dict({}) spark_client.read.return_value = target_df file_reader = FileReader( @@ -64,11 +72,11 @@ def test_consume_with_stream_without_schema(self, spark_client, target_df): # assert # assert call for schema infer - spark_client.read.assert_any_call(format=format, options=options) + spark_client.read.assert_any_call(format=format, path=path, **options) # assert call for stream read # stream spark_client.read.assert_called_with( - format=format, options=options, schema=output_df.schema, stream=stream + format=format, schema=output_df.schema, stream=stream, path=path, **options ) assert target_df.collect() == output_df.collect() diff --git a/tests/unit/butterfree/extract/readers/test_kafka_reader.py b/tests/unit/butterfree/extract/readers/test_kafka_reader.py index 5a07cbdd9..f1ea82ae3 100644 --- a/tests/unit/butterfree/extract/readers/test_kafka_reader.py +++ b/tests/unit/butterfree/extract/readers/test_kafka_reader.py @@ -99,7 +99,7 @@ def test_consume( # assert spark_client.read.assert_called_once_with( - format="kafka", options=options, stream=kafka_reader.stream + format="kafka", stream=kafka_reader.stream, **options ) assert_dataframe_equality(target_df, output_df) diff --git a/tests/unit/butterfree/extract/readers/test_reader.py b/tests/unit/butterfree/extract/readers/test_reader.py index c210a756d..bcceacbd1 100644 --- a/tests/unit/butterfree/extract/readers/test_reader.py +++ b/tests/unit/butterfree/extract/readers/test_reader.py @@ -1,7 +1,9 @@ import pytest from pyspark.sql.functions import expr +from 
butterfree.dataframe_service import IncrementalStrategy from butterfree.extract.readers import FileReader +from butterfree.testing.dataframe import assert_dataframe_equality def add_value_transformer(df, column, value): @@ -146,9 +148,66 @@ def test_build_with_columns( # act file_reader.build( - client=spark_client, columns=[("col1", "new_col1"), ("col2", "new_col2")], + client=spark_client, + columns=[("col1", "new_col1"), ("col2", "new_col2")], ) result_df = spark_session.sql("select * from test") # assert assert column_target_df.collect() == result_df.collect() + + def test_build_with_incremental_strategy( + self, incremental_source_df, spark_client, spark_session + ): + # arrange + readers = [ + # directly from column + FileReader( + id="test_1", path="path/to/file", format="format" + ).with_incremental_strategy( + incremental_strategy=IncrementalStrategy(column="date") + ), + # from milliseconds + FileReader( + id="test_2", path="path/to/file", format="format" + ).with_incremental_strategy( + incremental_strategy=IncrementalStrategy().from_milliseconds( + column_name="milliseconds" + ) + ), + # from str + FileReader( + id="test_3", path="path/to/file", format="format" + ).with_incremental_strategy( + incremental_strategy=IncrementalStrategy().from_string( + column_name="date_str", mask="dd/MM/yyyy" + ) + ), + # from year, month, day partitions + FileReader( + id="test_4", path="path/to/file", format="format" + ).with_incremental_strategy( + incremental_strategy=( + IncrementalStrategy().from_year_month_day_partitions() + ) + ), + ] + + spark_client.read.return_value = incremental_source_df + target_df = incremental_source_df.where( + "date >= date('2020-07-29') and date <= date('2020-07-31')" + ) + + # act + for reader in readers: + reader.build( + client=spark_client, start_date="2020-07-29", end_date="2020-07-31" + ) + + output_dfs = [ + spark_session.table(f"test_{i + 1}") for i, _ in enumerate(readers) + ] + + # assert + for output_df in output_dfs: + assert_dataframe_equality(output_df=output_df, target_df=target_df) diff --git a/tests/unit/butterfree/extract/readers/test_table_reader.py b/tests/unit/butterfree/extract/readers/test_table_reader.py index 65f3be236..1a2f56f23 100644 --- a/tests/unit/butterfree/extract/readers/test_table_reader.py +++ b/tests/unit/butterfree/extract/readers/test_table_reader.py @@ -5,7 +5,14 @@ class TestTableReader: @pytest.mark.parametrize( - "database, table", [("database", 123), (123, None,)], + "database, table", + [ + ("database", 123), + ( + 123, + None, + ), + ], ) def test_init_invalid_params(self, database, table): # act and assert diff --git a/tests/unit/butterfree/extract/test_source.py b/tests/unit/butterfree/extract/test_source.py index 53af8b658..842d2210f 100644 --- a/tests/unit/butterfree/extract/test_source.py +++ b/tests/unit/butterfree/extract/test_source.py @@ -14,7 +14,8 @@ def test_construct(self, mocker, target_df): # when source_selector = Source( - readers=[reader], query=f"select * from {reader_id}", # noqa + readers=[reader], + query=f"select * from {reader_id}", # noqa ) result_df = source_selector.construct(spark_client) @@ -32,7 +33,8 @@ def test_is_cached(self, mocker, target_df): # when source_selector = Source( - readers=[reader], query=f"select * from {reader_id}", # noqa + readers=[reader], + query=f"select * from {reader_id}", # noqa ) result_df = source_selector.construct(spark_client) diff --git a/tests/unit/butterfree/hooks/__init__.py b/tests/unit/butterfree/hooks/__init__.py new file mode 100644 index 
000000000..e69de29bb diff --git a/tests/unit/butterfree/hooks/schema_compatibility/__init__.py b/tests/unit/butterfree/hooks/schema_compatibility/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/butterfree/hooks/schema_compatibility/test_cassandra_table_schema_compatibility_hook.py b/tests/unit/butterfree/hooks/schema_compatibility/test_cassandra_table_schema_compatibility_hook.py new file mode 100644 index 000000000..eccb8d8cc --- /dev/null +++ b/tests/unit/butterfree/hooks/schema_compatibility/test_cassandra_table_schema_compatibility_hook.py @@ -0,0 +1,49 @@ +from unittest.mock import MagicMock + +import pytest + +from butterfree.clients import CassandraClient +from butterfree.hooks.schema_compatibility import CassandraTableSchemaCompatibilityHook + + +class TestCassandraTableSchemaCompatibilityHook: + def test_run_compatible_schema(self, spark_session): + cassandra_client = CassandraClient(host=["mock"], keyspace="dummy_keyspace") + + cassandra_client.sql = MagicMock( # type: ignore + return_value=[ + {"column_name": "feature1", "type": "text"}, + {"column_name": "feature2", "type": "int"}, + ] + ) + + table = "table" + + input_dataframe = spark_session.sql("select 'abc' as feature1, 1 as feature2") + + hook = CassandraTableSchemaCompatibilityHook(cassandra_client, table) + + # act and assert + assert hook.run(input_dataframe) == input_dataframe + + def test_run_incompatible_schema(self, spark_session): + cassandra_client = CassandraClient(host=["mock"], keyspace="dummy_keyspace") + + cassandra_client.sql = MagicMock( # type: ignore + return_value=[ + {"column_name": "feature1", "type": "text"}, + {"column_name": "feature2", "type": "bigint"}, + ] + ) + + table = "table" + + input_dataframe = spark_session.sql("select 'abc' as feature1, 1 as feature2") + + hook = CassandraTableSchemaCompatibilityHook(cassandra_client, table) + + # act and assert + with pytest.raises( + ValueError, match="There's a schema incompatibility between" + ): + hook.run(input_dataframe) diff --git a/tests/unit/butterfree/hooks/schema_compatibility/test_spark_table_schema_compatibility_hook.py b/tests/unit/butterfree/hooks/schema_compatibility/test_spark_table_schema_compatibility_hook.py new file mode 100644 index 000000000..3a31b600c --- /dev/null +++ b/tests/unit/butterfree/hooks/schema_compatibility/test_spark_table_schema_compatibility_hook.py @@ -0,0 +1,53 @@ +import pytest + +from butterfree.clients import SparkClient +from butterfree.hooks.schema_compatibility import SparkTableSchemaCompatibilityHook + + +class TestSparkTableSchemaCompatibilityHook: + @pytest.mark.parametrize( + "table, database, target_table_expression", + [("table", "database", "`database`.`table`"), ("table", None, "`table`")], + ) + def test_build_table_expression(self, table, database, target_table_expression): + # arrange + spark_client = SparkClient() + + # act + result_table_expression = SparkTableSchemaCompatibilityHook( + spark_client, table, database + ).table_expression + + # assert + assert target_table_expression == result_table_expression + + def test_run_compatible_schema(self, spark_session): + # arrange + spark_client = SparkClient() + target_table = spark_session.sql( + "select 1 as feature_a, 'abc' as feature_b, true as other_feature" + ) + input_dataframe = spark_session.sql("select 1 as feature_a, 'abc' as feature_b") + target_table.registerTempTable("test") + + hook = SparkTableSchemaCompatibilityHook(spark_client, "test") + + # act and assert + assert hook.run(input_dataframe) == 
input_dataframe + + def test_run_incompatible_schema(self, spark_session): + # arrange + spark_client = SparkClient() + target_table = spark_session.sql( + "select 1 as feature_a, 'abc' as feature_b, true as other_feature" + ) + input_dataframe = spark_session.sql( + "select 1 as feature_a, 'abc' as feature_b, true as unregisted_column" + ) + target_table.registerTempTable("test") + + hook = SparkTableSchemaCompatibilityHook(spark_client, "test") + + # act and assert + with pytest.raises(ValueError, match="The dataframe has a schema incompatible"): + hook.run(input_dataframe) diff --git a/tests/unit/butterfree/hooks/test_hookable_component.py b/tests/unit/butterfree/hooks/test_hookable_component.py new file mode 100644 index 000000000..37e34e691 --- /dev/null +++ b/tests/unit/butterfree/hooks/test_hookable_component.py @@ -0,0 +1,107 @@ +import pytest +from pyspark.sql.functions import expr + +from butterfree.hooks import Hook, HookableComponent +from butterfree.testing.dataframe import assert_dataframe_equality + + +class TestComponent(HookableComponent): + def construct(self, dataframe): + pre_hook_df = self.run_pre_hooks(dataframe) + construct_df = pre_hook_df.withColumn("feature", expr("feature * feature")) + return self.run_post_hooks(construct_df) + + +class AddHook(Hook): + def __init__(self, value): + self.value = value + + def run(self, dataframe): + return dataframe.withColumn("feature", expr(f"feature + {self.value}")) + + +class TestHookableComponent: + def test_add_hooks(self): + # arrange + hook1 = AddHook(value=1) + hook2 = AddHook(value=2) + hook3 = AddHook(value=3) + hook4 = AddHook(value=4) + hookable_component = HookableComponent() + + # act + hookable_component.add_pre_hook(hook1, hook2) + hookable_component.add_post_hook(hook3, hook4) + + # assert + assert hookable_component.pre_hooks == [hook1, hook2] + assert hookable_component.post_hooks == [hook3, hook4] + + @pytest.mark.parametrize( + "enable_pre_hooks, enable_post_hooks", + [("not boolean", False), (False, "not boolean")], + ) + def test_invalid_enable_hook(self, enable_pre_hooks, enable_post_hooks): + # arrange + hookable_component = HookableComponent() + + # act and assert + with pytest.raises(ValueError): + hookable_component.enable_pre_hooks = enable_pre_hooks + hookable_component.enable_post_hooks = enable_post_hooks + + @pytest.mark.parametrize( + "pre_hooks, post_hooks", + [ + ([AddHook(1)], "not a list of hooks"), + ([AddHook(1)], [AddHook(1), 2, 3]), + ("not a list of hooks", [AddHook(1)]), + ([AddHook(1), 2, 3], [AddHook(1)]), + ], + ) + def test_invalid_hooks(self, pre_hooks, post_hooks): + # arrange + hookable_component = HookableComponent() + + # act and assert + with pytest.raises(ValueError): + hookable_component.pre_hooks = pre_hooks + hookable_component.post_hooks = post_hooks + + @pytest.mark.parametrize( + "pre_hook, enable_pre_hooks, post_hook, enable_post_hooks", + [ + (AddHook(value=1), False, AddHook(value=1), True), + (AddHook(value=1), True, AddHook(value=1), False), + ("not a pre-hook", True, AddHook(value=1), True), + (AddHook(value=1), True, "not a pre-hook", True), + ], + ) + def test_add_invalid_hooks( + self, pre_hook, enable_pre_hooks, post_hook, enable_post_hooks + ): + # arrange + hookable_component = HookableComponent() + hookable_component.enable_pre_hooks = enable_pre_hooks + hookable_component.enable_post_hooks = enable_post_hooks + + # act and assert + with pytest.raises(ValueError): + hookable_component.add_pre_hook(pre_hook) + hookable_component.add_post_hook(post_hook) + 
+ def test_run_hooks(self, spark_session): + # arrange + input_dataframe = spark_session.sql("select 2 as feature") + test_component = ( + TestComponent() + .add_pre_hook(AddHook(value=1)) + .add_post_hook(AddHook(value=1)) + ) + target_table = spark_session.sql("select 10 as feature") + + # act + output_df = test_component.construct(input_dataframe) + + # assert + assert_dataframe_equality(output_df, target_table) diff --git a/tests/unit/butterfree/load/conftest.py b/tests/unit/butterfree/load/conftest.py index 7c2549c58..d0bb2c3be 100644 --- a/tests/unit/butterfree/load/conftest.py +++ b/tests/unit/butterfree/load/conftest.py @@ -20,7 +20,11 @@ def feature_set(): ] ts_feature = TimestampFeature(from_column=TIMESTAMP_COLUMN) features = [ - Feature(name="feature", description="Description", dtype=DataType.BIGINT,) + Feature( + name="feature", + description="Description", + dtype=DataType.BIGINT, + ) ] return FeatureSet( "feature_set", @@ -32,6 +36,31 @@ def feature_set(): ) +@fixture +def feature_set_incremental(): + key_features = [ + KeyFeature(name="id", description="Description", dtype=DataType.INTEGER) + ] + ts_feature = TimestampFeature(from_column=TIMESTAMP_COLUMN) + features = [ + Feature( + name="feature", + description="test", + transformation=AggregatedTransform( + functions=[Function(functions.sum, DataType.INTEGER)] + ), + ), + ] + return AggregatedFeatureSet( + "feature_set", + "entity", + "description", + keys=key_features, + timestamp=ts_feature, + features=features, + ) + + @fixture def feature_set_dataframe(spark_context, spark_session): data = [ diff --git a/tests/unit/butterfree/load/processing/test_json_transform.py b/tests/unit/butterfree/load/processing/test_json_transform.py index 73949eea7..78320d108 100644 --- a/tests/unit/butterfree/load/processing/test_json_transform.py +++ b/tests/unit/butterfree/load/processing/test_json_transform.py @@ -3,7 +3,9 @@ class TestJsonTransform: def test_json_transformation( - self, input_df, json_df, + self, + input_df, + json_df, ): result_df = json_transform(dataframe=input_df) diff --git a/tests/unit/butterfree/load/test_sink.py b/tests/unit/butterfree/load/test_sink.py index 93b5e2797..517f651e0 100644 --- a/tests/unit/butterfree/load/test_sink.py +++ b/tests/unit/butterfree/load/test_sink.py @@ -136,6 +136,7 @@ def test_flush_streaming_df(self, feature_set): mocked_stream_df.start.return_value = Mock(spec=StreamingQuery) online_feature_store_writer = OnlineFeatureStoreWriter() + online_feature_store_writer_on_entity = OnlineFeatureStoreWriter( write_to_entity=True ) @@ -173,6 +174,7 @@ def test_flush_with_multiple_online_writers( feature_set.name = "my_feature_set" online_feature_store_writer = OnlineFeatureStoreWriter() + online_feature_store_writer_on_entity = OnlineFeatureStoreWriter( write_to_entity=True ) diff --git a/tests/unit/butterfree/load/writers/test_delta_writer.py b/tests/unit/butterfree/load/writers/test_delta_writer.py new file mode 100644 index 000000000..550f6d057 --- /dev/null +++ b/tests/unit/butterfree/load/writers/test_delta_writer.py @@ -0,0 +1,83 @@ +import os +from unittest import mock + +import pytest + +from butterfree.clients import SparkClient +from butterfree.load.writers import DeltaWriter + +DELTA_LOCATION = "spark-warehouse" + + +class TestDeltaWriter: + + def __checkFileExists(self, file_name: str = "test_delta_table") -> bool: + return os.path.exists(os.path.join(DELTA_LOCATION, file_name)) + + @pytest.fixture + def merge_builder_mock(self): + builder = mock.MagicMock() + 
builder.whenMatchedDelete.return_value = builder + builder.whenMatchedUpdateAll.return_value = builder + builder.whenNotMatchedInsertAll.return_value = builder + return builder + + def test_merge(self, feature_set_dataframe, merge_builder_mock): + + client = SparkClient() + delta_writer = DeltaWriter() + delta_writer.merge = mock.MagicMock() + + DeltaWriter().merge( + client=client, + database=None, + table="test_delta_table", + merge_on=["id"], + source_df=feature_set_dataframe, + ) + + assert merge_builder_mock.execute.assert_called_once + + # Step 2 + source = client.conn.createDataFrame( + [(1, "test3"), (2, "test4"), (3, "test5")], ["id", "feature"] + ) + + DeltaWriter().merge( + client=client, + database=None, + table="test_delta_table", + merge_on=["id"], + source_df=source, + when_not_matched_insert_condition=None, + when_matched_update_condition="id > 2", + ) + + assert merge_builder_mock.execute.assert_called_once + + def test_optimize(self, mocker): + + client = SparkClient() + conn_mock = mocker.patch( + "butterfree.clients.SparkClient.conn", return_value=mock.Mock() + ) + dw = DeltaWriter() + + dw.optimize = mock.MagicMock(client) + dw.optimize(client, "a_table") + + conn_mock.assert_called_once + + def test_vacuum(self, mocker): + + client = SparkClient() + conn_mock = mocker.patch( + "butterfree.clients.SparkClient.conn", return_value=mock.Mock() + ) + dw = DeltaWriter() + retention_hours = 24 + dw.vacuum = mock.MagicMock(client) + + dw.vacuum("a_table", retention_hours, client) + + conn_mock.assert_called_once diff --git a/tests/unit/butterfree/load/writers/test_historical_feature_store_writer.py b/tests/unit/butterfree/load/writers/test_historical_feature_store_writer.py index 14c067f92..d9d9181a4 100644 --- a/tests/unit/butterfree/load/writers/test_historical_feature_store_writer.py +++ b/tests/unit/butterfree/load/writers/test_historical_feature_store_writer.py @@ -1,5 +1,6 @@ import datetime import random +from unittest import mock import pytest from pyspark.sql.functions import spark_partition_id @@ -19,7 +20,7 @@ def test_write( feature_set, ): # given - spark_client = mocker.stub("spark_client") + spark_client = SparkClient() spark_client.write_table = mocker.stub("write_table") writer = HistoricalFeatureStoreWriter() @@ -41,7 +42,60 @@ def test_write( assert ( writer.PARTITION_BY == spark_client.write_table.call_args[1]["partition_by"] ) + + def test_write_interval_mode( + self, + feature_set_dataframe, + historical_feature_set_dataframe, + mocker, + feature_set, + ): + # given + spark_client = SparkClient() + spark_client.write_table = mocker.stub("write_table") + spark_client.conn.conf.set( + "spark.sql.sources.partitionOverwriteMode", "dynamic" + ) + writer = HistoricalFeatureStoreWriter(interval_mode=True) + + # when + writer.write( + feature_set=feature_set, + dataframe=feature_set_dataframe, + spark_client=spark_client, + ) + result_df = spark_client.write_table.call_args[1]["dataframe"] + + # then + assert_dataframe_equality(historical_feature_set_dataframe, result_df) + + assert writer.database == spark_client.write_table.call_args[1]["database"] assert feature_set.name == spark_client.write_table.call_args[1]["table_name"] + assert ( + writer.PARTITION_BY == spark_client.write_table.call_args[1]["partition_by"] + ) + + def test_write_interval_mode_invalid_partition_mode( + self, + feature_set_dataframe, + historical_feature_set_dataframe, + mocker, + feature_set, + ): + # given + spark_client = SparkClient() + spark_client.write_dataframe = 
mocker.stub("write_dataframe") + spark_client.conn.conf.set("spark.sql.sources.partitionOverwriteMode", "static") + + writer = HistoricalFeatureStoreWriter(interval_mode=True) + + # when + with pytest.raises(RuntimeError): + _ = writer.write( + feature_set=feature_set, + dataframe=feature_set_dataframe, + spark_client=spark_client, + ) def test_write_in_debug_mode( self, @@ -65,33 +119,104 @@ def test_write_in_debug_mode( # then assert_dataframe_equality(historical_feature_set_dataframe, result_df) - def test_validate(self, feature_set_dataframe, mocker, feature_set): + def test_write_in_debug_mode_with_interval_mode( + self, + feature_set_dataframe, + historical_feature_set_dataframe, + feature_set, + spark_session, + mocker, + ): + # given + spark_client = SparkClient() + spark_client.write_dataframe = mocker.stub("write_dataframe") + spark_client.conn.conf.set( + "spark.sql.sources.partitionOverwriteMode", "dynamic" + ) + writer = HistoricalFeatureStoreWriter(debug_mode=True, interval_mode=True) + + # when + writer.write( + feature_set=feature_set, + dataframe=feature_set_dataframe, + spark_client=spark_client, + ) + result_df = spark_session.table(f"historical_feature_store__{feature_set.name}") + + # then + assert_dataframe_equality(historical_feature_set_dataframe, result_df) + + def test_merge_from_historical_writer( + self, + feature_set, + feature_set_dataframe, + mocker, + ): + # given + spark_client = SparkClient() + + spark_client.write_table = mocker.stub("write_table") + writer = HistoricalFeatureStoreWriter(merge_on=["id", "timestamp"]) + + static_mock = mocker.patch( + "butterfree.load.writers.DeltaWriter.merge", return_value=mock.Mock() + ) + + writer.write( + feature_set=feature_set, + dataframe=feature_set_dataframe, + spark_client=spark_client, + ) + + assert static_mock.call_count == 1 + + def test_validate(self, historical_feature_set_dataframe, mocker, feature_set): # given spark_client = mocker.stub("spark_client") spark_client.read_table = mocker.stub("read_table") - spark_client.read_table.return_value = feature_set_dataframe + spark_client.read_table.return_value = historical_feature_set_dataframe writer = HistoricalFeatureStoreWriter() # when - writer.validate(feature_set, feature_set_dataframe, spark_client) + writer.validate(feature_set, historical_feature_set_dataframe, spark_client) # then spark_client.read_table.assert_called_once() - def test_validate_false(self, feature_set_dataframe, mocker, feature_set): + def test_validate_interval_mode( + self, historical_feature_set_dataframe, mocker, feature_set + ): # given spark_client = mocker.stub("spark_client") - spark_client.read_table = mocker.stub("read_table") + spark_client.read = mocker.stub("read") + spark_client.read.return_value = historical_feature_set_dataframe + + writer = HistoricalFeatureStoreWriter(interval_mode=True) + + # when + writer.validate(feature_set, historical_feature_set_dataframe, spark_client) + + # then + spark_client.read.assert_called_once() + + def test_validate_false( + self, historical_feature_set_dataframe, mocker, feature_set + ): + # given + spark_client = mocker.stub("spark_client") + spark_client.read = mocker.stub("read") # limiting df to 1 row, now the counts should'n be the same - spark_client.read_table.return_value = feature_set_dataframe.limit(1) + spark_client.read.return_value = historical_feature_set_dataframe.limit(1) - writer = HistoricalFeatureStoreWriter() + writer = HistoricalFeatureStoreWriter(interval_mode=True) # when with pytest.raises(AssertionError): - 
_ = writer.validate(feature_set, feature_set_dataframe, spark_client) + _ = writer.validate( + feature_set, historical_feature_set_dataframe, spark_client + ) def test__create_partitions(self, spark_session, spark_context): # arrange @@ -201,6 +326,7 @@ def test_write_with_transform( # given spark_client = mocker.stub("spark_client") spark_client.write_table = mocker.stub("write_table") + writer = HistoricalFeatureStoreWriter().with_(json_transform) # when diff --git a/tests/unit/butterfree/load/writers/test_online_feature_store_writer.py b/tests/unit/butterfree/load/writers/test_online_feature_store_writer.py index 87823c552..78f6862ee 100644 --- a/tests/unit/butterfree/load/writers/test_online_feature_store_writer.py +++ b/tests/unit/butterfree/load/writers/test_online_feature_store_writer.py @@ -94,6 +94,7 @@ def test_write_in_debug_mode( latest_feature_set_dataframe, feature_set, spark_session, + mocker, ): # given spark_client = SparkClient() @@ -110,9 +111,7 @@ def test_write_in_debug_mode( # then assert_dataframe_equality(latest_feature_set_dataframe, result_df) - def test_write_in_debug_and_stream_mode( - self, feature_set, spark_session, - ): + def test_write_in_debug_and_stream_mode(self, feature_set, spark_session): # arrange spark_client = SparkClient() diff --git a/tests/unit/butterfree/migrations/__init__.py b/tests/unit/butterfree/migrations/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/butterfree/migrations/database_migration/__init__.py b/tests/unit/butterfree/migrations/database_migration/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/butterfree/migrations/database_migration/conftest.py b/tests/unit/butterfree/migrations/database_migration/conftest.py new file mode 100644 index 000000000..3d3662d82 --- /dev/null +++ b/tests/unit/butterfree/migrations/database_migration/conftest.py @@ -0,0 +1,76 @@ +from pyspark.sql.types import ( + ArrayType, + DoubleType, + FloatType, + LongType, + StringType, + TimestampType, +) +from pytest import fixture + +from butterfree.constants import DataType +from butterfree.transform import FeatureSet +from butterfree.transform.features import Feature, KeyFeature, TimestampFeature + + +@fixture +def db_schema(): + return [ + {"column_name": "id", "type": LongType(), "primary_key": True}, + {"column_name": "timestamp", "type": TimestampType(), "primary_key": False}, + { + "column_name": "feature1__avg_over_1_week_rolling_windows", + "type": DoubleType(), + "primary_key": False, + }, + { + "column_name": "feature1__avg_over_2_days_rolling_windows", + "type": DoubleType(), + "primary_key": False, + }, + ] + + +@fixture +def fs_schema(): + return [ + {"column_name": "id", "type": LongType(), "primary_key": True}, + {"column_name": "timestamp", "type": TimestampType(), "primary_key": True}, + {"column_name": "new_feature", "type": FloatType(), "primary_key": False}, + { + "column_name": "array_feature", + "type": ArrayType(StringType(), True), + "primary_key": False, + }, + { + "column_name": "feature1__avg_over_1_week_rolling_windows", + "type": FloatType(), + "primary_key": False, + }, + ] + + +@fixture +def feature_set(): + feature_set = FeatureSet( + name="feature_set", + entity="entity", + description="description", + features=[ + Feature( + name="feature_float", + description="test", + dtype=DataType.FLOAT, + ), + ], + keys=[ + KeyFeature( + name="id", + description="The device ID", + dtype=DataType.BIGINT, + ) + ], + timestamp=TimestampFeature(), + ) + + return 
feature_set diff --git a/tests/unit/butterfree/migrations/database_migration/test_cassandra_migration.py b/tests/unit/butterfree/migrations/database_migration/test_cassandra_migration.py new file mode 100644 index 000000000..5e89b65bf --- /dev/null +++ b/tests/unit/butterfree/migrations/database_migration/test_cassandra_migration.py @@ -0,0 +1,43 @@ +from butterfree.migrations.database_migration import CassandraMigration + + +class TestCassandraMigration: + def test_queries(self, fs_schema, db_schema): + cassandra_migration = CassandraMigration() + expected_query = [ + "ALTER TABLE table_name ADD (new_feature FloatType);", + "ALTER TABLE table_name DROP (feature1__avg_over_2_days_rolling_windows);", + "ALTER TABLE table_name ALTER " + "feature1__avg_over_1_week_rolling_windows TYPE FloatType;", + ] + query = cassandra_migration.create_query(fs_schema, "table_name", db_schema) + + assert query, expected_query + + def test_queries_on_entity(self, fs_schema, db_schema): + cassandra_migration = CassandraMigration() + expected_query = [ + "ALTER TABLE table_name ADD (new_feature FloatType);", + "ALTER TABLE table_name ALTER " + "feature1__avg_over_1_week_rolling_windows TYPE FloatType;", + ] + query = cassandra_migration.create_query( + fs_schema, "table_name", db_schema, True + ) + + assert query, expected_query + + def test_create_table_query(self, fs_schema): + + cassandra_migration = CassandraMigration() + expected_query = [ + "CREATE TABLE test.table_name " + "(id LongType, timestamp TimestampType, new_feature FloatType, " + "array_feature ArrayType(StringType(), True), " + "feature1__avg_over_1_week_rolling_windows FloatType, " + "PRIMARY KEY (id, timestamp));" + ] + + query = cassandra_migration.create_query(fs_schema, "table_name") + + assert query, expected_query diff --git a/tests/unit/butterfree/migrations/database_migration/test_database_migration.py b/tests/unit/butterfree/migrations/database_migration/test_database_migration.py new file mode 100644 index 000000000..ea7ce8158 --- /dev/null +++ b/tests/unit/butterfree/migrations/database_migration/test_database_migration.py @@ -0,0 +1,68 @@ +from pyspark.sql.types import DoubleType, FloatType, LongType, TimestampType + +from butterfree.load.writers import HistoricalFeatureStoreWriter +from butterfree.migrations.database_migration import CassandraMigration, Diff + + +class TestDatabaseMigration: + def test__get_diff_empty(self, mocker, db_schema): + fs_schema = [ + {"column_name": "id", "type": LongType(), "primary_key": True}, + {"column_name": "timestamp", "type": TimestampType(), "primary_key": False}, + { + "column_name": "feature1__avg_over_1_week_rolling_windows", + "type": DoubleType(), + "primary_key": False, + }, + { + "column_name": "feature1__avg_over_2_days_rolling_windows", + "type": DoubleType(), + "primary_key": False, + }, + ] + m = CassandraMigration() + m._client = mocker.stub("client") + diff = m._get_diff(fs_schema, db_schema) + assert not diff + + def test__get_diff(self, mocker, db_schema): + fs_schema = [ + {"column_name": "id", "type": LongType(), "primary_key": True}, + {"column_name": "timestamp", "type": TimestampType(), "primary_key": True}, + {"column_name": "new_feature", "type": FloatType(), "primary_key": False}, + { + "column_name": "feature1__avg_over_1_week_rolling_windows", + "type": FloatType(), + "primary_key": False, + }, + ] + expected_diff = { + Diff("timestamp", kind=Diff.Kind.ALTER_KEY, value=None), + Diff("new_feature", kind=Diff.Kind.ADD, value=FloatType()), + Diff( + 
"feature1__avg_over_2_days_rolling_windows", + kind=Diff.Kind.DROP, + value=None, + ), + Diff( + "feature1__avg_over_1_week_rolling_windows", + kind=Diff.Kind.ALTER_TYPE, + value=FloatType(), + ), + } + + m = CassandraMigration() + m._client = mocker.stub("client") + diff = m._get_diff(fs_schema, db_schema) + assert diff == expected_diff + + def test_apply_migration(self, feature_set, mocker): + # given + m = CassandraMigration() + m.apply_migration = mocker.stub("apply_migration") + + # when + m.apply_migration(feature_set, HistoricalFeatureStoreWriter()) + + # then + m.apply_migration.assert_called_once() diff --git a/tests/unit/butterfree/migrations/database_migration/test_metastore_migration.py b/tests/unit/butterfree/migrations/database_migration/test_metastore_migration.py new file mode 100644 index 000000000..d9c2de3c6 --- /dev/null +++ b/tests/unit/butterfree/migrations/database_migration/test_metastore_migration.py @@ -0,0 +1,49 @@ +from butterfree.migrations.database_migration import MetastoreMigration + + +class TestMetastoreMigration: + def test_queries(self, fs_schema, db_schema): + metastore_migration = MetastoreMigration() + + expected_query = [ + "ALTER TABLE test.table_name ADD IF NOT EXISTS " + "columns (new_feature FloatType);", + "ALTER TABLE table_name DROP IF EXISTS " + "(feature1__avg_over_2_days_rolling_windows None);", + "ALTER TABLE table_name ALTER COLUMN " + "feature1__avg_over_1_week_rolling_windows FloatType;", + ] + + query = metastore_migration.create_query(fs_schema, "table_name", db_schema) + + assert query, expected_query + + def test_queries_on_entity(self, fs_schema, db_schema): + metastore_migration = MetastoreMigration() + + expected_query = [ + "ALTER TABLE test.table_name ADD IF NOT EXISTS " + "columns (new_feature FloatType);", + "ALTER TABLE table_name ALTER COLUMN " + "feature1__avg_over_1_week_rolling_windows FloatType;", + ] + + query = metastore_migration.create_query( + fs_schema, "table_name", db_schema, True + ) + + assert query, expected_query + + def test_create_table_query(self, fs_schema): + + metastore_migration = MetastoreMigration() + + expected_query = [ + "CREATE TABLE IF NOT EXISTS test.table_name " + "(id LongType, timestamp TimestampType, new_feature FloatType) " + "PARTITIONED BY (year INT, month INT, day INT);" + ] + + query = metastore_migration.create_query(fs_schema, "table_name") + + assert query, expected_query diff --git a/tests/unit/butterfree/pipelines/conftest.py b/tests/unit/butterfree/pipelines/conftest.py new file mode 100644 index 000000000..f17e5f41e --- /dev/null +++ b/tests/unit/butterfree/pipelines/conftest.py @@ -0,0 +1,72 @@ +from unittest.mock import Mock + +from pyspark.sql import functions +from pytest import fixture + +from butterfree.clients import SparkClient +from butterfree.constants import DataType +from butterfree.constants.columns import TIMESTAMP_COLUMN +from butterfree.extract import Source +from butterfree.extract.readers import TableReader +from butterfree.load import Sink +from butterfree.load.writers import HistoricalFeatureStoreWriter +from butterfree.pipelines import FeatureSetPipeline +from butterfree.transform import FeatureSet +from butterfree.transform.features import Feature, KeyFeature, TimestampFeature +from butterfree.transform.transformations import SparkFunctionTransform +from butterfree.transform.utils import Function + + +@fixture() +def feature_set_pipeline(): + test_pipeline = FeatureSetPipeline( + spark_client=SparkClient(), + source=Mock( + spec=Source, + readers=[ + 
TableReader( + id="source_a", + database="db", + table="table", + ) + ], + query="select * from source_a", + ), + feature_set=Mock( + spec=FeatureSet, + name="feature_set", + entity="entity", + description="description", + keys=[ + KeyFeature( + name="user_id", + description="The user's Main ID or device ID", + dtype=DataType.INTEGER, + ) + ], + timestamp=TimestampFeature(from_column="ts"), + features=[ + Feature( + name="listing_page_viewed__rent_per_month", + description="Average of something.", + transformation=SparkFunctionTransform( + functions=[ + Function(functions.avg, DataType.FLOAT), + Function(functions.stddev_pop, DataType.FLOAT), + ], + ).with_window( + partition_by="user_id", + order_by=TIMESTAMP_COLUMN, + window_definition=["7 days", "2 weeks"], + mode="fixed_windows", + ), + ), + ], + ), + sink=Mock( + spec=Sink, + writers=[HistoricalFeatureStoreWriter(db_config=None)], + ), + ) + + return test_pipeline diff --git a/tests/unit/butterfree/pipelines/test_feature_set_pipeline.py b/tests/unit/butterfree/pipelines/test_feature_set_pipeline.py index 1bc3c7071..5a67e77d4 100644 --- a/tests/unit/butterfree/pipelines/test_feature_set_pipeline.py +++ b/tests/unit/butterfree/pipelines/test_feature_set_pipeline.py @@ -17,15 +17,25 @@ from butterfree.load.writers.writer import Writer from butterfree.pipelines.feature_set_pipeline import FeatureSetPipeline from butterfree.transform import FeatureSet -from butterfree.transform.aggregated_feature_set import AggregatedFeatureSet from butterfree.transform.features import Feature, KeyFeature, TimestampFeature -from butterfree.transform.transformations import ( - AggregatedTransform, - SparkFunctionTransform, -) +from butterfree.transform.transformations import SparkFunctionTransform from butterfree.transform.utils import Function +def get_reader(): + table_reader = TableReader( + id="source_a", + database="db", + table="table", + ) + + return table_reader + + +def get_historical_writer(): + return HistoricalFeatureStoreWriter(db_config=None) + + class TestFeatureSetPipeline: def test_feature_set_args(self): # arrange and act @@ -42,8 +52,12 @@ def test_feature_set_args(self): pipeline = FeatureSetPipeline( source=Source( readers=[ - TableReader(id="source_a", database="db", table="table",), - FileReader(id="source_b", path="path", format="parquet",), + get_reader(), + FileReader( + id="source_b", + path="path", + format="parquet", + ), ], query="select a.*, b.specific_feature " "from source_a left join source_b on a.id=b.id", @@ -104,115 +118,29 @@ def test_feature_set_args(self): assert len(pipeline.sink.writers) == 2 assert all(isinstance(writer, Writer) for writer in pipeline.sink.writers) - def test_run(self, spark_session): - test_pipeline = FeatureSetPipeline( - spark_client=SparkClient(), - source=Mock( - spec=Source, - readers=[TableReader(id="source_a", database="db", table="table",)], - query="select * from source_a", - ), - feature_set=Mock( - spec=FeatureSet, - name="feature_set", - entity="entity", - description="description", - keys=[ - KeyFeature( - name="user_id", - description="The user's Main ID or device ID", - dtype=DataType.INTEGER, - ) - ], - timestamp=TimestampFeature(from_column="ts"), - features=[ - Feature( - name="listing_page_viewed__rent_per_month", - description="Average of something.", - transformation=SparkFunctionTransform( - functions=[ - Function(functions.avg, DataType.FLOAT), - Function(functions.stddev_pop, DataType.FLOAT), - ], - ).with_window( - partition_by="user_id", - order_by=TIMESTAMP_COLUMN, - 
window_definition=["7 days", "2 weeks"], - mode="fixed_windows", - ), - ), - ], - ), - sink=Mock( - spec=Sink, writers=[HistoricalFeatureStoreWriter(db_config=None)], - ), - ) - + def test_run(self, spark_session, feature_set_pipeline): # feature_set need to return a real df for streaming validation sample_df = spark_session.createDataFrame([{"a": "x", "b": "y", "c": "3"}]) - test_pipeline.feature_set.construct.return_value = sample_df + feature_set_pipeline.feature_set.construct.return_value = sample_df - test_pipeline.run() + feature_set_pipeline.run() - test_pipeline.source.construct.assert_called_once() - test_pipeline.feature_set.construct.assert_called_once() - test_pipeline.sink.flush.assert_called_once() - test_pipeline.sink.validate.assert_called_once() - - def test_run_with_repartition(self, spark_session): - test_pipeline = FeatureSetPipeline( - spark_client=SparkClient(), - source=Mock( - spec=Source, - readers=[TableReader(id="source_a", database="db", table="table",)], - query="select * from source_a", - ), - feature_set=Mock( - spec=FeatureSet, - name="feature_set", - entity="entity", - description="description", - keys=[ - KeyFeature( - name="user_id", - description="The user's Main ID or device ID", - dtype=DataType.INTEGER, - ) - ], - timestamp=TimestampFeature(from_column="ts"), - features=[ - Feature( - name="listing_page_viewed__rent_per_month", - description="Average of something.", - transformation=SparkFunctionTransform( - functions=[ - Function(functions.avg, DataType.FLOAT), - Function(functions.stddev_pop, DataType.FLOAT), - ], - ).with_window( - partition_by="user_id", - order_by=TIMESTAMP_COLUMN, - window_definition=["7 days", "2 weeks"], - mode="fixed_windows", - ), - ), - ], - ), - sink=Mock( - spec=Sink, writers=[HistoricalFeatureStoreWriter(db_config=None)], - ), - ) + feature_set_pipeline.source.construct.assert_called_once() + feature_set_pipeline.feature_set.construct.assert_called_once() + feature_set_pipeline.sink.flush.assert_called_once() + feature_set_pipeline.sink.validate.assert_called_once() + def test_run_with_repartition(self, spark_session, feature_set_pipeline): # feature_set need to return a real df for streaming validation sample_df = spark_session.createDataFrame([{"a": "x", "b": "y", "c": "3"}]) - test_pipeline.feature_set.construct.return_value = sample_df + feature_set_pipeline.feature_set.construct.return_value = sample_df - test_pipeline.run(partition_by=["id"]) + feature_set_pipeline.run(partition_by=["id"]) - test_pipeline.source.construct.assert_called_once() - test_pipeline.feature_set.construct.assert_called_once() - test_pipeline.sink.flush.assert_called_once() - test_pipeline.sink.validate.assert_called_once() + feature_set_pipeline.source.construct.assert_called_once() + feature_set_pipeline.feature_set.construct.assert_called_once() + feature_set_pipeline.sink.flush.assert_called_once() + feature_set_pipeline.sink.validate.assert_called_once() def test_source_raise(self): with pytest.raises(ValueError, match="source must be a Source instance"): @@ -221,7 +149,7 @@ def test_source_raise(self): source=Mock( spark_client=SparkClient(), readers=[ - TableReader(id="source_a", database="db", table="table",), + get_reader(), ], query="select * from source_a", ), @@ -257,7 +185,8 @@ def test_source_raise(self): ], ), sink=Mock( - spec=Sink, writers=[HistoricalFeatureStoreWriter(db_config=None)], + spec=Sink, + writers=[get_historical_writer()], ), ) @@ -270,7 +199,7 @@ def test_feature_set_raise(self): source=Mock( spec=Source, 
readers=[ - TableReader(id="source_a", database="db", table="table",), + get_reader(), ], query="select * from source_a", ), @@ -305,7 +234,8 @@ def test_feature_set_raise(self): ], ), sink=Mock( - spec=Sink, writers=[HistoricalFeatureStoreWriter(db_config=None)], + spec=Sink, + writers=[get_historical_writer()], ), ) @@ -316,7 +246,7 @@ def test_sink_raise(self): source=Mock( spec=Source, readers=[ - TableReader(id="source_a", database="db", table="table",), + get_reader(), ], query="select * from source_a", ), @@ -340,55 +270,31 @@ def test_sink_raise(self): key_columns=["user_id"], timestamp_column="ts", ), - sink=Mock(writers=[HistoricalFeatureStoreWriter(db_config=None)],), + sink=Mock( + writers=[get_historical_writer()], + ), ) - def test_run_agg_with_end_date(self, spark_session): - test_pipeline = FeatureSetPipeline( - spark_client=SparkClient(), - source=Mock( - spec=Source, - readers=[TableReader(id="source_a", database="db", table="table",)], - query="select * from source_a", - ), - feature_set=Mock( - spec=AggregatedFeatureSet, - name="feature_set", - entity="entity", - description="description", - keys=[ - KeyFeature( - name="user_id", - description="The user's Main ID or device ID", - dtype=DataType.INTEGER, - ) - ], - timestamp=TimestampFeature(from_column="ts"), - features=[ - Feature( - name="listing_page_viewed__rent_per_month", - description="Average of something.", - transformation=AggregatedTransform( - functions=[ - Function(functions.avg, DataType.FLOAT), - Function(functions.stddev_pop, DataType.FLOAT), - ], - ), - ), - ], - ), - sink=Mock( - spec=Sink, writers=[HistoricalFeatureStoreWriter(db_config=None)], - ), - ) + def test_run_agg_with_end_date(self, spark_session, feature_set_pipeline): + # feature_set need to return a real df for streaming validation + sample_df = spark_session.createDataFrame([{"a": "x", "b": "y", "c": "3"}]) + feature_set_pipeline.feature_set.construct.return_value = sample_df + + feature_set_pipeline.run(end_date="2016-04-18") + + feature_set_pipeline.source.construct.assert_called_once() + feature_set_pipeline.feature_set.construct.assert_called_once() + feature_set_pipeline.sink.flush.assert_called_once() + feature_set_pipeline.sink.validate.assert_called_once() + def test_run_agg_with_start_date(self, spark_session, feature_set_pipeline): # feature_set need to return a real df for streaming validation sample_df = spark_session.createDataFrame([{"a": "x", "b": "y", "c": "3"}]) - test_pipeline.feature_set.construct.return_value = sample_df + feature_set_pipeline.feature_set.construct.return_value = sample_df - test_pipeline.run(end_date="2016-04-18") + feature_set_pipeline.run(start_date="2020-08-04") - test_pipeline.source.construct.assert_called_once() - test_pipeline.feature_set.construct.assert_called_once() - test_pipeline.sink.flush.assert_called_once() - test_pipeline.sink.validate.assert_called_once() + feature_set_pipeline.source.construct.assert_called_once() + feature_set_pipeline.feature_set.construct.assert_called_once() + feature_set_pipeline.sink.flush.assert_called_once() + feature_set_pipeline.sink.validate.assert_called_once() diff --git a/tests/unit/butterfree/reports/test_metadata.py b/tests/unit/butterfree/reports/test_metadata.py index 6f26cc558..093721df1 100644 --- a/tests/unit/butterfree/reports/test_metadata.py +++ b/tests/unit/butterfree/reports/test_metadata.py @@ -16,49 +16,63 @@ from butterfree.transform.utils import Function +def get_pipeline(): + + return FeatureSetPipeline( + source=Source( + readers=[ + 
TableReader( + id="source_a", + database="db", + table="table", + ), + FileReader( + id="source_b", + path="path", + format="parquet", + ), + ], + query="select a.*, b.specific_feature " + "from source_a left join source_b on a.id=b.id", + ), + feature_set=FeatureSet( + name="feature_set", + entity="entity", + description="description", + keys=[ + KeyFeature( + name="user_id", + description="The user's Main ID or device ID", + dtype=DataType.INTEGER, + ) + ], + timestamp=TimestampFeature(from_column="ts"), + features=[ + Feature( + name="page_viewed__rent_per_month", + description="Average of something.", + transformation=SparkFunctionTransform( + functions=[ + Function(functions.avg, DataType.FLOAT), + Function(functions.stddev_pop, DataType.DOUBLE), + ], + ), + ), + ], + ), + sink=Sink( + writers=[ + HistoricalFeatureStoreWriter(db_config=None), + OnlineFeatureStoreWriter(db_config=None), + ], + ), + ) + + class TestMetadata: def test_json(self): - pipeline = FeatureSetPipeline( - source=Source( - readers=[ - TableReader(id="source_a", database="db", table="table",), - FileReader(id="source_b", path="path", format="parquet",), - ], - query="select a.*, b.specific_feature " - "from source_a left join source_b on a.id=b.id", - ), - feature_set=FeatureSet( - name="feature_set", - entity="entity", - description="description", - keys=[ - KeyFeature( - name="user_id", - description="The user's Main ID or device ID", - dtype=DataType.INTEGER, - ) - ], - timestamp=TimestampFeature(from_column="ts"), - features=[ - Feature( - name="page_viewed__rent_per_month", - description="Average of something.", - transformation=SparkFunctionTransform( - functions=[ - Function(functions.avg, DataType.FLOAT), - Function(functions.stddev_pop, DataType.DOUBLE), - ], - ), - ), - ], - ), - sink=Sink( - writers=[ - HistoricalFeatureStoreWriter(db_config=None), - OnlineFeatureStoreWriter(db_config=None), - ], - ), - ) + + pipeline = get_pipeline() target_json = [ { @@ -102,47 +116,8 @@ def test_json(self): assert json == target_json def test_markdown(self): - pipeline = FeatureSetPipeline( - source=Source( - readers=[ - TableReader(id="source_a", database="db", table="table",), - FileReader(id="source_b", path="path", format="parquet",), - ], - query="select a.*, b.specific_feature " - "from source_a left join source_b on a.id=b.id", - ), - feature_set=FeatureSet( - name="feature_set", - entity="entity", - description="description", - keys=[ - KeyFeature( - name="user_id", - description="The user's Main ID or device ID", - dtype=DataType.INTEGER, - ) - ], - timestamp=TimestampFeature(from_column="ts"), - features=[ - Feature( - name="page_viewed__rent_per_month", - description="Average of something.", - transformation=SparkFunctionTransform( - functions=[ - Function(functions.avg, DataType.FLOAT), - Function(functions.stddev_pop, DataType.DOUBLE), - ], - ), - ), - ], - ), - sink=Sink( - writers=[ - HistoricalFeatureStoreWriter(db_config=None), - OnlineFeatureStoreWriter(db_config=None), - ], - ), - ) + + pipeline = get_pipeline() target_md = ( "\n# Feature_set\n\n## Description\n\n\ndescription \n\n" diff --git a/tests/unit/butterfree/transform/conftest.py b/tests/unit/butterfree/transform/conftest.py index 2d7d3e50c..c0ebb47af 100644 --- a/tests/unit/butterfree/transform/conftest.py +++ b/tests/unit/butterfree/transform/conftest.py @@ -1,11 +1,99 @@ import json from unittest.mock import Mock +import pyspark.pandas as ps +from pyspark.sql import functions +from pyspark.sql.functions import col +from 
pyspark.sql.types import TimestampType from pytest import fixture from butterfree.constants import DataType from butterfree.constants.columns import TIMESTAMP_COLUMN +from butterfree.transform import FeatureSet +from butterfree.transform.aggregated_feature_set import AggregatedFeatureSet from butterfree.transform.features import Feature, KeyFeature, TimestampFeature +from butterfree.transform.transformations import ( + AggregatedTransform, + SparkFunctionTransform, +) +from butterfree.transform.utils import Function + + +def create_dataframe(data, timestamp_col="ts"): + pdf = ps.DataFrame.from_dict(data) + df = pdf.to_spark() + df = df.withColumn( + TIMESTAMP_COLUMN, df[timestamp_col].cast(DataType.TIMESTAMP.spark) + ) + return df + + +def create_dataframe_from_data( + spark_context, spark_session, data, timestamp_col="timestamp", use_json=False +): + if use_json: + df = spark_session.read.json( + spark_context.parallelize(data).map(lambda x: json.dumps(x)) + ) + else: + df = create_dataframe(data, timestamp_col=timestamp_col) + + df = df.withColumn(timestamp_col, col(timestamp_col).cast(TimestampType())) + return df + + +def create_rolling_windows_agg_dataframe( + spark_context, spark_session, data, timestamp_col="timestamp", use_json=False +): + if use_json: + df = spark_session.read.json( + spark_context.parallelize(data).map(lambda x: json.dumps(x)) + ) + df = df.withColumn( + timestamp_col, col(timestamp_col).cast(DataType.TIMESTAMP.spark) + ) + else: + df = create_dataframe(data, timestamp_col=timestamp_col) + + return df + + +def build_data(rows, base_features, dynamic_features=None): + """ + Build a list of dictionaries for a DataFrame with dynamic features. + + :param rows: List of tuples with (id, timestamp, base_values, dynamic_values). + :param base_features: List of base feature names (strings). + :param dynamic_features: List of dynamic feature names, + mapped to the index of dynamic_values (optional). + :return: List of dictionaries for DataFrame creation. 
+ """ + data = [] + for row in rows: + id_value, timestamp_value, base_values, dynamic_values = row + + entry = { + "id": id_value, + "timestamp": timestamp_value, + } + + # Add the base feature values + entry.update( + {feature: value for feature, value in zip(base_features, base_values)} + ) + + # Add the dynamic feature values, if any + if dynamic_features: + entry.update( + { + feature: dynamic_values[idx] + for idx, feature in enumerate(dynamic_features) + } + ) + + data.append(entry) + + return data def make_dataframe(spark_context, spark_session): @@ -46,10 +134,7 @@ def make_dataframe(spark_context, spark_session): "nonfeature": 0, }, ] - df = spark_session.read.json(spark_context.parallelize(data, 1)) - df = df.withColumn(TIMESTAMP_COLUMN, df.ts.cast(DataType.TIMESTAMP.spark)) - - return df + return create_dataframe(data) def make_filtering_dataframe(spark_context, spark_session): @@ -62,12 +147,7 @@ def make_filtering_dataframe(spark_context, spark_session): {"id": 1, "ts": 6, "feature1": None, "feature2": None, "feature3": None}, {"id": 1, "ts": 7, "feature1": None, "feature2": None, "feature3": None}, ] - df = spark_session.read.json( - spark_context.parallelize(data).map(lambda x: json.dumps(x)) - ) - df = df.withColumn(TIMESTAMP_COLUMN, df.ts.cast(DataType.TIMESTAMP.spark)) - - return df + return create_dataframe(data) def make_output_filtering_dataframe(spark_context, spark_session): @@ -78,53 +158,94 @@ def make_output_filtering_dataframe(spark_context, spark_session): {"id": 1, "ts": 4, "feature1": 0, "feature2": 1, "feature3": 1}, {"id": 1, "ts": 6, "feature1": None, "feature2": None, "feature3": None}, ] - df = spark_session.read.json( - spark_context.parallelize(data).map(lambda x: json.dumps(x)) - ) - df = df.withColumn(TIMESTAMP_COLUMN, df.ts.cast(DataType.TIMESTAMP.spark)) - - return df + return create_dataframe(data) def make_rolling_windows_agg_dataframe(spark_context, spark_session): - data = [ - { - "id": 1, - "timestamp": "2016-04-11 00:00:00", - "feature1__avg_over_1_week_rolling_windows": None, - "feature2__avg_over_1_week_rolling_windows": None, - }, - { - "id": 1, - "timestamp": "2016-04-12 00:00:00", - "feature1__avg_over_1_week_rolling_windows": 300.0, - "feature2__avg_over_1_week_rolling_windows": 350.0, - }, - { - "id": 1, - "timestamp": "2016-04-19 00:00:00", - "feature1__avg_over_1_week_rolling_windows": None, - "feature2__avg_over_1_week_rolling_windows": None, - }, - { - "id": 1, - "timestamp": "2016-04-23 00:00:00", - "feature1__avg_over_1_week_rolling_windows": 1000.0, - "feature2__avg_over_1_week_rolling_windows": 1100.0, - }, - { - "id": 1, - "timestamp": "2016-04-30 00:00:00", - "feature1__avg_over_1_week_rolling_windows": None, - "feature2__avg_over_1_week_rolling_windows": None, - }, + rows = [ + (1, "2016-04-11 00:00:00", [None, None], None), + (1, "2016-04-12 00:00:00", [300.0, 350.0], None), + (1, "2016-04-19 00:00:00", [None, None], None), + (1, "2016-04-23 00:00:00", [1000.0, 1100.0], None), + (1, "2016-04-30 00:00:00", [None, None], None), ] - df = spark_session.read.json( - spark_context.parallelize(data).map(lambda x: json.dumps(x)) - ) - df = df.withColumn("timestamp", df.timestamp.cast(DataType.TIMESTAMP.spark)) - return df + base_features = [ + "feature1__avg_over_1_week_rolling_windows", + "feature2__avg_over_1_week_rolling_windows", + ] + + data = build_data(rows, base_features) + return create_dataframe_from_data(spark_context, spark_session, data) + + +def 
+    rows = [
+        (1, "2016-04-11 12:00:00", [266.6666666666667, 300.0], None),
+        (1, "2016-04-12 00:00:00", [300.0, 350.0], None),
+        (1, "2016-04-12 12:00:00", [400.0, 500.0], None),
+    ]
+
+    base_features = [
+        "feature1__avg_over_1_day_rolling_windows",
+        "feature2__avg_over_1_day_rolling_windows",
+    ]
+
+    data = build_data(rows, base_features)
+    return create_dataframe_from_data(spark_context, spark_session, data)
+
+
+def make_multiple_rolling_windows_hour_slide_agg_dataframe(
+    spark_context, spark_session
+):
+    rows = [
+        (
+            1,
+            "2016-04-11 12:00:00",
+            [],
+            [266.6666666666667, 266.6666666666667, 300.0, 300.0],
+        ),
+        (1, "2016-04-12 00:00:00", [], [300.0, 300.0, 350.0, 350.0]),
+        (1, "2016-04-13 12:00:00", [], [400.0, 300.0, 500.0, 350.0]),
+        (1, "2016-04-14 00:00:00", [], [None, 300.0, None, 350.0]),
+        (1, "2016-04-14 12:00:00", [], [None, 400.0, None, 500.0]),
+    ]
+
+    dynamic_features = [
+        "feature1__avg_over_2_days_rolling_windows",
+        "feature1__avg_over_3_days_rolling_windows",
+        "feature2__avg_over_2_days_rolling_windows",
+        "feature2__avg_over_3_days_rolling_windows",
+    ]
+
+    data = build_data(rows, [], dynamic_features=dynamic_features)
+    return create_dataframe_from_data(spark_context, spark_session, data, use_json=True)
+
+
+def create_rolling_window_dataframe(
+    spark_context, spark_session, rows, base_features, dynamic_features=None
+):
+    """
+    Create a Spark DataFrame with aggregated rolling-window features.
+
+    :param spark_context: Spark context.
+    :param spark_session: Spark session.
+    :param rows: List of tuples with (id, timestamp, base_values, dynamic_values).
+    :param base_features: List of base feature names (strings).
+    :param dynamic_features: List of dynamic feature names, each one mapping to
+        the corresponding index in dynamic_values (optional).
+    :return: Spark DataFrame.
+    """
+    data = build_data(rows, base_features, dynamic_features)
+
+    # Convert the list of dicts into an RDD of JSON strings
+    rdd = spark_context.parallelize(data).map(lambda x: json.dumps(x))
+
+    # Create the Spark DataFrame from the RDD
+    df = spark_session.read.json(rdd)
+
+    # Cast the "timestamp" column to the TIMESTAMP type
+    df = df.withColumn("timestamp", df.timestamp.cast(DataType.TIMESTAMP.spark))
+    return df
 
 
 def make_fs(spark_context, spark_session):
@@ -171,8 +292,7 @@ def make_fs_dataframe_with_distinct(spark_context, spark_session):
             "h3": "86a8100efffffff",
         },
     ]
-    df = spark_session.read.json(spark_context.parallelize(data, 1))
-    df = df.withColumn("timestamp", df.timestamp.cast(DataType.TIMESTAMP.spark))
+    df = create_dataframe(data, "timestamp")
 
     return df
 
@@ -200,10 +320,7 @@ def make_target_df_distinct(spark_context, spark_session):
             "feature__sum_over_3_days_rolling_windows": None,
         },
     ]
-    df = spark_session.read.json(
-        spark_context.parallelize(data).map(lambda x: json.dumps(x))
-    )
-    df = df.withColumn("timestamp", df.timestamp.cast(DataType.TIMESTAMP.spark))
+    df = create_dataframe(data, "timestamp")
 
     return df
 
@@ -233,6 +350,18 @@ def rolling_windows_agg_dataframe(spark_context, spark_session):
     return make_rolling_windows_agg_dataframe(spark_context, spark_session)
 
 
+@fixture
+def rolling_windows_hour_slide_agg_dataframe(spark_context, spark_session):
+    return make_rolling_windows_hour_slide_agg_dataframe(spark_context, spark_session)
+
+
+@fixture
+def multiple_rolling_windows_hour_slide_agg_dataframe(spark_context, spark_session):
+    return make_multiple_rolling_windows_hour_slide_agg_dataframe(
+        spark_context, spark_session
+    )
+
+
 @fixture
 def feature_set_with_distinct_dataframe(spark_context, spark_session):
     return make_fs_dataframe_with_distinct(spark_context, spark_session)
@@ -297,3 +426,72 @@ def key_id():
 @fixture
 def timestamp_c():
     return TimestampFeature()
+
+
+@fixture
+def feature_set():
+    feature_set = FeatureSet(
+        name="feature_set",
+        entity="entity",
+        description="description",
+        features=[
+            Feature(
+                name="feature1",
+                description="test",
+                transformation=SparkFunctionTransform(
+                    functions=[
+                        Function(functions.avg, DataType.FLOAT),
+                        Function(functions.stddev_pop, DataType.DOUBLE),
+                    ]
+                ).with_window(
+                    partition_by="id",
+                    order_by=TIMESTAMP_COLUMN,
+                    mode="fixed_windows",
+                    window_definition=["2 minutes", "15 minutes"],
+                ),
+            ),
+        ],
+        keys=[
+            KeyFeature(
+                name="id",
+                description="The user's Main ID or device ID",
+                dtype=DataType.BIGINT,
+            )
+        ],
+        timestamp=TimestampFeature(),
+    )
+
+    return feature_set
+
+
+@fixture
+def agg_feature_set():
+    return AggregatedFeatureSet(
+        name="name",
+        entity="entity",
+        description="description",
+        features=[
+            Feature(
+                name="feature1",
+                description="test",
+                transformation=AggregatedTransform(
+                    functions=[Function(functions.avg, DataType.DOUBLE)],
+                ),
+            ),
+            Feature(
+                name="feature2",
+                description="test",
+                transformation=AggregatedTransform(
+                    functions=[Function(functions.avg, DataType.DOUBLE)]
+                ),
+            ),
+        ],
+        keys=[
+            KeyFeature(
+                name="id",
+                description="description",
+                dtype=DataType.BIGINT,
+            )
+        ],
+        timestamp=TimestampFeature(),
+    )
diff --git a/tests/unit/butterfree/transform/features/conftest.py b/tests/unit/butterfree/transform/features/conftest.py
index e79c5075f..ae6444703 100644
--- a/tests/unit/butterfree/transform/features/conftest.py
+++ b/tests/unit/butterfree/transform/features/conftest.py
@@ -18,8 +18,8 @@ def feature_set_dataframe(spark_context, spark_session):
 
 @fixture
 def 
feature_set_dataframe_ms_from_column(spark_context, spark_session): data = [ - {"id": 1, "ts": 1581542311000, "feature": 100}, - {"id": 2, "ts": 1581542322000, "feature": 200}, + {"id": 1, "ts": 1581542311112, "feature": 100}, + {"id": 2, "ts": 1581542322223, "feature": 200}, ] return spark_session.read.json(spark_context.parallelize(data, 1)) @@ -27,8 +27,17 @@ def feature_set_dataframe_ms_from_column(spark_context, spark_session): @fixture def feature_set_dataframe_ms(spark_context, spark_session): data = [ - {"id": 1, TIMESTAMP_COLUMN: 1581542311000, "feature": 100}, - {"id": 2, TIMESTAMP_COLUMN: 1581542322000, "feature": 200}, + {"id": 1, TIMESTAMP_COLUMN: 1581542311112, "feature": 100}, + {"id": 2, TIMESTAMP_COLUMN: 1581542322223, "feature": 200}, + ] + return spark_session.read.json(spark_context.parallelize(data, 1)) + + +@fixture +def feature_set_dataframe_small_time_diff(spark_context, spark_session): + data = [ + {"id": 1, TIMESTAMP_COLUMN: 1581542311001, "feature": 100}, + {"id": 2, TIMESTAMP_COLUMN: 1581542311002, "feature": 200}, ] return spark_session.read.json(spark_context.parallelize(data, 1)) diff --git a/tests/unit/butterfree/transform/features/test_feature.py b/tests/unit/butterfree/transform/features/test_feature.py index 14a89f2cf..01bb41e5a 100644 --- a/tests/unit/butterfree/transform/features/test_feature.py +++ b/tests/unit/butterfree/transform/features/test_feature.py @@ -98,7 +98,9 @@ def test_feature_transform_with_from_column_and_column_name_exists( def test_feature_transform_with_dtype(self, feature_set_dataframe): test_feature = Feature( - name="feature", description="unit test", dtype=DataType.TIMESTAMP, + name="feature", + description="unit test", + dtype=DataType.TIMESTAMP, ) df = test_feature.transform(feature_set_dataframe) diff --git a/tests/unit/butterfree/transform/features/test_timestamp_feature.py b/tests/unit/butterfree/transform/features/test_timestamp_feature.py index c77450362..42ab40a2c 100644 --- a/tests/unit/butterfree/transform/features/test_timestamp_feature.py +++ b/tests/unit/butterfree/transform/features/test_timestamp_feature.py @@ -1,18 +1,26 @@ -from pyspark.sql.types import StringType +from datetime import datetime +import pytz +from pyspark.sql.types import StringType, StructField, StructType + +from butterfree.clients import SparkClient from butterfree.constants import DataType from butterfree.constants.columns import TIMESTAMP_COLUMN from butterfree.transform.features import TimestampFeature +# from pyspark.sql.types import * + class TestTimestampFeature: def test_args_without_transformation(self): test_key = TimestampFeature(from_column="ts") + test_key_ntz = TimestampFeature(dtype=DataType.TIMESTAMP_NTZ, from_column="ts") assert test_key.name == TIMESTAMP_COLUMN assert test_key.from_column == "ts" assert test_key.dtype == DataType.TIMESTAMP + assert test_key_ntz.dtype == DataType.TIMESTAMP_NTZ def test_transform(self, feature_set_dataframe): @@ -32,8 +40,8 @@ def test_transform_ms_from_column(self, feature_set_dataframe_ms_from_column): df = df.withColumn("timestamp", df["timestamp"].cast(StringType())).collect() - assert df[0]["timestamp"] == "2020-02-12 21:18:31" - assert df[1]["timestamp"] == "2020-02-12 21:18:42" + assert df[0]["timestamp"] == "2020-02-12 21:18:31.112" + assert df[1]["timestamp"] == "2020-02-12 21:18:42.223" def test_transform_ms(self, feature_set_dataframe_ms): @@ -43,8 +51,22 @@ def test_transform_ms(self, feature_set_dataframe_ms): df = df.withColumn("timestamp", 
df["timestamp"].cast(StringType())).collect() - assert df[0]["timestamp"] == "2020-02-12 21:18:31" - assert df[1]["timestamp"] == "2020-02-12 21:18:42" + assert df[0]["timestamp"] == "2020-02-12 21:18:31.112" + assert df[1]["timestamp"] == "2020-02-12 21:18:42.223" + + def test_transform_ms_from_column_small_time_diff( + self, feature_set_dataframe_small_time_diff + ): + + test_key = TimestampFeature(from_ms=True) + + df = test_key.transform(feature_set_dataframe_small_time_diff).orderBy( + "timestamp" + ) + + df = df.withColumn("timestamp", df["timestamp"].cast(StringType())).collect() + + assert df[0]["timestamp"] != df[1]["timestamp"] def test_transform_mask(self, feature_set_dataframe_date): @@ -56,3 +78,73 @@ def test_transform_mask(self, feature_set_dataframe_date): assert df[0]["timestamp"] == "2020-02-07 00:00:00" assert df[1]["timestamp"] == "2020-02-08 00:00:00" + + def test_timezone_configs(self): + + spark = SparkClient() + now = datetime.now() + + # Testing a new timezone + spark.conn.conf.set("spark.sql.session.timeZone", "GMT-5") + + time_list = [(now, now)] + rdd = spark.conn.sparkContext.parallelize(time_list) + + schema = StructType( + [ + StructField("ts", DataType.TIMESTAMP.spark, True), + StructField("ts_ntz", DataType.TIMESTAMP_NTZ.spark, True), + ] + ) + df = spark.conn.createDataFrame(rdd, schema) + df.createOrReplaceTempView("temp_tz_table") + + df1 = spark.conn.sql("""SELECT ts, ts_ntz FROM temp_tz_table""") + df2 = df1.withColumns( + {"ts": df1.ts.cast(StringType()), "ts_ntz": df1.ts_ntz.cast(StringType())} + ) + df2_vals = df2.collect()[0] + + assert df2_vals.ts != df2_vals.ts_ntz + + # New TZ. Column with TZ must have a != value; Column NTZ must keep its value + spark.conn.conf.set("spark.sql.session.timeZone", "GMT-7") + + df3 = spark.conn.sql("""SELECT ts, ts_ntz FROM temp_tz_table""") + df4 = df3.withColumns( + {"ts": df1.ts.cast(StringType()), "ts_ntz": df1.ts_ntz.cast(StringType())} + ) + df4_vals = df4.collect()[0] + + assert df4_vals.ts != df2_vals.ts + assert df4_vals.ts_ntz == df2_vals.ts_ntz + + def test_timezone(self): + + spark = SparkClient() + + my_date = datetime.now(pytz.timezone("US/Pacific")) + + datetime_mask = "%Y-%m-%d %H:%M" + + data = [ + {"id": 1, TIMESTAMP_COLUMN: str(my_date), "feature": 100}, + {"id": 2, TIMESTAMP_COLUMN: str(my_date), "feature": 200}, + ] + + df = spark.conn.read.json(spark.conn._sc.parallelize(data, 1)) + df.createOrReplaceTempView("time_table") + + df2 = spark.sql("SELECT TIMESTAMP AS ts FROM time_table") + + time_value = datetime.fromisoformat(df2.collect()[0].ts).strftime(datetime_mask) + + df_different_timezone = df2.withColumn( + "ts", df2.ts.cast(DataType.TIMESTAMP.spark) + ) + df_no_timezone = df2.withColumn("ts", df2.ts.cast(DataType.TIMESTAMP_NTZ.spark)) + + assert ( + df_different_timezone.collect()[0].ts.strftime(datetime_mask) != time_value + ) + assert df_no_timezone.collect()[0].ts.strftime(datetime_mask) == time_value diff --git a/tests/unit/butterfree/transform/test_aggregated_feature_set.py b/tests/unit/butterfree/transform/test_aggregated_feature_set.py index 2c404feab..38ec249aa 100644 --- a/tests/unit/butterfree/transform/test_aggregated_feature_set.py +++ b/tests/unit/butterfree/transform/test_aggregated_feature_set.py @@ -1,13 +1,6 @@ import pytest from pyspark.sql import functions -from pyspark.sql.types import ( - ArrayType, - DoubleType, - FloatType, - LongType, - StringType, - TimestampType, -) +from pyspark.sql.types import DoubleType, LongType, TimestampType from butterfree.clients import 
SparkClient
 from butterfree.constants import DataType
@@ -51,33 +44,14 @@ def test_feature_set_with_invalid_feature(self, key_id, timestamp_c, dataframe):
             ).construct(dataframe, spark_client)
 
     def test_agg_feature_set_with_window(
-        self, key_id, timestamp_c, dataframe, rolling_windows_agg_dataframe
+        self,
+        dataframe,
+        rolling_windows_agg_dataframe,
+        agg_feature_set,
     ):
         spark_client = SparkClient()
 
-        fs = AggregatedFeatureSet(
-            name="name",
-            entity="entity",
-            description="description",
-            features=[
-                Feature(
-                    name="feature1",
-                    description="unit test",
-                    transformation=AggregatedTransform(
-                        functions=[Function(functions.avg, DataType.FLOAT)]
-                    ),
-                ),
-                Feature(
-                    name="feature2",
-                    description="unit test",
-                    transformation=AggregatedTransform(
-                        functions=[Function(functions.avg, DataType.FLOAT)]
-                    ),
-                ),
-            ],
-            keys=[key_id],
-            timestamp=timestamp_c,
-        ).with_windows(definitions=["1 week"])
+        fs = agg_feature_set.with_windows(definitions=["1 week"])
 
         # raises without end date
         with pytest.raises(ValueError):
@@ -89,7 +63,47 @@ def test_agg_feature_set_with_window(
         output_df = fs.construct(dataframe, spark_client, end_date="2016-05-01")
         assert_dataframe_equality(output_df, rolling_windows_agg_dataframe)
 
-    def test_get_schema(self):
+    def test_agg_feature_set_with_smaller_slide(
+        self,
+        dataframe,
+        rolling_windows_hour_slide_agg_dataframe,
+        agg_feature_set,
+    ):
+        spark_client = SparkClient()
+
+        fs = agg_feature_set.with_windows(definitions=["1 day"], slide="12 hours")
+
+        # raises without end date
+        with pytest.raises(ValueError):
+            _ = fs.construct(dataframe, spark_client)
+
+        # filters with date smaller than mocked max
+        output_df = fs.construct(dataframe, spark_client, end_date="2016-04-17")
+        assert_dataframe_equality(output_df, rolling_windows_hour_slide_agg_dataframe)
+
+    def test_agg_feature_set_with_smaller_slide_and_multiple_windows(
+        self,
+        dataframe,
+        multiple_rolling_windows_hour_slide_agg_dataframe,
+        agg_feature_set,
+    ):
+        spark_client = SparkClient()
+
+        fs = agg_feature_set.with_windows(
+            definitions=["2 days", "3 days"], slide="12 hours"
+        )
+
+        # raises without end date
+        with pytest.raises(ValueError):
+            _ = fs.construct(dataframe, spark_client)
+
+        # filters with date smaller than mocked max
+        output_df = fs.construct(dataframe, spark_client, end_date="2016-04-17")
+        assert_dataframe_equality(
+            output_df, multiple_rolling_windows_hour_slide_agg_dataframe
+        )
+
+    def test_get_schema(self, agg_feature_set):
         expected_schema = [
             {"column_name": "id", "type": LongType(), "primary_key": True},
             {"column_name": "timestamp", "type": TimestampType(), "primary_key": False},
@@ -104,61 +118,20 @@ def test_get_schema(self):
                 "primary_key": False,
             },
             {
-                "column_name": "feature1__stddev_pop_over_1_week_rolling_windows",
-                "type": FloatType(),
-                "primary_key": False,
-            },
-            {
-                "column_name": "feature1__stddev_pop_over_2_days_rolling_windows",
-                "type": FloatType(),
-                "primary_key": False,
-            },
-            {
-                "column_name": "feature2__count_over_1_week_rolling_windows",
-                "type": ArrayType(StringType(), True),
+                "column_name": "feature2__avg_over_1_week_rolling_windows",
+                "type": DoubleType(),
                 "primary_key": False,
             },
             {
-                "column_name": "feature2__count_over_2_days_rolling_windows",
-                "type": ArrayType(StringType(), True),
+                "column_name": "feature2__avg_over_2_days_rolling_windows",
+                "type": DoubleType(),
                 "primary_key": False,
             },
         ]
 
-        feature_set = AggregatedFeatureSet(
-            name="feature_set",
-            entity="entity",
-            description="description",
-            features=[
-                Feature(
-                    
name="feature1", - description="test", - transformation=AggregatedTransform( - functions=[ - Function(functions.avg, DataType.DOUBLE), - Function(functions.stddev_pop, DataType.FLOAT), - ], - ), - ), - Feature( - name="feature2", - description="test", - transformation=AggregatedTransform( - functions=[Function(functions.count, DataType.ARRAY_STRING)] - ), - ), - ], - keys=[ - KeyFeature( - name="id", - description="The user's Main ID or device ID", - dtype=DataType.BIGINT, - ) - ], - timestamp=TimestampFeature(), - ).with_windows(definitions=["1 week", "2 days"]) - - schema = feature_set.get_schema() + schema = agg_feature_set.with_windows( + definitions=["1 week", "2 days"] + ).get_schema() assert schema == expected_schema @@ -389,3 +362,38 @@ def test_feature_transform_with_data_type_array(self, spark_context, spark_sessi # assert assert_dataframe_equality(target_df, output_df) + + def test_define_start_date(self, agg_feature_set): + start_date = agg_feature_set.with_windows( + definitions=["1 week", "2 days"] + ).define_start_date("2020-08-04") + + assert isinstance(start_date, str) + assert start_date == "2020-07-27" + + def test_feature_set_start_date( + self, + timestamp_c, + feature_set_with_distinct_dataframe, + ): + fs = AggregatedFeatureSet( + name="name", + entity="entity", + description="description", + features=[ + Feature( + name="feature", + description="test", + transformation=AggregatedTransform( + functions=[Function(functions.sum, DataType.INTEGER)] + ), + ), + ], + keys=[KeyFeature(name="h3", description="test", dtype=DataType.STRING)], + timestamp=timestamp_c, + ).with_windows(["10 days", "3 weeks", "90 days"]) + + # assert + start_date = fs.define_start_date("2016-04-14") + + assert start_date == "2016-01-14" diff --git a/tests/unit/butterfree/transform/test_feature_set.py b/tests/unit/butterfree/transform/test_feature_set.py index bdb1ff7d4..e907dc0a8 100644 --- a/tests/unit/butterfree/transform/test_feature_set.py +++ b/tests/unit/butterfree/transform/test_feature_set.py @@ -3,25 +3,23 @@ import pytest from pyspark.sql import functions as F from pyspark.sql.types import DoubleType, FloatType, LongType, TimestampType -from tests.unit.butterfree.transform.conftest import ( - feature_add, - feature_divide, - key_id, - timestamp_c, -) from butterfree.clients import SparkClient from butterfree.constants import DataType -from butterfree.constants.columns import TIMESTAMP_COLUMN from butterfree.testing.dataframe import assert_dataframe_equality from butterfree.transform import FeatureSet -from butterfree.transform.features import Feature, KeyFeature, TimestampFeature +from butterfree.transform.features import Feature from butterfree.transform.transformations import ( AggregatedTransform, - SparkFunctionTransform, SQLExpressionTransform, ) from butterfree.transform.utils import Function +from tests.unit.butterfree.transform.conftest import ( + feature_add, + feature_divide, + key_id, + timestamp_c, +) class TestFeatureSet: @@ -72,7 +70,14 @@ class TestFeatureSet: None, [feature_add, feature_divide], ), - ("name", "entity", "description", [key_id], timestamp_c, [None],), + ( + "name", + "entity", + "description", + [key_id], + timestamp_c, + [None], + ), ], ) def test_cannot_instantiate( @@ -341,7 +346,7 @@ def test_feature_set_with_invalid_feature(self, key_id, timestamp_c, dataframe): timestamp=timestamp_c, ).construct(dataframe, spark_client) - def test_get_schema(self): + def test_get_schema(self, feature_set): expected_schema = [ {"column_name": "id", "type": 
LongType(), "primary_key": True}, {"column_name": "timestamp", "type": TimestampType(), "primary_key": False}, @@ -367,37 +372,6 @@ def test_get_schema(self): }, ] - feature_set = FeatureSet( - name="feature_set", - entity="entity", - description="description", - features=[ - Feature( - name="feature1", - description="test", - transformation=SparkFunctionTransform( - functions=[ - Function(F.avg, DataType.FLOAT), - Function(F.stddev_pop, DataType.DOUBLE), - ] - ).with_window( - partition_by="id", - order_by=TIMESTAMP_COLUMN, - mode="fixed_windows", - window_definition=["2 minutes", "15 minutes"], - ), - ), - ], - keys=[ - KeyFeature( - name="id", - description="The user's Main ID or device ID", - dtype=DataType.BIGINT, - ) - ], - timestamp=TimestampFeature(), - ) - schema = feature_set.get_schema() assert schema == expected_schema @@ -421,3 +395,9 @@ def test_feature_without_datatype(self, key_id, timestamp_c, dataframe): keys=[key_id], timestamp=timestamp_c, ).construct(dataframe, spark_client) + + def test_define_start_date(self, feature_set): + start_date = feature_set.define_start_date("2020-08-04") + + assert isinstance(start_date, str) + assert start_date == "2020-08-04" diff --git a/tests/unit/butterfree/transform/transformations/conftest.py b/tests/unit/butterfree/transform/transformations/conftest.py index 8f3c13bff..41bc63d5c 100644 --- a/tests/unit/butterfree/transform/transformations/conftest.py +++ b/tests/unit/butterfree/transform/transformations/conftest.py @@ -62,7 +62,7 @@ def target_df_spark(spark_context, spark_session): "timestamp": "2016-04-11 11:31:11", "feature1": 200, "feature2": 200, - "feature__cos": 0.4871876750070059, + "feature__cos": 0.48718767500700594, }, { "id": 1, diff --git a/tests/unit/butterfree/transform/transformations/test_aggregated_transform.py b/tests/unit/butterfree/transform/transformations/test_aggregated_transform.py index 6cdebf74d..96ff682a8 100644 --- a/tests/unit/butterfree/transform/transformations/test_aggregated_transform.py +++ b/tests/unit/butterfree/transform/transformations/test_aggregated_transform.py @@ -44,7 +44,10 @@ def test_output_columns(self): assert all( [ a == b - for a, b in zip(df_columns, ["feature1__avg", "feature1__stddev_pop"],) + for a, b in zip( + df_columns, + ["feature1__avg", "feature1__stddev_pop"], + ) ] ) @@ -64,7 +67,7 @@ def test_blank_aggregation(self, feature_set_dataframe): name="feature1", description="unit test", transformation=AggregatedTransform( - functions=[Function(func="", data_type="")] + functions=[Function(func=None, data_type="")] ), ) diff --git a/tests/unit/butterfree/transform/transformations/test_custom_transform.py b/tests/unit/butterfree/transform/transformations/test_custom_transform.py index 4198d9bda..d87cc7cb1 100644 --- a/tests/unit/butterfree/transform/transformations/test_custom_transform.py +++ b/tests/unit/butterfree/transform/transformations/test_custom_transform.py @@ -21,7 +21,9 @@ def test_feature_transform(self, feature_set_dataframe): description="unit test", dtype=DataType.BIGINT, transformation=CustomTransform( - transformer=divide, column1="feature1", column2="feature2", + transformer=divide, + column1="feature1", + column2="feature2", ), ) @@ -44,7 +46,9 @@ def test_output_columns(self, feature_set_dataframe): description="unit test", dtype=DataType.BIGINT, transformation=CustomTransform( - transformer=divide, column1="feature1", column2="feature2", + transformer=divide, + column1="feature1", + column2="feature2", ), ) @@ -59,7 +63,9 @@ def 
test_custom_transform_output(self, feature_set_dataframe): description="unit test", dtype=DataType.BIGINT, transformation=CustomTransform( - transformer=divide, column1="feature1", column2="feature2", + transformer=divide, + column1="feature1", + column2="feature2", ), ) diff --git a/tests/unit/butterfree/transform/transformations/test_h3_transform.py b/tests/unit/butterfree/transform/transformations/test_h3_transform.py index 4b3308ebe..d4ad6493e 100644 --- a/tests/unit/butterfree/transform/transformations/test_h3_transform.py +++ b/tests/unit/butterfree/transform/transformations/test_h3_transform.py @@ -64,9 +64,9 @@ def test_import_error(self): for m in modules: del sys.modules[m] with pytest.raises(ModuleNotFoundError, match="you must install"): - from butterfree.transform.transformations.h3_transform import ( # noqa - H3HashTransform, # noqa - ) # noqa + from butterfree.transform.transformations.h3_transform import ( # noqa; noqa + H3HashTransform, + ) def test_with_stack(self, h3_input_df, h3_with_stack_target_df): # arrange diff --git a/tests/unit/butterfree/transform/transformations/test_spark_function_transform.py b/tests/unit/butterfree/transform/transformations/test_spark_function_transform.py index fe8bca85c..cf88657a0 100644 --- a/tests/unit/butterfree/transform/transformations/test_spark_function_transform.py +++ b/tests/unit/butterfree/transform/transformations/test_spark_function_transform.py @@ -126,7 +126,9 @@ def test_feature_transform_output_row_windows( transformation=SparkFunctionTransform( functions=[Function(functions.avg, DataType.DOUBLE)], ).with_window( - partition_by="id", mode="row_windows", window_definition=["2 events"], + partition_by="id", + mode="row_windows", + window_definition=["2 events"], ), ) diff --git a/tests/unit/butterfree/transform/transformations/test_sql_expression_transform.py b/tests/unit/butterfree/transform/transformations/test_sql_expression_transform.py index 9cc2e687e..814f83012 100644 --- a/tests/unit/butterfree/transform/transformations/test_sql_expression_transform.py +++ b/tests/unit/butterfree/transform/transformations/test_sql_expression_transform.py @@ -43,7 +43,15 @@ def test_output_columns(self): df_columns = test_feature.get_output_columns() - assert all([a == b for a, b in zip(df_columns, ["feature1_over_feature2"],)]) + assert all( + [ + a == b + for a, b in zip( + df_columns, + ["feature1_over_feature2"], + ) + ] + ) def test_feature_transform_output(self, feature_set_dataframe): test_feature = Feature(