From 90ab2c82009c051c7f1871b40575b021185eca57 Mon Sep 17 00:00:00 2001 From: Ali Date: Fri, 18 Jun 2021 14:13:43 +0000 Subject: [PATCH] Changes related to the MOlPCQM --- .authors.yml | 52 +- .github/PULL_REQUEST_TEMPLATE.md | 10 +- .github/workflows/lint.yml | 62 +- .github/workflows/test.yml | 106 +- .gitignore | 298 +-- .mailmap | 32 +- AUTHORS.rst | 16 +- CHANGELOG.rst | 66 +- CODE_OF_CONDUCT.md | 256 +- CONTRIBUTING.md | 6 +- LICENSE | 402 ++-- MANIFEST.in | 16 +- README.md | 136 +- SECURITY.md | 6 +- docs/_assets/css/custom.css | 66 +- docs/api/goli.data.md | 2 +- docs/api/goli.features.md | 2 +- docs/api/goli.nn.md | 2 +- docs/api/goli.trainer.md | 2 +- docs/api/goli.utils.md | 2 +- docs/cli_references.md | 16 +- docs/contribute.md | 86 +- docs/datasets.md | 86 +- docs/design.md | 228 +- docs/images/logo-title.svg | 376 +-- docs/images/logo.svg | 354 +-- docs/index.md | 82 +- docs/license.md | 6 +- docs/pretrained_models.md | 30 +- .../basics/implementing_gnn_layers.ipynb | 640 ++--- .../basics/making_gnn_networks.ipynb | 454 ++-- docs/tutorials/basics/using_gnn_layers.ipynb | 748 +++--- .../simple-molecular-model.ipynb | 1908 +++++++-------- env.yml | 132 +- expts/config_ZINC_bench_gnn.yaml | 316 +-- expts/config_bindingDB_pretrained.yaml | 102 +- expts/config_htsfp_pcba.yaml | 404 ++-- expts/config_micro_ZINC.yaml | 336 +-- expts/config_molHIV.yaml | 328 +-- expts/config_molHIV_pretrained.yaml | 90 +- expts/config_molPCBA.yaml | 392 ++-- expts/config_molPCQM4M.yaml | 363 +-- expts/config_molbace_pretrained.yaml | 90 +- expts/config_mollipo_pretrained.yaml | 90 +- expts/config_moltox21_pretrained.yaml | 90 +- ...config_single_atom_dataset_pretrained.yaml | 102 +- expts/data/micro_zinc_splits.csv | 1202 +++++----- expts/data/tiny_zinc_splits.csv | 122 +- expts/example_zinc.yaml | 44 +- expts/main_run.py | 153 +- expts/main_run_predict.py | 156 +- expts/main_run_test.py | 100 +- goli/__init__.py | 18 +- goli/_version.py | 2 +- goli/cli/__init__.py | 4 +- goli/cli/data.py | 114 +- goli/cli/main.py | 22 +- goli/config/__init__.py | 14 +- goli/config/_load.py | 32 +- goli/config/_loader.py | 302 +-- goli/config/config_convert.py | 60 +- goli/config/zinc_default_fulldgl.yaml | 144 +- goli/data/__init__.py | 16 +- goli/data/collate.py | 74 +- goli/data/datamodule.py | 1914 +++++++-------- .../single_atom_dataset.csv | 14 +- goli/data/utils.py | 144 +- goli/features/__init__.py | 12 +- goli/features/featurizer.py | 1467 ++++++------ goli/features/nmp.py | 174 +- goli/features/positional_encoding.py | 202 +- goli/features/properties.py | 212 +- goli/features/spectral.py | 278 +-- goli/nn/__init__.py | 6 +- goli/nn/architectures.py | 2076 ++++++++--------- goli/nn/base_layers.py | 634 ++--- goli/nn/dgl_layers/__init__.py | 14 +- goli/nn/dgl_layers/base_dgl_layer.py | 350 +-- goli/nn/dgl_layers/dgn_layer.py | 338 +-- goli/nn/dgl_layers/gat_layer.py | 320 +-- goli/nn/dgl_layers/gated_gcn_layer.py | 380 +-- goli/nn/dgl_layers/gcn_layer.py | 308 +-- goli/nn/dgl_layers/gin_layer.py | 388 +-- goli/nn/dgl_layers/pna_layer.py | 1168 +++++----- goli/nn/dgl_layers/pooling.py | 672 +++--- goli/nn/dgn_operations.py | 594 ++--- goli/nn/pna_operations.py | 176 +- goli/nn/residual_connections.py | 1046 ++++----- goli/trainer/__init__.py | 10 +- goli/trainer/metrics.py | 628 ++--- goli/trainer/model_summary.py | 148 +- goli/trainer/predictor.py | 1144 ++++----- goli/utils/__init__.py | 4 +- goli/utils/arg_checker.py | 418 ++-- goli/utils/decorators.py | 32 +- goli/utils/fs.py | 418 ++-- goli/utils/read_file.py | 318 +-- goli/utils/spaces.py | 100 +- goli/utils/tensor.py | 536 ++--- goli/visualization/vis_utils.py | 72 +- mkdocs.yml | 166 +- news/TEMPLATE.rst | 46 +- news/cache.rst | 46 +- news/datasets.rst | 48 +- news/ogb.rst | 46 +- .../dev-datamodule-invalidate-cache.ipynb | 898 +++---- notebooks/dev-datamodule-ogb.ipynb | 878 +++---- notebooks/dev-datamodule.ipynb | 858 +++---- notebooks/dev-pretrained.ipynb | 954 ++++---- notebooks/dev-training-loop.ipynb | 1826 +++++++-------- notebooks/dev.ipynb | 226 +- notebooks/running-model-from-config.ipynb | 440 ++-- pyproject.toml | 46 +- rever.xsh | 56 +- setup.py | 34 +- tests/conftest.py | 22 +- tests/data/config_micro_ZINC.yaml | 336 +-- tests/test_architectures.py | 1266 +++++----- tests/test_data_utils.py | 40 +- tests/test_datamodule.py | 418 ++-- tests/test_featurizer.py | 434 ++-- tests/test_gnn_layers.py | 730 +++--- tests/test_metrics.py | 298 +-- tests/test_positional_encoding.py | 200 +- tests/test_predictor_module.py | 100 +- tests/test_residual_connections.py | 472 ++-- tests/test_utils.py | 164 +- 127 files changed, 19875 insertions(+), 19856 deletions(-) diff --git a/.authors.yml b/.authors.yml index 3c13495e3..b87bc6b65 100644 --- a/.authors.yml +++ b/.authors.yml @@ -1,26 +1,26 @@ -- name: Dom - email: dominique@invivoai.com - num_commits: 103 - first_commit: 2021-01-20 01:31:50 - github: invivoai -- name: DomInvivo - email: 47570400+DomInvivo@users.noreply.github.com - num_commits: 2 - first_commit: 2021-01-19 15:30:48 - github: invivoai -- name: Hadrien Mary - email: hadrien.mary@gmail.com - alternate_emails: - - hadim@users.noreply.github.com - num_commits: 46 - first_commit: 2021-01-20 10:31:10 - github: invivoai -- name: Ubuntu - email: ubuntu@ip-172-31-19-91.us-east-2.compute.internal - num_commits: 8 - first_commit: 2021-03-23 12:26:32 -- name: Therence1 - email: 38595485+Therence1@users.noreply.github.com - num_commits: 1 - first_commit: 2021-03-21 11:41:43 - github: Therence1 +- name: Dom + email: dominique@invivoai.com + num_commits: 103 + first_commit: 2021-01-20 01:31:50 + github: invivoai +- name: DomInvivo + email: 47570400+DomInvivo@users.noreply.github.com + num_commits: 2 + first_commit: 2021-01-19 15:30:48 + github: invivoai +- name: Hadrien Mary + email: hadrien.mary@gmail.com + alternate_emails: + - hadim@users.noreply.github.com + num_commits: 46 + first_commit: 2021-01-20 10:31:10 + github: invivoai +- name: Ubuntu + email: ubuntu@ip-172-31-19-91.us-east-2.compute.internal + num_commits: 8 + first_commit: 2021-03-23 12:26:32 +- name: Therence1 + email: 38595485+Therence1@users.noreply.github.com + num_commits: 1 + first_commit: 2021-03-21 11:41:43 + github: Therence1 diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 56db38d4a..03d53df63 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -1,5 +1,5 @@ -Checklist: - -- [ ] Added a `news` entry: _copy `news/TEMPLATE.rst` to `news/my-feature-or-branch.rst`) and edit it._ - ---- +Checklist: + +- [ ] Added a `news` entry: _copy `news/TEMPLATE.rst` to `news/my-feature-or-branch.rst`) and edit it._ + +--- diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index f7427062e..daf282303 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -1,31 +1,31 @@ -name: Lint - -on: - push: - branches: - - "master" - pull_request: - branches: - - "*" - -jobs: - lint-source-code: - name: Lint - runs-on: "ubuntu-latest" - - steps: - - name: Checkout the code - uses: actions/checkout@v2 - - - name: Set up Python 3.8 - uses: actions/setup-python@v2 - with: - python-version: 3.8 - - - name: Install black - run: | - pip install black==20.8b1 - - - name: Lint - run: | - black --check . +name: Lint + +on: + push: + branches: + - "master" + pull_request: + branches: + - "*" + +jobs: + lint-source-code: + name: Lint + runs-on: "ubuntu-latest" + + steps: + - name: Checkout the code + uses: actions/checkout@v2 + + - name: Set up Python 3.8 + uses: actions/setup-python@v2 + with: + python-version: 3.8 + + - name: Install black + run: | + pip install black==20.8b1 + + - name: Lint + run: | + black --check . diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 4aea93bff..a7a553f6c 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -1,53 +1,53 @@ -name: test - -on: - push: - branches: - - "master" - pull_request: - branches: - - "*" - - "!privpage" - -jobs: - test: - runs-on: "ubuntu-latest" - defaults: - run: - shell: bash -l {0} - steps: - - name: Checkout the code - uses: actions/checkout@v2 - - - name: Setup conda - uses: conda-incubator/setup-miniconda@v2 - with: - miniforge-variant: Mambaforge - use-mamba: true - activate-environment: goli - - - name: Install Dependencies - run: mamba env update -f env.yml - - - name: Install library - run: python -m pip install . - - - name: Run tests - run: pytest - - - name: Test CLI - run: goli --help - - - name: Test building the doc - run: | - # Build and serve the doc - mkdocs build - - - name: Deploy the doc - if: ${{ github.ref == 'refs/heads/master' }} - run: | - # Get the privpage branch - git fetch origin privpage - - # Build and serve the doc - mkdocs gh-deploy +name: test + +on: + push: + branches: + - "master" + pull_request: + branches: + - "*" + - "!privpage" + +jobs: + test: + runs-on: "ubuntu-latest" + defaults: + run: + shell: bash -l {0} + steps: + - name: Checkout the code + uses: actions/checkout@v2 + + - name: Setup conda + uses: conda-incubator/setup-miniconda@v2 + with: + miniforge-variant: Mambaforge + use-mamba: true + activate-environment: goli + + - name: Install Dependencies + run: mamba env update -f env.yml + + - name: Install library + run: python -m pip install . + + - name: Run tests + run: pytest + + - name: Test CLI + run: goli --help + + - name: Test building the doc + run: | + # Build and serve the doc + mkdocs build + + - name: Deploy the doc + if: ${{ github.ref == 'refs/heads/master' }} + run: | + # Get the privpage branch + git fetch origin privpage + + # Build and serve the doc + mkdocs gh-deploy diff --git a/.gitignore b/.gitignore index 54346a342..ce25dcb6f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,149 +1,149 @@ -# Byte-compiled / optimized / DLL files -__pycache__/ -*.py[cod] -*$py.class - -# C extensions -*.so - -# Distribution / packaging -.Python -build/ -develop-eggs/ -dist/ -downloads/ -eggs/ -.eggs/ -lib/ -lib64/ -parts/ -sdist/ -var/ -wheels/ -pip-wheel-metadata/ -share/python-wheels/ -*.egg-info/ -.installed.cfg -*.egg -MANIFEST - -# PyInstaller -# Usually these files are written by a python script from a template -# before PyInstaller builds the exe, so as to inject date/other infos into it. -*.manifest -*.spec - -# Installer logs -pip-log.txt -pip-delete-this-directory.txt - -# Unit test / coverage reports -htmlcov/ -.tox/ -.nox/ -.coverage -.coverage.* -.cache -nosetests.xml -coverage.xml -*.cover -*.py,cover -.hypothesis/ -.pytest_cache/ - -# Translations -*.mo -*.pot - -# Django stuff: -*.log -local_settings.py -db.sqlite3 -db.sqlite3-journal - -# Flask stuff: -instance/ -.webassets-cache - -# Scrapy stuff: -.scrapy - -# Sphinx documentation -docs/_build/ - -# PyBuilder -target/ - -# Jupyter Notebook -.ipynb_checkpoints - -# IPython -profile_default/ -ipython_config.py - -# pyenv -.python-version - -# pipenv -# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. -# However, in case of collaboration, if having platform-specific dependencies or dependencies -# having no cross-platform support, pipenv may install dependencies that don't work, or not -# install all needed dependencies. -#Pipfile.lock - -# PEP 582; used by e.g. github.com/David-OConnor/pyflow -__pypackages__/ - -# Celery stuff -celerybeat-schedule -celerybeat.pid - -# SageMath parsed files -*.sage.py - -# Environments -.env -.venv -env/ -venv/ -ENV/ -env.bak/ -venv.bak/ - -# Spyder project settings -.spyderproject -.spyproject - -# Rope project settings -.ropeproject - -# mkdocs documentation -/site - -# mypy -.mypy_cache/ -.dmypy.json -dmypy.json - -# Pyre type checker -.pyre/ - -# Custom GitIgnore -*.code-workspace -lightning_logs/ -logs/ -.vscode/ -models_checkpoints/ -multirun/ -outputs/ -aws_multirun/ - -rever/ - - - -goli/data/ZINC_bench_gnn/ -goli/data/BindingDB/ -goli/data/cache/ - -predictions/ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# Custom GitIgnore +*.code-workspace +lightning_logs/ +logs/ +.vscode/ +models_checkpoints/ +multirun/ +outputs/ +aws_multirun/ + +rever/ + + + +goli/data/ZINC_bench_gnn/ +goli/data/BindingDB/ +goli/data/cache/ + +predictions/ diff --git a/.mailmap b/.mailmap index bbb78bf29..3d9bc2e89 100644 --- a/.mailmap +++ b/.mailmap @@ -1,16 +1,16 @@ -# This file was autogenerated by rever: https://regro.github.io/rever-docs/ -# This prevent git from showing duplicates with various logging commands. -# See the git documentation for more details. The syntax is: -# -# good-name bad-name -# -# You can skip bad-name if it is the same as good-name and is unique in the repo. -# -# This file is up-to-date if the command git log --format="%aN <%aE>" | sort -u -# gives no duplicates. - -Dom -DomInvivo <47570400+DomInvivo@users.noreply.github.com> -Hadrien Mary Hadrien Mary -Therence1 <38595485+Therence1@users.noreply.github.com> -Ubuntu +# This file was autogenerated by rever: https://regro.github.io/rever-docs/ +# This prevent git from showing duplicates with various logging commands. +# See the git documentation for more details. The syntax is: +# +# good-name bad-name +# +# You can skip bad-name if it is the same as good-name and is unique in the repo. +# +# This file is up-to-date if the command git log --format="%aN <%aE>" | sort -u +# gives no duplicates. + +Dom +DomInvivo <47570400+DomInvivo@users.noreply.github.com> +Hadrien Mary Hadrien Mary +Therence1 <38595485+Therence1@users.noreply.github.com> +Ubuntu diff --git a/AUTHORS.rst b/AUTHORS.rst index f3a53c795..8906e10f9 100644 --- a/AUTHORS.rst +++ b/AUTHORS.rst @@ -1,8 +1,8 @@ -All of the people who have made at least one contribution to goli. -Authors are sorted alphabetically. - -* Dom -* DomInvivo -* Hadrien Mary -* Therence1 -* Ubuntu +All of the people who have made at least one contribution to goli. +Authors are sorted alphabetically. + +* Dom +* DomInvivo +* Hadrien Mary +* Therence1 +* Ubuntu diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 272083297..997aceb9b 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,33 +1,33 @@ -===================== -goli Change Log -===================== - -.. current developments - -v0.1.0 -==================== - -**Added:** - -* First working version of goli. Browse the documentation and tutorials for more details. - -**Authors:** - -* Dom -* Hadrien Mary -* Therence1 -* Ubuntu - - - -v0.0.1 -==================== - -**Added:** - -* Fake release to test the process. - -**Authors:** - - - +===================== +goli Change Log +===================== + +.. current developments + +v0.1.0 +==================== + +**Added:** + +* First working version of goli. Browse the documentation and tutorials for more details. + +**Authors:** + +* Dom +* Hadrien Mary +* Therence1 +* Ubuntu + + + +v0.0.1 +==================== + +**Added:** + +* Fake release to test the process. + +**Authors:** + + + diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md index 18c914718..d28f7f910 100644 --- a/CODE_OF_CONDUCT.md +++ b/CODE_OF_CONDUCT.md @@ -1,128 +1,128 @@ -# Contributor Covenant Code of Conduct - -## Our Pledge - -We as members, contributors, and leaders pledge to make participation in our -community a harassment-free experience for everyone, regardless of age, body -size, visible or invisible disability, ethnicity, sex characteristics, gender -identity and expression, level of experience, education, socio-economic status, -nationality, personal appearance, race, religion, or sexual identity -and orientation. - -We pledge to act and interact in ways that contribute to an open, welcoming, -diverse, inclusive, and healthy community. - -## Our Standards - -Examples of behavior that contributes to a positive environment for our -community include: - -* Demonstrating empathy and kindness toward other people -* Being respectful of differing opinions, viewpoints, and experiences -* Giving and gracefully accepting constructive feedback -* Accepting responsibility and apologizing to those affected by our mistakes, - and learning from the experience -* Focusing on what is best not just for us as individuals, but for the - overall community - -Examples of unacceptable behavior include: - -* The use of sexualized language or imagery, and sexual attention or - advances of any kind -* Trolling, insulting or derogatory comments, and personal or political attacks -* Public or private harassment -* Publishing others' private information, such as a physical or email - address, without their explicit permission -* Other conduct which could reasonably be considered inappropriate in a - professional setting - -## Enforcement Responsibilities - -Community leaders are responsible for clarifying and enforcing our standards of -acceptable behavior and will take appropriate and fair corrective action in -response to any behavior that they deem inappropriate, threatening, offensive, -or harmful. - -Community leaders have the right and responsibility to remove, edit, or reject -comments, commits, code, wiki edits, issues, and other contributions that are -not aligned to this Code of Conduct, and will communicate reasons for moderation -decisions when appropriate. - -## Scope - -This Code of Conduct applies within all community spaces, and also applies when -an individual is officially representing the community in public spaces. -Examples of representing our community include using an official e-mail address, -posting via an official social media account, or acting as an appointed -representative at an online or offline event. - -## Enforcement - -Instances of abusive, harassing, or otherwise unacceptable behavior may be -reported to the community leaders responsible for enforcement at -. -All complaints will be reviewed and investigated promptly and fairly. - -All community leaders are obligated to respect the privacy and security of the -reporter of any incident. - -## Enforcement Guidelines - -Community leaders will follow these Community Impact Guidelines in determining -the consequences for any action they deem in violation of this Code of Conduct: - -### 1. Correction - -**Community Impact**: Use of inappropriate language or other behavior deemed -unprofessional or unwelcome in the community. - -**Consequence**: A private, written warning from community leaders, providing -clarity around the nature of the violation and an explanation of why the -behavior was inappropriate. A public apology may be requested. - -### 2. Warning - -**Community Impact**: A violation through a single incident or series -of actions. - -**Consequence**: A warning with consequences for continued behavior. No -interaction with the people involved, including unsolicited interaction with -those enforcing the Code of Conduct, for a specified period of time. This -includes avoiding interactions in community spaces as well as external channels -like social media. Violating these terms may lead to a temporary or -permanent ban. - -### 3. Temporary Ban - -**Community Impact**: A serious violation of community standards, including -sustained inappropriate behavior. - -**Consequence**: A temporary ban from any sort of interaction or public -communication with the community for a specified period of time. No public or -private interaction with the people involved, including unsolicited interaction -with those enforcing the Code of Conduct, is allowed during this period. -Violating these terms may lead to a permanent ban. - -### 4. Permanent Ban - -**Community Impact**: Demonstrating a pattern of violation of community -standards, including sustained inappropriate behavior, harassment of an -individual, or aggression toward or disparagement of classes of individuals. - -**Consequence**: A permanent ban from any sort of public interaction within -the community. - -## Attribution - -This Code of Conduct is adapted from the [Contributor Covenant][homepage], -version 2.0, available at -https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. - -Community Impact Guidelines were inspired by [Mozilla's code of conduct -enforcement ladder](https://github.com/mozilla/diversity). - -[homepage]: https://www.contributor-covenant.org - -For answers to common questions about this code of conduct, see the FAQ at -https://www.contributor-covenant.org/faq. Translations are available at -https://www.contributor-covenant.org/translations. +# Contributor Covenant Code of Conduct + +## Our Pledge + +We as members, contributors, and leaders pledge to make participation in our +community a harassment-free experience for everyone, regardless of age, body +size, visible or invisible disability, ethnicity, sex characteristics, gender +identity and expression, level of experience, education, socio-economic status, +nationality, personal appearance, race, religion, or sexual identity +and orientation. + +We pledge to act and interact in ways that contribute to an open, welcoming, +diverse, inclusive, and healthy community. + +## Our Standards + +Examples of behavior that contributes to a positive environment for our +community include: + +* Demonstrating empathy and kindness toward other people +* Being respectful of differing opinions, viewpoints, and experiences +* Giving and gracefully accepting constructive feedback +* Accepting responsibility and apologizing to those affected by our mistakes, + and learning from the experience +* Focusing on what is best not just for us as individuals, but for the + overall community + +Examples of unacceptable behavior include: + +* The use of sexualized language or imagery, and sexual attention or + advances of any kind +* Trolling, insulting or derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or email + address, without their explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Enforcement Responsibilities + +Community leaders are responsible for clarifying and enforcing our standards of +acceptable behavior and will take appropriate and fair corrective action in +response to any behavior that they deem inappropriate, threatening, offensive, +or harmful. + +Community leaders have the right and responsibility to remove, edit, or reject +comments, commits, code, wiki edits, issues, and other contributions that are +not aligned to this Code of Conduct, and will communicate reasons for moderation +decisions when appropriate. + +## Scope + +This Code of Conduct applies within all community spaces, and also applies when +an individual is officially representing the community in public spaces. +Examples of representing our community include using an official e-mail address, +posting via an official social media account, or acting as an appointed +representative at an online or offline event. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported to the community leaders responsible for enforcement at +. +All complaints will be reviewed and investigated promptly and fairly. + +All community leaders are obligated to respect the privacy and security of the +reporter of any incident. + +## Enforcement Guidelines + +Community leaders will follow these Community Impact Guidelines in determining +the consequences for any action they deem in violation of this Code of Conduct: + +### 1. Correction + +**Community Impact**: Use of inappropriate language or other behavior deemed +unprofessional or unwelcome in the community. + +**Consequence**: A private, written warning from community leaders, providing +clarity around the nature of the violation and an explanation of why the +behavior was inappropriate. A public apology may be requested. + +### 2. Warning + +**Community Impact**: A violation through a single incident or series +of actions. + +**Consequence**: A warning with consequences for continued behavior. No +interaction with the people involved, including unsolicited interaction with +those enforcing the Code of Conduct, for a specified period of time. This +includes avoiding interactions in community spaces as well as external channels +like social media. Violating these terms may lead to a temporary or +permanent ban. + +### 3. Temporary Ban + +**Community Impact**: A serious violation of community standards, including +sustained inappropriate behavior. + +**Consequence**: A temporary ban from any sort of interaction or public +communication with the community for a specified period of time. No public or +private interaction with the people involved, including unsolicited interaction +with those enforcing the Code of Conduct, is allowed during this period. +Violating these terms may lead to a permanent ban. + +### 4. Permanent Ban + +**Community Impact**: Demonstrating a pattern of violation of community +standards, including sustained inappropriate behavior, harassment of an +individual, or aggression toward or disparagement of classes of individuals. + +**Consequence**: A permanent ban from any sort of public interaction within +the community. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], +version 2.0, available at +https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. + +Community Impact Guidelines were inspired by [Mozilla's code of conduct +enforcement ladder](https://github.com/mozilla/diversity). + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see the FAQ at +https://www.contributor-covenant.org/faq. Translations are available at +https://www.contributor-covenant.org/translations. diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 9549f8e65..b7870fe7c 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,3 +1,3 @@ -# Contribute - -See documentation at https://valence-discovery.github.io/goli/ +# Contribute + +See documentation at https://valence-discovery.github.io/goli/ diff --git a/LICENSE b/LICENSE index 97b34df28..f048b6a9d 100644 --- a/LICENSE +++ b/LICENSE @@ -1,201 +1,201 @@ - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright 2021 Valence - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2021 Valence + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/MANIFEST.in b/MANIFEST.in index 36d7eed6d..ee5b0c287 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,8 +1,8 @@ -include README.md -include *.py -graft goli/*.yaml -recursive-include goli/ *.csv -recursive-include goli/config/ *.yaml -recursive-exclude * __pycache__ -recursive-exclude * *.py[co] -recursive-exclude * .ipynb_checkpoints/* +include README.md +include *.py +graft goli/*.yaml +recursive-include goli/ *.csv +recursive-include goli/config/ *.yaml +recursive-exclude * __pycache__ +recursive-exclude * *.py[co] +recursive-exclude * .ipynb_checkpoints/* diff --git a/README.md b/README.md index c11f8487c..acda39891 100644 --- a/README.md +++ b/README.md @@ -1,68 +1,68 @@ -
- -

The Graph Of LIfe Library.

-
- ---- - -[![Binder](http://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/valence-discovery/goli/master?urlpath=lab/tree/docs/tutorials/) -[![PyPI](https://img.shields.io/pypi/v/goli)](https://pypi.org/project/goli/) -[![Conda](https://img.shields.io/conda/v/conda-forge/goli?label=conda&color=success)](https://anaconda.org/conda-forge/goli) -[![PyPI - Downloads](https://img.shields.io/pypi/dm/goli)](https://pypi.org/project/goli/) -[![Conda](https://img.shields.io/conda/dn/conda-forge/goli)](https://anaconda.org/conda-forge/goli) -[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/goli)](https://pypi.org/project/goli/) -[![license](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/valence-discovery/goli/blob/master/LICENSE) -[![GitHub Repo stars](https://img.shields.io/github/stars/valence-discovery/goli)](https://github.com/valence-discovery/goli/stargazers) -[![GitHub Repo stars](https://img.shields.io/github/forks/valence-discovery/goli)](https://github.com/valence-discovery/goli/network/members) - -A deep learning library focused on graph representation learning for real-world chemical tasks. - -- ✅ State-of-the-art GNN architectures. -- 🐍 Extensible API: build your own GNN model and train it with ease. -- ⚗️ Rich featurization: powerful and flexible built-in molecular featurization. -- 🧠 Pretrained models: for fast and easy inference or transfer learning. -- ⮔ Read-to-use training loop based on [Pytorch Lightning](https://www.pytorchlightning.ai/). - -## Try Online - -Visit [![Binder](http://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/valence-discovery/goli/master?urlpath=lab/tree/docs/tutorials/) and try goli online. - -## Documentation - -Visit https://valence-discovery.github.io/goli/. - -## Installation - -Use either [`mamba`](https://github.com/mamba-org/mamba) or [`conda`](https://docs.conda.io/en/latest/): - -```bash -mamba install -c conda-forge goli -``` - -or pip: - -```bash -pip install goli -``` - -## Quick API Tour - -```python -import goli - -# TODO: show a quick snippet of goli that: -# - build a model and train it -# - or load a model and do inference. -``` - -## Changelogs - -See the latest changelogs at [CHANGELOG.rst](./CHANGELOG.rst). - -## License - -Under the Apache-2.0 license. See [LICENSE](LICENSE). - -## Authors - -See [AUTHORS.rst](./AUTHORS.rst). +
+ +

The Graph Of LIfe Library.

+
+ +--- + +[![Binder](http://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/valence-discovery/goli/master?urlpath=lab/tree/docs/tutorials/) +[![PyPI](https://img.shields.io/pypi/v/goli)](https://pypi.org/project/goli/) +[![Conda](https://img.shields.io/conda/v/conda-forge/goli?label=conda&color=success)](https://anaconda.org/conda-forge/goli) +[![PyPI - Downloads](https://img.shields.io/pypi/dm/goli)](https://pypi.org/project/goli/) +[![Conda](https://img.shields.io/conda/dn/conda-forge/goli)](https://anaconda.org/conda-forge/goli) +[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/goli)](https://pypi.org/project/goli/) +[![license](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/valence-discovery/goli/blob/master/LICENSE) +[![GitHub Repo stars](https://img.shields.io/github/stars/valence-discovery/goli)](https://github.com/valence-discovery/goli/stargazers) +[![GitHub Repo stars](https://img.shields.io/github/forks/valence-discovery/goli)](https://github.com/valence-discovery/goli/network/members) + +A deep learning library focused on graph representation learning for real-world chemical tasks. + +- ✅ State-of-the-art GNN architectures. +- 🐍 Extensible API: build your own GNN model and train it with ease. +- ⚗️ Rich featurization: powerful and flexible built-in molecular featurization. +- 🧠 Pretrained models: for fast and easy inference or transfer learning. +- ⮔ Read-to-use training loop based on [Pytorch Lightning](https://www.pytorchlightning.ai/). + +## Try Online + +Visit [![Binder](http://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/valence-discovery/goli/master?urlpath=lab/tree/docs/tutorials/) and try goli online. + +## Documentation + +Visit https://valence-discovery.github.io/goli/. + +## Installation + +Use either [`mamba`](https://github.com/mamba-org/mamba) or [`conda`](https://docs.conda.io/en/latest/): + +```bash +mamba install -c conda-forge goli +``` + +or pip: + +```bash +pip install goli +``` + +## Quick API Tour + +```python +import goli + +# TODO: show a quick snippet of goli that: +# - build a model and train it +# - or load a model and do inference. +``` + +## Changelogs + +See the latest changelogs at [CHANGELOG.rst](./CHANGELOG.rst). + +## License + +Under the Apache-2.0 license. See [LICENSE](LICENSE). + +## Authors + +See [AUTHORS.rst](./AUTHORS.rst). diff --git a/SECURITY.md b/SECURITY.md index a983683b6..8cbdd1f40 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -1,3 +1,3 @@ -# Security Policy - -Please report any security-related issues directly to hadrien@valencediscovery.com. +# Security Policy + +Please report any security-related issues directly to hadrien@valencediscovery.com. diff --git a/docs/_assets/css/custom.css b/docs/_assets/css/custom.css index 3427b49bb..f798946a3 100644 --- a/docs/_assets/css/custom.css +++ b/docs/_assets/css/custom.css @@ -1,33 +1,33 @@ -/* Indentation. */ -div.doc-contents:not(.first) { - padding-left: 25px; - border-left: 4px solid rgba(230, 230, 230); - margin-bottom: 80px; -} - -/* Don't capitalize names. */ -h5.doc-heading { - text-transform: none !important; -} - -/* Don't use vertical space on hidden ToC entries. */ -.hidden-toc::before { - margin-top: 0 !important; - padding-top: 0 !important; -} - -/* Don't show permalink of hidden ToC entries. */ -.hidden-toc a.headerlink { - display: none; -} - -/* Avoid breaking parameters name, etc. in table cells. */ -td code { - word-break: normal !important; -} - -/* For pieces of Markdown rendered in table cells. */ -td p { - margin-top: 0 !important; - margin-bottom: 0 !important; -} +/* Indentation. */ +div.doc-contents:not(.first) { + padding-left: 25px; + border-left: 4px solid rgba(230, 230, 230); + margin-bottom: 80px; +} + +/* Don't capitalize names. */ +h5.doc-heading { + text-transform: none !important; +} + +/* Don't use vertical space on hidden ToC entries. */ +.hidden-toc::before { + margin-top: 0 !important; + padding-top: 0 !important; +} + +/* Don't show permalink of hidden ToC entries. */ +.hidden-toc a.headerlink { + display: none; +} + +/* Avoid breaking parameters name, etc. in table cells. */ +td code { + word-break: normal !important; +} + +/* For pieces of Markdown rendered in table cells. */ +td p { + margin-top: 0 !important; + margin-bottom: 0 !important; +} diff --git a/docs/api/goli.data.md b/docs/api/goli.data.md index 74c3cc81b..3322b6e8d 100644 --- a/docs/api/goli.data.md +++ b/docs/api/goli.data.md @@ -1 +1 @@ -::: goli.data +::: goli.data diff --git a/docs/api/goli.features.md b/docs/api/goli.features.md index 2f3437572..4903291a0 100644 --- a/docs/api/goli.features.md +++ b/docs/api/goli.features.md @@ -1 +1 @@ -::: goli.features +::: goli.features diff --git a/docs/api/goli.nn.md b/docs/api/goli.nn.md index ec19d11d8..0b2ee6b43 100644 --- a/docs/api/goli.nn.md +++ b/docs/api/goli.nn.md @@ -1 +1 @@ -::: goli.nn +::: goli.nn diff --git a/docs/api/goli.trainer.md b/docs/api/goli.trainer.md index 1044b7574..1640c869f 100644 --- a/docs/api/goli.trainer.md +++ b/docs/api/goli.trainer.md @@ -1 +1 @@ -::: goli.trainer +::: goli.trainer diff --git a/docs/api/goli.utils.md b/docs/api/goli.utils.md index 7dbb0e157..ea63a6fc1 100644 --- a/docs/api/goli.utils.md +++ b/docs/api/goli.utils.md @@ -1 +1 @@ -::: goli.utils +::: goli.utils diff --git a/docs/cli_references.md b/docs/cli_references.md index 29a13ad5c..fcbf2f23b 100644 --- a/docs/cli_references.md +++ b/docs/cli_references.md @@ -1,8 +1,8 @@ -# CLI Reference - -This page provides documentation for our command line tools. - -::: mkdocs-click - :module: goli.cli - :command: main_cli - :command: data_cli +# CLI Reference + +This page provides documentation for our command line tools. + +::: mkdocs-click + :module: goli.cli + :command: main_cli + :command: data_cli diff --git a/docs/contribute.md b/docs/contribute.md index 383e08bf9..573bd69f9 100644 --- a/docs/contribute.md +++ b/docs/contribute.md @@ -1,43 +1,43 @@ -# Contribute - -The below documents the development lifecycle of Datamol. - -## Setup a dev environment - -```bash -conda create -n goli -conda activate goli - -mamba env update -f env.yml - -conda deactivate && conda activate goli -pip install -e . -``` - -## Run tests - -```bash -pytest -``` - -## Build the documentation - -You can build and serve the documentation locally with: - -```bash -# Build and serve the doc -mike serve -``` - -## Release a new version - -- Run check: `rever check`. -- Bump and release new version: `rever VERSION_NUMBER`. -- Releasing a new version will do the following things in that order: - - Update `AUTHORS.rst`. - - Update `CHANGELOG.rst`. - - Bump the version number in `setup.py` and `_version.py`. - - Add a git tag. - - Push the git tag. - - Add a new release on the GH repo associated with the git tag. - - Update the conda forge feedstock. +# Contribute + +The below documents the development lifecycle of Datamol. + +## Setup a dev environment + +```bash +conda create -n goli +conda activate goli + +mamba env update -f env.yml + +conda deactivate && conda activate goli +pip install -e . +``` + +## Run tests + +```bash +pytest +``` + +## Build the documentation + +You can build and serve the documentation locally with: + +```bash +# Build and serve the doc +mike serve +``` + +## Release a new version + +- Run check: `rever check`. +- Bump and release new version: `rever VERSION_NUMBER`. +- Releasing a new version will do the following things in that order: + - Update `AUTHORS.rst`. + - Update `CHANGELOG.rst`. + - Bump the version number in `setup.py` and `_version.py`. + - Add a git tag. + - Push the git tag. + - Add a new release on the GH repo associated with the git tag. + - Update the conda forge feedstock. diff --git a/docs/datasets.md b/docs/datasets.md index df52153cc..830bfa17f 100644 --- a/docs/datasets.md +++ b/docs/datasets.md @@ -1,43 +1,43 @@ -# GOLI Datasets - -GOLI datasets are hosted at on Google Cloud Storage at `gs://goli-public/datasets`. GOLI provides a convenient utility functions to list and download those datasets: - -```python -import goli - -dataset_dir = "/my/path" -data_path = goli.data.utils.download_goli_dataset("goli-zinc-micro", output_path=dataset_dir) -print(data_path) -# /my/path/goli-zinc-micro -``` - -## `goli-zinc-micro` - -ADD DESCRIPTION. - -- Number of molecules: xxx -- Label columns: xxx -- Split informations. - -## `goli-zinc-bench-gnn` - -ADD DESCRIPTION. - -- Number of molecules: xxx -- Label columns: xxx- Split informations. - -## `goli-htsfp` - -ADD DESCRIPTION. - -- Number of molecules: xxx -- Label columns: xxx -- Split informations. - -## `goli-htsfp-pcba` - -ADD DESCRIPTION. - -- Number of molecules: xxx -- Label columns: xxx -- Split informations. +# GOLI Datasets + +GOLI datasets are hosted at on Google Cloud Storage at `gs://goli-public/datasets`. GOLI provides a convenient utility functions to list and download those datasets: + +```python +import goli + +dataset_dir = "/my/path" +data_path = goli.data.utils.download_goli_dataset("goli-zinc-micro", output_path=dataset_dir) +print(data_path) +# /my/path/goli-zinc-micro +``` + +## `goli-zinc-micro` + +ADD DESCRIPTION. + +- Number of molecules: xxx +- Label columns: xxx +- Split informations. + +## `goli-zinc-bench-gnn` + +ADD DESCRIPTION. + +- Number of molecules: xxx +- Label columns: xxx- Split informations. + +## `goli-htsfp` + +ADD DESCRIPTION. + +- Number of molecules: xxx +- Label columns: xxx +- Split informations. + +## `goli-htsfp-pcba` + +ADD DESCRIPTION. + +- Number of molecules: xxx +- Label columns: xxx +- Split informations. diff --git a/docs/design.md b/docs/design.md index 2fc762309..ba73b6a3f 100644 --- a/docs/design.md +++ b/docs/design.md @@ -1,114 +1,114 @@ -# Goli Library Design - ---- - -**Section from the previous README:** - -### Data setup - -Then, you need to download the data needed to run the code. Right now, we have 2 sets of data folders, present in the link [here](https://drive.google.com/drive/folders/1RrbNZkEE2rf41_iroa1LbIyegW00h3Ql?usp=sharing). - -- **micro_ZINC** (Synthetic dataset) - - A small subset (1000 mols) of the ZINC dataset - - The score is the subtraction of the computed LogP and the synthetic accessibility score SA - - The data must be downloaded to the folder `./goli/data/micro_ZINC/` - -- **ZINC_bench_gnn** (Synthetic dataset) - - A subset (12000 mols) of the ZINC dataset - - The score is the subtraction of the computed LogP and the synthetic accessibility score SA - - These are the same 12k molecules provided by the [Benchmarking-gnn](https://github.com/graphdeeplearning/benchmarking-gnns) repository. - - We provide the pre-processed graphs in `ZINC_bench_gnn/data_from_benchmark` - - We provide the SMILES in `ZINC_bench_gnn/smiles_score.csv`, with the train-val-test indexes in the file `indexes_train_val_test.csv`. - - The first 10k elements are the training set - - The next 1k the valid set - - The last 1k the test set. - - The data must be downloaded to the folder `./goli/data/ZINC_bench_gnn/` - -Then, you can run the main file to make sure that all the dependancies are correctly installed and that the code works as expected. - -```bash -python expts/main_micro_zinc.py -``` - ---- - -**TODO: explain the internal design of Goli so people can contribute to it more easily.** - -## Structure of the code - -The code is built to rapidly iterate on different architectures of neural networks (NN) and graph neural networks (GNN) with Pytorch. The main focus of this work is molecular tasks, and we use the package `rdkit` to transform molecular SMILES into graphs. - -### data_parser - -This folder contains tools that allow tdependenciesrent kind of molecular data files, such as `.csv` or `.xlsx` with SMILES data, or `.sdf` files with 3D data. - -### dgl - -This folder contains the code necessary for compatibility with the Deep Graph Library (DGL), and implements many state of the art GNN methods, such as GCN, GIN, MPNN and PNA. -It also contains major network architecture implemented for DGL (feed-forward, resnet, skip-connections, densenet), along with the DGL graph transformer for molecules. -**_I was thinking of removing most models to avoid having too much maintenance to do, since they either don't perform well on molecular tasks (GAT, RingGNN) or they are generalized by DGN (GCN, GIN, MPNN, PNA, GraphSage)_**. - -### features - -Different utilities for molecules, such as Smiles to adjacency graph transformer, molecular property extraction, atomic properties, bond properties, ... - -**_The MolecularTransformer and AdjGraphTransformer come from ivbase, but I don't like them. I think we should replace them with something simpler and give more flexibility for combining one-hot embedding with physical properties embedding._**. - -### trainer - -The trainer contains the interface to the `pytorch-lightning` library, with `ModelWrapper` being the main class used for any NN model, either for regression or classification. It also contains some modifications to the logger and model_summary from `pytorch-lightning` to enable more flexibility. - -### utils - -Any kind of utilities that can be used anywhere, including argument checkers and configuration loader - -### visualization - -Plot visualization tools - -## Modifying the code - -### Adding a new GNN layer - -Any new GNN layer must inherit from the class `goli.nn.dgl_layers.base_dgl_layer.BaseDGLLayer` and be implemented in the folder `goli/dgl/dgl_layers`, imported in the file `goli/dgl/architectures.py`, and in the same file, added to the function `FeedForwardDGL._parse_gnn_layer`. - -To be used in the configuration file as a `goli.model.layer_name`, it must also be implemented with some variable parameters in the file `expts/config_gnns.yaml`. - -### Adding a new NN architecture - -All NN and GNN architectures compatible with the `DGL` library are provided in the file `goli/dgl/architectures.py`. When implementing a new architecture, it is highly recommended to inherit from `goli.nn.architectures.FeedForwardNN` for regular neural networks, from `goli.nn.architectures.FeedForwardDGL` for DGL neural network, or from any of their sub-classes. - -### Changing the ModelWrapper and loss function - -The `ModelWrapper` is a general pytorch-lightning module that should work with any kind of `pytorch.nn.Module` or `pl.LightningModule`. The class defines a structure of including models, loss functions, batch sizes, collate functions, metrics... - -Some loss functions are already implemented in the ModelWrapper, including `mse, bce, mae, cosine`, but some tasks will require more complex loss functions. One can add any new function in `goli.trainer.predictor.ModelWrapper._parse_loss_fun`. - -### Changing the metrics used - -**_!WARNING! The metrics implementation was done for pytorch-lightning v0.8. There has been major changes to how the metrics are used and defined, so the whole implementation must change._** - -Our current code is compatible with the metrics defined by _pytorch-lightning_, which include a great set of metrics. We also added the PearsonR and SpearmanR as they are important correlation metrics. You can define any new metric in the file `goli/trainer/metrics.py`. The metric must inherit from `TensorMetric` and must be added to the dictionary `goli.trainer.metrics.METRICS_DICT`. - -To use the metric, you can easily add it's name from `METRICS_DICT` in the yaml configuration file, at the address `metrics.metrics_dict`. Each metric has an underlying dictionnary with a mandatory `threshold` key containing information on how to threshold the prediction/target before computing the metric. Any `kwargs` arguments of the metric must also be added. - -## (OLD) Running a hyper-parameter search - -In the current repository, we use `hydra-core` to launch multiple experiments in a grid-search manner. It works by specifying the parameters that we want to change from a given YAML file. - -Below is an example of running a set of 3\*2\*2\*2=24 experiments, 3 variations of the gnn type _layer_name_, 2 variations of the learning rate _lr_, 2 variations of the hidden dimension _hidden_dim_, 2 variations of the network depth _hidden_depth_. All parameters not mentionned in the code below are unchanged from the file `expts/main_micro_ZINC.py`. - - python expts/main_micro_ZINC.py --multirun \ - model.layer_name=gin,gcn,pna-conv3 \ - constants.exp_name="testing_hydra" \ - constants.device="cuda:0" \ - constants.ignore_train_error=true \ - predictor.lr=1e-4,1e-3 \ - model.gnn_kwargs.hidden_dim=32,64 \ - model.gnn_kwargs.hidden_depth=4,8 - -The results of the run will be available in the folder `multirun/[CURRENT-DATE]/[CURRENT-TIME]`. To open the results in tensorflow, run the following command using _bash_ or _powershell_ - -`tensorboard --logdir 'multirun/[CURRENT-DATE]/[CURRENT-TIME]/' --port 8000` - -Then open a web-browser and enter the address `http://localhost:8000/`. +# Goli Library Design + +--- + +**Section from the previous README:** + +### Data setup + +Then, you need to download the data needed to run the code. Right now, we have 2 sets of data folders, present in the link [here](https://drive.google.com/drive/folders/1RrbNZkEE2rf41_iroa1LbIyegW00h3Ql?usp=sharing). + +- **micro_ZINC** (Synthetic dataset) + - A small subset (1000 mols) of the ZINC dataset + - The score is the subtraction of the computed LogP and the synthetic accessibility score SA + - The data must be downloaded to the folder `./goli/data/micro_ZINC/` + +- **ZINC_bench_gnn** (Synthetic dataset) + - A subset (12000 mols) of the ZINC dataset + - The score is the subtraction of the computed LogP and the synthetic accessibility score SA + - These are the same 12k molecules provided by the [Benchmarking-gnn](https://github.com/graphdeeplearning/benchmarking-gnns) repository. + - We provide the pre-processed graphs in `ZINC_bench_gnn/data_from_benchmark` + - We provide the SMILES in `ZINC_bench_gnn/smiles_score.csv`, with the train-val-test indexes in the file `indexes_train_val_test.csv`. + - The first 10k elements are the training set + - The next 1k the valid set + - The last 1k the test set. + - The data must be downloaded to the folder `./goli/data/ZINC_bench_gnn/` + +Then, you can run the main file to make sure that all the dependancies are correctly installed and that the code works as expected. + +```bash +python expts/main_micro_zinc.py +``` + +--- + +**TODO: explain the internal design of Goli so people can contribute to it more easily.** + +## Structure of the code + +The code is built to rapidly iterate on different architectures of neural networks (NN) and graph neural networks (GNN) with Pytorch. The main focus of this work is molecular tasks, and we use the package `rdkit` to transform molecular SMILES into graphs. + +### data_parser + +This folder contains tools that allow tdependenciesrent kind of molecular data files, such as `.csv` or `.xlsx` with SMILES data, or `.sdf` files with 3D data. + +### dgl + +This folder contains the code necessary for compatibility with the Deep Graph Library (DGL), and implements many state of the art GNN methods, such as GCN, GIN, MPNN and PNA. +It also contains major network architecture implemented for DGL (feed-forward, resnet, skip-connections, densenet), along with the DGL graph transformer for molecules. +**_I was thinking of removing most models to avoid having too much maintenance to do, since they either don't perform well on molecular tasks (GAT, RingGNN) or they are generalized by DGN (GCN, GIN, MPNN, PNA, GraphSage)_**. + +### features + +Different utilities for molecules, such as Smiles to adjacency graph transformer, molecular property extraction, atomic properties, bond properties, ... + +**_The MolecularTransformer and AdjGraphTransformer come from ivbase, but I don't like them. I think we should replace them with something simpler and give more flexibility for combining one-hot embedding with physical properties embedding._**. + +### trainer + +The trainer contains the interface to the `pytorch-lightning` library, with `ModelWrapper` being the main class used for any NN model, either for regression or classification. It also contains some modifications to the logger and model_summary from `pytorch-lightning` to enable more flexibility. + +### utils + +Any kind of utilities that can be used anywhere, including argument checkers and configuration loader + +### visualization + +Plot visualization tools + +## Modifying the code + +### Adding a new GNN layer + +Any new GNN layer must inherit from the class `goli.nn.dgl_layers.base_dgl_layer.BaseDGLLayer` and be implemented in the folder `goli/dgl/dgl_layers`, imported in the file `goli/dgl/architectures.py`, and in the same file, added to the function `FeedForwardDGL._parse_gnn_layer`. + +To be used in the configuration file as a `goli.model.layer_name`, it must also be implemented with some variable parameters in the file `expts/config_gnns.yaml`. + +### Adding a new NN architecture + +All NN and GNN architectures compatible with the `DGL` library are provided in the file `goli/dgl/architectures.py`. When implementing a new architecture, it is highly recommended to inherit from `goli.nn.architectures.FeedForwardNN` for regular neural networks, from `goli.nn.architectures.FeedForwardDGL` for DGL neural network, or from any of their sub-classes. + +### Changing the ModelWrapper and loss function + +The `ModelWrapper` is a general pytorch-lightning module that should work with any kind of `pytorch.nn.Module` or `pl.LightningModule`. The class defines a structure of including models, loss functions, batch sizes, collate functions, metrics... + +Some loss functions are already implemented in the ModelWrapper, including `mse, bce, mae, cosine`, but some tasks will require more complex loss functions. One can add any new function in `goli.trainer.predictor.ModelWrapper._parse_loss_fun`. + +### Changing the metrics used + +**_!WARNING! The metrics implementation was done for pytorch-lightning v0.8. There has been major changes to how the metrics are used and defined, so the whole implementation must change._** + +Our current code is compatible with the metrics defined by _pytorch-lightning_, which include a great set of metrics. We also added the PearsonR and SpearmanR as they are important correlation metrics. You can define any new metric in the file `goli/trainer/metrics.py`. The metric must inherit from `TensorMetric` and must be added to the dictionary `goli.trainer.metrics.METRICS_DICT`. + +To use the metric, you can easily add it's name from `METRICS_DICT` in the yaml configuration file, at the address `metrics.metrics_dict`. Each metric has an underlying dictionnary with a mandatory `threshold` key containing information on how to threshold the prediction/target before computing the metric. Any `kwargs` arguments of the metric must also be added. + +## (OLD) Running a hyper-parameter search + +In the current repository, we use `hydra-core` to launch multiple experiments in a grid-search manner. It works by specifying the parameters that we want to change from a given YAML file. + +Below is an example of running a set of 3\*2\*2\*2=24 experiments, 3 variations of the gnn type _layer_name_, 2 variations of the learning rate _lr_, 2 variations of the hidden dimension _hidden_dim_, 2 variations of the network depth _hidden_depth_. All parameters not mentionned in the code below are unchanged from the file `expts/main_micro_ZINC.py`. + + python expts/main_micro_ZINC.py --multirun \ + model.layer_name=gin,gcn,pna-conv3 \ + constants.exp_name="testing_hydra" \ + constants.device="cuda:0" \ + constants.ignore_train_error=true \ + predictor.lr=1e-4,1e-3 \ + model.gnn_kwargs.hidden_dim=32,64 \ + model.gnn_kwargs.hidden_depth=4,8 + +The results of the run will be available in the folder `multirun/[CURRENT-DATE]/[CURRENT-TIME]`. To open the results in tensorflow, run the following command using _bash_ or _powershell_ + +`tensorboard --logdir 'multirun/[CURRENT-DATE]/[CURRENT-TIME]/' --port 8000` + +Then open a web-browser and enter the address `http://localhost:8000/`. diff --git a/docs/images/logo-title.svg b/docs/images/logo-title.svg index 6f9828c0f..d8d909e98 100644 --- a/docs/images/logo-title.svg +++ b/docs/images/logo-title.svg @@ -1,188 +1,188 @@ - - - - - - - - - - image/svg+xml - - - - - - - - goli - - - - - - - - - - - - - - - - - - - + + + + + + + + + + image/svg+xml + + + + + + + + goli + + + + + + + + + + + + + + + + + + + diff --git a/docs/images/logo.svg b/docs/images/logo.svg index 0733fede4..c4a9bb471 100644 --- a/docs/images/logo.svg +++ b/docs/images/logo.svg @@ -1,177 +1,177 @@ - - - - - - - - - - image/svg+xml - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + image/svg+xml + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/docs/index.md b/docs/index.md index b008c8400..dea66e55c 100644 --- a/docs/index.md +++ b/docs/index.md @@ -1,41 +1,41 @@ -# Overview - -A deep learning library focused on graph representation learning for real-world chemical tasks. - -- ✅ State-of-the-art GNN architectures. -- 🐍 Extensible API: build your own GNN model and train it with ease. -- ⚗️ Rich featurization: powerful and flexible built-in molecular featurization. -- 🧠 Pretrained models: for fast and easy inference or transfer learning. -- ⮔ Read-to-use training loop based on [Pytorch Lightning](https://www.pytorchlightning.ai/). - -## Try Online - -Visit [![Binder](http://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/valence-discovery/goli/master?urlpath=lab/tree/docs/*tutorials*.ipynb) and try goli online. - -## Documentation - -Visit https://valence-discovery.github.io/goli/. - -## Installation - -Use either [`mamba`](https://github.com/mamba-org/mamba) or [`conda`](https://docs.conda.io/en/latest/): - -```bash -mamba install -c conda-forge goli -``` - -or pip: - -```bash -pip install goli -``` - -## Quick API Tour - -```python -import goli - -# TODO: show a quick snippet of goli that: -# - build a model and train it -# - or load a model and do inference. -``` +# Overview + +A deep learning library focused on graph representation learning for real-world chemical tasks. + +- ✅ State-of-the-art GNN architectures. +- 🐍 Extensible API: build your own GNN model and train it with ease. +- ⚗️ Rich featurization: powerful and flexible built-in molecular featurization. +- 🧠 Pretrained models: for fast and easy inference or transfer learning. +- ⮔ Read-to-use training loop based on [Pytorch Lightning](https://www.pytorchlightning.ai/). + +## Try Online + +Visit [![Binder](http://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/valence-discovery/goli/master?urlpath=lab/tree/docs/*tutorials*.ipynb) and try goli online. + +## Documentation + +Visit https://valence-discovery.github.io/goli/. + +## Installation + +Use either [`mamba`](https://github.com/mamba-org/mamba) or [`conda`](https://docs.conda.io/en/latest/): + +```bash +mamba install -c conda-forge goli +``` + +or pip: + +```bash +pip install goli +``` + +## Quick API Tour + +```python +import goli + +# TODO: show a quick snippet of goli that: +# - build a model and train it +# - or load a model and do inference. +``` diff --git a/docs/license.md b/docs/license.md index ec5a3857a..7b738dbc5 100644 --- a/docs/license.md +++ b/docs/license.md @@ -1,3 +1,3 @@ -``` -{!LICENSE!} -``` +``` +{!LICENSE!} +``` diff --git a/docs/pretrained_models.md b/docs/pretrained_models.md index 9ecff895d..2fa01d36d 100644 --- a/docs/pretrained_models.md +++ b/docs/pretrained_models.md @@ -1,15 +1,15 @@ -# GOLI pretrained models - -GOLI provides a set of pretrained models that you can use for inference or transfer learning. The models are available on Google Cloud Storage at `gs://goli-public/pretrained-models`. - -You can load a pretrained models using the GOLI API: - -```python -import goli - -predictor = goli.trainer.PredictorModule.load_pretrained_models("goli-zinc-micro-dummy-test") -``` - -## `goli-zinc-micro-dummy-test` - -Dummy model used for testing purposes _(probably to delete in the future)_. +# GOLI pretrained models + +GOLI provides a set of pretrained models that you can use for inference or transfer learning. The models are available on Google Cloud Storage at `gs://goli-public/pretrained-models`. + +You can load a pretrained models using the GOLI API: + +```python +import goli + +predictor = goli.trainer.PredictorModule.load_pretrained_models("goli-zinc-micro-dummy-test") +``` + +## `goli-zinc-micro-dummy-test` + +Dummy model used for testing purposes _(probably to delete in the future)_. diff --git a/docs/tutorials/basics/implementing_gnn_layers.ipynb b/docs/tutorials/basics/implementing_gnn_layers.ipynb index ec154a81a..05422b3c5 100644 --- a/docs/tutorials/basics/implementing_gnn_layers.ipynb +++ b/docs/tutorials/basics/implementing_gnn_layers.ipynb @@ -1,321 +1,321 @@ -{ - "metadata": { - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.8-final" - }, - "orig_nbformat": 2, - "kernelspec": { - "name": "python3", - "display_name": "Python 3", - "language": "python" - } - }, - "nbformat": 4, - "nbformat_minor": 2, - "cells": [ - { - "source": [ - "# Creating GNN layers\n", - "\n", - "In this example, you will learn how to create a custom GNN class such that it is compatible with the flexible network architecture `goli.nn.architectures.FeedForwardDGL`. We will first start by a simple layer that does not use edges, to a more complex layer that uses edges.\n", - "\n", - "Since these examples are built on top of DGL, we recommend looking at their [library](https://docs.dgl.ai/en/0.5.x/index.html) for more info. " - ], - "cell_type": "markdown", - "metadata": {} - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "output_type": "stream", - "name": "stderr", - "text": [ - "Using backend: pytorch\n" - ] - } - ], - "source": [ - "%load_ext autoreload\n", - "%autoreload 2\n", - "\n", - "import torch\n", - "import dgl\n", - "from copy import deepcopy\n", - "\n", - "from goli.nn.dgl_layers import BaseDGLLayer\n", - "from goli.nn.base_layers import FCLayer\n", - "from goli.utils.decorators import classproperty\n", - "\n", - "\n", - "_ = torch.manual_seed(42)" - ] - }, - { - "source": [ - "## Pre-defining test variables\n", - "\n", - "We define below a small batched graph on which we can test the created layers" - ], - "cell_type": "markdown", - "metadata": {} - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Graph(num_nodes=7, num_edges=14,\n ndata_schemes={'h': Scheme(shape=(5,), dtype=torch.float64)}\n edata_schemes={'e': Scheme(shape=(13,), dtype=torch.float64)})\n" - ] - } - ], - "source": [ - "in_dim = 5 # Input node-feature dimensions\n", - "out_dim = 11 # Desired output node-feature dimensions\n", - "in_dim_edges = 13 # Input edge-feature dimensions\n", - "\n", - "# Let's create 2 simple graphs. Here the tensors represent the connectivity between nodes\n", - "g1 = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 3])))\n", - "g2 = dgl.graph((torch.tensor([0, 0, 0, 1]), torch.tensor([0, 1, 2, 0])))\n", - "\n", - "# We add some node features to the graphs\n", - "g1.ndata[\"h\"] = torch.rand(g1.num_nodes(), in_dim, dtype=float)\n", - "g2.ndata[\"h\"] = torch.rand(g2.num_nodes(), in_dim, dtype=float)\n", - "\n", - "# We also add some edge features to the graphs\n", - "g1.edata[\"e\"] = torch.rand(g1.num_edges(), in_dim_edges, dtype=float)\n", - "g2.edata[\"e\"] = torch.rand(g2.num_edges(), in_dim_edges, dtype=float)\n", - "\n", - "# Finally we batch the graphs in a way compatible with the DGL library\n", - "bg = dgl.batch([g1, g2])\n", - "bg = dgl.add_self_loop(bg)\n", - "\n", - "# The batched graph will show as a single graph with 7 nodes\n", - "print(bg)" - ] - }, - { - "source": [ - "## Creating a simple layer\n", - "\n", - "Here, we will show how to create a GNN layer that does a mean aggregation on the neighbouring features.\n", - "\n", - "First, for the layer to be fully compatible with the flexible architecture provided by `goli.nn.architectures.FeedForwardDGL`, it needs to inhering from the class `goli.nn.dgl_layers.BaseDGLLayer`. This base-layer has multiple virtual methods that must be implemented in any class that inherits from it.\n", - "\n", - "The virtual methods are below\n", - "\n", - "- `layer_supports_edges`: We want to return `False` since our layer doesn't support edges\n", - "- `layer_inputs_edges`: We want to return `False` since our layer doesn't input edges\n", - "- `layer_outputs_edges`: We want to return `False` since our layer doesn't output edges\n", - "- `layer_outdim_factor`: We want to return `1` since the output dimension does not depend on internal parameters.\n", - "\n", - "The example is given below" - ], - "cell_type": "markdown", - "metadata": {} - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "class SimpleMeanLayer(BaseDGLLayer):\n", - " def __init__(self, in_dim, out_dim, activation, dropout, batch_norm):\n", - " # Initialize the parent class\n", - " super().__init__( in_dim=in_dim, out_dim=out_dim, activation=activation,\n", - " dropout=dropout, batch_norm=batch_norm)\n", - "\n", - " # Create the layer with learned parameters\n", - " self.layer = FCLayer(in_dim=in_dim, out_dim=out_dim)\n", - "\n", - " def forward(self, g, h):\n", - " # We first apply the mean aggregation\n", - " g.ndata[\"h\"] = h\n", - " g.update_all(message_func=dgl.function.copy_u(\"h\", \"m\"), \n", - " reduce_func=dgl.function.mean(\"m\", \"h\"))\n", - "\n", - " # Then we apply the FCLayer, and the non-linearities\n", - " h = g.ndata[\"h\"]\n", - " h = self.layer(h)\n", - " h = self.apply_norm_activation_dropout(h)\n", - " return h\n", - "\n", - " # Finally, we define all the virtual properties according to how\n", - " # the class works\n", - " @classproperty\n", - " def layer_supports_edges(cls):\n", - " return False\n", - "\n", - " @property\n", - " def layer_inputs_edges(self):\n", - " return False\n", - "\n", - " @property\n", - " def layer_outputs_edges(self):\n", - " return False\n", - "\n", - " @property\n", - " def out_dim_factor(self):\n", - " return 1 " - ] - }, - { - "source": [ - "Now, we are ready to test the `SimpleMeanLayer` on some DGL graphs. Note that in this example, we **ignore** the edge features since they are not supported." - ], - "cell_type": "markdown", - "metadata": {} - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "torch.Size([7, 5])\ntorch.Size([7, 11])\n" - ] - } - ], - "source": [ - "graph = deepcopy(bg)\n", - "h_in = graph.ndata[\"h\"]\n", - "layer = SimpleMeanLayer(\n", - " in_dim=in_dim, out_dim=out_dim, \n", - " activation=\"relu\", dropout=.3, batch_norm=True).to(float)\n", - "h_out = layer(graph, h_in)\n", - "\n", - "print(h_in.shape)\n", - "print(h_out.shape)" - ] - }, - { - "source": [ - "## Creating a complex layer with edges\n", - "\n", - "Here, we will show how to create a GNN layer that does a mean aggregation on the neighbouring features, concatenated to the edge features with their neighbours. In that case, only the node features will change, and the network will not update the edge features.\n", - "\n", - "The virtual methods will have different outputs\n", - "\n", - "- `layer_supports_edges`: We want to return `True` since our layer does support edges\n", - "- `layer_inputs_edges`: We want to return `True` since our layer does input edges\n", - "- `layer_outputs_edges`: We want to return `False` since our layer will not output new edges\n", - "- `layer_outdim_factor`: We want to return `1` since the output dimension does not depend on internal parameters.\n", - "\n", - "The example is given below" - ], - "cell_type": "markdown", - "metadata": {} - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "class ComplexMeanLayer(BaseDGLLayer):\n", - " def __init__(self, in_dim, out_dim, in_dim_edges, activation, dropout, batch_norm):\n", - " # Initialize the parent class\n", - " super().__init__( in_dim=in_dim, out_dim=out_dim, activation=activation,\n", - " dropout=dropout, batch_norm=batch_norm)\n", - "\n", - " # Create the layer with learned parameters. Note the addition\n", - " self.layer = FCLayer(in_dim=in_dim + in_dim_edges, out_dim=out_dim)\n", - "\n", - " def cat_nodes_edges(self, edges):\n", - " # Create a message \"m\" by concatenating \"h\" and \"e\" for each pair of nodes\n", - " nodes_edges = torch.cat([edges.src[\"h\"], edges.data[\"e\"]], dim=-1)\n", - " return {\"m\": nodes_edges}\n", - "\n", - " def get_edges_messages(self, edges): # Simply return the messages on the edges\n", - " return {\"m\": edges.data[\"m\"]}\n", - "\n", - " def forward(self, g, h, e):\n", - "\n", - " # We first concatenate both the node and edge features on the edges\n", - " g.ndata[\"h\"] = h\n", - " g.edata[\"e\"] = e\n", - " g.apply_edges(self.cat_nodes_edges)\n", - "\n", - " # Then we apply the mean aggregation to generate a message \"m\"\n", - " g.update_all(message_func=self.get_edges_messages, \n", - " reduce_func=dgl.function.mean(\"m\", \"h\"))\n", - "\n", - " # Finally we apply the FCLayer, and the non-linearities\n", - " h = g.ndata[\"h\"]\n", - " h = self.layer(h)\n", - " h = self.apply_norm_activation_dropout(h)\n", - " return h\n", - "\n", - " # Finally, we define all the virtual properties according to how\n", - " # the class works\n", - " @classproperty\n", - " def layer_supports_edges(cls):\n", - " return True\n", - "\n", - " @property\n", - " def layer_inputs_edges(self):\n", - " return True\n", - "\n", - " @property\n", - " def layer_outputs_edges(self):\n", - " return False\n", - "\n", - " @property\n", - " def out_dim_factor(self):\n", - " return 1 " - ] - }, - { - "source": [ - "Now, we are ready to test the `ComplexMeanLayer` on some DGL graphs. Note that in this example, we **use** the edge features since they are mandatory." - ], - "cell_type": "markdown", - "metadata": {} - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "torch.Size([7, 5])\ntorch.Size([7, 11])\n" - ] - } - ], - "source": [ - "graph = deepcopy(bg)\n", - "h_in = graph.ndata[\"h\"]\n", - "e_in = graph.edata[\"e\"]\n", - "layer = ComplexMeanLayer(\n", - " in_dim=in_dim, out_dim=out_dim, in_dim_edges=in_dim_edges,\n", - " activation=\"relu\", dropout=.3, batch_norm=True).to(float)\n", - "h_out = layer(graph, h_in, e_in)\n", - "\n", - "print(h_in.shape)\n", - "print(h_out.shape)" - ] - } - ] +{ + "metadata": { + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.8-final" + }, + "orig_nbformat": 2, + "kernelspec": { + "name": "python3", + "display_name": "Python 3", + "language": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 2, + "cells": [ + { + "source": [ + "# Creating GNN layers\n", + "\n", + "In this example, you will learn how to create a custom GNN class such that it is compatible with the flexible network architecture `goli.nn.architectures.FeedForwardDGL`. We will first start by a simple layer that does not use edges, to a more complex layer that uses edges.\n", + "\n", + "Since these examples are built on top of DGL, we recommend looking at their [library](https://docs.dgl.ai/en/0.5.x/index.html) for more info. " + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "Using backend: pytorch\n" + ] + } + ], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n", + "\n", + "import torch\n", + "import dgl\n", + "from copy import deepcopy\n", + "\n", + "from goli.nn.dgl_layers import BaseDGLLayer\n", + "from goli.nn.base_layers import FCLayer\n", + "from goli.utils.decorators import classproperty\n", + "\n", + "\n", + "_ = torch.manual_seed(42)" + ] + }, + { + "source": [ + "## Pre-defining test variables\n", + "\n", + "We define below a small batched graph on which we can test the created layers" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Graph(num_nodes=7, num_edges=14,\n ndata_schemes={'h': Scheme(shape=(5,), dtype=torch.float64)}\n edata_schemes={'e': Scheme(shape=(13,), dtype=torch.float64)})\n" + ] + } + ], + "source": [ + "in_dim = 5 # Input node-feature dimensions\n", + "out_dim = 11 # Desired output node-feature dimensions\n", + "in_dim_edges = 13 # Input edge-feature dimensions\n", + "\n", + "# Let's create 2 simple graphs. Here the tensors represent the connectivity between nodes\n", + "g1 = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 3])))\n", + "g2 = dgl.graph((torch.tensor([0, 0, 0, 1]), torch.tensor([0, 1, 2, 0])))\n", + "\n", + "# We add some node features to the graphs\n", + "g1.ndata[\"h\"] = torch.rand(g1.num_nodes(), in_dim, dtype=float)\n", + "g2.ndata[\"h\"] = torch.rand(g2.num_nodes(), in_dim, dtype=float)\n", + "\n", + "# We also add some edge features to the graphs\n", + "g1.edata[\"e\"] = torch.rand(g1.num_edges(), in_dim_edges, dtype=float)\n", + "g2.edata[\"e\"] = torch.rand(g2.num_edges(), in_dim_edges, dtype=float)\n", + "\n", + "# Finally we batch the graphs in a way compatible with the DGL library\n", + "bg = dgl.batch([g1, g2])\n", + "bg = dgl.add_self_loop(bg)\n", + "\n", + "# The batched graph will show as a single graph with 7 nodes\n", + "print(bg)" + ] + }, + { + "source": [ + "## Creating a simple layer\n", + "\n", + "Here, we will show how to create a GNN layer that does a mean aggregation on the neighbouring features.\n", + "\n", + "First, for the layer to be fully compatible with the flexible architecture provided by `goli.nn.architectures.FeedForwardDGL`, it needs to inhering from the class `goli.nn.dgl_layers.BaseDGLLayer`. This base-layer has multiple virtual methods that must be implemented in any class that inherits from it.\n", + "\n", + "The virtual methods are below\n", + "\n", + "- `layer_supports_edges`: We want to return `False` since our layer doesn't support edges\n", + "- `layer_inputs_edges`: We want to return `False` since our layer doesn't input edges\n", + "- `layer_outputs_edges`: We want to return `False` since our layer doesn't output edges\n", + "- `layer_outdim_factor`: We want to return `1` since the output dimension does not depend on internal parameters.\n", + "\n", + "The example is given below" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "class SimpleMeanLayer(BaseDGLLayer):\n", + " def __init__(self, in_dim, out_dim, activation, dropout, batch_norm):\n", + " # Initialize the parent class\n", + " super().__init__( in_dim=in_dim, out_dim=out_dim, activation=activation,\n", + " dropout=dropout, batch_norm=batch_norm)\n", + "\n", + " # Create the layer with learned parameters\n", + " self.layer = FCLayer(in_dim=in_dim, out_dim=out_dim)\n", + "\n", + " def forward(self, g, h):\n", + " # We first apply the mean aggregation\n", + " g.ndata[\"h\"] = h\n", + " g.update_all(message_func=dgl.function.copy_u(\"h\", \"m\"), \n", + " reduce_func=dgl.function.mean(\"m\", \"h\"))\n", + "\n", + " # Then we apply the FCLayer, and the non-linearities\n", + " h = g.ndata[\"h\"]\n", + " h = self.layer(h)\n", + " h = self.apply_norm_activation_dropout(h)\n", + " return h\n", + "\n", + " # Finally, we define all the virtual properties according to how\n", + " # the class works\n", + " @classproperty\n", + " def layer_supports_edges(cls):\n", + " return False\n", + "\n", + " @property\n", + " def layer_inputs_edges(self):\n", + " return False\n", + "\n", + " @property\n", + " def layer_outputs_edges(self):\n", + " return False\n", + "\n", + " @property\n", + " def out_dim_factor(self):\n", + " return 1 " + ] + }, + { + "source": [ + "Now, we are ready to test the `SimpleMeanLayer` on some DGL graphs. Note that in this example, we **ignore** the edge features since they are not supported." + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "torch.Size([7, 5])\ntorch.Size([7, 11])\n" + ] + } + ], + "source": [ + "graph = deepcopy(bg)\n", + "h_in = graph.ndata[\"h\"]\n", + "layer = SimpleMeanLayer(\n", + " in_dim=in_dim, out_dim=out_dim, \n", + " activation=\"relu\", dropout=.3, batch_norm=True).to(float)\n", + "h_out = layer(graph, h_in)\n", + "\n", + "print(h_in.shape)\n", + "print(h_out.shape)" + ] + }, + { + "source": [ + "## Creating a complex layer with edges\n", + "\n", + "Here, we will show how to create a GNN layer that does a mean aggregation on the neighbouring features, concatenated to the edge features with their neighbours. In that case, only the node features will change, and the network will not update the edge features.\n", + "\n", + "The virtual methods will have different outputs\n", + "\n", + "- `layer_supports_edges`: We want to return `True` since our layer does support edges\n", + "- `layer_inputs_edges`: We want to return `True` since our layer does input edges\n", + "- `layer_outputs_edges`: We want to return `False` since our layer will not output new edges\n", + "- `layer_outdim_factor`: We want to return `1` since the output dimension does not depend on internal parameters.\n", + "\n", + "The example is given below" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "class ComplexMeanLayer(BaseDGLLayer):\n", + " def __init__(self, in_dim, out_dim, in_dim_edges, activation, dropout, batch_norm):\n", + " # Initialize the parent class\n", + " super().__init__( in_dim=in_dim, out_dim=out_dim, activation=activation,\n", + " dropout=dropout, batch_norm=batch_norm)\n", + "\n", + " # Create the layer with learned parameters. Note the addition\n", + " self.layer = FCLayer(in_dim=in_dim + in_dim_edges, out_dim=out_dim)\n", + "\n", + " def cat_nodes_edges(self, edges):\n", + " # Create a message \"m\" by concatenating \"h\" and \"e\" for each pair of nodes\n", + " nodes_edges = torch.cat([edges.src[\"h\"], edges.data[\"e\"]], dim=-1)\n", + " return {\"m\": nodes_edges}\n", + "\n", + " def get_edges_messages(self, edges): # Simply return the messages on the edges\n", + " return {\"m\": edges.data[\"m\"]}\n", + "\n", + " def forward(self, g, h, e):\n", + "\n", + " # We first concatenate both the node and edge features on the edges\n", + " g.ndata[\"h\"] = h\n", + " g.edata[\"e\"] = e\n", + " g.apply_edges(self.cat_nodes_edges)\n", + "\n", + " # Then we apply the mean aggregation to generate a message \"m\"\n", + " g.update_all(message_func=self.get_edges_messages, \n", + " reduce_func=dgl.function.mean(\"m\", \"h\"))\n", + "\n", + " # Finally we apply the FCLayer, and the non-linearities\n", + " h = g.ndata[\"h\"]\n", + " h = self.layer(h)\n", + " h = self.apply_norm_activation_dropout(h)\n", + " return h\n", + "\n", + " # Finally, we define all the virtual properties according to how\n", + " # the class works\n", + " @classproperty\n", + " def layer_supports_edges(cls):\n", + " return True\n", + "\n", + " @property\n", + " def layer_inputs_edges(self):\n", + " return True\n", + "\n", + " @property\n", + " def layer_outputs_edges(self):\n", + " return False\n", + "\n", + " @property\n", + " def out_dim_factor(self):\n", + " return 1 " + ] + }, + { + "source": [ + "Now, we are ready to test the `ComplexMeanLayer` on some DGL graphs. Note that in this example, we **use** the edge features since they are mandatory." + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "torch.Size([7, 5])\ntorch.Size([7, 11])\n" + ] + } + ], + "source": [ + "graph = deepcopy(bg)\n", + "h_in = graph.ndata[\"h\"]\n", + "e_in = graph.edata[\"e\"]\n", + "layer = ComplexMeanLayer(\n", + " in_dim=in_dim, out_dim=out_dim, in_dim_edges=in_dim_edges,\n", + " activation=\"relu\", dropout=.3, batch_norm=True).to(float)\n", + "h_out = layer(graph, h_in, e_in)\n", + "\n", + "print(h_in.shape)\n", + "print(h_out.shape)" + ] + } + ] } \ No newline at end of file diff --git a/docs/tutorials/basics/making_gnn_networks.ipynb b/docs/tutorials/basics/making_gnn_networks.ipynb index 5bff7ffd7..4fe295bdb 100644 --- a/docs/tutorials/basics/making_gnn_networks.ipynb +++ b/docs/tutorials/basics/making_gnn_networks.ipynb @@ -1,228 +1,228 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Making GNN Networks\n", - "\n", - "In this example, you will learn how to easily build a full GNN network using any kind of GNN layer." - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "output_type": "stream", - "name": "stderr", - "text": [ - "Using backend: pytorch\n" - ] - } - ], - "source": [ - "%load_ext autoreload\n", - "%autoreload 2\n", - "\n", - "import torch\n", - "import dgl\n", - "from copy import deepcopy\n", - "\n", - "from goli.nn.dgl_layers import PNAMessagePassingLayer\n", - "from goli.nn.architectures import FullDGLNetwork\n", - "\n", - "_ = torch.manual_seed(42)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We will first create some simple batched graphs that will be used accross the examples." - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Graph(num_nodes=7, num_edges=14,\n ndata_schemes={'feat': Scheme(shape=(5,), dtype=torch.float64)}\n edata_schemes={'feat': Scheme(shape=(13,), dtype=torch.float64)})\n" - ] - } - ], - "source": [ - "in_dim = 5 # Input node-feature dimensions\n", - "out_dim = 11 # Desired output node-feature dimensions\n", - "in_dim_edges = 13 # Input edge-feature dimensions\n", - "\n", - "# Let's create 2 simple graphs. Here the tensors represent the connectivity between nodes\n", - "g1 = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 3])))\n", - "g2 = dgl.graph((torch.tensor([0, 0, 0, 1]), torch.tensor([0, 1, 2, 0])))\n", - "\n", - "# We add some node features to the graphs\n", - "g1.ndata[\"feat\"] = torch.rand(g1.num_nodes(), in_dim, dtype=float)\n", - "g2.ndata[\"feat\"] = torch.rand(g2.num_nodes(), in_dim, dtype=float)\n", - "\n", - "# We also add some edge features to the graphs\n", - "g1.edata[\"feat\"] = torch.rand(g1.num_edges(), in_dim_edges, dtype=float)\n", - "g2.edata[\"feat\"] = torch.rand(g2.num_edges(), in_dim_edges, dtype=float)\n", - "\n", - "# Finally we batch the graphs in a way compatible with the DGL library\n", - "bg = dgl.batch([g1, g2])\n", - "bg = dgl.add_self_loop(bg)\n", - "\n", - "# The batched graph will show as a single graph with 7 nodes\n", - "print(bg)\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Building the network\n", - "\n", - "To build the network, we must define the arguments to pass at the different steps:\n", - "\n", - "- `pre_nn_kwargs`: The parameters used by a feed-forward neural network on the input node-features, before passing to the convolutional layers. See class `FeedForwardNN` for details on the required parameters. Will be ignored if set to `None`.\n", - "\n", - "- `gnn_kwargs`: The parameters used by a feed-forward **graph** neural network on the features after it has passed through the pre-processing neural network. See class `FeedForwardDGL` for details on the required parameters.\n", - "\n", - "- `post_nn_kwargs`: The parameters used by a feed-forward neural network on the features after the GNN layers. See class `FeedForwardNN` for details on the required parameters. Will be ignored if set to `None`." - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "temp_dim_1 = 23\n", - "temp_dim_2 = 17\n", - "\n", - "pre_nn_kwargs = {\n", - " \"in_dim\": in_dim,\n", - " \"out_dim\": temp_dim_1,\n", - " \"hidden_dims\": [4, 4, 4],\n", - " \"activation\": \"relu\",\n", - " \"last_activation\": \"none\",\n", - " \"batch_norm\": True,\n", - " \"dropout\": 0.2, }\n", - "\n", - "post_nn_kwargs = {\n", - " \"in_dim\": temp_dim_2,\n", - " \"out_dim\": out_dim,\n", - " \"hidden_dims\": [6, 6],\n", - " \"activation\": \"relu\",\n", - " \"last_activation\": \"sigmoid\",\n", - " \"batch_norm\": False,\n", - " \"dropout\": 0., }\n", - "\n", - "layer_kwargs = {\n", - " \"aggregators\": [\"mean\", \"max\", \"sum\"], \n", - " \"scalers\": [\"identity\", \"amplification\"],}\n", - "\n", - "gnn_kwargs = {\n", - " \"in_dim\": temp_dim_1,\n", - " \"out_dim\": temp_dim_2,\n", - " \"hidden_dims\": [5, 5, 5, 5, 5, 5],\n", - " \"residual_type\": \"densenet\",\n", - " \"residual_skip_steps\": 2,\n", - " \"layer_type\": PNAMessagePassingLayer,\n", - " \"pooling\": [\"sum\"],\n", - " \"activation\": \"relu\",\n", - " \"last_activation\": \"none\",\n", - " \"batch_norm\": False,\n", - " \"dropout\": 0.2,\n", - " \"in_dim_edges\": in_dim_edges,\n", - " \"layer_kwargs\": layer_kwargs,\n", - "}\n", - "\n", - "gnn_net = FullDGLNetwork(\n", - " pre_nn_kwargs=pre_nn_kwargs, \n", - " gnn_kwargs=gnn_kwargs, \n", - " post_nn_kwargs=post_nn_kwargs).to(float)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Applying the network\n", - "\n", - "Once the network is defined, we only need to run the forward pass on the input graphs to get a prediction.\n", - "\n", - "The network will handle the node and edge features depending on it's parameters and layer type." - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "torch.Size([7, 5])\ntorch.Size([1, 11])\n\n\nDGL_GNN\n---------\n pre-NN(depth=4, ResidualConnectionNone)\n [FCLayer[5 -> 4 -> 4 -> 4 -> 23]\n \n GNN(depth=7, ResidualConnectionDenseNet(skip_steps=2))\n PNAMessagePassingLayer[23 -> 5 -> 5 -> 5 -> 5 -> 5 -> 5 -> 17]\n -> Pooling(['sum']) -> FCLayer(17 -> 17, activation=None)\n \n post-NN(depth=3, ResidualConnectionNone)\n [FCLayer[17 -> 6 -> 6 -> 11]\n" - ] - } - ], - "source": [ - "graph = deepcopy(bg)\n", - "h_in = graph.ndata[\"feat\"]\n", - "\n", - "h_out = gnn_net(graph)\n", - "\n", - "print(h_in.shape)\n", - "print(h_out.shape)\n", - "print(\"\\n\")\n", - "print(gnn_net)\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "name": "python3", - "display_name": "Python 3.8.8 64-bit ('goli': conda)", - "metadata": { - "interpreter": { - "hash": "0db34acba4cc11eb0fa2be1290630bee7d7b89811a824f1a15cd26acd4633567" - } - } - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.8-final" - }, - "widgets": { - "application/vnd.jupyter.widget-state+json": { - "state": {}, - "version_major": 2, - "version_minor": 0 - } - } - }, - "nbformat": 4, - "nbformat_minor": 4 +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Making GNN Networks\n", + "\n", + "In this example, you will learn how to easily build a full GNN network using any kind of GNN layer." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "Using backend: pytorch\n" + ] + } + ], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n", + "\n", + "import torch\n", + "import dgl\n", + "from copy import deepcopy\n", + "\n", + "from goli.nn.dgl_layers import PNAMessagePassingLayer\n", + "from goli.nn.architectures import FullDGLNetwork\n", + "\n", + "_ = torch.manual_seed(42)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We will first create some simple batched graphs that will be used accross the examples." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Graph(num_nodes=7, num_edges=14,\n ndata_schemes={'feat': Scheme(shape=(5,), dtype=torch.float64)}\n edata_schemes={'feat': Scheme(shape=(13,), dtype=torch.float64)})\n" + ] + } + ], + "source": [ + "in_dim = 5 # Input node-feature dimensions\n", + "out_dim = 11 # Desired output node-feature dimensions\n", + "in_dim_edges = 13 # Input edge-feature dimensions\n", + "\n", + "# Let's create 2 simple graphs. Here the tensors represent the connectivity between nodes\n", + "g1 = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 3])))\n", + "g2 = dgl.graph((torch.tensor([0, 0, 0, 1]), torch.tensor([0, 1, 2, 0])))\n", + "\n", + "# We add some node features to the graphs\n", + "g1.ndata[\"feat\"] = torch.rand(g1.num_nodes(), in_dim, dtype=float)\n", + "g2.ndata[\"feat\"] = torch.rand(g2.num_nodes(), in_dim, dtype=float)\n", + "\n", + "# We also add some edge features to the graphs\n", + "g1.edata[\"feat\"] = torch.rand(g1.num_edges(), in_dim_edges, dtype=float)\n", + "g2.edata[\"feat\"] = torch.rand(g2.num_edges(), in_dim_edges, dtype=float)\n", + "\n", + "# Finally we batch the graphs in a way compatible with the DGL library\n", + "bg = dgl.batch([g1, g2])\n", + "bg = dgl.add_self_loop(bg)\n", + "\n", + "# The batched graph will show as a single graph with 7 nodes\n", + "print(bg)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Building the network\n", + "\n", + "To build the network, we must define the arguments to pass at the different steps:\n", + "\n", + "- `pre_nn_kwargs`: The parameters used by a feed-forward neural network on the input node-features, before passing to the convolutional layers. See class `FeedForwardNN` for details on the required parameters. Will be ignored if set to `None`.\n", + "\n", + "- `gnn_kwargs`: The parameters used by a feed-forward **graph** neural network on the features after it has passed through the pre-processing neural network. See class `FeedForwardDGL` for details on the required parameters.\n", + "\n", + "- `post_nn_kwargs`: The parameters used by a feed-forward neural network on the features after the GNN layers. See class `FeedForwardNN` for details on the required parameters. Will be ignored if set to `None`." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "temp_dim_1 = 23\n", + "temp_dim_2 = 17\n", + "\n", + "pre_nn_kwargs = {\n", + " \"in_dim\": in_dim,\n", + " \"out_dim\": temp_dim_1,\n", + " \"hidden_dims\": [4, 4, 4],\n", + " \"activation\": \"relu\",\n", + " \"last_activation\": \"none\",\n", + " \"batch_norm\": True,\n", + " \"dropout\": 0.2, }\n", + "\n", + "post_nn_kwargs = {\n", + " \"in_dim\": temp_dim_2,\n", + " \"out_dim\": out_dim,\n", + " \"hidden_dims\": [6, 6],\n", + " \"activation\": \"relu\",\n", + " \"last_activation\": \"sigmoid\",\n", + " \"batch_norm\": False,\n", + " \"dropout\": 0., }\n", + "\n", + "layer_kwargs = {\n", + " \"aggregators\": [\"mean\", \"max\", \"sum\"], \n", + " \"scalers\": [\"identity\", \"amplification\"],}\n", + "\n", + "gnn_kwargs = {\n", + " \"in_dim\": temp_dim_1,\n", + " \"out_dim\": temp_dim_2,\n", + " \"hidden_dims\": [5, 5, 5, 5, 5, 5],\n", + " \"residual_type\": \"densenet\",\n", + " \"residual_skip_steps\": 2,\n", + " \"layer_type\": PNAMessagePassingLayer,\n", + " \"pooling\": [\"sum\"],\n", + " \"activation\": \"relu\",\n", + " \"last_activation\": \"none\",\n", + " \"batch_norm\": False,\n", + " \"dropout\": 0.2,\n", + " \"in_dim_edges\": in_dim_edges,\n", + " \"layer_kwargs\": layer_kwargs,\n", + "}\n", + "\n", + "gnn_net = FullDGLNetwork(\n", + " pre_nn_kwargs=pre_nn_kwargs, \n", + " gnn_kwargs=gnn_kwargs, \n", + " post_nn_kwargs=post_nn_kwargs).to(float)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Applying the network\n", + "\n", + "Once the network is defined, we only need to run the forward pass on the input graphs to get a prediction.\n", + "\n", + "The network will handle the node and edge features depending on it's parameters and layer type." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "torch.Size([7, 5])\ntorch.Size([1, 11])\n\n\nDGL_GNN\n---------\n pre-NN(depth=4, ResidualConnectionNone)\n [FCLayer[5 -> 4 -> 4 -> 4 -> 23]\n \n GNN(depth=7, ResidualConnectionDenseNet(skip_steps=2))\n PNAMessagePassingLayer[23 -> 5 -> 5 -> 5 -> 5 -> 5 -> 5 -> 17]\n -> Pooling(['sum']) -> FCLayer(17 -> 17, activation=None)\n \n post-NN(depth=3, ResidualConnectionNone)\n [FCLayer[17 -> 6 -> 6 -> 11]\n" + ] + } + ], + "source": [ + "graph = deepcopy(bg)\n", + "h_in = graph.ndata[\"feat\"]\n", + "\n", + "h_out = gnn_net(graph)\n", + "\n", + "print(h_in.shape)\n", + "print(h_out.shape)\n", + "print(\"\\n\")\n", + "print(gnn_net)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "name": "python3", + "display_name": "Python 3.8.8 64-bit ('goli': conda)", + "metadata": { + "interpreter": { + "hash": "0db34acba4cc11eb0fa2be1290630bee7d7b89811a824f1a15cd26acd4633567" + } + } + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.8-final" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "state": {}, + "version_major": 2, + "version_minor": 0 + } + } + }, + "nbformat": 4, + "nbformat_minor": 4 } \ No newline at end of file diff --git a/docs/tutorials/basics/using_gnn_layers.ipynb b/docs/tutorials/basics/using_gnn_layers.ipynb index 31aedcf0d..1a0a188da 100644 --- a/docs/tutorials/basics/using_gnn_layers.ipynb +++ b/docs/tutorials/basics/using_gnn_layers.ipynb @@ -1,375 +1,375 @@ -{ - "metadata": { - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.8-final" - }, - "orig_nbformat": 2, - "kernelspec": { - "name": "python3", - "display_name": "Python 3", - "language": "python" - } - }, - "nbformat": 4, - "nbformat_minor": 2, - "cells": [ - { - "source": [ - "# Using GNN layers\n", - "\n", - "In this example, you will learn how to use the GCN, GIN, Gated-GCN and PNA layers in a simple `forward` context." - ], - "cell_type": "markdown", - "metadata": {} - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "output_type": "stream", - "name": "stderr", - "text": [ - "Using backend: pytorch\n" - ] - } - ], - "source": [ - "%load_ext autoreload\n", - "%autoreload 2\n", - "\n", - "import torch\n", - "import dgl\n", - "from copy import deepcopy\n", - "\n", - "from goli.nn.dgl_layers import (\n", - " GCNLayer,\n", - " GINLayer,\n", - " GATLayer,\n", - " GatedGCNLayer,\n", - " PNAConvolutionalLayer,\n", - " PNAMessagePassingLayer,\n", - ")\n", - "\n", - "_ = torch.manual_seed(42)" - ] - }, - { - "source": [ - "We will first create some simple batched graphs that will be used accross the examples." - ], - "cell_type": "markdown", - "metadata": {} - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Graph(num_nodes=7, num_edges=14,\n ndata_schemes={'h': Scheme(shape=(5,), dtype=torch.float64)}\n edata_schemes={'e': Scheme(shape=(13,), dtype=torch.float64)})\n" - ] - } - ], - "source": [ - "in_dim = 5 # Input node-feature dimensions\n", - "out_dim = 11 # Desired output node-feature dimensions\n", - "in_dim_edges = 13 # Input edge-feature dimensions\n", - "\n", - "# Let's create 2 simple graphs. Here the tensors represent the connectivity between nodes\n", - "g1 = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 3])))\n", - "g2 = dgl.graph((torch.tensor([0, 0, 0, 1]), torch.tensor([0, 1, 2, 0])))\n", - "\n", - "# We add some node features to the graphs\n", - "g1.ndata[\"h\"] = torch.rand(g1.num_nodes(), in_dim, dtype=float)\n", - "g2.ndata[\"h\"] = torch.rand(g2.num_nodes(), in_dim, dtype=float)\n", - "\n", - "# We also add some edge features to the graphs\n", - "g1.edata[\"e\"] = torch.rand(g1.num_edges(), in_dim_edges, dtype=float)\n", - "g2.edata[\"e\"] = torch.rand(g2.num_edges(), in_dim_edges, dtype=float)\n", - "\n", - "# Finally we batch the graphs in a way compatible with the DGL library\n", - "bg = dgl.batch([g1, g2])\n", - "bg = dgl.add_self_loop(bg)\n", - "\n", - "# The batched graph will show as a single graph with 7 nodes\n", - "print(bg)\n" - ] - }, - { - "source": [ - "## GCN Layer\n", - "\n", - "To use the GCN layer from the *Kipf et al.* paper, the steps are very simple. We create the layer with the desired attributes, and apply it to the graph.\n", - "\n", - "Kipf, Thomas N., and Max Welling. \"Semi-supervised classification with graph convolutional networks.\" arXiv preprint arXiv:1609.02907 (2016)." - ], - "cell_type": "markdown", - "metadata": {} - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "GCNLayer(5 -> 11, activation=relu)\ntorch.Size([7, 5])\ntorch.Size([7, 11])\n" - ] - } - ], - "source": [ - "# We first need to extract the node features from the graph.\n", - "# The GCN method doesn't support edge features, so we ignore them\n", - "graph = deepcopy(bg)\n", - "h_in = graph.ndata[\"h\"]\n", - "\n", - "# We create the layer\n", - "layer = GCNLayer(\n", - " in_dim=in_dim, out_dim=out_dim, \n", - " activation=\"relu\", dropout=.3, batch_norm=True).to(float)\n", - "\n", - "# We apply the forward loop on the node features\n", - "h_out = layer(graph, h_in)\n", - "\n", - "# 7 is the number of nodes, 5 number of input features and 11 number of output features\n", - "print(layer)\n", - "print(h_in.shape)\n", - "print(h_out.shape)" - ] - }, - { - "source": [ - "## GIN Layer\n", - "\n", - "To use the GIN layer from the *Xu et al.* paper, the steps are identical to GCN.\n", - "\n", - "Xu, Keyulu, et al. \"How powerful are graph neural networks?.\" arXiv preprint arXiv:1810.00826 (2018)." - ], - "cell_type": "markdown", - "metadata": {} - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "GINLayer(5 -> 11, activation=relu)\ntorch.Size([7, 5])\ntorch.Size([7, 11])\n" - ] - } - ], - "source": [ - "graph = deepcopy(bg)\n", - "h_in = graph.ndata[\"h\"]\n", - "layer = GINLayer(\n", - " in_dim=in_dim, out_dim=out_dim, \n", - " activation=\"relu\", dropout=.3, batch_norm=True).to(float)\n", - "h_out = layer(graph, h_in)\n", - "\n", - "print(layer)\n", - "print(h_in.shape)\n", - "print(h_out.shape)" - ] - }, - { - "source": [ - "## GAT Layer\n", - "\n", - "To use the GAT layer from the *Velickovic et al.* paper, the steps are identical to GCN, but the output dimension is multiplied by the number of heads.\n", - "\n", - "Velickovic, Petar, et al. \"Graph attention networks.\" arXiv preprint arXiv:1710.10903 (2017)." - ], - "cell_type": "markdown", - "metadata": {} - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "GATLayer(5 -> 11 * 5, activation=elu)\ntorch.Size([7, 5])\ntorch.Size([7, 55])\n" - ] - } - ], - "source": [ - "graph = deepcopy(bg)\n", - "h_in = graph.ndata[\"h\"]\n", - "layer = GATLayer(\n", - " in_dim=in_dim, out_dim=out_dim, num_heads=5,\n", - " activation=\"elu\", dropout=.3, batch_norm=True).to(float)\n", - "h_out = layer(graph, h_in)\n", - "\n", - "print(layer)\n", - "print(h_in.shape)\n", - "print(h_out.shape)" - ] - }, - { - "source": [ - "## Gated-GCN Layer\n", - "\n", - "To use the Gated-GCN layer from the *Bresson et al.* paper, the steps are different since the layer requires edge features as inputs, and outputs new edge features.\n", - "\n", - "Bresson, Xavier, and Thomas Laurent. \"Residual gated graph convnets.\" arXiv preprint arXiv:1711.07553 (2017)." - ], - "cell_type": "markdown", - "metadata": {} - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "GatedGCNLayer(5 -> 11, activation=relu)\ntorch.Size([7, 5])\ntorch.Size([7, 11])\ntorch.Size([14, 13])\ntorch.Size([14, 11])\n" - ] - } - ], - "source": [ - "# We first need to extract the node and edge features from the graph.\n", - "graph = deepcopy(bg)\n", - "h_in = graph.ndata[\"h\"]\n", - "e_in = graph.edata[\"e\"]\n", - "\n", - "# We create the layer\n", - "layer = GatedGCNLayer(\n", - " in_dim=in_dim, out_dim=out_dim, \n", - " in_dim_edges=in_dim_edges, out_dim_edges=out_dim,\n", - " activation=\"relu\", dropout=.3, batch_norm=True).to(float)\n", - "\n", - "# We apply the forward loop on the node features\n", - "h_out, e_out = layer(graph, h_in, e_in)\n", - "\n", - "# 7 is the number of nodes, 5 number of input features and 11 number of output features\n", - "# 13 is the number of input edge features\n", - "print(layer)\n", - "print(h_in.shape)\n", - "print(h_out.shape)\n", - "print(e_in.shape)\n", - "print(e_out.shape)" - ] - }, - { - "source": [ - "# PNA\n", - "\n", - "PNA is a multi-aggregator method proposed by *Corso et al.*. It supports 2 types of aggregations, convolutional *PNA-conv* or message passing *PNA-msgpass*.\n", - "\n", - "PNA: Principal Neighbourhood Aggregation \n", - "Gabriele Corso, Luca Cavalleri, Dominique Beaini, Pietro Lio, Petar Velickovic\n", - "https://arxiv.org/abs/2004.05718\n", - "\n", - "## PNA-conv\n", - "\n", - "First, let's focus on the PNA-conv. In this case, it works exactly as the GCN and GIN methods." - ], - "cell_type": "markdown", - "metadata": {} - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "PNAConvolutionalLayer(5 -> 11, activation=relu)\ntorch.Size([7, 5])\ntorch.Size([7, 11])\n" - ] - } - ], - "source": [ - "graph = deepcopy(bg)\n", - "h_in = graph.ndata[\"h\"]\n", - "\n", - "# We create the layer, and need to specify the aggregators and scalers\n", - "layer = PNAConvolutionalLayer(\n", - " in_dim=in_dim, out_dim=out_dim, \n", - " aggregators=[\"mean\", \"max\", \"min\", \"std\"],\n", - " scalers=[\"identity\", \"amplification\", \"attenuation\"],\n", - " activation=\"relu\", dropout=.3, batch_norm=True).to(float)\n", - "\n", - "h_out = layer(graph, h_in)\n", - "\n", - "print(layer)\n", - "print(h_in.shape)\n", - "print(h_out.shape)" - ] - }, - { - "source": [ - "## PNA-msgpass\n", - "\n", - "The PNA message passing is typically more powerful that the convolutional one, and it supports edges as inputs, but doesn't output edges. It's usage is very similar to the *PNA-conv*, but we need to specify the edge dimensions and features." - ], - "cell_type": "markdown", - "metadata": {} - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "PNAMessagePassingLayer(5 -> 11, activation=relu)\ntorch.Size([7, 5])\ntorch.Size([7, 11])\n" - ] - } - ], - "source": [ - "graph = deepcopy(bg)\n", - "h_in = graph.ndata[\"h\"]\n", - "e_in = graph.edata[\"e\"]\n", - "\n", - "# We create the layer, and need to specify the aggregators and scalers\n", - "layer = PNAMessagePassingLayer(\n", - " in_dim=in_dim, out_dim=out_dim, in_dim_edges=in_dim_edges,\n", - " aggregators=[\"mean\", \"max\", \"min\", \"std\"],\n", - " scalers=[\"identity\", \"amplification\", \"attenuation\"],\n", - " activation=\"relu\", dropout=.3, batch_norm=True).to(float)\n", - "\n", - "h_out = layer(graph, h_in, e_in)\n", - "\n", - "print(layer)\n", - "print(h_in.shape)\n", - "print(h_out.shape)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ] +{ + "metadata": { + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.8-final" + }, + "orig_nbformat": 2, + "kernelspec": { + "name": "python3", + "display_name": "Python 3", + "language": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 2, + "cells": [ + { + "source": [ + "# Using GNN layers\n", + "\n", + "In this example, you will learn how to use the GCN, GIN, Gated-GCN and PNA layers in a simple `forward` context." + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "Using backend: pytorch\n" + ] + } + ], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n", + "\n", + "import torch\n", + "import dgl\n", + "from copy import deepcopy\n", + "\n", + "from goli.nn.dgl_layers import (\n", + " GCNLayer,\n", + " GINLayer,\n", + " GATLayer,\n", + " GatedGCNLayer,\n", + " PNAConvolutionalLayer,\n", + " PNAMessagePassingLayer,\n", + ")\n", + "\n", + "_ = torch.manual_seed(42)" + ] + }, + { + "source": [ + "We will first create some simple batched graphs that will be used accross the examples." + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Graph(num_nodes=7, num_edges=14,\n ndata_schemes={'h': Scheme(shape=(5,), dtype=torch.float64)}\n edata_schemes={'e': Scheme(shape=(13,), dtype=torch.float64)})\n" + ] + } + ], + "source": [ + "in_dim = 5 # Input node-feature dimensions\n", + "out_dim = 11 # Desired output node-feature dimensions\n", + "in_dim_edges = 13 # Input edge-feature dimensions\n", + "\n", + "# Let's create 2 simple graphs. Here the tensors represent the connectivity between nodes\n", + "g1 = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 3])))\n", + "g2 = dgl.graph((torch.tensor([0, 0, 0, 1]), torch.tensor([0, 1, 2, 0])))\n", + "\n", + "# We add some node features to the graphs\n", + "g1.ndata[\"h\"] = torch.rand(g1.num_nodes(), in_dim, dtype=float)\n", + "g2.ndata[\"h\"] = torch.rand(g2.num_nodes(), in_dim, dtype=float)\n", + "\n", + "# We also add some edge features to the graphs\n", + "g1.edata[\"e\"] = torch.rand(g1.num_edges(), in_dim_edges, dtype=float)\n", + "g2.edata[\"e\"] = torch.rand(g2.num_edges(), in_dim_edges, dtype=float)\n", + "\n", + "# Finally we batch the graphs in a way compatible with the DGL library\n", + "bg = dgl.batch([g1, g2])\n", + "bg = dgl.add_self_loop(bg)\n", + "\n", + "# The batched graph will show as a single graph with 7 nodes\n", + "print(bg)\n" + ] + }, + { + "source": [ + "## GCN Layer\n", + "\n", + "To use the GCN layer from the *Kipf et al.* paper, the steps are very simple. We create the layer with the desired attributes, and apply it to the graph.\n", + "\n", + "Kipf, Thomas N., and Max Welling. \"Semi-supervised classification with graph convolutional networks.\" arXiv preprint arXiv:1609.02907 (2016)." + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "GCNLayer(5 -> 11, activation=relu)\ntorch.Size([7, 5])\ntorch.Size([7, 11])\n" + ] + } + ], + "source": [ + "# We first need to extract the node features from the graph.\n", + "# The GCN method doesn't support edge features, so we ignore them\n", + "graph = deepcopy(bg)\n", + "h_in = graph.ndata[\"h\"]\n", + "\n", + "# We create the layer\n", + "layer = GCNLayer(\n", + " in_dim=in_dim, out_dim=out_dim, \n", + " activation=\"relu\", dropout=.3, batch_norm=True).to(float)\n", + "\n", + "# We apply the forward loop on the node features\n", + "h_out = layer(graph, h_in)\n", + "\n", + "# 7 is the number of nodes, 5 number of input features and 11 number of output features\n", + "print(layer)\n", + "print(h_in.shape)\n", + "print(h_out.shape)" + ] + }, + { + "source": [ + "## GIN Layer\n", + "\n", + "To use the GIN layer from the *Xu et al.* paper, the steps are identical to GCN.\n", + "\n", + "Xu, Keyulu, et al. \"How powerful are graph neural networks?.\" arXiv preprint arXiv:1810.00826 (2018)." + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "GINLayer(5 -> 11, activation=relu)\ntorch.Size([7, 5])\ntorch.Size([7, 11])\n" + ] + } + ], + "source": [ + "graph = deepcopy(bg)\n", + "h_in = graph.ndata[\"h\"]\n", + "layer = GINLayer(\n", + " in_dim=in_dim, out_dim=out_dim, \n", + " activation=\"relu\", dropout=.3, batch_norm=True).to(float)\n", + "h_out = layer(graph, h_in)\n", + "\n", + "print(layer)\n", + "print(h_in.shape)\n", + "print(h_out.shape)" + ] + }, + { + "source": [ + "## GAT Layer\n", + "\n", + "To use the GAT layer from the *Velickovic et al.* paper, the steps are identical to GCN, but the output dimension is multiplied by the number of heads.\n", + "\n", + "Velickovic, Petar, et al. \"Graph attention networks.\" arXiv preprint arXiv:1710.10903 (2017)." + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "GATLayer(5 -> 11 * 5, activation=elu)\ntorch.Size([7, 5])\ntorch.Size([7, 55])\n" + ] + } + ], + "source": [ + "graph = deepcopy(bg)\n", + "h_in = graph.ndata[\"h\"]\n", + "layer = GATLayer(\n", + " in_dim=in_dim, out_dim=out_dim, num_heads=5,\n", + " activation=\"elu\", dropout=.3, batch_norm=True).to(float)\n", + "h_out = layer(graph, h_in)\n", + "\n", + "print(layer)\n", + "print(h_in.shape)\n", + "print(h_out.shape)" + ] + }, + { + "source": [ + "## Gated-GCN Layer\n", + "\n", + "To use the Gated-GCN layer from the *Bresson et al.* paper, the steps are different since the layer requires edge features as inputs, and outputs new edge features.\n", + "\n", + "Bresson, Xavier, and Thomas Laurent. \"Residual gated graph convnets.\" arXiv preprint arXiv:1711.07553 (2017)." + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "GatedGCNLayer(5 -> 11, activation=relu)\ntorch.Size([7, 5])\ntorch.Size([7, 11])\ntorch.Size([14, 13])\ntorch.Size([14, 11])\n" + ] + } + ], + "source": [ + "# We first need to extract the node and edge features from the graph.\n", + "graph = deepcopy(bg)\n", + "h_in = graph.ndata[\"h\"]\n", + "e_in = graph.edata[\"e\"]\n", + "\n", + "# We create the layer\n", + "layer = GatedGCNLayer(\n", + " in_dim=in_dim, out_dim=out_dim, \n", + " in_dim_edges=in_dim_edges, out_dim_edges=out_dim,\n", + " activation=\"relu\", dropout=.3, batch_norm=True).to(float)\n", + "\n", + "# We apply the forward loop on the node features\n", + "h_out, e_out = layer(graph, h_in, e_in)\n", + "\n", + "# 7 is the number of nodes, 5 number of input features and 11 number of output features\n", + "# 13 is the number of input edge features\n", + "print(layer)\n", + "print(h_in.shape)\n", + "print(h_out.shape)\n", + "print(e_in.shape)\n", + "print(e_out.shape)" + ] + }, + { + "source": [ + "# PNA\n", + "\n", + "PNA is a multi-aggregator method proposed by *Corso et al.*. It supports 2 types of aggregations, convolutional *PNA-conv* or message passing *PNA-msgpass*.\n", + "\n", + "PNA: Principal Neighbourhood Aggregation \n", + "Gabriele Corso, Luca Cavalleri, Dominique Beaini, Pietro Lio, Petar Velickovic\n", + "https://arxiv.org/abs/2004.05718\n", + "\n", + "## PNA-conv\n", + "\n", + "First, let's focus on the PNA-conv. In this case, it works exactly as the GCN and GIN methods." + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "PNAConvolutionalLayer(5 -> 11, activation=relu)\ntorch.Size([7, 5])\ntorch.Size([7, 11])\n" + ] + } + ], + "source": [ + "graph = deepcopy(bg)\n", + "h_in = graph.ndata[\"h\"]\n", + "\n", + "# We create the layer, and need to specify the aggregators and scalers\n", + "layer = PNAConvolutionalLayer(\n", + " in_dim=in_dim, out_dim=out_dim, \n", + " aggregators=[\"mean\", \"max\", \"min\", \"std\"],\n", + " scalers=[\"identity\", \"amplification\", \"attenuation\"],\n", + " activation=\"relu\", dropout=.3, batch_norm=True).to(float)\n", + "\n", + "h_out = layer(graph, h_in)\n", + "\n", + "print(layer)\n", + "print(h_in.shape)\n", + "print(h_out.shape)" + ] + }, + { + "source": [ + "## PNA-msgpass\n", + "\n", + "The PNA message passing is typically more powerful that the convolutional one, and it supports edges as inputs, but doesn't output edges. It's usage is very similar to the *PNA-conv*, but we need to specify the edge dimensions and features." + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "PNAMessagePassingLayer(5 -> 11, activation=relu)\ntorch.Size([7, 5])\ntorch.Size([7, 11])\n" + ] + } + ], + "source": [ + "graph = deepcopy(bg)\n", + "h_in = graph.ndata[\"h\"]\n", + "e_in = graph.edata[\"e\"]\n", + "\n", + "# We create the layer, and need to specify the aggregators and scalers\n", + "layer = PNAMessagePassingLayer(\n", + " in_dim=in_dim, out_dim=out_dim, in_dim_edges=in_dim_edges,\n", + " aggregators=[\"mean\", \"max\", \"min\", \"std\"],\n", + " scalers=[\"identity\", \"amplification\", \"attenuation\"],\n", + " activation=\"relu\", dropout=.3, batch_norm=True).to(float)\n", + "\n", + "h_out = layer(graph, h_in, e_in)\n", + "\n", + "print(layer)\n", + "print(h_in.shape)\n", + "print(h_out.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ] } \ No newline at end of file diff --git a/docs/tutorials/model_training/simple-molecular-model.ipynb b/docs/tutorials/model_training/simple-molecular-model.ipynb index c383499b7..9f71398ca 100644 --- a/docs/tutorials/model_training/simple-molecular-model.ipynb +++ b/docs/tutorials/model_training/simple-molecular-model.ipynb @@ -1,954 +1,954 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Building and training a simple model from configurations\n", - "\n", - "This tutorial will walk you through how to use a configuration file to define all the parameters of a model and of the trainer. This tutorial focuses on training from SMILES data in a CSV format.\n", - "\n", - "## Creating the yaml file\n", - "\n", - "The first step is to create a YAML file containing all the required configurations, with an example given at `goli/expts/config_micro_ZINC.yaml`. We will go through each part of the configurations." - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "import yaml\n", - "import omegaconf" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "def print_config_with_key(config, key):\n", - " new_config = {key: config[key]}\n", - " print(omegaconf.OmegaConf.to_yaml(new_config))" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Yaml file loaded\n" - ] - } - ], - "source": [ - "# First, let's read the yaml configuration file\n", - "with open(\"../../../expts/config_micro_ZINC.yaml\", \"r\") as file:\n", - " yaml_config = yaml.load(file, Loader=yaml.FullLoader)\n", - "\n", - "print(\"Yaml file loaded\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Constants\n", - "\n", - "First, we define the constants such as the random seed and whether the model should raise or ignore an error." - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "constants:\n", - " seed: 42\n", - " raise_train_error: true\n", - "\n" - ] - } - ], - "source": [ - "print_config_with_key(yaml_config, \"constants\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Datamodule\n", - "\n", - "Here, we define all the parameters required by the datamodule to run correctly, such as the dataset path, whether to cache, the columns for the training, the molecular featurization to use, the train/val/test splits and the batch size.\n", - "\n", - "For more details, see class `goli.data.datamodule.DGLFromSmilesDataModule`" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "datamodule:\n", - " df_path: goli/data/micro_ZINC/micro_ZINC.csv\n", - " cache_data_path: goli/data/cache/micro_ZINC/full.cache\n", - " label_cols:\n", - " - score\n", - " smiles_col: SMILES\n", - " featurization_n_jobs: -1\n", - " featurization_progress: true\n", - " featurization:\n", - " atom_property_list_onehot:\n", - " - atomic-number\n", - " - valence\n", - " atom_property_list_float:\n", - " - mass\n", - " - electronegativity\n", - " - in-ring\n", - " edge_property_list: []\n", - " add_self_loop: false\n", - " explicit_H: false\n", - " use_bonds_weights: false\n", - " split_val: 0.2\n", - " split_test: 0.2\n", - " split_seed: 42\n", - " splits_path: null\n", - " batch_size_train_val: 128\n", - " batch_size_test: 256\n", - " num_workers: -1\n", - " pin_memory: false\n", - "\n" - ] - } - ], - "source": [ - "print_config_with_key(yaml_config, \"datamodule\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Architecture\n", - "\n", - "In the architecture, we define all the layers for the model, including the layers for the pre-processing MLP (input layers `pre-nn`), the post-processing MLP (output layers `post-nn`), and the main GNN (graph neural network `gnn`).\n", - "\n", - "The parameters allow to chose the feature size, the depth, the skip connections, the pooling and the virtual node. It also support different GNN layers such as `gcn`, `gin`, `gat`, `gated-gcn`, `pna-conv` and `pna-msgpass`.\n", - "\n", - "For more details, see the following classes:\n", - "\n", - "- `goli.nn.architecture.FullDGLNetwork`: Main class for the architecture\n", - "- `goli.nn.architecture.FeedForwardNN`: Main class for the inputs and outputs MLP\n", - "- `goli.nn.architecture.FeedForwardDGL`: Main class for the GNN layers" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "architecture:\n", - " model_type: fulldglnetwork\n", - " pre_nn:\n", - " out_dim: 32\n", - " hidden_dims: 32\n", - " depth: 1\n", - " activation: relu\n", - " last_activation: none\n", - " dropout: 0.1\n", - " batch_norm: true\n", - " last_batch_norm: true\n", - " residual_type: none\n", - " gnn:\n", - " out_dim: 32\n", - " hidden_dims: 32\n", - " depth: 4\n", - " activation: relu\n", - " last_activation: none\n", - " dropout: 0.1\n", - " batch_norm: true\n", - " last_batch_norm: true\n", - " residual_type: simple\n", - " pooling: sum\n", - " virtual_node: sum\n", - " layer_type: pna-msgpass\n", - " layer_kwargs:\n", - " aggregators:\n", - " - mean\n", - " - max\n", - " - min\n", - " - std\n", - " scalers:\n", - " - identity\n", - " - amplification\n", - " - attenuation\n", - " post_nn:\n", - " out_dim: 1\n", - " hidden_dims: 32\n", - " depth: 2\n", - " activation: relu\n", - " last_activation: none\n", - " dropout: 0.1\n", - " batch_norm: true\n", - " last_batch_norm: false\n", - " residual_type: none\n", - "\n" - ] - } - ], - "source": [ - "print_config_with_key(yaml_config, \"architecture\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Predictor\n", - "\n", - "In the predictor, we define the loss functions, the metrics to track on the progress bar, and all the parameters necessary for the optimizer." - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "predictor:\n", - " metrics_on_progress_bar:\n", - " - mae\n", - " - pearsonr\n", - " - f1 > 3\n", - " - precision > 3\n", - " loss_fun: mse\n", - " random_seed: 42\n", - " optim_kwargs:\n", - " lr: 0.01\n", - " weight_decay: 1.0e-07\n", - " lr_reduce_on_plateau_kwargs:\n", - " factor: 0.5\n", - " patience: 7\n", - " scheduler_kwargs:\n", - " monitor: loss/val\n", - " frequency: 1\n", - " target_nan_mask: 0\n", - "\n" - ] - } - ], - "source": [ - "print_config_with_key(yaml_config, \"predictor\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Metrics\n", - "\n", - "All the metrics can be defined there. If we want to use a classification metric, we can also define a threshold.\n", - "\n", - "See class `goli.trainer.metrics.MetricWrapper` for more details.\n", - "\n", - "See `goli.trainer.metrics.METRICS_CLASSIFICATION` and `goli.trainer.metrics.METRICS_REGRESSION` for a dictionnary of accepted metrics." - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "metrics:\n", - "- name: mae\n", - " metric: mae\n", - " threshold_kwargs: null\n", - "- name: pearsonr\n", - " metric: pearsonr\n", - " threshold_kwargs: null\n", - "- name: f1 > 3\n", - " metric: f1\n", - " num_classes: 2\n", - " average: micro\n", - " threshold_kwargs:\n", - " operator: greater\n", - " threshold: 3\n", - " th_on_preds: true\n", - " th_on_target: true\n", - "- name: f1 > 5\n", - " metric: f1\n", - " num_classes: 2\n", - " average: micro\n", - " threshold_kwargs:\n", - " operator: greater\n", - " threshold: 5\n", - " th_on_preds: true\n", - " th_on_target: true\n", - "- name: precision > 3\n", - " metric: precision\n", - " class_reduction: micro\n", - " threshold_kwargs:\n", - " operator: greater\n", - " threshold: 3\n", - " th_on_preds: true\n", - " th_on_target: true\n", - "\n" - ] - } - ], - "source": [ - "print_config_with_key(yaml_config, \"metrics\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Trainer\n", - "\n", - "Finally, the Trainer defines the parameters for the number of epochs to train, the checkpoints, and the patience." - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "trainer:\n", - " logger:\n", - " save_dir: logs/micro_ZINC\n", - " early_stopping:\n", - " monitor: loss/val\n", - " min_delta: 0\n", - " patience: 10\n", - " mode: min\n", - " model_checkpoint:\n", - " dirpath: models_checkpoints/micro_ZINC/\n", - " filename: bob\n", - " monitor: loss/val\n", - " mode: min\n", - " save_top_k: 1\n", - " period: 1\n", - " trainer:\n", - " max_epochs: 25\n", - " min_epochs: 5\n", - " gpus: 1\n", - "\n" - ] - } - ], - "source": [ - "print_config_with_key(yaml_config, \"trainer\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Training the model\n", - "\n", - "Now that we defined all the configuration files, we want to train the model. The steps are fairly easy using the config loaders, and are given below." - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Using backend: pytorch\n", - "2021-03-25 09:44:37.314 | WARNING | goli.config._loader:load_trainer:111 - Number of GPUs selected is `1`, but will be ignored since no GPU are available on this device\n", - "/home/dominique/anaconda3/envs/goli/lib/python3.8/site-packages/pytorch_lightning/utilities/distributed.py:51: UserWarning: Checkpoint directory models_checkpoints/micro_ZINC/ exists and is not empty.\n", - " warnings.warn(*args, **kwargs)\n", - "GPU available: False, used: False\n", - "TPU available: None, using: 0 TPU cores\n", - "2021-03-25 09:44:37.331 | INFO | goli.data.datamodule:prepare_data:153 - Reload data from goli/data/cache/micro_ZINC/full.cache.\n", - "\n", - "datamodule:\n", - " name: DGLFromSmilesDataModule\n", - "len: 1000\n", - "batch_size_train_val: 128\n", - "batch_size_test: 256\n", - "num_node_feats: 55\n", - "num_edge_feats: 0\n", - "collate_fn: goli_collate_fn\n", - "featurization:\n", - " atom_property_list_onehot:\n", - " - atomic-number\n", - " - valence\n", - " atom_property_list_float:\n", - " - mass\n", - " - electronegativity\n", - " - in-ring\n", - " edge_property_list: []\n", - " add_self_loop: false\n", - " explicit_H: false\n", - " use_bonds_weights: false\n", - " \n", - "\n", - "{'mae': mean_absolute_error, 'pearsonr': pearsonr, 'f1 > 3': f1(>3), 'f1 > 5': f1(>5), 'precision > 3': precision(>3)}\n", - "DGL_GNN\n", - "---------\n", - " pre-NN(depth=1, ResidualConnectionNone)\n", - " [FCLayer[55 -> 32]\n", - " \n", - " GNN(depth=4, ResidualConnectionSimple(skip_steps=1))\n", - " PNAMessagePassingLayer[32 -> 32 -> 32 -> 32 -> 32]\n", - " -> Pooling(sum) -> FCLayer(32 -> 32, activation=None)\n", - " \n", - " post-NN(depth=2, ResidualConnectionNone)\n", - " [FCLayer[32 -> 32 -> 1]\n", - " | Name | Type | Params\n", - "------------------------------------------------------------------------------\n", - "0 | model | FullDGLNetwork | 69.7 K\n", - "1 | model.pre_nn | FeedForwardNN | 1.9 K \n", - "2 | model.pre_nn.activation | ReLU | 0 \n", - "3 | model.pre_nn.residual_layer | ResidualConnectionNone | 0 \n", - "4 | model.pre_nn.layers | ModuleList | 1.9 K \n", - "5 | model.pre_nn.layers.0 | FCLayer | 1.9 K \n", - "6 | model.gnn | FeedForwardDGL | 66.7 K\n", - "7 | model.gnn.activation | ReLU | 0 \n", - "8 | model.gnn.layers | ModuleList | 62.2 K\n", - "9 | model.gnn.layers.0 | PNAMessagePassingLayer | 15.6 K\n", - "10 | model.gnn.layers.1 | PNAMessagePassingLayer | 15.6 K\n", - "11 | model.gnn.layers.2 | PNAMessagePassingLayer | 15.6 K\n", - "12 | model.gnn.layers.3 | PNAMessagePassingLayer | 15.6 K\n", - "13 | model.gnn.virtual_node_layers | ModuleList | 3.4 K \n", - "14 | model.gnn.virtual_node_layers.0 | VirtualNode | 1.1 K \n", - "15 | model.gnn.virtual_node_layers.1 | VirtualNode | 1.1 K \n", - "16 | model.gnn.virtual_node_layers.2 | VirtualNode | 1.1 K \n", - "17 | model.gnn.residual_layer | ResidualConnectionSimple | 0 \n", - "18 | model.gnn.global_pool_layer | ModuleListConcat | 0 \n", - "19 | model.gnn.global_pool_layer.0 | SumPooling | 0 \n", - "20 | model.gnn.out_linear | FCLayer | 1.1 K \n", - "21 | model.gnn.out_linear.linear | Linear | 1.1 K \n", - "22 | model.gnn.out_linear.dropout | Dropout | 0 \n", - "23 | model.gnn.out_linear.batch_norm | BatchNorm1d | 64 \n", - "24 | model.post_nn | FeedForwardNN | 1.2 K \n", - "25 | model.post_nn.activation | ReLU | 0 \n", - "26 | model.post_nn.residual_layer | ResidualConnectionNone | 0 \n", - "27 | model.post_nn.layers | ModuleList | 1.2 K \n", - "28 | model.post_nn.layers.0 | FCLayer | 1.1 K \n", - "29 | model.post_nn.layers.1 | FCLayer | 33 \n", - "30 | loss_fun | MSELoss | 0 \n", - "------------------------------------------------------------------------------\n", - "69.7 K Trainable params\n", - "0 Non-trainable params\n", - "69.7 K Total params\n", - "0.279 Total estimated model params size (MB)\n", - "\n", - " | Name | Type | Params\n", - "--------------------------------------------\n", - "0 | model | FullDGLNetwork | 69.7 K\n", - "1 | loss_fun | MSELoss | 0 \n", - "--------------------------------------------\n", - "69.7 K Trainable params\n", - "0 Non-trainable params\n", - "69.7 K Total params\n", - "0.279 Total estimated model params size (MB)\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "61dc1894ee264599ab493d982b390430", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation sanity check: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/dominique/anaconda3/envs/goli/lib/python3.8/site-packages/pytorch_lightning/utilities/distributed.py:51: UserWarning: The validation_epoch_end should not return anything as of 9.1. To log, use self.log(...) or self.write(...) directly in the LightningModule\n", - " warnings.warn(*args, **kwargs)\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "657567b3d0b546a1a648173c2bfb1e4a", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Training: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "01792491c7fd49b08ce5086832135c7b", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validating: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "b9d96f2469234fe380d07ffd806350ed", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validating: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "eed00b1f81524fe99e07e2084c952532", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validating: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "c435a81142c24be09a0e38d7575b365b", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validating: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "ae0357d542ec4021b0c9c38fb38bd11c", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validating: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "a1d24f83f1504d3b80b0741d1c0f404b", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validating: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "647aee1f810f407c90697c394d0f604f", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validating: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "c69086a4802e421f816cab1ebbc20ae9", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validating: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "ebb4a32a0f78470ba21ff5bea8b450f4", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validating: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "840d27d095344fcd9a74862e61a2fe7f", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validating: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "905c7ee70f4b4282871c130b0c7b9f0a", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validating: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "daab3285c0854c23a4e3dd47846c820e", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validating: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "e0e0d9096fc64198a470ae1b3cd7f351", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validating: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "99ac351f4e334e8c838a6913ef6bee08", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validating: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "69b47fad071248eab8095d67e33b5d5e", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validating: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "8c68dcc01135429e845427bb6908f414", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validating: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "096aaea9ce2649fba9bf70b99b7e7955", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validating: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "d8ecff999d934a119157a3e0ca7a1c6a", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validating: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "4287a56d059b4eb2966eb2e90498a210", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validating: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "0a5ffab4db4e4768a4876b01a8b10f96", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validating: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "6177ef595f9542598e5b065d6d77bb32", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validating: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "e86938e35b0b443791119e37dd2e2199", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validating: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "98aff21b49cc434dbaaaf12c355ab783", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validating: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "e4c88c49c0c843c09934e9786e9b6aa5", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validating: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "dcb917418a084d4ba36d57f5b0406819", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validating: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "1" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from os.path import dirname, abspath\n", - "from copy import deepcopy\n", - "\n", - "import goli\n", - "from goli.config._loader import (load_datamodule, load_metrics, load_architecture, load_predictor, load_trainer)\n", - "\n", - "MAIN_DIR = dirname(dirname(abspath(goli.__file__)))\n", - "os.chdir(MAIN_DIR)\n", - "\n", - "cfg = dict(deepcopy(yaml_config))\n", - "\n", - "# Load and initialize the dataset\n", - "datamodule = load_datamodule(cfg)\n", - "print(\"\\ndatamodule:\\n\", datamodule, \"\\n\")\n", - "\n", - "# Initialize the network\n", - "model_class, model_kwargs = load_architecture(\n", - " cfg,\n", - " in_dim_nodes=datamodule.num_node_feats_with_positional_encoding,\n", - " in_dim_edges=datamodule.num_edge_feats,\n", - ")\n", - "\n", - "metrics = load_metrics(cfg)\n", - "print(metrics)\n", - "\n", - "predictor = load_predictor(cfg, model_class, model_kwargs, metrics)\n", - "\n", - "print(predictor.model)\n", - "print(predictor.summarize(mode=4, to_print=False))\n", - "\n", - "trainer = load_trainer(cfg, metrics)\n", - "\n", - "# Run the model training\n", - "trainer.fit(model=predictor, datamodule=datamodule)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.4" - }, - "widgets": { - "application/vnd.jupyter.widget-state+json": { - "state": {}, - "version_major": 2, - "version_minor": 0 - } - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Building and training a simple model from configurations\n", + "\n", + "This tutorial will walk you through how to use a configuration file to define all the parameters of a model and of the trainer. This tutorial focuses on training from SMILES data in a CSV format.\n", + "\n", + "## Creating the yaml file\n", + "\n", + "The first step is to create a YAML file containing all the required configurations, with an example given at `goli/expts/config_micro_ZINC.yaml`. We will go through each part of the configurations." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import yaml\n", + "import omegaconf" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "def print_config_with_key(config, key):\n", + " new_config = {key: config[key]}\n", + " print(omegaconf.OmegaConf.to_yaml(new_config))" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Yaml file loaded\n" + ] + } + ], + "source": [ + "# First, let's read the yaml configuration file\n", + "with open(\"../../../expts/config_micro_ZINC.yaml\", \"r\") as file:\n", + " yaml_config = yaml.load(file, Loader=yaml.FullLoader)\n", + "\n", + "print(\"Yaml file loaded\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Constants\n", + "\n", + "First, we define the constants such as the random seed and whether the model should raise or ignore an error." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "constants:\n", + " seed: 42\n", + " raise_train_error: true\n", + "\n" + ] + } + ], + "source": [ + "print_config_with_key(yaml_config, \"constants\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Datamodule\n", + "\n", + "Here, we define all the parameters required by the datamodule to run correctly, such as the dataset path, whether to cache, the columns for the training, the molecular featurization to use, the train/val/test splits and the batch size.\n", + "\n", + "For more details, see class `goli.data.datamodule.DGLFromSmilesDataModule`" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "datamodule:\n", + " df_path: goli/data/micro_ZINC/micro_ZINC.csv\n", + " cache_data_path: goli/data/cache/micro_ZINC/full.cache\n", + " label_cols:\n", + " - score\n", + " smiles_col: SMILES\n", + " featurization_n_jobs: -1\n", + " featurization_progress: true\n", + " featurization:\n", + " atom_property_list_onehot:\n", + " - atomic-number\n", + " - valence\n", + " atom_property_list_float:\n", + " - mass\n", + " - electronegativity\n", + " - in-ring\n", + " edge_property_list: []\n", + " add_self_loop: false\n", + " explicit_H: false\n", + " use_bonds_weights: false\n", + " split_val: 0.2\n", + " split_test: 0.2\n", + " split_seed: 42\n", + " splits_path: null\n", + " batch_size_train_val: 128\n", + " batch_size_test: 256\n", + " num_workers: -1\n", + " pin_memory: false\n", + "\n" + ] + } + ], + "source": [ + "print_config_with_key(yaml_config, \"datamodule\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Architecture\n", + "\n", + "In the architecture, we define all the layers for the model, including the layers for the pre-processing MLP (input layers `pre-nn`), the post-processing MLP (output layers `post-nn`), and the main GNN (graph neural network `gnn`).\n", + "\n", + "The parameters allow to chose the feature size, the depth, the skip connections, the pooling and the virtual node. It also support different GNN layers such as `gcn`, `gin`, `gat`, `gated-gcn`, `pna-conv` and `pna-msgpass`.\n", + "\n", + "For more details, see the following classes:\n", + "\n", + "- `goli.nn.architecture.FullDGLNetwork`: Main class for the architecture\n", + "- `goli.nn.architecture.FeedForwardNN`: Main class for the inputs and outputs MLP\n", + "- `goli.nn.architecture.FeedForwardDGL`: Main class for the GNN layers" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "architecture:\n", + " model_type: fulldglnetwork\n", + " pre_nn:\n", + " out_dim: 32\n", + " hidden_dims: 32\n", + " depth: 1\n", + " activation: relu\n", + " last_activation: none\n", + " dropout: 0.1\n", + " batch_norm: true\n", + " last_batch_norm: true\n", + " residual_type: none\n", + " gnn:\n", + " out_dim: 32\n", + " hidden_dims: 32\n", + " depth: 4\n", + " activation: relu\n", + " last_activation: none\n", + " dropout: 0.1\n", + " batch_norm: true\n", + " last_batch_norm: true\n", + " residual_type: simple\n", + " pooling: sum\n", + " virtual_node: sum\n", + " layer_type: pna-msgpass\n", + " layer_kwargs:\n", + " aggregators:\n", + " - mean\n", + " - max\n", + " - min\n", + " - std\n", + " scalers:\n", + " - identity\n", + " - amplification\n", + " - attenuation\n", + " post_nn:\n", + " out_dim: 1\n", + " hidden_dims: 32\n", + " depth: 2\n", + " activation: relu\n", + " last_activation: none\n", + " dropout: 0.1\n", + " batch_norm: true\n", + " last_batch_norm: false\n", + " residual_type: none\n", + "\n" + ] + } + ], + "source": [ + "print_config_with_key(yaml_config, \"architecture\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Predictor\n", + "\n", + "In the predictor, we define the loss functions, the metrics to track on the progress bar, and all the parameters necessary for the optimizer." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "predictor:\n", + " metrics_on_progress_bar:\n", + " - mae\n", + " - pearsonr\n", + " - f1 > 3\n", + " - precision > 3\n", + " loss_fun: mse\n", + " random_seed: 42\n", + " optim_kwargs:\n", + " lr: 0.01\n", + " weight_decay: 1.0e-07\n", + " lr_reduce_on_plateau_kwargs:\n", + " factor: 0.5\n", + " patience: 7\n", + " scheduler_kwargs:\n", + " monitor: loss/val\n", + " frequency: 1\n", + " target_nan_mask: 0\n", + "\n" + ] + } + ], + "source": [ + "print_config_with_key(yaml_config, \"predictor\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Metrics\n", + "\n", + "All the metrics can be defined there. If we want to use a classification metric, we can also define a threshold.\n", + "\n", + "See class `goli.trainer.metrics.MetricWrapper` for more details.\n", + "\n", + "See `goli.trainer.metrics.METRICS_CLASSIFICATION` and `goli.trainer.metrics.METRICS_REGRESSION` for a dictionnary of accepted metrics." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "metrics:\n", + "- name: mae\n", + " metric: mae\n", + " threshold_kwargs: null\n", + "- name: pearsonr\n", + " metric: pearsonr\n", + " threshold_kwargs: null\n", + "- name: f1 > 3\n", + " metric: f1\n", + " num_classes: 2\n", + " average: micro\n", + " threshold_kwargs:\n", + " operator: greater\n", + " threshold: 3\n", + " th_on_preds: true\n", + " th_on_target: true\n", + "- name: f1 > 5\n", + " metric: f1\n", + " num_classes: 2\n", + " average: micro\n", + " threshold_kwargs:\n", + " operator: greater\n", + " threshold: 5\n", + " th_on_preds: true\n", + " th_on_target: true\n", + "- name: precision > 3\n", + " metric: precision\n", + " class_reduction: micro\n", + " threshold_kwargs:\n", + " operator: greater\n", + " threshold: 3\n", + " th_on_preds: true\n", + " th_on_target: true\n", + "\n" + ] + } + ], + "source": [ + "print_config_with_key(yaml_config, \"metrics\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Trainer\n", + "\n", + "Finally, the Trainer defines the parameters for the number of epochs to train, the checkpoints, and the patience." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "trainer:\n", + " logger:\n", + " save_dir: logs/micro_ZINC\n", + " early_stopping:\n", + " monitor: loss/val\n", + " min_delta: 0\n", + " patience: 10\n", + " mode: min\n", + " model_checkpoint:\n", + " dirpath: models_checkpoints/micro_ZINC/\n", + " filename: bob\n", + " monitor: loss/val\n", + " mode: min\n", + " save_top_k: 1\n", + " period: 1\n", + " trainer:\n", + " max_epochs: 25\n", + " min_epochs: 5\n", + " gpus: 1\n", + "\n" + ] + } + ], + "source": [ + "print_config_with_key(yaml_config, \"trainer\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Training the model\n", + "\n", + "Now that we defined all the configuration files, we want to train the model. The steps are fairly easy using the config loaders, and are given below." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Using backend: pytorch\n", + "2021-03-25 09:44:37.314 | WARNING | goli.config._loader:load_trainer:111 - Number of GPUs selected is `1`, but will be ignored since no GPU are available on this device\n", + "/home/dominique/anaconda3/envs/goli/lib/python3.8/site-packages/pytorch_lightning/utilities/distributed.py:51: UserWarning: Checkpoint directory models_checkpoints/micro_ZINC/ exists and is not empty.\n", + " warnings.warn(*args, **kwargs)\n", + "GPU available: False, used: False\n", + "TPU available: None, using: 0 TPU cores\n", + "2021-03-25 09:44:37.331 | INFO | goli.data.datamodule:prepare_data:153 - Reload data from goli/data/cache/micro_ZINC/full.cache.\n", + "\n", + "datamodule:\n", + " name: DGLFromSmilesDataModule\n", + "len: 1000\n", + "batch_size_train_val: 128\n", + "batch_size_test: 256\n", + "num_node_feats: 55\n", + "num_edge_feats: 0\n", + "collate_fn: goli_collate_fn\n", + "featurization:\n", + " atom_property_list_onehot:\n", + " - atomic-number\n", + " - valence\n", + " atom_property_list_float:\n", + " - mass\n", + " - electronegativity\n", + " - in-ring\n", + " edge_property_list: []\n", + " add_self_loop: false\n", + " explicit_H: false\n", + " use_bonds_weights: false\n", + " \n", + "\n", + "{'mae': mean_absolute_error, 'pearsonr': pearsonr, 'f1 > 3': f1(>3), 'f1 > 5': f1(>5), 'precision > 3': precision(>3)}\n", + "DGL_GNN\n", + "---------\n", + " pre-NN(depth=1, ResidualConnectionNone)\n", + " [FCLayer[55 -> 32]\n", + " \n", + " GNN(depth=4, ResidualConnectionSimple(skip_steps=1))\n", + " PNAMessagePassingLayer[32 -> 32 -> 32 -> 32 -> 32]\n", + " -> Pooling(sum) -> FCLayer(32 -> 32, activation=None)\n", + " \n", + " post-NN(depth=2, ResidualConnectionNone)\n", + " [FCLayer[32 -> 32 -> 1]\n", + " | Name | Type | Params\n", + "------------------------------------------------------------------------------\n", + "0 | model | FullDGLNetwork | 69.7 K\n", + "1 | model.pre_nn | FeedForwardNN | 1.9 K \n", + "2 | model.pre_nn.activation | ReLU | 0 \n", + "3 | model.pre_nn.residual_layer | ResidualConnectionNone | 0 \n", + "4 | model.pre_nn.layers | ModuleList | 1.9 K \n", + "5 | model.pre_nn.layers.0 | FCLayer | 1.9 K \n", + "6 | model.gnn | FeedForwardDGL | 66.7 K\n", + "7 | model.gnn.activation | ReLU | 0 \n", + "8 | model.gnn.layers | ModuleList | 62.2 K\n", + "9 | model.gnn.layers.0 | PNAMessagePassingLayer | 15.6 K\n", + "10 | model.gnn.layers.1 | PNAMessagePassingLayer | 15.6 K\n", + "11 | model.gnn.layers.2 | PNAMessagePassingLayer | 15.6 K\n", + "12 | model.gnn.layers.3 | PNAMessagePassingLayer | 15.6 K\n", + "13 | model.gnn.virtual_node_layers | ModuleList | 3.4 K \n", + "14 | model.gnn.virtual_node_layers.0 | VirtualNode | 1.1 K \n", + "15 | model.gnn.virtual_node_layers.1 | VirtualNode | 1.1 K \n", + "16 | model.gnn.virtual_node_layers.2 | VirtualNode | 1.1 K \n", + "17 | model.gnn.residual_layer | ResidualConnectionSimple | 0 \n", + "18 | model.gnn.global_pool_layer | ModuleListConcat | 0 \n", + "19 | model.gnn.global_pool_layer.0 | SumPooling | 0 \n", + "20 | model.gnn.out_linear | FCLayer | 1.1 K \n", + "21 | model.gnn.out_linear.linear | Linear | 1.1 K \n", + "22 | model.gnn.out_linear.dropout | Dropout | 0 \n", + "23 | model.gnn.out_linear.batch_norm | BatchNorm1d | 64 \n", + "24 | model.post_nn | FeedForwardNN | 1.2 K \n", + "25 | model.post_nn.activation | ReLU | 0 \n", + "26 | model.post_nn.residual_layer | ResidualConnectionNone | 0 \n", + "27 | model.post_nn.layers | ModuleList | 1.2 K \n", + "28 | model.post_nn.layers.0 | FCLayer | 1.1 K \n", + "29 | model.post_nn.layers.1 | FCLayer | 33 \n", + "30 | loss_fun | MSELoss | 0 \n", + "------------------------------------------------------------------------------\n", + "69.7 K Trainable params\n", + "0 Non-trainable params\n", + "69.7 K Total params\n", + "0.279 Total estimated model params size (MB)\n", + "\n", + " | Name | Type | Params\n", + "--------------------------------------------\n", + "0 | model | FullDGLNetwork | 69.7 K\n", + "1 | loss_fun | MSELoss | 0 \n", + "--------------------------------------------\n", + "69.7 K Trainable params\n", + "0 Non-trainable params\n", + "69.7 K Total params\n", + "0.279 Total estimated model params size (MB)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "61dc1894ee264599ab493d982b390430", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validation sanity check: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/dominique/anaconda3/envs/goli/lib/python3.8/site-packages/pytorch_lightning/utilities/distributed.py:51: UserWarning: The validation_epoch_end should not return anything as of 9.1. To log, use self.log(...) or self.write(...) directly in the LightningModule\n", + " warnings.warn(*args, **kwargs)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "657567b3d0b546a1a648173c2bfb1e4a", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Training: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "01792491c7fd49b08ce5086832135c7b", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validating: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "b9d96f2469234fe380d07ffd806350ed", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validating: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "eed00b1f81524fe99e07e2084c952532", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validating: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "c435a81142c24be09a0e38d7575b365b", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validating: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "ae0357d542ec4021b0c9c38fb38bd11c", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validating: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "a1d24f83f1504d3b80b0741d1c0f404b", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validating: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "647aee1f810f407c90697c394d0f604f", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validating: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "c69086a4802e421f816cab1ebbc20ae9", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validating: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "ebb4a32a0f78470ba21ff5bea8b450f4", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validating: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "840d27d095344fcd9a74862e61a2fe7f", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validating: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "905c7ee70f4b4282871c130b0c7b9f0a", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validating: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "daab3285c0854c23a4e3dd47846c820e", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validating: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "e0e0d9096fc64198a470ae1b3cd7f351", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validating: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "99ac351f4e334e8c838a6913ef6bee08", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validating: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "69b47fad071248eab8095d67e33b5d5e", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validating: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "8c68dcc01135429e845427bb6908f414", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validating: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "096aaea9ce2649fba9bf70b99b7e7955", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validating: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "d8ecff999d934a119157a3e0ca7a1c6a", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validating: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "4287a56d059b4eb2966eb2e90498a210", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validating: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "0a5ffab4db4e4768a4876b01a8b10f96", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validating: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "6177ef595f9542598e5b065d6d77bb32", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validating: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "e86938e35b0b443791119e37dd2e2199", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validating: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "98aff21b49cc434dbaaaf12c355ab783", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validating: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "e4c88c49c0c843c09934e9786e9b6aa5", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validating: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "dcb917418a084d4ba36d57f5b0406819", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validating: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "1" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from os.path import dirname, abspath\n", + "from copy import deepcopy\n", + "\n", + "import goli\n", + "from goli.config._loader import (load_datamodule, load_metrics, load_architecture, load_predictor, load_trainer)\n", + "\n", + "MAIN_DIR = dirname(dirname(abspath(goli.__file__)))\n", + "os.chdir(MAIN_DIR)\n", + "\n", + "cfg = dict(deepcopy(yaml_config))\n", + "\n", + "# Load and initialize the dataset\n", + "datamodule = load_datamodule(cfg)\n", + "print(\"\\ndatamodule:\\n\", datamodule, \"\\n\")\n", + "\n", + "# Initialize the network\n", + "model_class, model_kwargs = load_architecture(\n", + " cfg,\n", + " in_dim_nodes=datamodule.num_node_feats_with_positional_encoding,\n", + " in_dim_edges=datamodule.num_edge_feats,\n", + ")\n", + "\n", + "metrics = load_metrics(cfg)\n", + "print(metrics)\n", + "\n", + "predictor = load_predictor(cfg, model_class, model_kwargs, metrics)\n", + "\n", + "print(predictor.model)\n", + "print(predictor.summarize(mode=4, to_print=False))\n", + "\n", + "trainer = load_trainer(cfg, metrics)\n", + "\n", + "# Run the model training\n", + "trainer.fit(model=predictor, datamodule=datamodule)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.4" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "state": {}, + "version_major": 2, + "version_minor": 0 + } + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/env.yml b/env.yml index f2307e223..06a3ba457 100644 --- a/env.yml +++ b/env.yml @@ -1,66 +1,66 @@ -channels: - - pytorch - - dglteam - - conda-forge - -dependencies: - - python >=3.7 - - pip - - click - - loguru - - omegaconf >=2.0.0 - - tqdm - - ipywidgets - - # scientific - - numpy - - scipy >=1.4 - - pandas >=1.0 - - scikit-learn - - # viz - - matplotlib >=3.0.1 - - seaborn - - # cloud IO - - fsspec - - s3fs - - gcsfs - - appdirs - - # ML packages - - cudatoolkit # works also with CPU-only system. - - pytorch >=1.7.0 - - torchvision - - tensorboard - - pytorch-lightning >=1.2.3 - - torchmetrics >=0.2.0 - - dgl >=0.5.2 - - hydra-core >=1.0 - - ogb - - # chemistry - - rdkit >=2020.09 - - datamol >=0.3.4 - - mordred - - umap-learn - - # Dev - - pytest >=6.0 - - pytest-cov - - black >=20.8b1 - - jupyterlab - - anaconda-client - - # Doc - - mkdocs - - mkdocs-material - - mkdocs-material-extensions - - mkdocstrings - - mkdocs-jupyter - - mkdocs-click - - markdown-include - - # Releasing tools - - rever >=0.4.5 - - conda-smithy +channels: + - pytorch + - dglteam + - conda-forge + +dependencies: + - python >=3.7 + - pip + - click + - loguru + - omegaconf >=2.0.0 + - tqdm + - ipywidgets + + # scientific + - numpy + - scipy >=1.4 + - pandas >=1.0 + - scikit-learn + + # viz + - matplotlib >=3.0.1 + - seaborn + + # cloud IO + - fsspec + - s3fs + - gcsfs + - appdirs + + # ML packages + - cudatoolkit # works also with CPU-only system. + - pytorch >=1.7.0 + - torchvision + - tensorboard + - pytorch-lightning >=1.2.3 + - torchmetrics >=0.2.0 + - dgl >=0.5.2 + - hydra-core >=1.0 + - ogb + + # chemistry + - rdkit >=2020.09 + - datamol >=0.3.4 + - mordred + - umap-learn + + # Dev + - pytest >=6.0 + - pytest-cov + - black >=20.8b1 + - jupyterlab + - anaconda-client + + # Doc + - mkdocs + - mkdocs-material + - mkdocs-material-extensions + - mkdocstrings + - mkdocs-jupyter + - mkdocs-click + - markdown-include + + # Releasing tools + - rever >=0.4.5 + - conda-smithy diff --git a/expts/config_ZINC_bench_gnn.yaml b/expts/config_ZINC_bench_gnn.yaml index 76db1d973..173b1f629 100644 --- a/expts/config_ZINC_bench_gnn.yaml +++ b/expts/config_ZINC_bench_gnn.yaml @@ -1,159 +1,159 @@ -constants: - seed: &seed 42 - raise_train_error: true # Whether the code should raise an error if it crashes during training - -datamodule: - module_type: "DGLFromSmilesDataModule" - df_path: goli/data/ZINC_bench_gnn/smiles_score.csv - cache_data_path: goli/data/cache/ZINC_bench_gnn/smiles_score.cache - label_cols: ['score'] - smiles_col: SMILES - - # Featurization - featurization_n_jobs: -1 - featurization_progress: True - featurization: - atom_property_list_onehot: [atomic-number, valence] - atom_property_list_float: [mass, electronegativity] - edge_property_list: [bond-type-onehot, stereo] - add_self_loop: False - explicit_H: False - use_bonds_weights: False - pos_encoding_as_features: &pos_enc - pos_type: laplacian_eigvec - num_pos: 3 - normalization: "none" - disconnected_comp: True - pos_encoding_as_directions: *pos_enc - - # Train, val, test parameters - split_val: null - split_test: null - split_seed: *seed - splits_path: goli/data/ZINC_bench_gnn/indexes_train_val_test.csv - batch_size_train_val: 128 - batch_size_test: 256 - - # Data loading - num_workers: 0 - pin_memory: False - persistent_workers: False # Keep True on Windows if running multiple workers, false for single worker - - -architecture: - model_type: fulldglnetwork - pre_nn: # Set as null to avoid a pre-nn network - out_dim: &middle_dim 90 - hidden_dims: *middle_dim - depth: 0 - activation: relu - last_activation: none - dropout: &dropout 0. - batch_norm: &batch_norm True - last_batch_norm: *batch_norm - residual_type: none - - gnn: # Set as null to avoid a post-nn network - out_dim: *middle_dim - hidden_dims: *middle_dim - depth: 4 - activation: relu - last_activation: none - dropout: *dropout - batch_norm: *batch_norm - last_batch_norm: *batch_norm - residual_type: simple - pooling: 'sum' - virtual_node: none - layer_type: 'dgn-msgpass' - layer_kwargs: - # num_heads: 3 - aggregators: [mean, max, dir1/dx_abs, dir1/smooth] - scalers: [identity, amplification, attenuation] - - post_nn: - out_dim: 1 - hidden_dims: *middle_dim - depth: 0 - activation: relu - last_activation: none - dropout: *dropout - batch_norm: *batch_norm - last_batch_norm: False - residual_type: none - - -predictor: - metrics_on_progress_bar: ["mae", "pearsonr", "f1 < 0", "precision < 0"] - loss_fun: mse - random_seed: *seed - optim_kwargs: - lr: 1.e-3 - weight_decay: 3.e-6 - lr_reduce_on_plateau_kwargs: - factor: 0.5 - patience: 50 - min_lr: 1.e-5 - scheduler_kwargs: - monitor: &monitor loss/val - frequency: 1 - target_nan_mask: 0 # null: no mask, 0: 0 mask, ignore: ignore nan values from loss - - -metrics: - - name: mae - metric: mae - threshold_kwargs: null - - - name: pearsonr - metric: pearsonr - threshold_kwargs: null - - - name: f1 < 0 - metric: f1 - num_classes: 2 - average: micro - threshold_kwargs: &threshold_0 - operator: lower - threshold: 0 - th_on_preds: True - th_on_target: True - target_to_int: True - - - name: f1 < -1 - metric: f1 - num_classes: 2 - average: micro - threshold_kwargs: - operator: lower - threshold: -1 - th_on_preds: True - th_on_target: True - target_to_int: True - - - name: precision < 0 - metric: precision - class_reduction: micro - threshold_kwargs: *threshold_0 - -trainer: - logger: - save_dir: logs/ZINC_bench_gnn - early_stopping: - monitor: *monitor - min_delta: 0 - patience: 200 - mode: &mode min - model_checkpoint: - dirpath: models_checkpoints/ZINC_bench_gnn/ - filename: "model" - monitor: *monitor - mode: *mode - save_top_k: 1 - period: 1 - trainer: - max_epochs: 2000 - min_epochs: 100 - gpus: 1 - +constants: + seed: &seed 42 + raise_train_error: true # Whether the code should raise an error if it crashes during training + +datamodule: + module_type: "DGLFromSmilesDataModule" + df_path: goli/data/ZINC_bench_gnn/smiles_score.csv + cache_data_path: goli/data/cache/ZINC_bench_gnn/smiles_score.cache + label_cols: ['score'] + smiles_col: SMILES + + # Featurization + featurization_n_jobs: -1 + featurization_progress: True + featurization: + atom_property_list_onehot: [atomic-number, valence] + atom_property_list_float: [mass, electronegativity] + edge_property_list: [bond-type-onehot, stereo] + add_self_loop: False + explicit_H: False + use_bonds_weights: False + pos_encoding_as_features: &pos_enc + pos_type: laplacian_eigvec + num_pos: 3 + normalization: "none" + disconnected_comp: True + pos_encoding_as_directions: *pos_enc + + # Train, val, test parameters + split_val: null + split_test: null + split_seed: *seed + splits_path: goli/data/ZINC_bench_gnn/indexes_train_val_test.csv + batch_size_train_val: 128 + batch_size_test: 256 + + # Data loading + num_workers: 0 + pin_memory: False + persistent_workers: False # Keep True on Windows if running multiple workers, false for single worker + + +architecture: + model_type: fulldglnetwork + pre_nn: # Set as null to avoid a pre-nn network + out_dim: &middle_dim 90 + hidden_dims: *middle_dim + depth: 0 + activation: relu + last_activation: none + dropout: &dropout 0. + batch_norm: &batch_norm True + last_batch_norm: *batch_norm + residual_type: none + + gnn: # Set as null to avoid a post-nn network + out_dim: *middle_dim + hidden_dims: *middle_dim + depth: 4 + activation: relu + last_activation: none + dropout: *dropout + batch_norm: *batch_norm + last_batch_norm: *batch_norm + residual_type: simple + pooling: 'sum' + virtual_node: none + layer_type: 'dgn-msgpass' + layer_kwargs: + # num_heads: 3 + aggregators: [mean, max, dir1/dx_abs, dir1/smooth] + scalers: [identity, amplification, attenuation] + + post_nn: + out_dim: 1 + hidden_dims: *middle_dim + depth: 0 + activation: relu + last_activation: none + dropout: *dropout + batch_norm: *batch_norm + last_batch_norm: False + residual_type: none + + +predictor: + metrics_on_progress_bar: ["mae", "pearsonr", "f1 < 0", "precision < 0"] + loss_fun: mse + random_seed: *seed + optim_kwargs: + lr: 1.e-3 + weight_decay: 3.e-6 + lr_reduce_on_plateau_kwargs: + factor: 0.5 + patience: 50 + min_lr: 1.e-5 + scheduler_kwargs: + monitor: &monitor loss/val + frequency: 1 + target_nan_mask: 0 # null: no mask, 0: 0 mask, ignore: ignore nan values from loss + + +metrics: + - name: mae + metric: mae + threshold_kwargs: null + + - name: pearsonr + metric: pearsonr + threshold_kwargs: null + + - name: f1 < 0 + metric: f1 + num_classes: 2 + average: micro + threshold_kwargs: &threshold_0 + operator: lower + threshold: 0 + th_on_preds: True + th_on_target: True + target_to_int: True + + - name: f1 < -1 + metric: f1 + num_classes: 2 + average: micro + threshold_kwargs: + operator: lower + threshold: -1 + th_on_preds: True + th_on_target: True + target_to_int: True + + - name: precision < 0 + metric: precision + class_reduction: micro + threshold_kwargs: *threshold_0 + +trainer: + logger: + save_dir: logs/ZINC_bench_gnn + early_stopping: + monitor: *monitor + min_delta: 0 + patience: 200 + mode: &mode min + model_checkpoint: + dirpath: models_checkpoints/ZINC_bench_gnn/ + filename: "model" + monitor: *monitor + mode: *mode + save_top_k: 1 + period: 1 + trainer: + max_epochs: 2000 + min_epochs: 100 + gpus: 1 + \ No newline at end of file diff --git a/expts/config_bindingDB_pretrained.yaml b/expts/config_bindingDB_pretrained.yaml index e58ed308a..41306e7bc 100644 --- a/expts/config_bindingDB_pretrained.yaml +++ b/expts/config_bindingDB_pretrained.yaml @@ -1,52 +1,52 @@ -constants: - seed: &seed 42 - raise_train_error: true # Whether the code should raise an error if it crashes during training - -datamodule: - module_type: "DGLFromSmilesDataModule" - args: - df_path: goli/data/BindingDB/BindingDB_All.tsv - cache_data_path: null # goli/data/cache/BindingDB/full.cache - label_cols: ['IC50 (nM)'] - smiles_col: Ligand SMILES - sample_size: null - - # Weights - weights_type: null - - # Featurization - featurization_n_jobs: -1 - featurization_progress: True - featurization: - atom_property_list_onehot: [atomic-number, valence] - atom_property_list_float: [mass, electronegativity, in-ring, hybridization, chirality, aromatic, degree, formal-charge, single-bond, double-bond, radical-electron, vdw-radius, covalent-radius, metal] - edge_property_list: [bond-type-onehot, bond-type-float, stereo, in-ring, conjugated, estimated-bond-length] - add_self_loop: False - explicit_H: False - use_bonds_weights: False - pos_encoding_as_features: &pos_enc - pos_type: laplacian_eigvec - num_pos: 3 - normalization: "none" - disconnected_comp: True - pos_encoding_as_directions: *pos_enc - - # Train, val, test parameters - split_val: 0.2 - split_test: 0.2 - split_seed: *seed - splits_path: null - batch_size_train_val: 64 - batch_size_test: 64 - - # Data loading - num_workers: 0 - pin_memory: False - persistent_workers: False # Keep True on Windows if running multiple workers - - -trainer: - trainer: - gpus: 1 - +constants: + seed: &seed 42 + raise_train_error: true # Whether the code should raise an error if it crashes during training + +datamodule: + module_type: "DGLFromSmilesDataModule" + args: + df_path: goli/data/BindingDB/BindingDB_All.tsv + cache_data_path: null # goli/data/cache/BindingDB/full.cache + label_cols: ['IC50 (nM)'] + smiles_col: Ligand SMILES + sample_size: null + + # Weights + weights_type: null + + # Featurization + featurization_n_jobs: -1 + featurization_progress: True + featurization: + atom_property_list_onehot: [atomic-number, valence] + atom_property_list_float: [mass, electronegativity, in-ring, hybridization, chirality, aromatic, degree, formal-charge, single-bond, double-bond, radical-electron, vdw-radius, covalent-radius, metal] + edge_property_list: [bond-type-onehot, bond-type-float, stereo, in-ring, conjugated, estimated-bond-length] + add_self_loop: False + explicit_H: False + use_bonds_weights: False + pos_encoding_as_features: &pos_enc + pos_type: laplacian_eigvec + num_pos: 3 + normalization: "none" + disconnected_comp: True + pos_encoding_as_directions: *pos_enc + + # Train, val, test parameters + split_val: 0.2 + split_test: 0.2 + split_seed: *seed + splits_path: null + batch_size_train_val: 64 + batch_size_test: 64 + + # Data loading + num_workers: 0 + pin_memory: False + persistent_workers: False # Keep True on Windows if running multiple workers + + +trainer: + trainer: + gpus: 1 + \ No newline at end of file diff --git a/expts/config_htsfp_pcba.yaml b/expts/config_htsfp_pcba.yaml index 3066804a0..cd516f6a9 100644 --- a/expts/config_htsfp_pcba.yaml +++ b/expts/config_htsfp_pcba.yaml @@ -1,203 +1,203 @@ -constants: - seed: &seed 42 - raise_train_error: True # Whether the code should raise an error if it crashes during training - -datamodule: - module_type: "DGLFromSmilesDataModule" - args: - df_path: https://storage.googleapis.com/goli-public/datasets/goli-htsfp-pcba.csv.gz - - cache_data_path: goli/data/cache/htsfp-pcba/full.cache - label_cols: "assayID-*" - smiles_col: SMILES - - # Weights - weights_type: null - sample_size: null - - # Featurization - featurization_n_jobs: -1 - featurization_progress: True - featurization: - atom_property_list_onehot: [atomic-number, valence] - atom_property_list_float: [mass, electronegativity, in-ring, hybridization, chirality, aromatic, degree, formal-charge, single-bond, double-bond, radical-electron, vdw-radius, covalent-radius, metal] - edge_property_list: [bond-type-onehot, bond-type-float, stereo, in-ring, conjugated, estimated-bond-length] - add_self_loop: False - explicit_H: False - use_bonds_weights: False - pos_encoding_as_features: &pos_enc - pos_type: laplacian_eigvec - num_pos: 3 - normalization: "none" - disconnected_comp: True - pos_encoding_as_directions: *pos_enc - - # Train, val, test parameters - split_val: 0.1 - split_test: 0.1 - split_seed: *seed - batch_size_train_val: 1024 - batch_size_test: 1024 - - # Data loading - num_workers: 8 - pin_memory: True - persistent_workers: False # Keep True on Windows if running multiple workers - - -architecture: - model_type: fulldglnetwork - pre_nn: # Set as null to avoid a pre-nn network - out_dim: &hidden_dim 400 - hidden_dims: *hidden_dim - depth: 3 - activation: relu - last_activation: none - dropout: &dropout_mlp 0.3 - last_dropout: *dropout_mlp - batch_norm: &batch_norm True - last_batch_norm: *batch_norm - residual_type: simple - - pre_nn_edges: # Set as null to avoid a pre-nn network - out_dim: 32 - hidden_dims: 32 - depth: 2 - activation: relu - last_activation: none - dropout: *dropout_mlp - last_dropout: *dropout_mlp - batch_norm: *batch_norm - last_batch_norm: *batch_norm - residual_type: simple - - gnn: # Set as null to avoid a post-nn network - out_dim: &out_mlp_dim 700 - hidden_dims: *hidden_dim - depth: 5 - activation: none - last_activation: none - dropout: &dropout_gnn 0.2 - last_dropout: *dropout_gnn - batch_norm: *batch_norm - last_batch_norm: *batch_norm - residual_type: simple - pooling: ['sum', 'max'] - virtual_node: 'none' - layer_type: 'dgn-msgpass' - layer_kwargs: - # num_heads: 3 - aggregators: [mean, max, sum, dir1/dx_abs] - scalers: [identity] - - post_nn: - out_dim: 689 - hidden_dims: *out_mlp_dim - depth: 3 - activation: relu - last_activation: sigmoid - dropout: *dropout_mlp - last_dropout: 0. - batch_norm: *batch_norm - last_batch_norm: False - residual_type: simple - -predictor: - metrics_on_progress_bar: ["mae", "averageprecision", "auroc"] - metrics_on_training_set: ["mae", "averageprecision"] - loss_fun: bce - random_seed: *seed - optim_kwargs: - lr: 5.e-3 - weight_decay: 0 - lr_reduce_on_plateau_kwargs: - factor: 0.5 - patience: 13 - min_lr: 2.e-4 - scheduler_kwargs: - monitor: &monitor averageprecision/val - mode: &mode max - frequency: 1 - target_nan_mask: ignore # null: no mask, 0: 0 mask, ignore: ignore nan values from loss - - -metrics: - - name: mae - metric: mae - threshold_kwargs: null - target_nan_mask: ignore-flatten - - - name: averageprecision - metric: averageprecision - target_nan_mask: ignore-mean-label - threshold_kwargs: - threshold: null - th_on_preds: False - th_on_target: False - target_to_int: True - - - name: auroc - metric: auroc - target_nan_mask: ignore-mean-label - threshold_kwargs: - threshold: null - th_on_preds: False - th_on_target: False - target_to_int: True - - - name: f1 > 0.5 - metric: f1 - num_classes: 2 - average: micro - target_nan_mask: ignore-mean-label - threshold_kwargs: &threshold_1 - operator: greater - threshold: 0.5 - th_on_preds: True - th_on_target: False - target_to_int: True - - - name: f1 > 0.25 - metric: f1 - num_classes: 2 - average: micro - target_nan_mask: ignore-mean-label - threshold_kwargs: &threshold_2 - operator: greater - threshold: 0.25 - th_on_preds: True - th_on_target: False - target_to_int: True - - - name: recall > 0.5 - metric: recall - target_nan_mask: ignore-mean-label - threshold_kwargs: *threshold_1 - - - name: recall > 0.25 - metric: recall - target_nan_mask: ignore-mean-label - threshold_kwargs: *threshold_2 - -trainer: - logger: - save_dir: logs/htsfp-pcba - early_stopping: - monitor: *monitor - min_delta: 0 - patience: 40 - mode: *mode - model_checkpoint: - dirpath: models_checkpoints/htsfp-pcba/ - filename: "model" - monitor: *monitor - mode: *mode - save_top_k: 1 - period: 1 - trainer: - max_epochs: 1000 - min_epochs: 100 - gpus: 1 - accumulate_grad_batches: 1 - +constants: + seed: &seed 42 + raise_train_error: True # Whether the code should raise an error if it crashes during training + +datamodule: + module_type: "DGLFromSmilesDataModule" + args: + df_path: https://storage.googleapis.com/goli-public/datasets/goli-htsfp-pcba.csv.gz + + cache_data_path: goli/data/cache/htsfp-pcba/full.cache + label_cols: "assayID-*" + smiles_col: SMILES + + # Weights + weights_type: null + sample_size: null + + # Featurization + featurization_n_jobs: -1 + featurization_progress: True + featurization: + atom_property_list_onehot: [atomic-number, valence] + atom_property_list_float: [mass, electronegativity, in-ring, hybridization, chirality, aromatic, degree, formal-charge, single-bond, double-bond, radical-electron, vdw-radius, covalent-radius, metal] + edge_property_list: [bond-type-onehot, bond-type-float, stereo, in-ring, conjugated, estimated-bond-length] + add_self_loop: False + explicit_H: False + use_bonds_weights: False + pos_encoding_as_features: &pos_enc + pos_type: laplacian_eigvec + num_pos: 3 + normalization: "none" + disconnected_comp: True + pos_encoding_as_directions: *pos_enc + + # Train, val, test parameters + split_val: 0.1 + split_test: 0.1 + split_seed: *seed + batch_size_train_val: 1024 + batch_size_test: 1024 + + # Data loading + num_workers: 8 + pin_memory: True + persistent_workers: False # Keep True on Windows if running multiple workers + + +architecture: + model_type: fulldglnetwork + pre_nn: # Set as null to avoid a pre-nn network + out_dim: &hidden_dim 400 + hidden_dims: *hidden_dim + depth: 3 + activation: relu + last_activation: none + dropout: &dropout_mlp 0.3 + last_dropout: *dropout_mlp + batch_norm: &batch_norm True + last_batch_norm: *batch_norm + residual_type: simple + + pre_nn_edges: # Set as null to avoid a pre-nn network + out_dim: 32 + hidden_dims: 32 + depth: 2 + activation: relu + last_activation: none + dropout: *dropout_mlp + last_dropout: *dropout_mlp + batch_norm: *batch_norm + last_batch_norm: *batch_norm + residual_type: simple + + gnn: # Set as null to avoid a post-nn network + out_dim: &out_mlp_dim 700 + hidden_dims: *hidden_dim + depth: 5 + activation: none + last_activation: none + dropout: &dropout_gnn 0.2 + last_dropout: *dropout_gnn + batch_norm: *batch_norm + last_batch_norm: *batch_norm + residual_type: simple + pooling: ['sum', 'max'] + virtual_node: 'none' + layer_type: 'dgn-msgpass' + layer_kwargs: + # num_heads: 3 + aggregators: [mean, max, sum, dir1/dx_abs] + scalers: [identity] + + post_nn: + out_dim: 689 + hidden_dims: *out_mlp_dim + depth: 3 + activation: relu + last_activation: sigmoid + dropout: *dropout_mlp + last_dropout: 0. + batch_norm: *batch_norm + last_batch_norm: False + residual_type: simple + +predictor: + metrics_on_progress_bar: ["mae", "averageprecision", "auroc"] + metrics_on_training_set: ["mae", "averageprecision"] + loss_fun: bce + random_seed: *seed + optim_kwargs: + lr: 5.e-3 + weight_decay: 0 + lr_reduce_on_plateau_kwargs: + factor: 0.5 + patience: 13 + min_lr: 2.e-4 + scheduler_kwargs: + monitor: &monitor averageprecision/val + mode: &mode max + frequency: 1 + target_nan_mask: ignore # null: no mask, 0: 0 mask, ignore: ignore nan values from loss + + +metrics: + - name: mae + metric: mae + threshold_kwargs: null + target_nan_mask: ignore-flatten + + - name: averageprecision + metric: averageprecision + target_nan_mask: ignore-mean-label + threshold_kwargs: + threshold: null + th_on_preds: False + th_on_target: False + target_to_int: True + + - name: auroc + metric: auroc + target_nan_mask: ignore-mean-label + threshold_kwargs: + threshold: null + th_on_preds: False + th_on_target: False + target_to_int: True + + - name: f1 > 0.5 + metric: f1 + num_classes: 2 + average: micro + target_nan_mask: ignore-mean-label + threshold_kwargs: &threshold_1 + operator: greater + threshold: 0.5 + th_on_preds: True + th_on_target: False + target_to_int: True + + - name: f1 > 0.25 + metric: f1 + num_classes: 2 + average: micro + target_nan_mask: ignore-mean-label + threshold_kwargs: &threshold_2 + operator: greater + threshold: 0.25 + th_on_preds: True + th_on_target: False + target_to_int: True + + - name: recall > 0.5 + metric: recall + target_nan_mask: ignore-mean-label + threshold_kwargs: *threshold_1 + + - name: recall > 0.25 + metric: recall + target_nan_mask: ignore-mean-label + threshold_kwargs: *threshold_2 + +trainer: + logger: + save_dir: logs/htsfp-pcba + early_stopping: + monitor: *monitor + min_delta: 0 + patience: 40 + mode: *mode + model_checkpoint: + dirpath: models_checkpoints/htsfp-pcba/ + filename: "model" + monitor: *monitor + mode: *mode + save_top_k: 1 + period: 1 + trainer: + max_epochs: 1000 + min_epochs: 100 + gpus: 1 + accumulate_grad_batches: 1 + \ No newline at end of file diff --git a/expts/config_micro_ZINC.yaml b/expts/config_micro_ZINC.yaml index f67018daa..e8e85b3bd 100644 --- a/expts/config_micro_ZINC.yaml +++ b/expts/config_micro_ZINC.yaml @@ -1,169 +1,169 @@ -constants: - seed: &seed 42 - raise_train_error: true # Whether the code should raise an error if it crashes during training - -datamodule: - module_type: "DGLFromSmilesDataModule" - args: - df_path: goli/data/micro_ZINC/micro_ZINC.csv - cache_data_path: goli/data/cache/micro_ZINC/full.cache - label_cols: ['score'] - smiles_col: SMILES - - # Featurization - featurization_n_jobs: -1 - featurization_progress: True - featurization: - atom_property_list_onehot: [atomic-number, valence] - atom_property_list_float: [mass, electronegativity, in-ring] - edge_property_list: [bond-type-onehot, stereo, in-ring] - add_self_loop: False - explicit_H: False - use_bonds_weights: False - pos_encoding_as_features: &pos_enc - pos_type: laplacian_eigvec - num_pos: 3 - normalization: "none" - disconnected_comp: True - pos_encoding_as_directions: *pos_enc - - # Train, val, test parameters - split_val: 0.2 - split_test: 0.2 - split_seed: *seed - splits_path: null - batch_size_train_val: 128 - batch_size_test: 128 - - # Data loading - num_workers: 0 - pin_memory: False - persistent_workers: False # Keep True on Windows if running multiple workers - - -architecture: - model_type: fulldglnetwork - pre_nn: # Set as null to avoid a pre-nn network - out_dim: 32 - hidden_dims: 32 - depth: 1 - activation: relu - last_activation: none - dropout: &dropout 0.1 - batch_norm: &batch_norm True - last_batch_norm: *batch_norm - residual_type: none - - pre_nn_edges: # Set as null to avoid a pre-nn network - out_dim: 16 - hidden_dims: 16 - depth: 2 - activation: relu - last_activation: none - dropout: *dropout - batch_norm: *batch_norm - last_batch_norm: *batch_norm - residual_type: none - - gnn: # Set as null to avoid a post-nn network - out_dim: 32 - hidden_dims: 32 - depth: 4 - activation: relu - last_activation: none - dropout: *dropout - batch_norm: *batch_norm - last_batch_norm: *batch_norm - residual_type: simple - pooling: [sum, max, dir1] - virtual_node: 'sum' - layer_type: 'dgn-msgpass' - layer_kwargs: - # num_heads: 3 - aggregators: [mean, max, dir1/dx_abs, dir1/smooth] - scalers: [identity, amplification, attenuation] - - post_nn: - out_dim: 1 - hidden_dims: 32 - depth: 2 - activation: relu - last_activation: none - dropout: *dropout - batch_norm: *batch_norm - last_batch_norm: False - residual_type: none - -predictor: - metrics_on_progress_bar: ["mae", "pearsonr", "f1 > 3", "precision > 3"] - loss_fun: mse - random_seed: *seed - optim_kwargs: - lr: 1.e-2 - weight_decay: 1.e-7 - lr_reduce_on_plateau_kwargs: - factor: 0.5 - patience: 7 - scheduler_kwargs: - monitor: &monitor loss/val - frequency: 1 - target_nan_mask: 0 # null: no mask, 0: 0 mask, ignore: ignore nan values from loss - - -metrics: - - name: mae - metric: mae - threshold_kwargs: null - - - name: pearsonr - metric: pearsonr - threshold_kwargs: null - - - name: f1 > 3 - metric: f1 - num_classes: 2 - average: micro - threshold_kwargs: &threshold_3 - operator: greater - threshold: 3 - th_on_preds: True - th_on_target: True - target_to_int: True - - - name: f1 > 5 - metric: f1 - num_classes: 2 - average: micro - threshold_kwargs: - operator: greater - threshold: 5 - th_on_preds: True - th_on_target: True - target_to_int: True - - - name: precision > 3 - metric: precision - class_reduction: micro - threshold_kwargs: *threshold_3 - -trainer: - logger: - save_dir: logs/micro_ZINC - early_stopping: - monitor: *monitor - min_delta: 0 - patience: 10 - mode: &mode min - model_checkpoint: - dirpath: models_checkpoints/micro_ZINC/ - filename: "model" - monitor: *monitor - mode: *mode - save_top_k: 1 - period: 1 - trainer: - max_epochs: 25 - min_epochs: 5 - gpus: 1 - +constants: + seed: &seed 42 + raise_train_error: true # Whether the code should raise an error if it crashes during training + +datamodule: + module_type: "DGLFromSmilesDataModule" + args: + df_path: goli/data/micro_ZINC/micro_ZINC.csv + cache_data_path: goli/data/cache/micro_ZINC/full.cache + label_cols: ['score'] + smiles_col: SMILES + + # Featurization + featurization_n_jobs: -1 + featurization_progress: True + featurization: + atom_property_list_onehot: [atomic-number, valence] + atom_property_list_float: [mass, electronegativity, in-ring] + edge_property_list: [bond-type-onehot, stereo, in-ring] + add_self_loop: False + explicit_H: False + use_bonds_weights: False + pos_encoding_as_features: &pos_enc + pos_type: laplacian_eigvec + num_pos: 3 + normalization: "none" + disconnected_comp: True + pos_encoding_as_directions: *pos_enc + + # Train, val, test parameters + split_val: 0.2 + split_test: 0.2 + split_seed: *seed + splits_path: null + batch_size_train_val: 128 + batch_size_test: 128 + + # Data loading + num_workers: 0 + pin_memory: False + persistent_workers: False # Keep True on Windows if running multiple workers + + +architecture: + model_type: fulldglnetwork + pre_nn: # Set as null to avoid a pre-nn network + out_dim: 32 + hidden_dims: 32 + depth: 1 + activation: relu + last_activation: none + dropout: &dropout 0.1 + batch_norm: &batch_norm True + last_batch_norm: *batch_norm + residual_type: none + + pre_nn_edges: # Set as null to avoid a pre-nn network + out_dim: 16 + hidden_dims: 16 + depth: 2 + activation: relu + last_activation: none + dropout: *dropout + batch_norm: *batch_norm + last_batch_norm: *batch_norm + residual_type: none + + gnn: # Set as null to avoid a post-nn network + out_dim: 32 + hidden_dims: 32 + depth: 4 + activation: relu + last_activation: none + dropout: *dropout + batch_norm: *batch_norm + last_batch_norm: *batch_norm + residual_type: simple + pooling: [sum, max, dir1] + virtual_node: 'sum' + layer_type: 'dgn-msgpass' + layer_kwargs: + # num_heads: 3 + aggregators: [mean, max, dir1/dx_abs, dir1/smooth] + scalers: [identity, amplification, attenuation] + + post_nn: + out_dim: 1 + hidden_dims: 32 + depth: 2 + activation: relu + last_activation: none + dropout: *dropout + batch_norm: *batch_norm + last_batch_norm: False + residual_type: none + +predictor: + metrics_on_progress_bar: ["mae", "pearsonr", "f1 > 3", "precision > 3"] + loss_fun: mse + random_seed: *seed + optim_kwargs: + lr: 1.e-2 + weight_decay: 1.e-7 + lr_reduce_on_plateau_kwargs: + factor: 0.5 + patience: 7 + scheduler_kwargs: + monitor: &monitor loss/val + frequency: 1 + target_nan_mask: 0 # null: no mask, 0: 0 mask, ignore: ignore nan values from loss + + +metrics: + - name: mae + metric: mae + threshold_kwargs: null + + - name: pearsonr + metric: pearsonr + threshold_kwargs: null + + - name: f1 > 3 + metric: f1 + num_classes: 2 + average: micro + threshold_kwargs: &threshold_3 + operator: greater + threshold: 3 + th_on_preds: True + th_on_target: True + target_to_int: True + + - name: f1 > 5 + metric: f1 + num_classes: 2 + average: micro + threshold_kwargs: + operator: greater + threshold: 5 + th_on_preds: True + th_on_target: True + target_to_int: True + + - name: precision > 3 + metric: precision + class_reduction: micro + threshold_kwargs: *threshold_3 + +trainer: + logger: + save_dir: logs/micro_ZINC + early_stopping: + monitor: *monitor + min_delta: 0 + patience: 10 + mode: &mode min + model_checkpoint: + dirpath: models_checkpoints/micro_ZINC/ + filename: "model" + monitor: *monitor + mode: *mode + save_top_k: 1 + period: 1 + trainer: + max_epochs: 25 + min_epochs: 5 + gpus: 1 + \ No newline at end of file diff --git a/expts/config_molHIV.yaml b/expts/config_molHIV.yaml index bd5dd2095..e2b7b2698 100644 --- a/expts/config_molHIV.yaml +++ b/expts/config_molHIV.yaml @@ -1,165 +1,165 @@ -constants: - seed: &seed 42 - raise_train_error: true # Whether the code should raise an error if it crashes during training - -datamodule: - module_type: "DGLOGBDataModule" - args: - cache_data_path: null - dataset_name: "ogbg-molhiv" - cache_data_path: goli/data/cache/ogb-molhiv/full.cache - - # Weights - weights_type: sample_balanced - - # Featurization - featurization_n_jobs: -1 - featurization_progress: True - featurization: - atom_property_list_onehot: [atomic-number, valence] - atom_property_list_float: [mass, electronegativity, in-ring] #, hybridization, chirality, aromatic, degree, formal-charge, single-bond] - edge_property_list: [bond-type-onehot, stereo, in-ring] - add_self_loop: False - explicit_H: False - use_bonds_weights: False - pos_encoding_as_features: &pos_enc - pos_type: laplacian_eigvec - num_pos: 3 - normalization: "none" - disconnected_comp: True - pos_encoding_as_directions: *pos_enc - - # Train, val, test parameters - batch_size_train_val: 256 - batch_size_test: 256 - - # Data loading - num_workers: 4 - pin_memory: False - persistent_workers: True # Keep True on Windows if running multiple workers - - -architecture: - model_type: fulldglnetwork - pre_nn: # Set as null to avoid a pre-nn network - out_dim: &hidden_dim 32 - hidden_dims: *hidden_dim - depth: 1 - activation: relu - last_activation: none - dropout: &dropout 0.1 - batch_norm: &batch_norm True - last_batch_norm: *batch_norm - residual_type: none - - gnn: # Set as null to avoid a post-nn network - out_dim: *hidden_dim - hidden_dims: *hidden_dim - depth: 4 - activation: none - last_activation: none - dropout: *dropout - batch_norm: *batch_norm - last_batch_norm: *batch_norm - residual_type: simple - pooling: sum - virtual_node: 'none' - layer_type: 'dgn-msgpass' - layer_kwargs: - # num_heads: 3 - aggregators: [mean, max, min, dir1/dx_abs, dir1/smooth] - scalers: [identity] - - post_nn: - out_dim: 1 - hidden_dims: *hidden_dim - depth: 1 - activation: relu - last_activation: sigmoid - dropout: *dropout - batch_norm: *batch_norm - last_batch_norm: False - residual_type: none - -predictor: - metrics_on_progress_bar: ["mae", "auroc", "f1 > 0.5"] - loss_fun: bce - random_seed: *seed - optim_kwargs: - lr: 1.e-4 - weight_decay: 0 - lr_reduce_on_plateau_kwargs: - factor: 0.5 - patience: 10 - min_lr: 1.e-5 - scheduler_kwargs: - monitor: &monitor auroc/val - mode: &mode max - frequency: 1 - target_nan_mask: 0 # null: no mask, 0: 0 mask, ignore: ignore nan values from loss - - -metrics: - - name: mae - metric: mae - threshold_kwargs: null - - - name: auroc - metric: auroc - threshold_kwargs: - threshold: null - th_on_preds: False - th_on_target: False - target_to_int: True - - - name: f1 > 0.5 - metric: f1 - num_classes: 2 - average: micro - threshold_kwargs: &threshold_1 - operator: greater - threshold: 0.5 - th_on_preds: True - th_on_target: False - target_to_int: True - - - name: f1 > 0.25 - metric: f1 - num_classes: 2 - average: micro - threshold_kwargs: &threshold_2 - operator: greater - threshold: 0.25 - th_on_preds: True - th_on_target: False - target_to_int: True - - - name: recall > 0.5 - metric: recall - threshold_kwargs: *threshold_1 - - - name: recall > 0.25 - metric: recall - threshold_kwargs: *threshold_2 - -trainer: - logger: - save_dir: logs/ogb-molhiv - early_stopping: - monitor: *monitor - min_delta: 0 - patience: 80 - mode: *mode - model_checkpoint: - dirpath: models_checkpoints/ogb-molhiv/ - filename: "model" - monitor: *monitor - mode: *mode - save_top_k: 1 - period: 1 - trainer: - max_epochs: 1000 - min_epochs: 100 - gpus: 1 - +constants: + seed: &seed 42 + raise_train_error: true # Whether the code should raise an error if it crashes during training + +datamodule: + module_type: "DGLOGBDataModule" + args: + cache_data_path: null + dataset_name: "ogbg-molhiv" + cache_data_path: goli/data/cache/ogb-molhiv/full.cache + + # Weights + weights_type: sample_balanced + + # Featurization + featurization_n_jobs: -1 + featurization_progress: True + featurization: + atom_property_list_onehot: [atomic-number, valence] + atom_property_list_float: [mass, electronegativity, in-ring] #, hybridization, chirality, aromatic, degree, formal-charge, single-bond] + edge_property_list: [bond-type-onehot, stereo, in-ring] + add_self_loop: False + explicit_H: False + use_bonds_weights: False + pos_encoding_as_features: &pos_enc + pos_type: laplacian_eigvec + num_pos: 3 + normalization: "none" + disconnected_comp: True + pos_encoding_as_directions: *pos_enc + + # Train, val, test parameters + batch_size_train_val: 256 + batch_size_test: 256 + + # Data loading + num_workers: 4 + pin_memory: False + persistent_workers: True # Keep True on Windows if running multiple workers + + +architecture: + model_type: fulldglnetwork + pre_nn: # Set as null to avoid a pre-nn network + out_dim: &hidden_dim 32 + hidden_dims: *hidden_dim + depth: 1 + activation: relu + last_activation: none + dropout: &dropout 0.1 + batch_norm: &batch_norm True + last_batch_norm: *batch_norm + residual_type: none + + gnn: # Set as null to avoid a post-nn network + out_dim: *hidden_dim + hidden_dims: *hidden_dim + depth: 4 + activation: none + last_activation: none + dropout: *dropout + batch_norm: *batch_norm + last_batch_norm: *batch_norm + residual_type: simple + pooling: sum + virtual_node: 'none' + layer_type: 'dgn-msgpass' + layer_kwargs: + # num_heads: 3 + aggregators: [mean, max, min, dir1/dx_abs, dir1/smooth] + scalers: [identity] + + post_nn: + out_dim: 1 + hidden_dims: *hidden_dim + depth: 1 + activation: relu + last_activation: sigmoid + dropout: *dropout + batch_norm: *batch_norm + last_batch_norm: False + residual_type: none + +predictor: + metrics_on_progress_bar: ["mae", "auroc", "f1 > 0.5"] + loss_fun: bce + random_seed: *seed + optim_kwargs: + lr: 1.e-4 + weight_decay: 0 + lr_reduce_on_plateau_kwargs: + factor: 0.5 + patience: 10 + min_lr: 1.e-5 + scheduler_kwargs: + monitor: &monitor auroc/val + mode: &mode max + frequency: 1 + target_nan_mask: 0 # null: no mask, 0: 0 mask, ignore: ignore nan values from loss + + +metrics: + - name: mae + metric: mae + threshold_kwargs: null + + - name: auroc + metric: auroc + threshold_kwargs: + threshold: null + th_on_preds: False + th_on_target: False + target_to_int: True + + - name: f1 > 0.5 + metric: f1 + num_classes: 2 + average: micro + threshold_kwargs: &threshold_1 + operator: greater + threshold: 0.5 + th_on_preds: True + th_on_target: False + target_to_int: True + + - name: f1 > 0.25 + metric: f1 + num_classes: 2 + average: micro + threshold_kwargs: &threshold_2 + operator: greater + threshold: 0.25 + th_on_preds: True + th_on_target: False + target_to_int: True + + - name: recall > 0.5 + metric: recall + threshold_kwargs: *threshold_1 + + - name: recall > 0.25 + metric: recall + threshold_kwargs: *threshold_2 + +trainer: + logger: + save_dir: logs/ogb-molhiv + early_stopping: + monitor: *monitor + min_delta: 0 + patience: 80 + mode: *mode + model_checkpoint: + dirpath: models_checkpoints/ogb-molhiv/ + filename: "model" + monitor: *monitor + mode: *mode + save_top_k: 1 + period: 1 + trainer: + max_epochs: 1000 + min_epochs: 100 + gpus: 1 + \ No newline at end of file diff --git a/expts/config_molHIV_pretrained.yaml b/expts/config_molHIV_pretrained.yaml index 12170c60b..3cba9733a 100644 --- a/expts/config_molHIV_pretrained.yaml +++ b/expts/config_molHIV_pretrained.yaml @@ -1,46 +1,46 @@ -constants: - seed: &seed 42 - raise_train_error: true # Whether the code should raise an error if it crashes during training - -datamodule: - module_type: "DGLOGBDataModule" - args: - cache_data_path: null - dataset_name: "ogbg-molhiv" - cache_data_path: goli/data/cache/ogb-molhiv/full.cache - - # Weights - weights_type: sample_balanced - - # Featurization - featurization_n_jobs: -1 - featurization_progress: True - featurization: - atom_property_list_onehot: [atomic-number, valence] - atom_property_list_float: [mass, electronegativity, in-ring, hybridization, chirality, aromatic, degree, formal-charge, single-bond, double-bond, radical-electron, vdw-radius, covalent-radius, metal] - edge_property_list: [bond-type-onehot, bond-type-float, stereo, in-ring, conjugated, estimated-bond-length] - add_self_loop: False - explicit_H: False - use_bonds_weights: False - pos_encoding_as_features: &pos_enc - pos_type: laplacian_eigvec - num_pos: 3 - normalization: "none" - disconnected_comp: True - pos_encoding_as_directions: *pos_enc - - # Train, val, test parameters - batch_size_train_val: 256 - batch_size_test: 256 - - # Data loading - num_workers: 4 - pin_memory: False - persistent_workers: True # Keep True on Windows if running multiple workers - - -trainer: - trainer: - gpus: 1 - +constants: + seed: &seed 42 + raise_train_error: true # Whether the code should raise an error if it crashes during training + +datamodule: + module_type: "DGLOGBDataModule" + args: + cache_data_path: null + dataset_name: "ogbg-molhiv" + cache_data_path: goli/data/cache/ogb-molhiv/full.cache + + # Weights + weights_type: sample_balanced + + # Featurization + featurization_n_jobs: -1 + featurization_progress: True + featurization: + atom_property_list_onehot: [atomic-number, valence] + atom_property_list_float: [mass, electronegativity, in-ring, hybridization, chirality, aromatic, degree, formal-charge, single-bond, double-bond, radical-electron, vdw-radius, covalent-radius, metal] + edge_property_list: [bond-type-onehot, bond-type-float, stereo, in-ring, conjugated, estimated-bond-length] + add_self_loop: False + explicit_H: False + use_bonds_weights: False + pos_encoding_as_features: &pos_enc + pos_type: laplacian_eigvec + num_pos: 3 + normalization: "none" + disconnected_comp: True + pos_encoding_as_directions: *pos_enc + + # Train, val, test parameters + batch_size_train_val: 256 + batch_size_test: 256 + + # Data loading + num_workers: 4 + pin_memory: False + persistent_workers: True # Keep True on Windows if running multiple workers + + +trainer: + trainer: + gpus: 1 + \ No newline at end of file diff --git a/expts/config_molPCBA.yaml b/expts/config_molPCBA.yaml index 548d183b4..dd12b1515 100644 --- a/expts/config_molPCBA.yaml +++ b/expts/config_molPCBA.yaml @@ -1,197 +1,197 @@ -constants: - seed: &seed 42 - raise_train_error: false # Whether the code should raise an error if it crashes during training - -datamodule: - module_type: "DGLOGBDataModule" - args: - dataset_name: "ogbg-molpcba" - cache_data_path: goli/data/cache/ogb-molpcba/full-morefeatures.cache - - # Weights - weights_type: null - sample_size: null - - # Featurization - featurization_n_jobs: 8 - featurization_progress: True - featurization: - atom_property_list_onehot: [atomic-number, valence] - atom_property_list_float: [mass, electronegativity, in-ring, hybridization, chirality, aromatic, degree, formal-charge, single-bond, double-bond, radical-electron, vdw-radius, covalent-radius, metal] - edge_property_list: [bond-type-onehot, bond-type-float, stereo, in-ring, conjugated, estimated-bond-length] - add_self_loop: False - explicit_H: False - use_bonds_weights: False - pos_encoding_as_features: &pos_enc - pos_type: laplacian_eigvec - num_pos: 3 - normalization: "none" - disconnected_comp: True - pos_encoding_as_directions: *pos_enc - - # Train, val, test parameters - batch_size_train_val: 1800 - batch_size_test: 1800 - - # Data loading - num_workers: 8 - pin_memory: False - persistent_workers: False # Keep True on Windows if running multiple workers - - -architecture: - model_type: fulldglnetwork - pre_nn: # Set as null to avoid a pre-nn network - out_dim: &hidden_dim 400 - hidden_dims: *hidden_dim - depth: 3 - activation: relu - last_activation: none - dropout: &dropout_mlp 0.3 - last_dropout: *dropout_mlp - batch_norm: &batch_norm True - last_batch_norm: *batch_norm - residual_type: simple - - pre_nn_edges: # Set as null to avoid a pre-nn network - out_dim: 32 - hidden_dims: 32 - depth: 2 - activation: relu - last_activation: none - dropout: *dropout_mlp - last_dropout: *dropout_mlp - batch_norm: *batch_norm - last_batch_norm: *batch_norm - residual_type: simple - - gnn: # Set as null to avoid a post-nn network - out_dim: *hidden_dim - hidden_dims: *hidden_dim - depth: 5 - activation: none - last_activation: none - dropout: &dropout_gnn 0.2 - last_dropout: *dropout_gnn - batch_norm: *batch_norm - last_batch_norm: *batch_norm - residual_type: simple - pooling: ['sum', 'max'] - virtual_node: 'none' - layer_type: 'dgn-msgpass' - layer_kwargs: - # num_heads: 3 - aggregators: [mean, max, sum, dir1/dx_abs] - scalers: [identity] - - post_nn: - out_dim: 128 - hidden_dims: *hidden_dim - depth: 3 - activation: relu - last_activation: sigmoid - dropout: *dropout_mlp - last_dropout: 0. - batch_norm: *batch_norm - last_batch_norm: False - residual_type: simple - -predictor: - metrics_on_progress_bar: ["mae", "averageprecision", "auroc"] - metrics_on_training_set: ["mae", "averageprecision", "auroc"] - loss_fun: bce - random_seed: *seed - optim_kwargs: - lr: 5.e-3 - weight_decay: 0 - lr_reduce_on_plateau_kwargs: - factor: 0.5 - patience: 20 - min_lr: 2.e-4 - scheduler_kwargs: - monitor: &monitor averageprecision/val - mode: &mode max - frequency: 1 - target_nan_mask: ignore # null: no mask, 0: 0 mask, ignore: ignore nan values from loss - - -metrics: - - name: mae - metric: mae - threshold_kwargs: null - target_nan_mask: ignore-flatten - - - name: averageprecision - metric: averageprecision - target_nan_mask: ignore-mean-label - threshold_kwargs: - threshold: null - th_on_preds: False - th_on_target: False - target_to_int: True - - - name: auroc - metric: auroc - target_nan_mask: ignore-mean-label - threshold_kwargs: - threshold: null - th_on_preds: False - th_on_target: False - target_to_int: True - - - name: f1 > 0.5 - metric: f1 - num_classes: 2 - average: micro - target_nan_mask: ignore-mean-label - threshold_kwargs: &threshold_1 - operator: greater - threshold: 0.5 - th_on_preds: True - th_on_target: False - target_to_int: True - - - name: f1 > 0.25 - metric: f1 - num_classes: 2 - average: micro - target_nan_mask: ignore-mean-label - threshold_kwargs: &threshold_2 - operator: greater - threshold: 0.25 - th_on_preds: True - th_on_target: False - target_to_int: True - - - name: recall > 0.5 - metric: recall - target_nan_mask: ignore-mean-label - threshold_kwargs: *threshold_1 - - - name: recall > 0.25 - metric: recall - target_nan_mask: ignore-mean-label - threshold_kwargs: *threshold_2 - -trainer: - logger: - save_dir: logs/ogb-molpcba - early_stopping: - monitor: *monitor - min_delta: 0 - patience: 80 - mode: *mode - model_checkpoint: - dirpath: models_checkpoints/ogb-molpcba/ - filename: "model" - monitor: *monitor - mode: *mode - save_top_k: 1 - period: 1 - trainer: - max_epochs: 1000 - min_epochs: 100 - gpus: 1 - accumulate_grad_batches: 1 - +constants: + seed: &seed 42 + raise_train_error: false # Whether the code should raise an error if it crashes during training + +datamodule: + module_type: "DGLOGBDataModule" + args: + dataset_name: "ogbg-molpcba" + cache_data_path: goli/data/cache/ogb-molpcba/full-morefeatures.cache + + # Weights + weights_type: null + sample_size: null + + # Featurization + featurization_n_jobs: 8 + featurization_progress: True + featurization: + atom_property_list_onehot: [atomic-number, valence] + atom_property_list_float: [mass, electronegativity, in-ring, hybridization, chirality, aromatic, degree, formal-charge, single-bond, double-bond, radical-electron, vdw-radius, covalent-radius, metal] + edge_property_list: [bond-type-onehot, bond-type-float, stereo, in-ring, conjugated, estimated-bond-length] + add_self_loop: False + explicit_H: False + use_bonds_weights: False + pos_encoding_as_features: &pos_enc + pos_type: laplacian_eigvec + num_pos: 3 + normalization: "none" + disconnected_comp: True + pos_encoding_as_directions: *pos_enc + + # Train, val, test parameters + batch_size_train_val: 1800 + batch_size_test: 1800 + + # Data loading + num_workers: 8 + pin_memory: False + persistent_workers: False # Keep True on Windows if running multiple workers + + +architecture: + model_type: fulldglnetwork + pre_nn: # Set as null to avoid a pre-nn network + out_dim: &hidden_dim 400 + hidden_dims: *hidden_dim + depth: 3 + activation: relu + last_activation: none + dropout: &dropout_mlp 0.3 + last_dropout: *dropout_mlp + batch_norm: &batch_norm True + last_batch_norm: *batch_norm + residual_type: simple + + pre_nn_edges: # Set as null to avoid a pre-nn network + out_dim: 32 + hidden_dims: 32 + depth: 2 + activation: relu + last_activation: none + dropout: *dropout_mlp + last_dropout: *dropout_mlp + batch_norm: *batch_norm + last_batch_norm: *batch_norm + residual_type: simple + + gnn: # Set as null to avoid a post-nn network + out_dim: *hidden_dim + hidden_dims: *hidden_dim + depth: 5 + activation: none + last_activation: none + dropout: &dropout_gnn 0.2 + last_dropout: *dropout_gnn + batch_norm: *batch_norm + last_batch_norm: *batch_norm + residual_type: simple + pooling: ['sum', 'max'] + virtual_node: 'none' + layer_type: 'dgn-msgpass' + layer_kwargs: + # num_heads: 3 + aggregators: [mean, max, sum, dir1/dx_abs] + scalers: [identity] + + post_nn: + out_dim: 128 + hidden_dims: *hidden_dim + depth: 3 + activation: relu + last_activation: sigmoid + dropout: *dropout_mlp + last_dropout: 0. + batch_norm: *batch_norm + last_batch_norm: False + residual_type: simple + +predictor: + metrics_on_progress_bar: ["mae", "averageprecision", "auroc"] + metrics_on_training_set: ["mae", "averageprecision", "auroc"] + loss_fun: bce + random_seed: *seed + optim_kwargs: + lr: 5.e-3 + weight_decay: 0 + lr_reduce_on_plateau_kwargs: + factor: 0.5 + patience: 20 + min_lr: 2.e-4 + scheduler_kwargs: + monitor: &monitor averageprecision/val + mode: &mode max + frequency: 1 + target_nan_mask: ignore # null: no mask, 0: 0 mask, ignore: ignore nan values from loss + + +metrics: + - name: mae + metric: mae + threshold_kwargs: null + target_nan_mask: ignore-flatten + + - name: averageprecision + metric: averageprecision + target_nan_mask: ignore-mean-label + threshold_kwargs: + threshold: null + th_on_preds: False + th_on_target: False + target_to_int: True + + - name: auroc + metric: auroc + target_nan_mask: ignore-mean-label + threshold_kwargs: + threshold: null + th_on_preds: False + th_on_target: False + target_to_int: True + + - name: f1 > 0.5 + metric: f1 + num_classes: 2 + average: micro + target_nan_mask: ignore-mean-label + threshold_kwargs: &threshold_1 + operator: greater + threshold: 0.5 + th_on_preds: True + th_on_target: False + target_to_int: True + + - name: f1 > 0.25 + metric: f1 + num_classes: 2 + average: micro + target_nan_mask: ignore-mean-label + threshold_kwargs: &threshold_2 + operator: greater + threshold: 0.25 + th_on_preds: True + th_on_target: False + target_to_int: True + + - name: recall > 0.5 + metric: recall + target_nan_mask: ignore-mean-label + threshold_kwargs: *threshold_1 + + - name: recall > 0.25 + metric: recall + target_nan_mask: ignore-mean-label + threshold_kwargs: *threshold_2 + +trainer: + logger: + save_dir: logs/ogb-molpcba + early_stopping: + monitor: *monitor + min_delta: 0 + patience: 80 + mode: *mode + model_checkpoint: + dirpath: models_checkpoints/ogb-molpcba/ + filename: "model" + monitor: *monitor + mode: *mode + save_top_k: 1 + period: 1 + trainer: + max_epochs: 1000 + min_epochs: 100 + gpus: 1 + accumulate_grad_batches: 1 + \ No newline at end of file diff --git a/expts/config_molPCQM4M.yaml b/expts/config_molPCQM4M.yaml index 5f0a1a663..e4a82da15 100644 --- a/expts/config_molPCQM4M.yaml +++ b/expts/config_molPCQM4M.yaml @@ -1,180 +1,185 @@ -constants: - seed: &seed 42 - raise_train_error: True # Whether the code should raise an error if it crashes during training - -datamodule: - module_type: "DGLOGBDataModule" - args: - cache_data_path: null - dataset_name: "ogbg-molpcqm4m" - - - # Weights - weights_type: null - sample_size: 10000 - - # Featurization - featurization_n_jobs: 8 - featurization_progress: True - featurization: - atom_property_list_onehot: [atomic-number, valence] - atom_property_list_float: [mass, electronegativity, in-ring, hybridization, chirality, aromatic, degree, formal-charge, single-bond, double-bond, radical-electron, vdw-radius, covalent-radius, metal] - edge_property_list: [bond-type-onehot, bond-type-float, stereo, in-ring, conjugated, estimated-bond-length] - add_self_loop: False - explicit_H: False - use_bonds_weights: False - pos_encoding_as_features: &pos_enc - pos_type: laplacian_eigvec - num_pos: 3 - normalization: "none" - disconnected_comp: True - pos_encoding_as_directions: *pos_enc - - # Train, val, test parameters - batch_size_train_val: 1500 - batch_size_test: 1500 - - # Data loading - num_workers: 0 - pin_memory: False - persistent_workers: False # Keep True on Windows if running multiple workers - - -architecture: - model_type: fulldglnetwork - pre_nn: # Set as null to avoid a pre-nn network - out_dim: &hidden_dim 420 - hidden_dims: *hidden_dim - depth: 3 - activation: relu - last_activation: none - dropout: &dropout_mlp 0.2 - last_dropout: *dropout_mlp - batch_norm: &batch_norm True - last_batch_norm: *batch_norm - residual_type: simple - - pre_nn_edges: # Set as null to avoid a pre-nn network - out_dim: 32 - hidden_dims: 32 - depth: 3 - activation: relu - last_activation: none - dropout: *dropout_mlp - last_dropout: *dropout_mlp - batch_norm: *batch_norm - last_batch_norm: *batch_norm - residual_type: simple - - gnn: # Set as null to avoid a post-nn network - out_dim: *hidden_dim - hidden_dims: *hidden_dim - depth: 5 - activation: none - last_activation: none - dropout: &dropout_gnn 0.2 - last_dropout: *dropout_gnn - batch_norm: *batch_norm - last_batch_norm: *batch_norm - residual_type: simple - pooling: ['sum', 'max', 'dir1'] - virtual_node: 'none' - layer_type: 'dgn-msgpass' - layer_kwargs: - # num_heads: 3 - aggregators: [mean, max, sum, dir1/dx_abs] - scalers: [identity] - - post_nn: - out_dim: 1 - hidden_dims: *hidden_dim - depth: 3 - activation: relu - last_activation: none - dropout: *dropout_mlp - last_dropout: 0. - batch_norm: *batch_norm - last_batch_norm: False - residual_type: simple - -predictor: - metrics_on_progress_bar: ["mae", "mse", "pearsonr"] - loss_fun: bce - random_seed: *seed - optim_kwargs: - lr: 5.e-3 - weight_decay: 0 - lr_reduce_on_plateau_kwargs: - factor: 0.5 - patience: 20 - min_lr: 2.e-4 - scheduler_kwargs: - monitor: &monitor mae/val - mode: &mode max - frequency: 1 - target_nan_mask: 0 # null: no mask, 0: 0 mask, ignore: ignore nan values from loss - - -metrics: - - name: mae - metric: mae - threshold_kwargs: null - - - name: pearsonr - metric: pearsonr - threshold_kwargs: null - - - name: mse - metric: mse - threshold_kwargs: null - - - name: spearmanr - metric: spearmanr - threshold_kwargs: null - - - - name: f1 > 5 - metric: f1 - num_classes: 2 - average: micro - threshold_kwargs: &threshold_1 - operator: greater - threshold: 5 - th_on_preds: True - th_on_target: True - target_to_int: True - - - name: f1 > 4 - metric: f1 - num_classes: 2 - average: micro - threshold_kwargs: &threshold_2 - operator: greater - threshold: 4 - th_on_preds: True - th_on_target: True - target_to_int: True - - -trainer: - logger: - save_dir: logs/ogb-molpcqm4m - early_stopping: - monitor: *monitor - min_delta: 0 - patience: 80 - mode: *mode - model_checkpoint: - dirpath: models_checkpoints/ogb-molpcqm4m/ - filename: "model" - monitor: *monitor - mode: *mode - save_top_k: 1 - period: 1 - trainer: - max_epochs: 1000 - min_epochs: 100 - gpus: 1 - accumulate_grad_batches: 1 - +constants: + seed: &seed 42 + raise_train_error: True # Whether the code should raise an error if it crashes during training + +datamodule: + module_type: "DGLOGBDataModule" + args: + cache_data_path: null + dataset_name: "ogbg-molpcqm4m" + + + # Weights + weights_type: null + sample_size: null + + # Featurization + featurization_n_jobs: 8 + featurization_progress: True + featurization: + atom_property_list_onehot: [atomic-number, valence] + atom_property_list_float: [mass, electronegativity, in-ring, hybridization, chirality, aromatic, degree, formal-charge, single-bond, double-bond, radical-electron, vdw-radius, covalent-radius, metal] + edge_property_list: [bond-type-onehot, bond-type-float, stereo, in-ring, conjugated, estimated-bond-length] + add_self_loop: False + explicit_H: False + use_bonds_weights: False + pos_encoding_as_features: &pos_enc + pos_type: laplacian_eigvec + num_pos: 3 + normalization: "none" + disconnected_comp: True + pos_encoding_as_directions: *pos_enc + + # Train, val, test parameters + batch_size_train_val: 512 + batch_size_test: 512 + + # Data loading + num_workers: 0 + pin_memory: False + persistent_workers: False # Keep True on Windows if running multiple workers + + +architecture: + model_type: fulldglnetwork + pre_nn: # Set as null to avoid a pre-nn network + out_dim: &hidden_dim 420 + hidden_dims: *hidden_dim + depth: 3 + activation: relu + last_activation: none + dropout: &dropout_mlp 0.2 + last_dropout: *dropout_mlp + batch_norm: &batch_norm True + last_batch_norm: *batch_norm + residual_type: simple + + pre_nn_edges: # Set as null to avoid a pre-nn network + out_dim: 32 + hidden_dims: 32 + depth: 3 + activation: relu + last_activation: none + dropout: *dropout_mlp + last_dropout: *dropout_mlp + batch_norm: *batch_norm + last_batch_norm: *batch_norm + residual_type: simple + + gnn: # Set as null to avoid a post-nn network + out_dim: *hidden_dim + hidden_dims: *hidden_dim + depth: 16 + activation: none + last_activation: none + dropout: &dropout_gnn 0.2 + last_dropout: *dropout_gnn + batch_norm: *batch_norm + last_batch_norm: *batch_norm + residual_type: simple + pooling: ['sum', 'max', 'dir1'] + virtual_node: 'none' + layer_type: 'dgn-msgpass' + layer_kwargs: + # num_heads: 3 + aggregators: [mean, max, sum, dir1/dx_abs] + scalers: [identity] + + post_nn: + out_dim: 1 + hidden_dims: *hidden_dim + depth: 3 + activation: relu + last_activation: none + dropout: *dropout_mlp + last_dropout: 0. + batch_norm: *batch_norm + last_batch_norm: False + residual_type: simple + +predictor: + metrics_on_progress_bar: ["mse", "mae", "pearsonr"] + loss_fun: mse + random_seed: *seed + optim_kwargs: + lr: 5.e-3 + weight_decay: 0 + lr_reduce_on_plateau_kwargs: + factor: 0.5 + patience: 20 + min_lr: 2.e-4 + scheduler_kwargs: + monitor: &monitor mae/val + mode: &mode min + frequency: 1 + target_nan_mask: ignore # null: no mask, 0: 0 mask, ignore: ignore nan values from loss + + +metrics: + - name: mae + metric: mae + threshold_kwargs: null + target_nan_mask: ignore-flatten + + + - name: pearsonr + metric: pearsonr + threshold_kwargs: null + target_nan_mask: ignore-flatten + + - name: mse + metric: mse + threshold_kwargs: null + target_nan_mask: ignore-flatten + + - name: spearmanr + metric: spearmanr + threshold_kwargs: null + target_nan_mask: ignore-flatten + + + - name: f1 > 5 + metric: f1 + num_classes: 2 + average: micro + threshold_kwargs: &threshold_1 + operator: greater + threshold: 5 + th_on_preds: True + th_on_target: True + target_to_int: True + + - name: f1 > 4 + metric: f1 + num_classes: 2 + average: micro + threshold_kwargs: &threshold_2 + operator: greater + threshold: 4 + th_on_preds: True + th_on_target: True + target_to_int: True + + +trainer: + logger: + save_dir: logs/ogb-molpcqm4m + early_stopping: + monitor: *monitor + min_delta: 0 + patience: 80 + mode: *mode + model_checkpoint: + dirpath: models_checkpoints/ogb-molpcqm4m/ + filename: "model" + monitor: *monitor + mode: *mode + save_top_k: 1 + period: 1 + trainer: + max_epochs: 1000 + min_epochs: 100 + gpus: 1 + accumulate_grad_batches: 1 + \ No newline at end of file diff --git a/expts/config_molbace_pretrained.yaml b/expts/config_molbace_pretrained.yaml index 498bd4bba..449467e04 100644 --- a/expts/config_molbace_pretrained.yaml +++ b/expts/config_molbace_pretrained.yaml @@ -1,46 +1,46 @@ -constants: - seed: &seed 42 - raise_train_error: true # Whether the code should raise an error if it crashes during training - -datamodule: - module_type: "DGLOGBDataModule" - args: - cache_data_path: null - dataset_name: "ogbg-molbace" - cache_data_path: goli/data/cache/ogb-molbace/full.cache - - # Weights - weights_type: null - - # Featurization - featurization_n_jobs: -1 - featurization_progress: True - featurization: - atom_property_list_onehot: [atomic-number, valence] - atom_property_list_float: [mass, electronegativity, in-ring, hybridization, chirality, aromatic, degree, formal-charge, single-bond, double-bond, radical-electron, vdw-radius, covalent-radius, metal] - edge_property_list: [bond-type-onehot, bond-type-float, stereo, in-ring, conjugated, estimated-bond-length] - add_self_loop: False - explicit_H: False - use_bonds_weights: False - pos_encoding_as_features: &pos_enc - pos_type: laplacian_eigvec - num_pos: 3 - normalization: "none" - disconnected_comp: True - pos_encoding_as_directions: *pos_enc - - # Train, val, test parameters - batch_size_train_val: 256 - batch_size_test: 256 - - # Data loading - num_workers: 4 - pin_memory: False - persistent_workers: True # Keep True on Windows if running multiple workers - - -trainer: - trainer: - gpus: 1 - +constants: + seed: &seed 42 + raise_train_error: true # Whether the code should raise an error if it crashes during training + +datamodule: + module_type: "DGLOGBDataModule" + args: + cache_data_path: null + dataset_name: "ogbg-molbace" + cache_data_path: goli/data/cache/ogb-molbace/full.cache + + # Weights + weights_type: null + + # Featurization + featurization_n_jobs: -1 + featurization_progress: True + featurization: + atom_property_list_onehot: [atomic-number, valence] + atom_property_list_float: [mass, electronegativity, in-ring, hybridization, chirality, aromatic, degree, formal-charge, single-bond, double-bond, radical-electron, vdw-radius, covalent-radius, metal] + edge_property_list: [bond-type-onehot, bond-type-float, stereo, in-ring, conjugated, estimated-bond-length] + add_self_loop: False + explicit_H: False + use_bonds_weights: False + pos_encoding_as_features: &pos_enc + pos_type: laplacian_eigvec + num_pos: 3 + normalization: "none" + disconnected_comp: True + pos_encoding_as_directions: *pos_enc + + # Train, val, test parameters + batch_size_train_val: 256 + batch_size_test: 256 + + # Data loading + num_workers: 4 + pin_memory: False + persistent_workers: True # Keep True on Windows if running multiple workers + + +trainer: + trainer: + gpus: 1 + \ No newline at end of file diff --git a/expts/config_mollipo_pretrained.yaml b/expts/config_mollipo_pretrained.yaml index 20941d092..df40afb89 100644 --- a/expts/config_mollipo_pretrained.yaml +++ b/expts/config_mollipo_pretrained.yaml @@ -1,46 +1,46 @@ -constants: - seed: &seed 42 - raise_train_error: true # Whether the code should raise an error if it crashes during training - -datamodule: - module_type: "DGLOGBDataModule" - args: - cache_data_path: null - dataset_name: "ogbg-mollipo" - cache_data_path: goli/data/cache/ogb-mollipo/full.cache - - # Weights - weights_type: null - - # Featurization - featurization_n_jobs: -1 - featurization_progress: True - featurization: - atom_property_list_onehot: [atomic-number, valence] - atom_property_list_float: [mass, electronegativity, in-ring, hybridization, chirality, aromatic, degree, formal-charge, single-bond, double-bond, radical-electron, vdw-radius, covalent-radius, metal] - edge_property_list: [bond-type-onehot, bond-type-float, stereo, in-ring, conjugated, estimated-bond-length] - add_self_loop: False - explicit_H: False - use_bonds_weights: False - pos_encoding_as_features: &pos_enc - pos_type: laplacian_eigvec - num_pos: 3 - normalization: "none" - disconnected_comp: True - pos_encoding_as_directions: *pos_enc - - # Train, val, test parameters - batch_size_train_val: 256 - batch_size_test: 256 - - # Data loading - num_workers: 4 - pin_memory: False - persistent_workers: True # Keep True on Windows if running multiple workers - - -trainer: - trainer: - gpus: 1 - +constants: + seed: &seed 42 + raise_train_error: true # Whether the code should raise an error if it crashes during training + +datamodule: + module_type: "DGLOGBDataModule" + args: + cache_data_path: null + dataset_name: "ogbg-mollipo" + cache_data_path: goli/data/cache/ogb-mollipo/full.cache + + # Weights + weights_type: null + + # Featurization + featurization_n_jobs: -1 + featurization_progress: True + featurization: + atom_property_list_onehot: [atomic-number, valence] + atom_property_list_float: [mass, electronegativity, in-ring, hybridization, chirality, aromatic, degree, formal-charge, single-bond, double-bond, radical-electron, vdw-radius, covalent-radius, metal] + edge_property_list: [bond-type-onehot, bond-type-float, stereo, in-ring, conjugated, estimated-bond-length] + add_self_loop: False + explicit_H: False + use_bonds_weights: False + pos_encoding_as_features: &pos_enc + pos_type: laplacian_eigvec + num_pos: 3 + normalization: "none" + disconnected_comp: True + pos_encoding_as_directions: *pos_enc + + # Train, val, test parameters + batch_size_train_val: 256 + batch_size_test: 256 + + # Data loading + num_workers: 4 + pin_memory: False + persistent_workers: True # Keep True on Windows if running multiple workers + + +trainer: + trainer: + gpus: 1 + \ No newline at end of file diff --git a/expts/config_moltox21_pretrained.yaml b/expts/config_moltox21_pretrained.yaml index 7b8369bd3..2c47bb036 100644 --- a/expts/config_moltox21_pretrained.yaml +++ b/expts/config_moltox21_pretrained.yaml @@ -1,46 +1,46 @@ -constants: - seed: &seed 42 - raise_train_error: true # Whether the code should raise an error if it crashes during training - -datamodule: - module_type: "DGLOGBDataModule" - args: - cache_data_path: null - dataset_name: "ogbg-moltox21" - cache_data_path: goli/data/cache/ogb-moltox21/full.cache - - # Weights - weights_type: null - - # Featurization - featurization_n_jobs: -1 - featurization_progress: True - featurization: - atom_property_list_onehot: [atomic-number, valence] - atom_property_list_float: [mass, electronegativity, in-ring, hybridization, chirality, aromatic, degree, formal-charge, single-bond, double-bond, radical-electron, vdw-radius, covalent-radius, metal] - edge_property_list: [bond-type-onehot, bond-type-float, stereo, in-ring, conjugated, estimated-bond-length] - add_self_loop: False - explicit_H: False - use_bonds_weights: False - pos_encoding_as_features: &pos_enc - pos_type: laplacian_eigvec - num_pos: 3 - normalization: "none" - disconnected_comp: True - pos_encoding_as_directions: *pos_enc - - # Train, val, test parameters - batch_size_train_val: 256 - batch_size_test: 256 - - # Data loading - num_workers: 4 - pin_memory: False - persistent_workers: True # Keep True on Windows if running multiple workers - - -trainer: - trainer: - gpus: 1 - +constants: + seed: &seed 42 + raise_train_error: true # Whether the code should raise an error if it crashes during training + +datamodule: + module_type: "DGLOGBDataModule" + args: + cache_data_path: null + dataset_name: "ogbg-moltox21" + cache_data_path: goli/data/cache/ogb-moltox21/full.cache + + # Weights + weights_type: null + + # Featurization + featurization_n_jobs: -1 + featurization_progress: True + featurization: + atom_property_list_onehot: [atomic-number, valence] + atom_property_list_float: [mass, electronegativity, in-ring, hybridization, chirality, aromatic, degree, formal-charge, single-bond, double-bond, radical-electron, vdw-radius, covalent-radius, metal] + edge_property_list: [bond-type-onehot, bond-type-float, stereo, in-ring, conjugated, estimated-bond-length] + add_self_loop: False + explicit_H: False + use_bonds_weights: False + pos_encoding_as_features: &pos_enc + pos_type: laplacian_eigvec + num_pos: 3 + normalization: "none" + disconnected_comp: True + pos_encoding_as_directions: *pos_enc + + # Train, val, test parameters + batch_size_train_val: 256 + batch_size_test: 256 + + # Data loading + num_workers: 4 + pin_memory: False + persistent_workers: True # Keep True on Windows if running multiple workers + + +trainer: + trainer: + gpus: 1 + \ No newline at end of file diff --git a/expts/config_single_atom_dataset_pretrained.yaml b/expts/config_single_atom_dataset_pretrained.yaml index 26787e7e4..3ca235f12 100644 --- a/expts/config_single_atom_dataset_pretrained.yaml +++ b/expts/config_single_atom_dataset_pretrained.yaml @@ -1,52 +1,52 @@ -constants: - seed: &seed 42 - raise_train_error: true # Whether the code should raise an error if it crashes during training - -datamodule: - module_type: "DGLFromSmilesDataModule" - args: - df_path: goli/data/single_atom_dataset/single_atom_dataset.csv - cache_data_path: null # goli/data/cache/BindingDB/full.cache - label_cols: ['score'] - smiles_col: SMILES - sample_size: null - - # Weights - weights_type: null - - # Featurization - featurization_n_jobs: 0 - featurization_progress: True - featurization: - atom_property_list_onehot: [atomic-number, valence] - atom_property_list_float: [mass, electronegativity, in-ring, hybridization, chirality, aromatic, degree, formal-charge, single-bond, double-bond, radical-electron, vdw-radius, covalent-radius, metal] - edge_property_list: [bond-type-onehot, bond-type-float, stereo, in-ring, conjugated, estimated-bond-length] - add_self_loop: False - explicit_H: False - use_bonds_weights: False - pos_encoding_as_features: &pos_enc - pos_type: laplacian_eigvec - num_pos: 3 - normalization: "none" - disconnected_comp: True - pos_encoding_as_directions: *pos_enc - - # Train, val, test parameters - split_val: 0.2 - split_test: 0.2 - split_seed: *seed - splits_path: null - batch_size_train_val: 2 - batch_size_test: 2 - - # Data loading - num_workers: 0 - pin_memory: False - persistent_workers: False # Keep True on Windows if running multiple workers - - -trainer: - trainer: - gpus: 1 - +constants: + seed: &seed 42 + raise_train_error: true # Whether the code should raise an error if it crashes during training + +datamodule: + module_type: "DGLFromSmilesDataModule" + args: + df_path: goli/data/single_atom_dataset/single_atom_dataset.csv + cache_data_path: null # goli/data/cache/BindingDB/full.cache + label_cols: ['score'] + smiles_col: SMILES + sample_size: null + + # Weights + weights_type: null + + # Featurization + featurization_n_jobs: 0 + featurization_progress: True + featurization: + atom_property_list_onehot: [atomic-number, valence] + atom_property_list_float: [mass, electronegativity, in-ring, hybridization, chirality, aromatic, degree, formal-charge, single-bond, double-bond, radical-electron, vdw-radius, covalent-radius, metal] + edge_property_list: [bond-type-onehot, bond-type-float, stereo, in-ring, conjugated, estimated-bond-length] + add_self_loop: False + explicit_H: False + use_bonds_weights: False + pos_encoding_as_features: &pos_enc + pos_type: laplacian_eigvec + num_pos: 3 + normalization: "none" + disconnected_comp: True + pos_encoding_as_directions: *pos_enc + + # Train, val, test parameters + split_val: 0.2 + split_test: 0.2 + split_seed: *seed + splits_path: null + batch_size_train_val: 2 + batch_size_test: 2 + + # Data loading + num_workers: 0 + pin_memory: False + persistent_workers: False # Keep True on Windows if running multiple workers + + +trainer: + trainer: + gpus: 1 + \ No newline at end of file diff --git a/expts/data/micro_zinc_splits.csv b/expts/data/micro_zinc_splits.csv index 8c3d3d20f..82fcdad6d 100644 --- a/expts/data/micro_zinc_splits.csv +++ b/expts/data/micro_zinc_splits.csv @@ -1,601 +1,601 @@ -train,val,test -957,392.0,654.0 -223,588.0,430.0 -916,538.0,317.0 -410,903.0,629.0 -287,923.0,927.0 -496,677.0,641.0 -456,144.0,518.0 -589,83.0,645.0 -151,385.0,649.0 -10,996.0,504.0 -7,506.0,716.0 -180,205.0,557.0 -769,161.0,931.0 -389,746.0,804.0 -304,99.0,30.0 -867,444.0,836.0 -177,814.0,469.0 -79,984.0,448.0 -937,292.0,863.0 -494,377.0,412.0 -771,986.0,61.0 -587,864.0,805.0 -552,966.0,326.0 -881,974.0,520.0 -808,106.0,263.0 -186,670.0,874.0 -925,238.0,920.0 -268,995.0,226.0 -108,387.0,749.0 -717,696.0,724.0 -37,871.0,583.0 -47,515.0,928.0 -258,596.0,117.0 -173,801.0,459.0 -979,301.0,633.0 -846,420.0,239.0 -578,688.0,421.0 -320,572.0,484.0 -69,451.0,399.0 -221,318.0,886.0 -96,87.0,133.0 -732,672.0,759.0 -303,944.0,774.0 -48,891.0,523.0 -455,311.0,103.0 -140,297.0,610.0 -734,776.0,743.0 -138,972.0,104.0 -699,548.0,574.0 -706,152.0,204.0 -713,418.0,822.0 -705,214.0,989.0 -230,254.0,620.0 -134,709.0,648.0 -551,164.0,66.0 -92,835.0,75.0 -890,395.0,929.0 -185,539.0,184.0 -691,445.0,712.0 -31,573.0,98.0 -195,894.0,997.0 -630,852.0,94.0 -447,893.0,959.0 -967,673.0,728.0 -280,819.0,689.0 -813,829.0,323.0 -788,711.0,658.0 -432,750.0,722.0 -876,153.0,305.0 -766,524.0,600.0 -764,800.0,23.0 -135,279.0,76.0 -12,700.0,73.0 -136,71.0,740.0 -406,942.0,482.0 -267,935.0,773.0 -234,854.0,208.0 -531,59.0,985.0 -274,124.0,812.0 -938,913.0,122.0 -512,591.0,398.0 -353,693.0,726.0 -127,790.0,380.0 -737,54.0,293.0 -401,283.0,429.0 -121,492.0,747.0 -563,702.0,906.0 -694,954.0,316.0 -176,360.0,842.0 -25,624.0,145.0 -187,898.0,302.0 -434,719.0,781.0 -828,628.0,65.0 -113,232.0,388.0 -199,537.0,314.0 -227,60.0,825.0 -411,206.0,536.0 -590,567.0,787.0 -581,878.0,495.0 -697,383.0,89.0 -147,683.0,331.0 -798,806.0,338.0 -918,49.0,656.0 -751,993.0,350.0 -415,85.0,485.0 -692,150.0,840.0 -120,64.0,760.0 -686,980.0,41.0 -466,356.0,261.0 -517,735.0,855.0 -101,486.0,439.0 -528,883.0,503.0 -845,18.0,38.0 -644,879.0,202.0 -698,478.0,519.0 -501,336.0,625.0 -86,40.0,105.0 -224,472.0,763.0 -422,522.0,493.0 -15,276.0,405.0 -767,397.0,950.0 -193,192.0,156.0 -437,889.0,351.0 -376,638.0,603.0 -910,549.0,461.0 -288,680.0,919.0 -171,381.0,765.0 -955,655.0,940.0 -932,490.0,352.0 -623,132.0,917.0 -848,17.0,142.0 -592,643.0,870.0 -155,24.0,508.0 -457,157.0,561.0 -216,229.0,174.0 -250,453.0,951.0 -483,857.0,550.0 -404,463.0,601.0 -282,626.0,904.0 -529,646.0,831.0 -513,477.0,668.0 -189,211.0,880.0 -441,307.0,824.0 -576,408.0,613.0 -2,428.0,657.0 -744,615.0,241.0 -111,378.0,402.0 -718,100.0,873.0 -391,499.0,16.0 -322,785.0,755.0 -566,675.0,584.0 -452,794.0,650.0 -860,129.0,669.0 -952,729.0,598.0 -921,290.0,468.0 -875,337.0,571.0 -42,602.0,660.0 -128,553.0,443.0 -172,431.0,219.0 -908,26.0,577.0 -21,994.0,489.0 -895,95.0,58.0 -414,479.0,373.0 -505,197.0,119.0 -809,278.0,652.0 -371,249.0,137.0 -899,442.0,868.0 -6,594.0,826.0 -960,269.0,666.0 -964,324.0,341.0 -84,887.0,252.0 -329,502.0,470.0 -344,662.0,858.0 -102,27.0,262.0 -933,123.0,953.0 -608,877.0,667.0 -681,862.0,359.0 -542,924.0,595.0 -977,945.0,419.0 -582,514.0,902.0 -568,256.0,1.0 -464,321.0,939.0 -355,343.0,833.0 -454,527.0,163.0 -115,400.0,67.0 -36,14.0,168.0 -190,295.0,851.0 -799,973.0,721.0 -642,965.0,74.0 -285,62.0,756.0 -328,110.0,245.0 -141,213.0,970.0 -237,555.0,786.0 -130,333.0,754.0 -81,704.0,362.0 -8,481.0,348.0 -210,792.0,827.0 -118,546.0,762.0 -963,196.0,375.0 -139,260.0,334.0 -992,, -13,, -417,, -632,, -427,, -28,, -560,, -525,, -509,, -723,, -264,, -653,, -181,, -823,, -884,, -91,, -471,, -535,, -782,, -914,, -257,, -143,, -179,, -222,, -255,, -897,, -20,, -146,, -349,, -922,, -565,, -170,, -200,, -497,, -365,, -820,, -607,, -661,, -242,, -160,, -235,, -236,, -68,, -701,, -43,, -948,, -90,, -983,, -240,, -310,, -46,, -166,, -57,, -838,, -346,, -892,, -720,, -690,, -271,, -856,, -97,, -299,, -423,, -209,, -384,, -357,, -810,, -265,, -943,, -541,, -165,, -768,, -659,, -368,, -225,, -821,, -604,, -631,, -22,, -253,, -714,, -544,, -159,, -148,, -956,, -796,, -476,, -33,, -684,, -178,, -9,, -207,, -545,, -532,, -635,, -149,, -296,, -533,, -715,, -473,, -850,, -559,, -647,, -298,, -990,, -982,, -640,, -926,, -498,, -270,, -738,, -912,, -736,, -424,, -272,, -56,, -651,, -366,, -770,, -784,, -637,, -911,, -599,, -802,, -885,, -78,, -586,, -968,, -830,, -843,, -866,, -731,, -999,, -319,, -50,, -361,, -175,, -861,, -987,, -289,, -450,, -394,, -832,, -564,, -687,, -909,, -363,, -11,, -45,, -248,, -543,, -679,, -34,, -480,, -158,, -674,, -547,, -347,, -791,, -569,, -300,, -425,, -981,, -975,, -386,, -507,, -847,, -218,, -627,, -154,, -436,, -286,, -930,, -888,, -35,, -685,, -72,, -570,, -109,, -217,, -562,, -795,, -962,, -82,, -775,, -882,, -580,, -460,, -554,, -703,, -407,, -526,, -511,, -93,, -621,, -4,, -752,, -131,, -315,, -370,, -978,, -70,, -593,, -682,, -947,, -818,, -339,, -676,, -971,, -345,, -797,, -3,, -616,, -663,, -390,, -725,, -462,, -281,, -664,, -745,, -0,, -844,, -617,, -946,, -294,, -859,, -748,, -915,, -998,, -742,, -530,, -116,, -309,, -741,, -988,, -275,, -606,, -900,, -374,, -367,, -107,, -312,, -639,, -182,, -739,, -934,, -39,, -811,, -5,, -63,, -487,, -80,, -753,, -51,, -678,, -335,, -841,, -426,, -284,, -29,, -251,, -611,, -585,, -396,, -491,, -340,, -815,, -488,, -793,, -358,, -77,, -259,, -44,, -228,, -243,, -125,, -761,, -516,, -958,, -558,, -665,, -540,, -244,, -215,, -277,, -521,, -403,, -198,, -730,, -413,, -332,, -710,, -458,, -500,, -579,, -112,, -865,, -247,, -52,, -695,, -758,, -614,, -382,, -619,, -372,, -393,, -126,, -961,, -449,, -839,, -634,, -816,, -88,, -169,, -907,, -597,, -707,, -789,, -474,, -556,, -618,, -727,, -778,, -575,, -32,, -991,, -636,, -191,, -114,, -803,, -233,, -896,, -246,, -446,, -379,, -612,, -733,, -779,, -905,, -183,, -475,, -435,, -409,, -837,, -949,, -273,, -212,, -941,, -433,, -467,, -780,, -465,, -510,, -220,, -438,, -708,, -817,, -416,, -325,, -969,, -188,, -872,, -55,, -313,, -291,, -53,, -194,, -369,, -849,, -869,, -853,, -772,, -201,, -203,, -777,, -327,, -807,, -976,, -231,, -783,, -167,, -364,, -306,, -342,, -266,, -609,, -901,, -162,, -534,, -440,, -834,, -671,, -330,, -308,, -19,, -936,, -354,, -757,, -622,, -605,, +train,val,test +957,392.0,654.0 +223,588.0,430.0 +916,538.0,317.0 +410,903.0,629.0 +287,923.0,927.0 +496,677.0,641.0 +456,144.0,518.0 +589,83.0,645.0 +151,385.0,649.0 +10,996.0,504.0 +7,506.0,716.0 +180,205.0,557.0 +769,161.0,931.0 +389,746.0,804.0 +304,99.0,30.0 +867,444.0,836.0 +177,814.0,469.0 +79,984.0,448.0 +937,292.0,863.0 +494,377.0,412.0 +771,986.0,61.0 +587,864.0,805.0 +552,966.0,326.0 +881,974.0,520.0 +808,106.0,263.0 +186,670.0,874.0 +925,238.0,920.0 +268,995.0,226.0 +108,387.0,749.0 +717,696.0,724.0 +37,871.0,583.0 +47,515.0,928.0 +258,596.0,117.0 +173,801.0,459.0 +979,301.0,633.0 +846,420.0,239.0 +578,688.0,421.0 +320,572.0,484.0 +69,451.0,399.0 +221,318.0,886.0 +96,87.0,133.0 +732,672.0,759.0 +303,944.0,774.0 +48,891.0,523.0 +455,311.0,103.0 +140,297.0,610.0 +734,776.0,743.0 +138,972.0,104.0 +699,548.0,574.0 +706,152.0,204.0 +713,418.0,822.0 +705,214.0,989.0 +230,254.0,620.0 +134,709.0,648.0 +551,164.0,66.0 +92,835.0,75.0 +890,395.0,929.0 +185,539.0,184.0 +691,445.0,712.0 +31,573.0,98.0 +195,894.0,997.0 +630,852.0,94.0 +447,893.0,959.0 +967,673.0,728.0 +280,819.0,689.0 +813,829.0,323.0 +788,711.0,658.0 +432,750.0,722.0 +876,153.0,305.0 +766,524.0,600.0 +764,800.0,23.0 +135,279.0,76.0 +12,700.0,73.0 +136,71.0,740.0 +406,942.0,482.0 +267,935.0,773.0 +234,854.0,208.0 +531,59.0,985.0 +274,124.0,812.0 +938,913.0,122.0 +512,591.0,398.0 +353,693.0,726.0 +127,790.0,380.0 +737,54.0,293.0 +401,283.0,429.0 +121,492.0,747.0 +563,702.0,906.0 +694,954.0,316.0 +176,360.0,842.0 +25,624.0,145.0 +187,898.0,302.0 +434,719.0,781.0 +828,628.0,65.0 +113,232.0,388.0 +199,537.0,314.0 +227,60.0,825.0 +411,206.0,536.0 +590,567.0,787.0 +581,878.0,495.0 +697,383.0,89.0 +147,683.0,331.0 +798,806.0,338.0 +918,49.0,656.0 +751,993.0,350.0 +415,85.0,485.0 +692,150.0,840.0 +120,64.0,760.0 +686,980.0,41.0 +466,356.0,261.0 +517,735.0,855.0 +101,486.0,439.0 +528,883.0,503.0 +845,18.0,38.0 +644,879.0,202.0 +698,478.0,519.0 +501,336.0,625.0 +86,40.0,105.0 +224,472.0,763.0 +422,522.0,493.0 +15,276.0,405.0 +767,397.0,950.0 +193,192.0,156.0 +437,889.0,351.0 +376,638.0,603.0 +910,549.0,461.0 +288,680.0,919.0 +171,381.0,765.0 +955,655.0,940.0 +932,490.0,352.0 +623,132.0,917.0 +848,17.0,142.0 +592,643.0,870.0 +155,24.0,508.0 +457,157.0,561.0 +216,229.0,174.0 +250,453.0,951.0 +483,857.0,550.0 +404,463.0,601.0 +282,626.0,904.0 +529,646.0,831.0 +513,477.0,668.0 +189,211.0,880.0 +441,307.0,824.0 +576,408.0,613.0 +2,428.0,657.0 +744,615.0,241.0 +111,378.0,402.0 +718,100.0,873.0 +391,499.0,16.0 +322,785.0,755.0 +566,675.0,584.0 +452,794.0,650.0 +860,129.0,669.0 +952,729.0,598.0 +921,290.0,468.0 +875,337.0,571.0 +42,602.0,660.0 +128,553.0,443.0 +172,431.0,219.0 +908,26.0,577.0 +21,994.0,489.0 +895,95.0,58.0 +414,479.0,373.0 +505,197.0,119.0 +809,278.0,652.0 +371,249.0,137.0 +899,442.0,868.0 +6,594.0,826.0 +960,269.0,666.0 +964,324.0,341.0 +84,887.0,252.0 +329,502.0,470.0 +344,662.0,858.0 +102,27.0,262.0 +933,123.0,953.0 +608,877.0,667.0 +681,862.0,359.0 +542,924.0,595.0 +977,945.0,419.0 +582,514.0,902.0 +568,256.0,1.0 +464,321.0,939.0 +355,343.0,833.0 +454,527.0,163.0 +115,400.0,67.0 +36,14.0,168.0 +190,295.0,851.0 +799,973.0,721.0 +642,965.0,74.0 +285,62.0,756.0 +328,110.0,245.0 +141,213.0,970.0 +237,555.0,786.0 +130,333.0,754.0 +81,704.0,362.0 +8,481.0,348.0 +210,792.0,827.0 +118,546.0,762.0 +963,196.0,375.0 +139,260.0,334.0 +992,, +13,, +417,, +632,, +427,, +28,, +560,, +525,, +509,, +723,, +264,, +653,, +181,, +823,, +884,, +91,, +471,, +535,, +782,, +914,, +257,, +143,, +179,, +222,, +255,, +897,, +20,, +146,, +349,, +922,, +565,, +170,, +200,, +497,, +365,, +820,, +607,, +661,, +242,, +160,, +235,, +236,, +68,, +701,, +43,, +948,, +90,, +983,, +240,, +310,, +46,, +166,, +57,, +838,, +346,, +892,, +720,, +690,, +271,, +856,, +97,, +299,, +423,, +209,, +384,, +357,, +810,, +265,, +943,, +541,, +165,, +768,, +659,, +368,, +225,, +821,, +604,, +631,, +22,, +253,, +714,, +544,, +159,, +148,, +956,, +796,, +476,, +33,, +684,, +178,, +9,, +207,, +545,, +532,, +635,, +149,, +296,, +533,, +715,, +473,, +850,, +559,, +647,, +298,, +990,, +982,, +640,, +926,, +498,, +270,, +738,, +912,, +736,, +424,, +272,, +56,, +651,, +366,, +770,, +784,, +637,, +911,, +599,, +802,, +885,, +78,, +586,, +968,, +830,, +843,, +866,, +731,, +999,, +319,, +50,, +361,, +175,, +861,, +987,, +289,, +450,, +394,, +832,, +564,, +687,, +909,, +363,, +11,, +45,, +248,, +543,, +679,, +34,, +480,, +158,, +674,, +547,, +347,, +791,, +569,, +300,, +425,, +981,, +975,, +386,, +507,, +847,, +218,, +627,, +154,, +436,, +286,, +930,, +888,, +35,, +685,, +72,, +570,, +109,, +217,, +562,, +795,, +962,, +82,, +775,, +882,, +580,, +460,, +554,, +703,, +407,, +526,, +511,, +93,, +621,, +4,, +752,, +131,, +315,, +370,, +978,, +70,, +593,, +682,, +947,, +818,, +339,, +676,, +971,, +345,, +797,, +3,, +616,, +663,, +390,, +725,, +462,, +281,, +664,, +745,, +0,, +844,, +617,, +946,, +294,, +859,, +748,, +915,, +998,, +742,, +530,, +116,, +309,, +741,, +988,, +275,, +606,, +900,, +374,, +367,, +107,, +312,, +639,, +182,, +739,, +934,, +39,, +811,, +5,, +63,, +487,, +80,, +753,, +51,, +678,, +335,, +841,, +426,, +284,, +29,, +251,, +611,, +585,, +396,, +491,, +340,, +815,, +488,, +793,, +358,, +77,, +259,, +44,, +228,, +243,, +125,, +761,, +516,, +958,, +558,, +665,, +540,, +244,, +215,, +277,, +521,, +403,, +198,, +730,, +413,, +332,, +710,, +458,, +500,, +579,, +112,, +865,, +247,, +52,, +695,, +758,, +614,, +382,, +619,, +372,, +393,, +126,, +961,, +449,, +839,, +634,, +816,, +88,, +169,, +907,, +597,, +707,, +789,, +474,, +556,, +618,, +727,, +778,, +575,, +32,, +991,, +636,, +191,, +114,, +803,, +233,, +896,, +246,, +446,, +379,, +612,, +733,, +779,, +905,, +183,, +475,, +435,, +409,, +837,, +949,, +273,, +212,, +941,, +433,, +467,, +780,, +465,, +510,, +220,, +438,, +708,, +817,, +416,, +325,, +969,, +188,, +872,, +55,, +313,, +291,, +53,, +194,, +369,, +849,, +869,, +853,, +772,, +201,, +203,, +777,, +327,, +807,, +976,, +231,, +783,, +167,, +364,, +306,, +342,, +266,, +609,, +901,, +162,, +534,, +440,, +834,, +671,, +330,, +308,, +19,, +936,, +354,, +757,, +622,, +605,, diff --git a/expts/data/tiny_zinc_splits.csv b/expts/data/tiny_zinc_splits.csv index cc9d34da3..03a538f92 100644 --- a/expts/data/tiny_zinc_splits.csv +++ b/expts/data/tiny_zinc_splits.csv @@ -1,61 +1,61 @@ -train,val,test -3,61.0,65.0 -21,64.0,24.0 -26,89.0,90.0 -42,28.0,37.0 -23,95.0,83.0 -87,46.0,63.0 -68,30.0,58.0 -99,91.0,85.0 -77,48.0,70.0 -41,43.0,18.0 -0,7.0,81.0 -36,72.0,94.0 -29,38.0,13.0 -11,47.0,1.0 -79,16.0,33.0 -82,62.0,14.0 -27,8.0,97.0 -51,76.0,2.0 -25,78.0,59.0 -88,6.0,44.0 -96,, -17,, -20,, -45,, -67,, -12,, -54,, -49,, -32,, -69,, -60,, -55,, -57,, -35,, -53,, -92,, -4,, -73,, -75,, -9,, -71,, -84,, -80,, -15,, -39,, -50,, -86,, -10,, -5,, -34,, -22,, -56,, -66,, -31,, -74,, -52,, -19,, -40,, -98,, -93,, +train,val,test +3,61.0,65.0 +21,64.0,24.0 +26,89.0,90.0 +42,28.0,37.0 +23,95.0,83.0 +87,46.0,63.0 +68,30.0,58.0 +99,91.0,85.0 +77,48.0,70.0 +41,43.0,18.0 +0,7.0,81.0 +36,72.0,94.0 +29,38.0,13.0 +11,47.0,1.0 +79,16.0,33.0 +82,62.0,14.0 +27,8.0,97.0 +51,76.0,2.0 +25,78.0,59.0 +88,6.0,44.0 +96,, +17,, +20,, +45,, +67,, +12,, +54,, +49,, +32,, +69,, +60,, +55,, +57,, +35,, +53,, +92,, +4,, +73,, +75,, +9,, +71,, +84,, +80,, +15,, +39,, +50,, +86,, +10,, +5,, +34,, +22,, +56,, +66,, +31,, +74,, +52,, +19,, +40,, +98,, +93,, diff --git a/expts/example_zinc.yaml b/expts/example_zinc.yaml index 73e5cda72..83d9ade2e 100644 --- a/expts/example_zinc.yaml +++ b/expts/example_zinc.yaml @@ -1,22 +1,22 @@ -data: - module_type: "DGLFromSmilesDataModule" - args: - df_path: null # could be set from the CLI. Ideally config files does not have paths. - cache_data_path: null # could be set from the CLI. Ideally config files does not have paths. - - smiles_col: "SMILES" - label_cols: ["SA"] - split_val: 0.2 - split_test: 0.2 - split_seed: 19 - - batch_size_train_val: 16 - batch_size_test: 16 - - featurization: - atom_property_list_float: [] - atom_property_list_onehot: ["atomic-number", "degree"] - edge_property_list: ["in-ring", "bond-type-onehot"] - add_self_loop: false - use_bonds_weights: false - explicit_H: false +data: + module_type: "DGLFromSmilesDataModule" + args: + df_path: null # could be set from the CLI. Ideally config files does not have paths. + cache_data_path: null # could be set from the CLI. Ideally config files does not have paths. + + smiles_col: "SMILES" + label_cols: ["SA"] + split_val: 0.2 + split_test: 0.2 + split_seed: 19 + + batch_size_train_val: 16 + batch_size_test: 16 + + featurization: + atom_property_list_float: [] + atom_property_list_onehot: ["atomic-number", "degree"] + edge_property_list: ["in-ring", "bond-type-onehot"] + add_self_loop: false + use_bonds_weights: false + explicit_H: false diff --git a/expts/main_run.py b/expts/main_run.py index 61967c751..2e6f4dfdf 100644 --- a/expts/main_run.py +++ b/expts/main_run.py @@ -1,76 +1,77 @@ -# General imports -import os -from os.path import dirname, abspath -import yaml -from copy import deepcopy -from omegaconf import DictConfig - - -# Current project imports -import goli -from goli.config._loader import load_datamodule, load_metrics, load_architecture, load_predictor, load_trainer - - -# Set up the working directory -MAIN_DIR = dirname(dirname(abspath(goli.__file__))) -# CONFIG_FILE = "expts/config_molPCBA.yaml" -# CONFIG_FILE = "expts/config_molHIV.yaml" -CONFIG_FILE = "expts/config_micro_ZINC.yaml" -# CONFIG_FILE = "expts/config_ZINC_bench_gnn.yaml" -# CONFIG_FILE = "expts/config_htsfp_pcba.yaml" -os.chdir(MAIN_DIR) - - -def main(cfg: DictConfig) -> None: - cfg = deepcopy(cfg) - - # Load and initialize the dataset - datamodule = load_datamodule(cfg) - print("\ndatamodule:\n", datamodule, "\n") - - # Initialize the network - model_class, model_kwargs = load_architecture( - cfg, - in_dim_nodes=datamodule.num_node_feats_with_positional_encoding, - in_dim_edges=datamodule.num_edge_feats, - ) - - metrics = load_metrics(cfg) - print(metrics) - - predictor = load_predictor(cfg, model_class, model_kwargs, metrics) - - print(predictor.model) - print(predictor.summarize(mode=4, to_print=False)) - - trainer = load_trainer(cfg) - - # Run the model training - print("\n------------ TRAINING STARTED ------------") - try: - trainer.fit(model=predictor, datamodule=datamodule) - print("\n------------ TRAINING COMPLETED ------------\n\n") - - except Exception as e: - if not cfg["constants"]["raise_train_error"]: - print("\n------------ TRAINING ERROR: ------------\n\n", e) - else: - raise e - - print("\n------------ TESTING STARTED ------------") - try: - ckpt_path = trainer.checkpoint_callbacks[0].best_model_path - trainer.test(model=predictor, datamodule=datamodule, ckpt_path=ckpt_path) - print("\n------------ TESTING COMPLETED ------------\n\n") - - except Exception as e: - if not cfg["constants"]["raise_train_error"]: - print("\n------------ TESTING ERROR: ------------\n\n", e) - else: - raise e - - -if __name__ == "__main__": - with open(os.path.join(MAIN_DIR, CONFIG_FILE), "r") as f: - cfg = yaml.safe_load(f) - main(cfg) +# General imports +import os +from os.path import dirname, abspath +import yaml +from copy import deepcopy +from omegaconf import DictConfig + + +# Current project imports +import goli +from goli.config._loader import load_datamodule, load_metrics, load_architecture, load_predictor, load_trainer + + +# Set up the working directory +MAIN_DIR = dirname(dirname(abspath(goli.__file__))) +# CONFIG_FILE = "expts/config_molPCBA.yaml" +# CONFIG_FILE = "expts/config_molHIV.yaml" +CONFIG_FILE = "expts/config_molPCQM4M.yaml" +# CONFIG_FILE = "expts/config_micro_ZINC.yaml" +# CONFIG_FILE = "expts/config_ZINC_bench_gnn.yaml" +# CONFIG_FILE = "expts/config_htsfp_pcba.yaml" +os.chdir(MAIN_DIR) + + +def main(cfg: DictConfig) -> None: + cfg = deepcopy(cfg) + + # Load and initialize the dataset + datamodule = load_datamodule(cfg) + print("\ndatamodule:\n", datamodule, "\n") + + # Initialize the network + model_class, model_kwargs = load_architecture( + cfg, + in_dim_nodes=datamodule.num_node_feats_with_positional_encoding, + in_dim_edges=datamodule.num_edge_feats, + ) + + metrics = load_metrics(cfg) + print(metrics) + + predictor = load_predictor(cfg, model_class, model_kwargs, metrics) + + print(predictor.model) + print(predictor.summarize(mode=4, to_print=False)) + + trainer = load_trainer(cfg) + + # Run the model training + print("\n------------ TRAINING STARTED ------------") + try: + trainer.fit(model=predictor, datamodule=datamodule) + print("\n------------ TRAINING COMPLETED ------------\n\n") + + except Exception as e: + if not cfg["constants"]["raise_train_error"]: + print("\n------------ TRAINING ERROR: ------------\n\n", e) + else: + raise e + + print("\n------------ TESTING STARTED ------------") + try: + ckpt_path = trainer.checkpoint_callbacks[0].best_model_path + trainer.test(model=predictor, datamodule=datamodule, ckpt_path=ckpt_path) + print("\n------------ TESTING COMPLETED ------------\n\n") + + except Exception as e: + if not cfg["constants"]["raise_train_error"]: + print("\n------------ TESTING ERROR: ------------\n\n", e) + else: + raise e + + +if __name__ == "__main__": + with open(os.path.join(MAIN_DIR, CONFIG_FILE), "r") as f: + cfg = yaml.safe_load(f) + main(cfg) diff --git a/expts/main_run_predict.py b/expts/main_run_predict.py index d326c09b8..775a12959 100644 --- a/expts/main_run_predict.py +++ b/expts/main_run_predict.py @@ -1,78 +1,78 @@ -# General imports -import os -from os.path import dirname, abspath -import yaml -from copy import deepcopy -from omegaconf import DictConfig -import numpy as np -import pandas as pd -import torch - -# Current project imports -import goli -from goli.config._loader import load_datamodule, load_trainer -from goli.utils.fs import mkdir -from goli.trainer.predictor import PredictorModule - - -# Set up the working directory -MAIN_DIR = dirname(dirname(abspath(goli.__file__))) -os.chdir(MAIN_DIR) - -DATA_NAME = "bindingDB" -MODEL_FILE = "models_checkpoints/ogb-molpcba/model-v2.ckpt" -CONFIG_FILE = f"expts/config_{DATA_NAME}_pretrained.yaml" - -# MODEL_FILE = "models_checkpoints/micro_ZINC/model.ckpt" -# CONFIG_FILE = "expts/config_micro_ZINC.yaml" - - -NUM_LAYERS_TO_DROP = [3] # range(3) - - -def main(cfg: DictConfig) -> None: - - cfg = deepcopy(cfg) - - # Load and initialize the dataset - datamodule = load_datamodule(cfg) - print("\ndatamodule:\n", datamodule, "\n") - - for num_layers_to_drop in NUM_LAYERS_TO_DROP: - - export_df_path = f"predictions/fingerprint-drop-output-{DATA_NAME}-{num_layers_to_drop}.csv.gz" - - predictor = PredictorModule.load_from_checkpoint(MODEL_FILE) - predictor.model.drop_post_nn_layers(num_layers_to_drop=num_layers_to_drop) - - print(predictor.model) - print(predictor.summarize(mode=4, to_print=False)) - - trainer = load_trainer(cfg) - - # Run the model prediction - preds = trainer.predict(model=predictor, datamodule=datamodule) - if isinstance(preds[0], torch.Tensor): - preds = [p.detach().cpu().numpy() for p in preds] - preds = np.concatenate(preds, axis=0) - - # Generate output dataframe - df = {"SMILES": datamodule.dataset.smiles} - - target = datamodule.dataset.labels - for ii in range(target.shape[1]): - df[f"Target-{ii}"] = target[:, ii] - - for ii in range(preds.shape[1]): - df[f"Preds-{ii}"] = preds[:, ii] - df = pd.DataFrame(df) - mkdir("predictions") - df.to_csv(export_df_path) - print(df) - print(f"file saved to:`{export_df_path}`") - - -if __name__ == "__main__": - with open(os.path.join(MAIN_DIR, CONFIG_FILE), "r") as f: - cfg = yaml.safe_load(f) - main(cfg) +# General imports +import os +from os.path import dirname, abspath +import yaml +from copy import deepcopy +from omegaconf import DictConfig +import numpy as np +import pandas as pd +import torch + +# Current project imports +import goli +from goli.config._loader import load_datamodule, load_trainer +from goli.utils.fs import mkdir +from goli.trainer.predictor import PredictorModule + + +# Set up the working directory +MAIN_DIR = dirname(dirname(abspath(goli.__file__))) +os.chdir(MAIN_DIR) + +DATA_NAME = "bindingDB" +MODEL_FILE = "models_checkpoints/ogb-molpcba/model-v2.ckpt" +CONFIG_FILE = f"expts/config_{DATA_NAME}_pretrained.yaml" + +# MODEL_FILE = "models_checkpoints/micro_ZINC/model.ckpt" +# CONFIG_FILE = "expts/config_micro_ZINC.yaml" + + +NUM_LAYERS_TO_DROP = [3] # range(3) + + +def main(cfg: DictConfig) -> None: + + cfg = deepcopy(cfg) + + # Load and initialize the dataset + datamodule = load_datamodule(cfg) + print("\ndatamodule:\n", datamodule, "\n") + + for num_layers_to_drop in NUM_LAYERS_TO_DROP: + + export_df_path = f"predictions/fingerprint-drop-output-{DATA_NAME}-{num_layers_to_drop}.csv.gz" + + predictor = PredictorModule.load_from_checkpoint(MODEL_FILE) + predictor.model.drop_post_nn_layers(num_layers_to_drop=num_layers_to_drop) + + print(predictor.model) + print(predictor.summarize(mode=4, to_print=False)) + + trainer = load_trainer(cfg) + + # Run the model prediction + preds = trainer.predict(model=predictor, datamodule=datamodule) + if isinstance(preds[0], torch.Tensor): + preds = [p.detach().cpu().numpy() for p in preds] + preds = np.concatenate(preds, axis=0) + + # Generate output dataframe + df = {"SMILES": datamodule.dataset.smiles} + + target = datamodule.dataset.labels + for ii in range(target.shape[1]): + df[f"Target-{ii}"] = target[:, ii] + + for ii in range(preds.shape[1]): + df[f"Preds-{ii}"] = preds[:, ii] + df = pd.DataFrame(df) + mkdir("predictions") + df.to_csv(export_df_path) + print(df) + print(f"file saved to:`{export_df_path}`") + + +if __name__ == "__main__": + with open(os.path.join(MAIN_DIR, CONFIG_FILE), "r") as f: + cfg = yaml.safe_load(f) + main(cfg) diff --git a/expts/main_run_test.py b/expts/main_run_test.py index bfeacf57c..bc61a9b83 100644 --- a/expts/main_run_test.py +++ b/expts/main_run_test.py @@ -1,50 +1,50 @@ -# General imports -import os -from os.path import dirname, abspath -import yaml -from copy import deepcopy -from omegaconf import DictConfig - -# Current project imports -import goli -from goli.config._loader import load_datamodule, load_metrics, load_trainer - -from goli.trainer.predictor import PredictorModule - - -# Set up the working directory -MAIN_DIR = dirname(dirname(abspath(goli.__file__))) -os.chdir(MAIN_DIR) - -MODEL_FILE = "models_checkpoints/ogb-molpcba/model-v2.ckpt" - -CONFIG_FILE = "expts/config_molPCBA.yaml" - - -def main(cfg: DictConfig) -> None: - - cfg = deepcopy(cfg) - - # Load and initialize the dataset - datamodule = load_datamodule(cfg) - print("\ndatamodule:\n", datamodule, "\n") - - metrics = load_metrics(cfg) - print(metrics) - - predictor = PredictorModule.load_from_checkpoint(MODEL_FILE) - predictor.metrics = metrics - - print(predictor.model) - print(predictor.summarize(mode=4, to_print=False)) - - trainer = load_trainer(cfg) - - # Run the model testing - trainer.test(model=predictor, datamodule=datamodule, ckpt_path=MODEL_FILE) - - -if __name__ == "__main__": - with open(os.path.join(MAIN_DIR, CONFIG_FILE), "r") as f: - cfg = yaml.safe_load(f) - main(cfg) +# General imports +import os +from os.path import dirname, abspath +import yaml +from copy import deepcopy +from omegaconf import DictConfig + +# Current project imports +import goli +from goli.config._loader import load_datamodule, load_metrics, load_trainer + +from goli.trainer.predictor import PredictorModule + + +# Set up the working directory +MAIN_DIR = dirname(dirname(abspath(goli.__file__))) +os.chdir(MAIN_DIR) + +MODEL_FILE = "models_checkpoints/ogb-molpcba/model-v2.ckpt" + +CONFIG_FILE = "expts/config_molPCBA.yaml" + + +def main(cfg: DictConfig) -> None: + + cfg = deepcopy(cfg) + + # Load and initialize the dataset + datamodule = load_datamodule(cfg) + print("\ndatamodule:\n", datamodule, "\n") + + metrics = load_metrics(cfg) + print(metrics) + + predictor = PredictorModule.load_from_checkpoint(MODEL_FILE) + predictor.metrics = metrics + + print(predictor.model) + print(predictor.summarize(mode=4, to_print=False)) + + trainer = load_trainer(cfg) + + # Run the model testing + trainer.test(model=predictor, datamodule=datamodule, ckpt_path=MODEL_FILE) + + +if __name__ == "__main__": + with open(os.path.join(MAIN_DIR, CONFIG_FILE), "r") as f: + cfg = yaml.safe_load(f) + main(cfg) diff --git a/goli/__init__.py b/goli/__init__.py index 4795eefde..e4e77b98d 100644 --- a/goli/__init__.py +++ b/goli/__init__.py @@ -1,9 +1,9 @@ -from ._version import __version__ - -from .config import load_config - -from . import utils -from . import features -from . import data -from . import nn -from . import trainer +from ._version import __version__ + +from .config import load_config + +from . import utils +from . import features +from . import data +from . import nn +from . import trainer diff --git a/goli/_version.py b/goli/_version.py index 3dc1f76bc..3f5c4a7d6 100644 --- a/goli/_version.py +++ b/goli/_version.py @@ -1 +1 @@ -__version__ = "0.1.0" +__version__ = "0.1.0" diff --git a/goli/cli/__init__.py b/goli/cli/__init__.py index 5c5f801f3..a60f913d1 100644 --- a/goli/cli/__init__.py +++ b/goli/cli/__init__.py @@ -1,2 +1,2 @@ -from .main import main_cli -from .data import data_cli +from .main import main_cli +from .data import data_cli diff --git a/goli/cli/data.py b/goli/cli/data.py index cfd58d5c1..45b20b42f 100644 --- a/goli/cli/data.py +++ b/goli/cli/data.py @@ -1,57 +1,57 @@ -import click - -from loguru import logger - -import goli - -from .main import main_cli - - -@main_cli.group(name="data", help="Goli datasets.") -def data_cli(): - pass - - -@data_cli.command(name="download", help="Download a Goli dataset.") -@click.option( - "-n", - "--name", - type=str, - required=True, - help="Name of the goli dataset to download.", -) -@click.option( - "-o", - "--output", - type=str, - required=True, - help="Where to download the Goli dataset.", -) -@click.option( - "--progress", - type=bool, - is_flag=True, - default=False, - required=False, - help="Whether to extract the dataset if it's a zip file.", -) -def download(name, output, progress): - - args = {} - args["name"] = name - args["output_path"] = output - args["extract_zip"] = True - args["progress"] = progress - - logger.info(f"Download dataset '{name}' into {output}.") - - fpath = goli.data.utils.download_goli_dataset(**args) - - logger.info(f"Dataset available at {fpath}.") - - -@data_cli.command(name="list", help="List available Goli dataset.") -def list(): - - logger.info("Goli datasets:") - logger.info(goli.data.utils.list_goli_datasets()) +import click + +from loguru import logger + +import goli + +from .main import main_cli + + +@main_cli.group(name="data", help="Goli datasets.") +def data_cli(): + pass + + +@data_cli.command(name="download", help="Download a Goli dataset.") +@click.option( + "-n", + "--name", + type=str, + required=True, + help="Name of the goli dataset to download.", +) +@click.option( + "-o", + "--output", + type=str, + required=True, + help="Where to download the Goli dataset.", +) +@click.option( + "--progress", + type=bool, + is_flag=True, + default=False, + required=False, + help="Whether to extract the dataset if it's a zip file.", +) +def download(name, output, progress): + + args = {} + args["name"] = name + args["output_path"] = output + args["extract_zip"] = True + args["progress"] = progress + + logger.info(f"Download dataset '{name}' into {output}.") + + fpath = goli.data.utils.download_goli_dataset(**args) + + logger.info(f"Dataset available at {fpath}.") + + +@data_cli.command(name="list", help="List available Goli dataset.") +def list(): + + logger.info("Goli datasets:") + logger.info(goli.data.utils.list_goli_datasets()) diff --git a/goli/cli/main.py b/goli/cli/main.py index 2161514e7..4a44b3a5f 100644 --- a/goli/cli/main.py +++ b/goli/cli/main.py @@ -1,11 +1,11 @@ -import click - - -@click.group() -@click.version_option() -def main_cli(): - pass - - -if __name__ == "__main__": - main_cli() +import click + + +@click.group() +@click.version_option() +def main_cli(): + pass + + +if __name__ == "__main__": + main_cli() diff --git a/goli/config/__init__.py b/goli/config/__init__.py index 2f13795b6..9ea001495 100644 --- a/goli/config/__init__.py +++ b/goli/config/__init__.py @@ -1,7 +1,7 @@ -from ._load import load_config - -from ._loader import load_architecture -from ._loader import load_datamodule -from ._loader import load_metrics -from ._loader import load_predictor -from ._loader import load_trainer +from ._load import load_config + +from ._loader import load_architecture +from ._loader import load_datamodule +from ._loader import load_metrics +from ._loader import load_predictor +from ._loader import load_trainer diff --git a/goli/config/_load.py b/goli/config/_load.py index bc587645b..0c24eab29 100644 --- a/goli/config/_load.py +++ b/goli/config/_load.py @@ -1,16 +1,16 @@ -import importlib.resources - -import omegaconf - - -def load_config(name: str): - """Load a default config file by its name. - - Args: - name: name of the config to load. - """ - - with importlib.resources.open_text("goli.config", f"{name}.yaml") as f: - config = omegaconf.OmegaConf.load(f) - - return config +import importlib.resources + +import omegaconf + + +def load_config(name: str): + """Load a default config file by its name. + + Args: + name: name of the config to load. + """ + + with importlib.resources.open_text("goli.config", f"{name}.yaml") as f: + config = omegaconf.OmegaConf.load(f) + + return config diff --git a/goli/config/_loader.py b/goli/config/_loader.py index 68c7afcf4..75fd91116 100644 --- a/goli/config/_loader.py +++ b/goli/config/_loader.py @@ -1,151 +1,151 @@ -from typing import List, Dict, Union, Any - -import omegaconf -from copy import deepcopy -import torch -from loguru import logger - -from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint -from pytorch_lightning import Trainer - -from goli.trainer.metrics import MetricWrapper -from goli.nn import FullDGLNetwork, FullDGLSiameseNetwork, FeedForwardNN -from goli.data.datamodule import DGLFromSmilesDataModule, DGLOGBDataModule -from goli.trainer.predictor import PredictorModule - -# from goli.trainer.model_summary import BestEpochFromSummary -from pytorch_lightning.loggers.tensorboard import TensorBoardLogger - - -DATAMODULE_DICT = { - "DGLFromSmilesDataModule": DGLFromSmilesDataModule, - "DGLOGBDataModule": DGLOGBDataModule, -} - - -def load_datamodule( - config: Union[omegaconf.DictConfig, Dict[str, Any]], -): - module_class = DATAMODULE_DICT[config["datamodule"]["module_type"]] - datamodule = module_class(**config["datamodule"]["args"]) - - return datamodule - - -def load_metrics(config: Union[omegaconf.DictConfig, Dict[str, Any]]): - - metrics = {} - cfg_metrics = deepcopy(config["metrics"]) - if cfg_metrics is None: - return metrics - - for this_metric in cfg_metrics: - name = this_metric.pop("name") - metrics[name] = MetricWrapper(**this_metric) - - return metrics - - -def load_architecture( - config: Union[omegaconf.DictConfig, Dict[str, Any]], - in_dim_nodes: int, - in_dim_edges: int, -): - - if isinstance(config, dict): - config = omegaconf.OmegaConf.create(config) - cfg_arch = config["architecture"] - - kwargs = {} - - # Select the architecture - model_type = cfg_arch["model_type"].lower() - if model_type == "fulldglnetwork": - model_class = FullDGLNetwork - elif model_type == "fulldglsiamesenetwork": - model_class = FullDGLSiameseNetwork - kwargs["dist_method"] = cfg_arch["dist_method"] - else: - raise ValueError(f"Unsupported model_type=`{model_type}`") - - # Prepare the various kwargs - pre_nn_kwargs = dict(cfg_arch["pre_nn"]) if cfg_arch["pre_nn"] is not None else None - pre_nn_edges_kwargs = dict(cfg_arch["pre_nn_edges"]) if cfg_arch["pre_nn_edges"] is not None else None - gnn_kwargs = dict(cfg_arch["gnn"]) - post_nn_kwargs = dict(cfg_arch["post_nn"]) if cfg_arch["post_nn"] is not None else None - - # Set the input dimensions - if pre_nn_kwargs is not None: - pre_nn_kwargs = dict(pre_nn_kwargs) - pre_nn_kwargs.setdefault("in_dim", in_dim_nodes) - else: - gnn_kwargs.setdefault("in_dim", in_dim_nodes) - - if pre_nn_edges_kwargs is not None: - pre_nn_edges_kwargs = dict(pre_nn_edges_kwargs) - pre_nn_edges_kwargs.setdefault("in_dim", in_dim_edges) - else: - gnn_kwargs.setdefault("in_dim_edges", in_dim_edges) - - # Set the parameters for the full network - model_kwargs = dict( - gnn_kwargs=gnn_kwargs, - pre_nn_kwargs=pre_nn_kwargs, - pre_nn_edges_kwargs=pre_nn_edges_kwargs, - post_nn_kwargs=post_nn_kwargs, - ) - - return model_class, model_kwargs - - -def load_predictor(config, model_class, model_kwargs, metrics): - # Defining the predictor - - cfg_pred = dict(deepcopy(config["predictor"])) - predictor = PredictorModule( - model_class=model_class, - model_kwargs=model_kwargs, - metrics=metrics, - **cfg_pred, - ) - - return predictor - - -def load_trainer(config): - cfg_trainer = deepcopy(config["trainer"]) - - # Set the number of gpus to 0 if no GPU is available - gpus = cfg_trainer["trainer"].pop("gpus", 0) - num_gpus = 0 - if isinstance(gpus, int): - num_gpus = gpus - elif isinstance(gpus, (list, tuple)): - num_gpus = len(gpus) - if (num_gpus > 0) and (not torch.cuda.is_available()): - logger.warning( - f"Number of GPUs selected is `{num_gpus}`, but will be ignored since no GPU are available on this device" - ) - gpus = 0 - - trainer_kwargs = {} - callbacks = [] - if "early_stopping" in cfg_trainer.keys(): - callbacks.append(EarlyStopping(**cfg_trainer["early_stopping"])) - - if "model_checkpoint" in cfg_trainer.keys(): - callbacks.append(ModelCheckpoint(**cfg_trainer["model_checkpoint"])) - - if "logger" in cfg_trainer.keys(): - trainer_kwargs["logger"] = TensorBoardLogger(**cfg_trainer["logger"], default_hp_metric=False) - - trainer_kwargs["callbacks"] = callbacks - - trainer = Trainer( - terminate_on_nan=True, - **cfg_trainer["trainer"], - **trainer_kwargs, - gpus=gpus, - ) - - return trainer +from typing import List, Dict, Union, Any + +import omegaconf +from copy import deepcopy +import torch +from loguru import logger + +from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint +from pytorch_lightning import Trainer + +from goli.trainer.metrics import MetricWrapper +from goli.nn import FullDGLNetwork, FullDGLSiameseNetwork, FeedForwardNN +from goli.data.datamodule import DGLFromSmilesDataModule, DGLOGBDataModule +from goli.trainer.predictor import PredictorModule + +# from goli.trainer.model_summary import BestEpochFromSummary +from pytorch_lightning.loggers.tensorboard import TensorBoardLogger + + +DATAMODULE_DICT = { + "DGLFromSmilesDataModule": DGLFromSmilesDataModule, + "DGLOGBDataModule": DGLOGBDataModule, +} + + +def load_datamodule( + config: Union[omegaconf.DictConfig, Dict[str, Any]], +): + module_class = DATAMODULE_DICT[config["datamodule"]["module_type"]] + datamodule = module_class(**config["datamodule"]["args"]) + + return datamodule + + +def load_metrics(config: Union[omegaconf.DictConfig, Dict[str, Any]]): + + metrics = {} + cfg_metrics = deepcopy(config["metrics"]) + if cfg_metrics is None: + return metrics + + for this_metric in cfg_metrics: + name = this_metric.pop("name") + metrics[name] = MetricWrapper(**this_metric) + + return metrics + + +def load_architecture( + config: Union[omegaconf.DictConfig, Dict[str, Any]], + in_dim_nodes: int, + in_dim_edges: int, +): + + if isinstance(config, dict): + config = omegaconf.OmegaConf.create(config) + cfg_arch = config["architecture"] + + kwargs = {} + + # Select the architecture + model_type = cfg_arch["model_type"].lower() + if model_type == "fulldglnetwork": + model_class = FullDGLNetwork + elif model_type == "fulldglsiamesenetwork": + model_class = FullDGLSiameseNetwork + kwargs["dist_method"] = cfg_arch["dist_method"] + else: + raise ValueError(f"Unsupported model_type=`{model_type}`") + + # Prepare the various kwargs + pre_nn_kwargs = dict(cfg_arch["pre_nn"]) if cfg_arch["pre_nn"] is not None else None + pre_nn_edges_kwargs = dict(cfg_arch["pre_nn_edges"]) if cfg_arch["pre_nn_edges"] is not None else None + gnn_kwargs = dict(cfg_arch["gnn"]) + post_nn_kwargs = dict(cfg_arch["post_nn"]) if cfg_arch["post_nn"] is not None else None + + # Set the input dimensions + if pre_nn_kwargs is not None: + pre_nn_kwargs = dict(pre_nn_kwargs) + pre_nn_kwargs.setdefault("in_dim", in_dim_nodes) + else: + gnn_kwargs.setdefault("in_dim", in_dim_nodes) + + if pre_nn_edges_kwargs is not None: + pre_nn_edges_kwargs = dict(pre_nn_edges_kwargs) + pre_nn_edges_kwargs.setdefault("in_dim", in_dim_edges) + else: + gnn_kwargs.setdefault("in_dim_edges", in_dim_edges) + + # Set the parameters for the full network + model_kwargs = dict( + gnn_kwargs=gnn_kwargs, + pre_nn_kwargs=pre_nn_kwargs, + pre_nn_edges_kwargs=pre_nn_edges_kwargs, + post_nn_kwargs=post_nn_kwargs, + ) + + return model_class, model_kwargs + + +def load_predictor(config, model_class, model_kwargs, metrics): + # Defining the predictor + + cfg_pred = dict(deepcopy(config["predictor"])) + predictor = PredictorModule( + model_class=model_class, + model_kwargs=model_kwargs, + metrics=metrics, + **cfg_pred, + ) + + return predictor + + +def load_trainer(config): + cfg_trainer = deepcopy(config["trainer"]) + + # Set the number of gpus to 0 if no GPU is available + gpus = cfg_trainer["trainer"].pop("gpus", 0) + num_gpus = 0 + if isinstance(gpus, int): + num_gpus = gpus + elif isinstance(gpus, (list, tuple)): + num_gpus = len(gpus) + if (num_gpus > 0) and (not torch.cuda.is_available()): + logger.warning( + f"Number of GPUs selected is `{num_gpus}`, but will be ignored since no GPU are available on this device" + ) + gpus = 0 + + trainer_kwargs = {} + callbacks = [] + if "early_stopping" in cfg_trainer.keys(): + callbacks.append(EarlyStopping(**cfg_trainer["early_stopping"])) + + if "model_checkpoint" in cfg_trainer.keys(): + callbacks.append(ModelCheckpoint(**cfg_trainer["model_checkpoint"])) + + if "logger" in cfg_trainer.keys(): + trainer_kwargs["logger"] = TensorBoardLogger(**cfg_trainer["logger"], default_hp_metric=False) + + trainer_kwargs["callbacks"] = callbacks + + trainer = Trainer( + terminate_on_nan=True, + **cfg_trainer["trainer"], + **trainer_kwargs, + gpus=gpus, + ) + + return trainer diff --git a/goli/config/config_convert.py b/goli/config/config_convert.py index a6258d3ba..4ba482960 100644 --- a/goli/config/config_convert.py +++ b/goli/config/config_convert.py @@ -1,30 +1,30 @@ -import omegaconf - - -def recursive_config_reformating(configs): - r""" - For a given configuration file, convert all `DictConfig` to `dict`, - all `ListConfig` to `list`, and all `byte` to `str`. - - This helps avoid errors when dumping a yaml file. - """ - - if isinstance(configs, omegaconf.DictConfig): - configs = dict(configs) - elif isinstance(configs, omegaconf.ListConfig): - configs = list(configs) - - if isinstance(configs, dict): - for k, v in configs.items(): - if isinstance(v, bytes): - configs[k] = str(v) - else: - configs[k] = recursive_config_reformating(v) - elif isinstance(configs, list): - for k, v in enumerate(configs): - if isinstance(v, bytes): - configs[k] = str(v) - else: - configs[k] = recursive_config_reformating(v) - - return configs +import omegaconf + + +def recursive_config_reformating(configs): + r""" + For a given configuration file, convert all `DictConfig` to `dict`, + all `ListConfig` to `list`, and all `byte` to `str`. + + This helps avoid errors when dumping a yaml file. + """ + + if isinstance(configs, omegaconf.DictConfig): + configs = dict(configs) + elif isinstance(configs, omegaconf.ListConfig): + configs = list(configs) + + if isinstance(configs, dict): + for k, v in configs.items(): + if isinstance(v, bytes): + configs[k] = str(v) + else: + configs[k] = recursive_config_reformating(v) + elif isinstance(configs, list): + for k, v in enumerate(configs): + if isinstance(v, bytes): + configs[k] = str(v) + else: + configs[k] = recursive_config_reformating(v) + + return configs diff --git a/goli/config/zinc_default_fulldgl.yaml b/goli/config/zinc_default_fulldgl.yaml index e3fec4d64..073ae054d 100644 --- a/goli/config/zinc_default_fulldgl.yaml +++ b/goli/config/zinc_default_fulldgl.yaml @@ -1,72 +1,72 @@ -data: - module_type: "DGLFromSmilesDataModule" - args: - # df_path: null - # cache_data_path: null - - smiles_col: "SMILES" - label_cols: ["SA"] - split_val: 0.2 - split_test: 0.2 - split_seed: 19 - - batch_size_train_val: 16 - batch_size_test: 16 - - featurization: - atom_property_list_float: [] - atom_property_list_onehot: ["atomic-number", "degree"] - edge_property_list: ["in-ring", "bond-type-onehot"] - add_self_loop: false - use_bonds_weights: false - explicit_H: false - -architecture: - model_type: fulldglnetwork - pre_nn_kwargs: - out_dim: 32 - hidden_dims: 32 - depth: 1 - activation: &activation relu - last_activation: &last_activation none - dropout: &dropout 0.1 - batch_norm: &batch_norm True - residual_type: none - - post_nn_kwargs: - out_dim: 32 - hidden_dims: 32 - depth: 2 - activation: *activation - last_activation: *last_activation - dropout: *dropout - batch_norm: *batch_norm - residual_type: none - - gnn_kwargs: - out_dim: 32 - hidden_dims: 32 - depth: 4 - activation: *activation - last_activation: *last_activation - dropout: *dropout - batch_norm: *batch_norm - residual_type: simple - pooling: "sum" - virtual_node: "sum" - layer_type: "pna-msgpass" - aggregators: [mean, max, min, std] - scalers: [identity, amplification, attenuation] - -predictor: - loss_fun: mse - optim_kwargs: - lr: 1.e-3 - weight_decay: 1e-7 - lr_reduce_on_plateau_kwargs: - factor: 0.5 - patience: 7 - scheduler_kwargs: - monitor: &monitor val_loss - frequency: 1 - target_nan_mask: 0 # null: no mask, 0: 0 mask, ignore: ignore nan values from loss +data: + module_type: "DGLFromSmilesDataModule" + args: + # df_path: null + # cache_data_path: null + + smiles_col: "SMILES" + label_cols: ["SA"] + split_val: 0.2 + split_test: 0.2 + split_seed: 19 + + batch_size_train_val: 16 + batch_size_test: 16 + + featurization: + atom_property_list_float: [] + atom_property_list_onehot: ["atomic-number", "degree"] + edge_property_list: ["in-ring", "bond-type-onehot"] + add_self_loop: false + use_bonds_weights: false + explicit_H: false + +architecture: + model_type: fulldglnetwork + pre_nn_kwargs: + out_dim: 32 + hidden_dims: 32 + depth: 1 + activation: &activation relu + last_activation: &last_activation none + dropout: &dropout 0.1 + batch_norm: &batch_norm True + residual_type: none + + post_nn_kwargs: + out_dim: 32 + hidden_dims: 32 + depth: 2 + activation: *activation + last_activation: *last_activation + dropout: *dropout + batch_norm: *batch_norm + residual_type: none + + gnn_kwargs: + out_dim: 32 + hidden_dims: 32 + depth: 4 + activation: *activation + last_activation: *last_activation + dropout: *dropout + batch_norm: *batch_norm + residual_type: simple + pooling: "sum" + virtual_node: "sum" + layer_type: "pna-msgpass" + aggregators: [mean, max, min, std] + scalers: [identity, amplification, attenuation] + +predictor: + loss_fun: mse + optim_kwargs: + lr: 1.e-3 + weight_decay: 1e-7 + lr_reduce_on_plateau_kwargs: + factor: 0.5 + patience: 7 + scheduler_kwargs: + monitor: &monitor val_loss + frequency: 1 + target_nan_mask: 0 # null: no mask, 0: 0 mask, ignore: ignore nan values from loss diff --git a/goli/data/__init__.py b/goli/data/__init__.py index b45cb586c..3641838e9 100644 --- a/goli/data/__init__.py +++ b/goli/data/__init__.py @@ -1,8 +1,8 @@ -from .utils import load_micro_zinc -from .utils import load_tiny_zinc - -from .collate import goli_collate_fn - -from .datamodule import DGLOGBDataModule -from .datamodule import DGLFromSmilesDataModule -from .datamodule import DGLOGBDataModule +from .utils import load_micro_zinc +from .utils import load_tiny_zinc + +from .collate import goli_collate_fn + +from .datamodule import DGLOGBDataModule +from .datamodule import DGLFromSmilesDataModule +from .datamodule import DGLOGBDataModule diff --git a/goli/data/collate.py b/goli/data/collate.py index f1d092725..17c283fa5 100644 --- a/goli/data/collate.py +++ b/goli/data/collate.py @@ -1,37 +1,37 @@ -import collections.abc - -from torch.utils.data.dataloader import default_collate - -import dgl - - -def goli_collate_fn(elements): - """This collate function is identical to the default - pytorch collate function but add support for `dgl.DGLGraph` - objects and use `dgl.batch` to batch graphs. - - Beside dgl graph collate, other objects are processed the same way - as the original torch collate function. See https://pytorch.org/docs/stable/data.html#dataloader-collate-fn - for more details. - - Important: - Only dgl graph within a dict are currently supported. It's should not be hard - to support dgl graphs from other objects. - - Note: - If goli needs to manipulate other tricky-to-batch objects. Support - for them should be added to this single collate function. - """ - - elem = elements[0] - - if isinstance(elem, collections.abc.Mapping): - batch = {} - for key in elem: - if isinstance(elem[key], dgl.DGLGraph): - batch[key] = dgl.batch([d[key] for d in elements]) - else: - batch[key] = default_collate([d[key] for d in elements]) - return batch - else: - return default_collate(elements) +import collections.abc + +from torch.utils.data.dataloader import default_collate + +import dgl + + +def goli_collate_fn(elements): + """This collate function is identical to the default + pytorch collate function but add support for `dgl.DGLGraph` + objects and use `dgl.batch` to batch graphs. + + Beside dgl graph collate, other objects are processed the same way + as the original torch collate function. See https://pytorch.org/docs/stable/data.html#dataloader-collate-fn + for more details. + + Important: + Only dgl graph within a dict are currently supported. It's should not be hard + to support dgl graphs from other objects. + + Note: + If goli needs to manipulate other tricky-to-batch objects. Support + for them should be added to this single collate function. + """ + + elem = elements[0] + + if isinstance(elem, collections.abc.Mapping): + batch = {} + for key in elem: + if isinstance(elem[key], dgl.DGLGraph): + batch[key] = dgl.batch([d[key] for d in elements]) + else: + batch[key] = default_collate([d[key] for d in elements]) + return batch + else: + return default_collate(elements) diff --git a/goli/data/datamodule.py b/goli/data/datamodule.py index abe9a92c6..caccfec87 100644 --- a/goli/data/datamodule.py +++ b/goli/data/datamodule.py @@ -1,957 +1,957 @@ -from typing import List, Dict, Union, Any, Callable, Optional, Tuple, Iterable - -import os -import functools -import importlib.resources -import zipfile - -from loguru import logger -import fsspec -import omegaconf - -import pandas as pd -import numpy as np - -from sklearn.model_selection import train_test_split - -import dgl -import pytorch_lightning as pl - -import datamol as dm - -from goli.utils import fs -from goli.features import mol_to_dglgraph -from goli.features import mol_to_dglgraph_signature -from goli.data.collate import goli_collate_fn -from goli.utils.arg_checker import check_arg_iterator - - -import torch -from torch.utils.data.dataloader import DataLoader, Dataset -from torch.utils.data import Subset - - -class DGLDataset(Dataset): - def __init__( - self, - features: List[dgl.DGLGraph], - labels: Union[torch.Tensor, np.ndarray], - smiles: Optional[List[str]] = None, - indices: Optional[List[str]] = None, - weights: Optional[Union[torch.Tensor, np.ndarray]] = None, - ): - self.smiles = smiles - self.features = features - self.labels = labels - self.indices = indices - self.weights = weights - - def __len__(self): - return len(self.features) - - def __getitem__(self, idx): - datum = {} - - if self.smiles is not None: - datum["smiles"] = self.smiles[idx] - - if self.indices is not None: - datum["indices"] = self.indices[idx] - - if self.weights is not None: - datum["weights"] = self.weights[idx] - - datum["features"] = self.features[idx] - datum["labels"] = self.labels[idx] - return datum - - -class DGLBaseDataModule(pl.LightningDataModule): - def __init__( - self, - batch_size_train_val: int = 16, - batch_size_test: int = 16, - num_workers: int = 0, - pin_memory: bool = True, - persistent_workers: bool = False, - collate_fn: Optional[Callable] = None, - ): - super().__init__() - - self.batch_size_train_val = batch_size_train_val - self.batch_size_test = batch_size_test - - self.num_workers = num_workers - self.pin_memory = pin_memory - self.persistent_workers = persistent_workers - - if collate_fn is None: - self.collate_fn = goli_collate_fn - else: - self.collate_fn = collate_fn - - self.dataset = None - self.train_ds = None - self.val_ds = None - self.test_ds = None - - def prepare_data(self): - raise NotImplementedError() - - def setup(self): - raise NotImplementedError() - - def train_dataloader(self, **kwargs): - return self._dataloader( - dataset=self.train_ds, # type: ignore - batch_size=self.batch_size_train_val, - shuffle=True, - ) - - def val_dataloader(self, **kwargs): - return self._dataloader( - dataset=self.val_ds, # type: ignore - batch_size=self.batch_size_train_val, - shuffle=False, - ) - - def test_dataloader(self, **kwargs): - - return self._dataloader( - dataset=self.test_ds, # type: ignore - batch_size=self.batch_size_test, - shuffle=False, - ) - - def predict_dataloader(self, **kwargs): - - return self._dataloader( - dataset=self.dataset, # type: ignore - batch_size=self.batch_size_test, - shuffle=False, - ) - - @property - def is_prepared(self): - raise NotImplementedError() - - @property - def is_setup(self): - raise NotImplementedError() - - @property - def num_node_feats(self): - raise NotImplementedError() - - @property - def num_edge_feats(self): - raise NotImplementedError() - - def get_first_graph(self): - raise NotImplementedError() - - # Private methods - - @staticmethod - def _read_csv(path, **kwargs): - if str(path).endswith((".csv", ".csv.gz", ".csv.zip", ".csv.bz2")): - sep = "," - elif str(path).endswith((".tsv", ".tsv.gz", ".tsv.zip", ".tsv.bz2")): - sep = "\t" - else: - raise ValueError(f"unsupported file `{path}`") - kwargs.setdefault("sep", sep) - df = pd.read_csv(path, **kwargs) - return df - - def _dataloader(self, dataset: Dataset, batch_size: int, shuffle: bool): - """Get a dataloader for a given dataset""" - - if self.num_workers == -1: - num_workers = os.cpu_count() - num_workers = num_workers if num_workers is not None else 0 - else: - num_workers = self.num_workers - - loader = DataLoader( - dataset=dataset, - num_workers=num_workers, - collate_fn=self.collate_fn, - pin_memory=self.pin_memory, - batch_size=batch_size, - shuffle=shuffle, - persistent_workers=self.persistent_workers, - ) - return loader - - -class DGLFromSmilesDataModule(DGLBaseDataModule): - """ - NOTE(hadim): let's make only one class for the moment and refactor with a parent class - once we have more concrete datamodules to implement. The class should be general enough - to be easily refactored. - - NOTE(hadim): splitting is not very full-featured yet; only random splitting on-the-fly - is allowed using a seed. Next is to add the availability to provide split indices data as input. - - NOTE(hadim): implement using weights. It should probably be a column in the dataframe. - """ - - def __init__( - self, - df: pd.DataFrame = None, - df_path: Optional[Union[str, os.PathLike]] = None, - cache_data_path: Optional[Union[str, os.PathLike]] = None, - featurization: Optional[Union[Dict[str, Any], omegaconf.DictConfig]] = None, - smiles_col: str = None, - label_cols: List[str] = None, - weights_col: str = None, - weights_type: str = None, - idx_col: str = None, - sample_size: Union[int, float, type(None)] = None, - split_val: float = 0.2, - split_test: float = 0.2, - split_seed: int = None, - splits_path: Optional[Union[str, os.PathLike]] = None, - batch_size_train_val: int = 16, - batch_size_test: int = 16, - num_workers: int = 0, - pin_memory: bool = True, - persistent_workers: bool = False, - featurization_n_jobs: int = -1, - featurization_progress: bool = False, - collate_fn: Optional[Callable] = None, - ): - """ - - Parameters: - df: a dataframe. - df_path: a path to a dataframe to load (CSV file). `df` takes precedence over - `df_path`. - cache_data_path: path where to save or reload the cached data. The path can be - remote (S3, GS, etc). - featurization: args to apply to the SMILES to DGL featurizer. - smiles_col: Name of the SMILES column. If set to `None`, it will look for - a column with the word "smile" (case insensitive) in it. - If no such column is found, an error will be raised. - label_cols: Name of the columns to use as labels, with different options. - - - `list`: A list of all column names to use - - `None`: All the columns are used except the SMILES one. - - `str`: The name of the single column to use - - `*str`: A string starting by a `*` means all columns whose name - ends with the specified `str` - - `str*`: A string ending by a `*` means all columns whose name - starts with the specified `str` - - weights_col: Name of the column to use as sample weights. If `None`, no - weights are used. This parameter cannot be used together with `weights_type`. - weights_type: The type of weights to use. This parameter cannot be used together with `weights_col`. - **It only supports multi-label binary classification.** - - Supported types: - - - `None`: No weights are used. - - `"sample_balanced"`: A weight is assigned to each sample inversely - proportional to the number of positive value. If there are multiple - labels, the product of the weights is used. - - `"sample_label_balanced"`: Similar to the `"sample_balanced"` weights, - but the weights are applied to each element individually, without - computing the product of the weights for a given sample. - - idx_col: Name of the columns to use as indices. Unused if set to None. - split_val: Ratio for the validation split. - split_test: Ratio for the test split. - split_seed: Seed to use for the random split. More complex splitting strategy - should be implemented. - splits_path: A path a CSV file containing indices for the splits. The file must contains - 3 columns "train", "val" and "test". It takes precedence over `split_val` and `split_test`. - batch_size_train_val: batch size for training and val dataset. - batch_size_test: batch size for test dataset. - num_workers: Number of workers for the dataloader. Use -1 to use all available - cores. - pin_memory: Whether to pin on paginated CPU memory for the dataloader. - featurization_n_jobs: Number of cores to use for the featurization. - featurization_progress: whether to show a progress bar during featurization. - collate_fn: A custom torch collate function. Default is to `goli.data.goli_collate_fn` - sample_size: - - - `int`: The maximum number of elements to take from the dataset. - - `float`: Value between 0 and 1 representing the fraction of the dataset to consider - - `None`: all elements are considered. - """ - super().__init__( - batch_size_train_val=batch_size_train_val, - batch_size_test=batch_size_test, - num_workers=num_workers, - pin_memory=pin_memory, - persistent_workers=persistent_workers, - collate_fn=collate_fn, - ) - - self.df = df - self.df_path = df_path - - self.cache_data_path = str(cache_data_path) if cache_data_path is not None else None - self.featurization = featurization - - self.smiles_col = smiles_col - self.label_cols = self._parse_label_cols(label_cols, smiles_col) - self.idx_col = idx_col - self.sample_size = sample_size - - self.weights_col = weights_col - self.weights_type = weights_type - if self.weights_col is not None: - assert self.weights_type is None - - self.split_val = split_val - self.split_test = split_test - self.split_seed = split_seed - self.splits_path = splits_path - - self.featurization_n_jobs = featurization_n_jobs - self.featurization_progress = featurization_progress - - self.dataset = None - self.train_ds = None - self.val_ds = None - self.test_ds = None - self.train_indices = None - self.val_indices = None - self.test_indices = None - - def prepare_data(self): - """Called only from a single process in distributed settings. Steps: - - - If cache is set and exists, reload from cache. - - Load the dataframe if its a path. - - Extract smiles and labels from the dataframe. - - Compute the features. - - Compute or set split indices. - - Make the list of dict dataset. - """ - - if self._load_from_cache(): - return True - - # Load the dataframe - if self.df is None: - # Only load the useful columns, as some dataset can be very large - # when loading all columns - usecols = ( - check_arg_iterator(self.smiles_col, enforce_type=list) - + check_arg_iterator(self.label_cols, enforce_type=list) - + check_arg_iterator(self.idx_col, enforce_type=list) - + check_arg_iterator(self.weights_col, enforce_type=list) - ) - - df = self._read_csv(self.df_path, usecols=usecols) - else: - df = self.df - - df = self._sub_sample_df(df) - - logger.info(f"Prepare dataset with {len(df)} data points.") - - # Extract smiles and labels - smiles, labels, indices, weights, sample_idx = self._extract_smiles_labels( - df, - smiles_col=self.smiles_col, - label_cols=self.label_cols, - idx_col=self.idx_col, - weights_col=self.weights_col, - weights_type=self.weights_type, - ) - - # Precompute the features - # NOTE(hadim): in case of very large dataset we could: - # - or cache the data and read from it during `next(iter(dataloader))` - # - or compute the features on-the-fly during `next(iter(dataloader))` - # For now we compute in advance and hold everything in memory. - featurization_args = self.featurization or {} - transform_smiles = functools.partial(mol_to_dglgraph, **featurization_args) - features = dm.utils.parallelized( - transform_smiles, - smiles, - progress=self.featurization_progress, - n_jobs=self.featurization_n_jobs, - ) - - # Warn about None molecules - is_none = np.array([ii for ii, feat in enumerate(features) if feat is None]) - if len(is_none) > 0: - mols_to_msg = [f"{sample_idx[idx]}: {smiles[idx]}" for idx in is_none] - msg = "\n".join(mols_to_msg) - logger.warning( - (f"{len(is_none)} molecules will be removed since they failed featurization:\n" + msg) - ) - - # Remove None molecules - if len(is_none) > 0: - df.drop(df.index[is_none], axis=0) - features = [feat for feat in features if not (feat is None)] - sample_idx = np.delete(sample_idx, is_none, axis=0) - smiles = np.delete(smiles, is_none, axis=0) - if labels is not None: - labels = np.delete(labels, is_none, axis=0) - if weights is not None: - weights = np.delete(weights, is_none, axis=0) - if indices is not None: - indices = np.delete(indices, is_none, axis=0) - - # Get splits indices - self.train_indices, self.val_indices, self.test_indices = self._get_split_indices( - len(df), - split_val=self.split_val, - split_test=self.split_test, - split_seed=self.split_seed, - splits_path=self.splits_path, - sample_idx=sample_idx, - ) - - # Make the torch datasets (mostly a wrapper there is no memory overhead here) - self.dataset = DGLDataset( - smiles=smiles, - features=features, - labels=labels, - indices=indices, - weights=weights, - ) - - self._save_to_cache() - - def setup(self, stage: str = None): - """Prepare the torch dataset. Called on every GPUs. Setting state here is ok.""" - - if stage == "fit" or stage is None: - self.train_ds = Subset(self.dataset, self.train_indices) # type: ignore - self.val_ds = Subset(self.dataset, self.val_indices) # type: ignore - - if stage == "test" or stage is None: - self.test_ds = Subset(self.dataset, self.test_indices) # type: ignore - - def _parse_label_cols(self, label_cols: Union[type(None), str, List[str]], smiles_col: str) -> List[str]: - r""" - Parse the choice of label columns depending on the type of input. - The input parameters `label_cols` and `smiles_col` are described in - the `__init__` method. - """ - if self.df is None: - # Only load the useful columns, as some dataset can be very large - # when loading all columns - df = self._read_csv(self.df_path, nrows=0) - else: - df = self.df - cols = list(df.columns) - - # A star `*` at the beginning or end of the string specifies to look for all - # columns that starts/end with a specific string - if isinstance(label_cols, str): - if label_cols[0] == "*": - label_cols = [col for col in cols if str(col).endswith(label_cols[1:])] - elif label_cols[-1] == "*": - label_cols = [col for col in cols if str(col).startswith(label_cols[:-1])] - else: - label_cols = [label_cols] - - elif label_cols is None: - label_cols = [col for col in cols if col != smiles_col] - - return check_arg_iterator(label_cols, enforce_type=list) - - @property - def is_prepared(self): - if not hasattr(self, "dataset"): - return False - return getattr(self, "dataset") is not None - - @property - def is_setup(self): - if not hasattr(self, "train_ds"): - return False - return getattr(self, "train_ds") is not None - - @property - def num_node_feats(self): - """Return the number of node features in the first graph""" - - graph = self.get_first_graph() - num_feats = 0 - if "feat" in graph.ndata.keys(): - num_feats += graph.ndata["feat"].shape[1] - return num_feats - - @property - def num_node_feats_with_positional_encoding(self): - """Return the number of node features in the first graph - including positional encoding features.""" - - graph = self.get_first_graph() - num_feats = 0 - if "feat" in graph.ndata.keys(): - num_feats += graph.ndata["feat"].shape[1] - if "pos_enc_feats_sign_flip" in graph.ndata.keys(): - num_feats += graph.ndata["pos_enc_feats_sign_flip"].shape[1] - if "pos_enc_feats_no_flip" in graph.ndata.keys(): - num_feats += graph.ndata["pos_enc_feats_no_flip"].shape[1] - return num_feats - - @property - def num_edge_feats(self): - """Return the number of edge features in the first graph""" - - graph = self.get_first_graph() - if "feat" in graph.edata.keys(): - return graph.edata["feat"].shape[1] # type: ignore - else: - return 0 - - def get_first_graph(self): - """ - Low memory footprint method to get the first datapoint DGL graph. - The first 10 rows of the data are read in case the first one has a featurization - error. If all 10 first element, then `None` is returned, otherwise the first - graph to not fail is returned. - """ - if self.df is None: - df = self._read_csv(self.df_path, nrows=10) - else: - df = self.df.iloc[0:10, :] - - smiles, _, _, _, _ = self._extract_smiles_labels(df, self.smiles_col, self.label_cols) - - featurization_args = self.featurization or {} - transform_smiles = functools.partial(mol_to_dglgraph, **featurization_args) - graph = None - for s in smiles: - graph = transform_smiles(s) - if graph is not None: - break - return graph - - # Private methods - - def _save_to_cache(self): - """Save the built dataset, indices and featurization arguments into a cache file.""" - - # Cache on disk - if self.cache_data_path is not None: - logger.info(f"Write prepared datamodule to {self.cache_data_path}") - cache = {} - cache["dataset"] = self.dataset - cache["train_indices"] = self.train_indices - cache["val_indices"] = self.val_indices - cache["test_indices"] = self.test_indices - - # Save featurization args used - cache["featurization_args"] = mol_to_dglgraph_signature(dict(self.featurization or {})) - - with fsspec.open(self.cache_data_path, "wb") as f: - torch.save(cache, f) - - def _load_from_cache(self): - """Attempt to reload the data from cache. Return True if data has been - reloaded from the cache. - """ - - if self.cache_data_path is None: - # Cache path is not provided. - return False - - if not fs.exists(self.cache_data_path): - # Cache patch does not exist. - return False - - # Reload from cache if it exists and is valid - logger.info(f"Try reloading the data module from {self.cache_data_path}.") - - # Load cache - with fsspec.open(self.cache_data_path, "rb") as f: - cache = torch.load(f) - - # Are the required keys present? - excepted_cache_keys = { - "dataset", - "test_indices", - "train_indices", - "val_indices", - } - if not set(cache.keys()) != excepted_cache_keys: - logger.info( - f"Cache looks invalid with keys: {cache.keys()}. Excepted keys are {excepted_cache_keys}" - ) - logger.info("Fallback to regular data preparation steps.") - return False - - # Is the featurization signature the same? - current_signature = mol_to_dglgraph_signature(dict(self.featurization or {})) - cache_signature = mol_to_dglgraph_signature(cache["featurization_args"]) - - if current_signature != cache_signature: - logger.info(f"Cache featurizer arguments are different than the provided ones.") - logger.info(f"Cache featurizer arguments: {cache_signature}") - logger.info(f"Provided featurizer arguments: {current_signature}.") - logger.info("Fallback to regular data preparation steps.") - return False - - # At this point the cache can be loaded - - self.dataset = cache["dataset"] - self.train_indices = cache["train_indices"] - self.val_indices = cache["val_indices"] - self.test_indices = cache["test_indices"] - - logger.info(f"Datamodule correctly reloaded from cache.") - - return True - - def _extract_smiles_labels( - self, - df: pd.DataFrame, - smiles_col: str = None, - label_cols: List[str] = [], - idx_col: str = None, - weights_col: str = None, - weights_type: str = None, - ) -> Tuple[ - np.ndarray, np.ndarray, Union[type(None), np.ndarray], Union[type(None), np.ndarray], np.ndarray - ]: - """For a given dataframe extract the SMILES and labels columns. Smiles is returned as a list - of string while labels are returned as a 2D numpy array. - """ - - if smiles_col is None: - smiles_col_all = [col for col in df.columns if "smile" in str(col).lower()] - if len(smiles_col_all) == 0: - raise ValueError(f"No SMILES column found in dataframe. Columns are {df.columns}") - elif len(smiles_col_all) > 1: - raise ValueError( - f"Multiple SMILES column found in dataframe. SMILES Columns are {smiles_col_all}" - ) - - smiles_col = smiles_col_all[0] - - if label_cols is None: - label_cols = df.columns.drop(smiles_col) - - label_cols = check_arg_iterator(label_cols, enforce_type=list) - - smiles = df[smiles_col].values - labels = [pd.to_numeric(df[col], errors="coerce") for col in label_cols] - labels = np.stack(labels, axis=1) - - indices = None - if idx_col is not None: - indices = df[idx_col].values - - sample_idx = df.index.values - - # Extract the weights - weights = None - if weights_col is not None: - weights = df[weights_col].values - elif weights_type is not None: - if not np.all((labels == 0) | (labels == 1)): - raise ValueError("Labels must be binary for `weights_type`") - - if weights_type == "sample_label_balanced": - ratio_pos_neg = np.sum(labels, axis=0, keepdims=1) / labels.shape[0] - weights = np.zeros(labels.shape) - weights[labels == 0] = ratio_pos_neg - weights[labels == 1] = ratio_pos_neg ** -1 - - elif weights_type == "sample_balanced": - ratio_pos_neg = np.sum(labels, axis=0, keepdims=1) / labels.shape[0] - weights = np.zeros(labels.shape) - weights[labels == 0] = ratio_pos_neg - weights[labels == 1] = ratio_pos_neg ** -1 - weights = np.prod(weights, axis=1) - - else: - raise ValueError(f"Undefined `weights_type` {weights_type}") - - weights /= np.max(weights) # Put the max weight to 1 - - return smiles, labels, indices, weights, sample_idx - - def _get_split_indices( - self, - dataset_size: int, - split_val: float, - split_test: float, - sample_idx: Optional[Iterable[int]] = None, - split_seed: int = None, - splits_path: Union[str, os.PathLike] = None, - ): - """Compute indices of random splits.""" - - if sample_idx is None: - sample_idx = np.arange(dataset_size) - - if splits_path is None: - # Random splitting - train_indices, val_test_indices = train_test_split( - sample_idx, - test_size=split_val + split_test, - random_state=split_seed, - ) - - sub_split_test = split_test / (split_test + split_val) - val_indices, test_indices = train_test_split( - val_test_indices, - test_size=sub_split_test, - random_state=split_seed, - ) - - else: - # Split from an indices file - with fsspec.open(str(splits_path)) as f: - splits = self._read_csv(splits_path) - - train_indices = splits["train"].dropna().astype("int").tolist() - val_indices = splits["val"].dropna().astype("int").tolist() - test_indices = splits["test"].dropna().astype("int").tolist() - - # Filter train, val and test indices - _, train_idx, _ = np.intersect1d(sample_idx,train_indices, return_indices=True) - train_indices = train_idx.tolist() - train_indices.sort() - _, valid_idx, _ = np.intersect1d(sample_idx,val_indices, return_indices=True) - val_indices = valid_idx.tolist() - val_indices.sort() - _, test_idx, _ = np.intersect1d(sample_idx, test_indices, return_indices=True) - test_indices = test_idx.tolist() - test_indices.sort() - - # train_indices = [ii for ii, idx in enumerate(sample_idx) if idx in train_indices] - # val_indices = [ii for ii, idx in enumerate(sample_idx) if idx in val_indices] - # test_indices = [ii for ii, idx in enumerate(sample_idx) if idx in test_indices] - - return train_indices, val_indices, test_indices - - def _sub_sample_df(self, df): - # Sub-sample the dataframe - if isinstance(self.sample_size, int): - n = min(self.sample_size, df.shape[0]) - df = df.sample(n=n) - elif isinstance(self.sample_size, float): - df = df.sample(f=self.sample_size) - elif self.sample_size is None: - pass - else: - raise ValueError(f"Wrong value for `self.sample_size`: {self.sample_size}") - - return df - - def __len__(self) -> int: - r""" - Returns the number of elements of the current DataModule - """ - if self.df is None: - df = self._read_csv(self.df_path, usecols=[self.smiles_col]) - else: - df = self.df - - return len(df) - - def to_dict(self): - obj_repr = {} - obj_repr["name"] = self.__class__.__name__ - obj_repr["len"] = len(self) - obj_repr["train_size"] = len(self.train_indices) if self.train_indices is not None else None - obj_repr["val_size"] = len(self.val_indices) if self.val_indices is not None else None - obj_repr["test_size"] = len(self.test_indices) if self.test_indices is not None else None - obj_repr["batch_size_train_val"] = self.batch_size_train_val - obj_repr["batch_size_test"] = self.batch_size_test - obj_repr["num_node_feats"] = self.num_node_feats - obj_repr["num_node_feats_with_positional_encoding"] = self.num_node_feats_with_positional_encoding - obj_repr["num_edge_feats"] = self.num_edge_feats - obj_repr["num_labels"] = len(self.label_cols) - obj_repr["collate_fn"] = self.collate_fn.__name__ - obj_repr["featurization"] = self.featurization - return obj_repr - - def __repr__(self): - """Controls how the class is printed""" - return omegaconf.OmegaConf.to_yaml(self.to_dict()) - - -class DGLOGBDataModule(DGLFromSmilesDataModule): - """Load an OGB GraphProp dataset.""" - - def __init__( - self, - dataset_name: str, - cache_data_path: Optional[Union[str, os.PathLike]] = None, - featurization: Optional[Union[Dict[str, Any], omegaconf.DictConfig]] = None, - weights_col: str = None, - weights_type: str = None, - sample_size: Union[int, float, type(None)] = None, - batch_size_train_val: int = 16, - batch_size_test: int = 16, - num_workers: int = 0, - pin_memory: bool = True, - persistent_workers: bool = False, - featurization_n_jobs: int = -1, - featurization_progress: bool = False, - collate_fn: Optional[Callable] = None, - ): - """ - - Parameters: - dataset_name: Name of the OGB dataset to load. Examples of possible datasets are - "ogbg-molhiv", "ogbg-molpcba", "ogbg-moltox21", "ogbg-molfreesolv". - cache_data_path: path where to save or reload the cached data. The path can be - remote (S3, GS, etc). - featurization: args to apply to the SMILES to DGL featurizer. - batch_size_train_val: batch size for training and val dataset. - batch_size_test: batch size for test dataset. - num_workers: Number of workers for the dataloader. Use -1 to use all available - cores. - pin_memory: Whether to pin on paginated CPU memory for the dataloader. - featurization_n_jobs: Number of cores to use for the featurization. - featurization_progress: whether to show a progress bar during featurization. - collate_fn: A custom torch collate function. Default is to `goli.data.goli_collate_fn` - sample_size: - - - `int`: The maximum number of elements to take from the dataset. - - `float`: Value between 0 and 1 representing the fraction of the dataset to consider - - `None`: all elements are considered. - """ - - self.dataset_name = dataset_name - - # Get OGB metadata - self.metadata = self._get_dataset_metadata(self.dataset_name) - - # Get dataset - df, idx_col, smiles_col, label_cols, splits_path = self._load_dataset(self.metadata) - - # Config for datamodule - dm_args = {} - dm_args["df"] = df - dm_args["cache_data_path"] = cache_data_path - dm_args["featurization"] = featurization - dm_args["smiles_col"] = smiles_col - dm_args["label_cols"] = label_cols - dm_args["idx_col"] = idx_col - dm_args["splits_path"] = splits_path - dm_args["batch_size_train_val"] = batch_size_train_val - dm_args["batch_size_test"] = batch_size_test - dm_args["num_workers"] = num_workers - dm_args["pin_memory"] = pin_memory - dm_args["featurization_n_jobs"] = featurization_n_jobs - dm_args["featurization_progress"] = featurization_progress - dm_args["persistent_workers"] = persistent_workers - dm_args["collate_fn"] = collate_fn - dm_args["weights_col"] = weights_col - dm_args["weights_type"] = weights_type - dm_args["sample_size"] = sample_size - - # Init DGLFromSmilesDataModule - super().__init__(**dm_args) - - def to_dict(self): - obj_repr = {} - obj_repr["dataset_name"] = self.dataset_name - obj_repr.update(super().to_dict()) - return obj_repr - - # Private methods - - def _load_dataset(self, metadata: dict): - """Download, extract and load an OGB dataset.""" - - base_dir = fs.get_cache_dir("ogb") - if metadata['download_name'] == "pcqm4m": - dataset_dir = base_dir / (metadata["download_name"] + "_kddcup2021") - else: - dataset_dir = base_dir / metadata["download_name"] - - if not dataset_dir.exists(): - - # Create cache filepath for zip file and associated folder - dataset_path = base_dir / f"{metadata['download_name']}.zip" - - # Download it - if not dataset_path.exists(): - logger.info(f"Downloading {metadata['url']} to {dataset_path}") - fs.copy(metadata["url"], dataset_path, progress=True) - - # Extract - zf = zipfile.ZipFile(dataset_path) - zf.extractall(base_dir) - - # Load CSV file - if metadata['download_name']== "pcqm4m": - df_path = dataset_dir / "raw" / "data.csv.gz" - else: - df_path = dataset_dir / "mapping" / "mol.csv.gz" - logger.info(f"Loading {df_path} in memory.") - df = pd.read_csv(df_path) - - # Load split from the OGB dataset and save them in a single CSV file - if metadata['download_name'] == "pcqm4m": - split_name = metadata["split"] - split_dict = torch.load(dataset_dir / "split_dict.pt") - train_split = pd.DataFrame(split_dict['train']) - val_split = pd.DataFrame(split_dict['valid']) - test_split = pd.DataFrame(split_dict['test']) - splits = pd.concat([train_split, val_split, test_split], axis=1) # type: ignore - splits.columns = ["train", "val", "test"] - - splits_path = dataset_dir / "split" - if not splits_path.exists(): - os.makedirs(splits_path) - splits_path = dataset_dir / f"{split_name}.csv.gz" - else: - splits_path = splits_path / f"{split_name}.csv.gz" - logger.info(f"Saving splits to {splits_path}") - splits.to_csv(splits_path, index=None) - else: - split_name = metadata["split"] - train_split = pd.read_csv(dataset_dir / "split" / split_name / "train.csv.gz", header=None) # type: ignore - val_split = pd.read_csv(dataset_dir / "split" / split_name / "valid.csv.gz", header=None) # type: ignore - test_split = pd.read_csv(dataset_dir / "split" / split_name / "test.csv.gz", header=None) # type: ignore - - splits = pd.concat([train_split, val_split, test_split], axis=1) # type: ignore - splits.columns = ["train", "val", "test"] - - splits_path = dataset_dir / "split" / f"{split_name}.csv.gz" - logger.info(f"Saving splits to {splits_path}") - splits.to_csv(splits_path, index=None) - - # Get column names: OGB columns are predictable - if metadata['download_name'] == "pcqm4m": - idx_col = df.columns[0] - smiles_col = df.columns[-2] - label_cols = df.columns[-1:].to_list() - else: - idx_col = df.columns[-1] - smiles_col = df.columns[-2] - label_cols = df.columns[:-2].to_list() - - return df, idx_col, smiles_col, label_cols, splits_path - - def _get_dataset_metadata(self, dataset_name: str): - ogb_metadata = self._get_ogb_metadata() - - if dataset_name not in ogb_metadata.index: - raise ValueError(f"'{dataset_name}' is not a valid dataset name.") - - return ogb_metadata.loc[dataset_name].to_dict() - - def _get_ogb_metadata(self): - """Get the metadata of OGB GraphProp datasets.""" - - with importlib.resources.open_text("ogb.graphproppred", "master.csv") as f: - ogb_metadata = pd.read_csv(f) - - ogb_metadata = ogb_metadata.set_index(ogb_metadata.columns[0]) - ogb_metadata = ogb_metadata.T - - # Only keep datasets of type 'mol' - ogb_metadata = ogb_metadata[ogb_metadata["data type"] == "mol"] - - return ogb_metadata +from typing import List, Dict, Union, Any, Callable, Optional, Tuple, Iterable + +import os +import functools +import importlib.resources +import zipfile + +from loguru import logger +import fsspec +import omegaconf + +import pandas as pd +import numpy as np + +from sklearn.model_selection import train_test_split + +import dgl +import pytorch_lightning as pl + +import datamol as dm + +from goli.utils import fs +from goli.features import mol_to_dglgraph +from goli.features import mol_to_dglgraph_signature +from goli.data.collate import goli_collate_fn +from goli.utils.arg_checker import check_arg_iterator + + +import torch +from torch.utils.data.dataloader import DataLoader, Dataset +from torch.utils.data import Subset + + +class DGLDataset(Dataset): + def __init__( + self, + features: List[dgl.DGLGraph], + labels: Union[torch.Tensor, np.ndarray], + smiles: Optional[List[str]] = None, + indices: Optional[List[str]] = None, + weights: Optional[Union[torch.Tensor, np.ndarray]] = None, + ): + self.smiles = smiles + self.features = features + self.labels = labels + self.indices = indices + self.weights = weights + + def __len__(self): + return len(self.features) + + def __getitem__(self, idx): + datum = {} + + if self.smiles is not None: + datum["smiles"] = self.smiles[idx] + + if self.indices is not None: + datum["indices"] = self.indices[idx] + + if self.weights is not None: + datum["weights"] = self.weights[idx] + + datum["features"] = self.features[idx] + datum["labels"] = self.labels[idx] + return datum + + +class DGLBaseDataModule(pl.LightningDataModule): + def __init__( + self, + batch_size_train_val: int = 16, + batch_size_test: int = 16, + num_workers: int = 0, + pin_memory: bool = True, + persistent_workers: bool = False, + collate_fn: Optional[Callable] = None, + ): + super().__init__() + + self.batch_size_train_val = batch_size_train_val + self.batch_size_test = batch_size_test + + self.num_workers = num_workers + self.pin_memory = pin_memory + self.persistent_workers = persistent_workers + + if collate_fn is None: + self.collate_fn = goli_collate_fn + else: + self.collate_fn = collate_fn + + self.dataset = None + self.train_ds = None + self.val_ds = None + self.test_ds = None + + def prepare_data(self): + raise NotImplementedError() + + def setup(self): + raise NotImplementedError() + + def train_dataloader(self, **kwargs): + return self._dataloader( + dataset=self.train_ds, # type: ignore + batch_size=self.batch_size_train_val, + shuffle=True, + ) + + def val_dataloader(self, **kwargs): + return self._dataloader( + dataset=self.val_ds, # type: ignore + batch_size=self.batch_size_train_val, + shuffle=False, + ) + + def test_dataloader(self, **kwargs): + + return self._dataloader( + dataset=self.test_ds, # type: ignore + batch_size=self.batch_size_test, + shuffle=False, + ) + + def predict_dataloader(self, **kwargs): + + return self._dataloader( + dataset=self.dataset, # type: ignore + batch_size=self.batch_size_test, + shuffle=False, + ) + + @property + def is_prepared(self): + raise NotImplementedError() + + @property + def is_setup(self): + raise NotImplementedError() + + @property + def num_node_feats(self): + raise NotImplementedError() + + @property + def num_edge_feats(self): + raise NotImplementedError() + + def get_first_graph(self): + raise NotImplementedError() + + # Private methods + + @staticmethod + def _read_csv(path, **kwargs): + if str(path).endswith((".csv", ".csv.gz", ".csv.zip", ".csv.bz2")): + sep = "," + elif str(path).endswith((".tsv", ".tsv.gz", ".tsv.zip", ".tsv.bz2")): + sep = "\t" + else: + raise ValueError(f"unsupported file `{path}`") + kwargs.setdefault("sep", sep) + df = pd.read_csv(path, **kwargs) + return df + + def _dataloader(self, dataset: Dataset, batch_size: int, shuffle: bool): + """Get a dataloader for a given dataset""" + + if self.num_workers == -1: + num_workers = os.cpu_count() + num_workers = num_workers if num_workers is not None else 0 + else: + num_workers = self.num_workers + + loader = DataLoader( + dataset=dataset, + num_workers=num_workers, + collate_fn=self.collate_fn, + pin_memory=self.pin_memory, + batch_size=batch_size, + shuffle=shuffle, + persistent_workers=self.persistent_workers, + ) + return loader + + +class DGLFromSmilesDataModule(DGLBaseDataModule): + """ + NOTE(hadim): let's make only one class for the moment and refactor with a parent class + once we have more concrete datamodules to implement. The class should be general enough + to be easily refactored. + + NOTE(hadim): splitting is not very full-featured yet; only random splitting on-the-fly + is allowed using a seed. Next is to add the availability to provide split indices data as input. + + NOTE(hadim): implement using weights. It should probably be a column in the dataframe. + """ + + def __init__( + self, + df: pd.DataFrame = None, + df_path: Optional[Union[str, os.PathLike]] = None, + cache_data_path: Optional[Union[str, os.PathLike]] = None, + featurization: Optional[Union[Dict[str, Any], omegaconf.DictConfig]] = None, + smiles_col: str = None, + label_cols: List[str] = None, + weights_col: str = None, + weights_type: str = None, + idx_col: str = None, + sample_size: Union[int, float, type(None)] = None, + split_val: float = 0.2, + split_test: float = 0.2, + split_seed: int = None, + splits_path: Optional[Union[str, os.PathLike]] = None, + batch_size_train_val: int = 16, + batch_size_test: int = 16, + num_workers: int = 0, + pin_memory: bool = True, + persistent_workers: bool = False, + featurization_n_jobs: int = -1, + featurization_progress: bool = False, + collate_fn: Optional[Callable] = None, + ): + """ + + Parameters: + df: a dataframe. + df_path: a path to a dataframe to load (CSV file). `df` takes precedence over + `df_path`. + cache_data_path: path where to save or reload the cached data. The path can be + remote (S3, GS, etc). + featurization: args to apply to the SMILES to DGL featurizer. + smiles_col: Name of the SMILES column. If set to `None`, it will look for + a column with the word "smile" (case insensitive) in it. + If no such column is found, an error will be raised. + label_cols: Name of the columns to use as labels, with different options. + + - `list`: A list of all column names to use + - `None`: All the columns are used except the SMILES one. + - `str`: The name of the single column to use + - `*str`: A string starting by a `*` means all columns whose name + ends with the specified `str` + - `str*`: A string ending by a `*` means all columns whose name + starts with the specified `str` + + weights_col: Name of the column to use as sample weights. If `None`, no + weights are used. This parameter cannot be used together with `weights_type`. + weights_type: The type of weights to use. This parameter cannot be used together with `weights_col`. + **It only supports multi-label binary classification.** + + Supported types: + + - `None`: No weights are used. + - `"sample_balanced"`: A weight is assigned to each sample inversely + proportional to the number of positive value. If there are multiple + labels, the product of the weights is used. + - `"sample_label_balanced"`: Similar to the `"sample_balanced"` weights, + but the weights are applied to each element individually, without + computing the product of the weights for a given sample. + + idx_col: Name of the columns to use as indices. Unused if set to None. + split_val: Ratio for the validation split. + split_test: Ratio for the test split. + split_seed: Seed to use for the random split. More complex splitting strategy + should be implemented. + splits_path: A path a CSV file containing indices for the splits. The file must contains + 3 columns "train", "val" and "test". It takes precedence over `split_val` and `split_test`. + batch_size_train_val: batch size for training and val dataset. + batch_size_test: batch size for test dataset. + num_workers: Number of workers for the dataloader. Use -1 to use all available + cores. + pin_memory: Whether to pin on paginated CPU memory for the dataloader. + featurization_n_jobs: Number of cores to use for the featurization. + featurization_progress: whether to show a progress bar during featurization. + collate_fn: A custom torch collate function. Default is to `goli.data.goli_collate_fn` + sample_size: + + - `int`: The maximum number of elements to take from the dataset. + - `float`: Value between 0 and 1 representing the fraction of the dataset to consider + - `None`: all elements are considered. + """ + super().__init__( + batch_size_train_val=batch_size_train_val, + batch_size_test=batch_size_test, + num_workers=num_workers, + pin_memory=pin_memory, + persistent_workers=persistent_workers, + collate_fn=collate_fn, + ) + + self.df = df + self.df_path = df_path + + self.cache_data_path = str(cache_data_path) if cache_data_path is not None else None + self.featurization = featurization + + self.smiles_col = smiles_col + self.label_cols = self._parse_label_cols(label_cols, smiles_col) + self.idx_col = idx_col + self.sample_size = sample_size + + self.weights_col = weights_col + self.weights_type = weights_type + if self.weights_col is not None: + assert self.weights_type is None + + self.split_val = split_val + self.split_test = split_test + self.split_seed = split_seed + self.splits_path = splits_path + + self.featurization_n_jobs = featurization_n_jobs + self.featurization_progress = featurization_progress + + self.dataset = None + self.train_ds = None + self.val_ds = None + self.test_ds = None + self.train_indices = None + self.val_indices = None + self.test_indices = None + + def prepare_data(self): + """Called only from a single process in distributed settings. Steps: + + - If cache is set and exists, reload from cache. + - Load the dataframe if its a path. + - Extract smiles and labels from the dataframe. + - Compute the features. + - Compute or set split indices. + - Make the list of dict dataset. + """ + + if self._load_from_cache(): + return True + + # Load the dataframe + if self.df is None: + # Only load the useful columns, as some dataset can be very large + # when loading all columns + usecols = ( + check_arg_iterator(self.smiles_col, enforce_type=list) + + check_arg_iterator(self.label_cols, enforce_type=list) + + check_arg_iterator(self.idx_col, enforce_type=list) + + check_arg_iterator(self.weights_col, enforce_type=list) + ) + + df = self._read_csv(self.df_path, usecols=usecols) + else: + df = self.df + + df = self._sub_sample_df(df) + + logger.info(f"Prepare dataset with {len(df)} data points.") + + # Extract smiles and labels + smiles, labels, indices, weights, sample_idx = self._extract_smiles_labels( + df, + smiles_col=self.smiles_col, + label_cols=self.label_cols, + idx_col=self.idx_col, + weights_col=self.weights_col, + weights_type=self.weights_type, + ) + + # Precompute the features + # NOTE(hadim): in case of very large dataset we could: + # - or cache the data and read from it during `next(iter(dataloader))` + # - or compute the features on-the-fly during `next(iter(dataloader))` + # For now we compute in advance and hold everything in memory. + featurization_args = self.featurization or {} + transform_smiles = functools.partial(mol_to_dglgraph, **featurization_args) + features = dm.utils.parallelized( + transform_smiles, + smiles, + progress=self.featurization_progress, + n_jobs=self.featurization_n_jobs, + ) + + # Warn about None molecules + is_none = np.array([ii for ii, feat in enumerate(features) if feat is None]) + if len(is_none) > 0: + mols_to_msg = [f"{sample_idx[idx]}: {smiles[idx]}" for idx in is_none] + msg = "\n".join(mols_to_msg) + logger.warning( + (f"{len(is_none)} molecules will be removed since they failed featurization:\n" + msg) + ) + + # Remove None molecules + if len(is_none) > 0: + df.drop(df.index[is_none], axis=0) + features = [feat for feat in features if not (feat is None)] + sample_idx = np.delete(sample_idx, is_none, axis=0) + smiles = np.delete(smiles, is_none, axis=0) + if labels is not None: + labels = np.delete(labels, is_none, axis=0) + if weights is not None: + weights = np.delete(weights, is_none, axis=0) + if indices is not None: + indices = np.delete(indices, is_none, axis=0) + + # Get splits indices + self.train_indices, self.val_indices, self.test_indices = self._get_split_indices( + len(df), + split_val=self.split_val, + split_test=self.split_test, + split_seed=self.split_seed, + splits_path=self.splits_path, + sample_idx=sample_idx, + ) + + # Make the torch datasets (mostly a wrapper there is no memory overhead here) + self.dataset = DGLDataset( + smiles=smiles, + features=features, + labels=labels, + indices=indices, + weights=weights, + ) + + self._save_to_cache() + + def setup(self, stage: str = None): + """Prepare the torch dataset. Called on every GPUs. Setting state here is ok.""" + + if stage == "fit" or stage is None: + self.train_ds = Subset(self.dataset, self.train_indices) # type: ignore + self.val_ds = Subset(self.dataset, self.val_indices) # type: ignore + + if stage == "test" or stage is None: + self.test_ds = Subset(self.dataset, self.test_indices) # type: ignore + + def _parse_label_cols(self, label_cols: Union[type(None), str, List[str]], smiles_col: str) -> List[str]: + r""" + Parse the choice of label columns depending on the type of input. + The input parameters `label_cols` and `smiles_col` are described in + the `__init__` method. + """ + if self.df is None: + # Only load the useful columns, as some dataset can be very large + # when loading all columns + df = self._read_csv(self.df_path, nrows=0) + else: + df = self.df + cols = list(df.columns) + + # A star `*` at the beginning or end of the string specifies to look for all + # columns that starts/end with a specific string + if isinstance(label_cols, str): + if label_cols[0] == "*": + label_cols = [col for col in cols if str(col).endswith(label_cols[1:])] + elif label_cols[-1] == "*": + label_cols = [col for col in cols if str(col).startswith(label_cols[:-1])] + else: + label_cols = [label_cols] + + elif label_cols is None: + label_cols = [col for col in cols if col != smiles_col] + + return check_arg_iterator(label_cols, enforce_type=list) + + @property + def is_prepared(self): + if not hasattr(self, "dataset"): + return False + return getattr(self, "dataset") is not None + + @property + def is_setup(self): + if not hasattr(self, "train_ds"): + return False + return getattr(self, "train_ds") is not None + + @property + def num_node_feats(self): + """Return the number of node features in the first graph""" + + graph = self.get_first_graph() + num_feats = 0 + if "feat" in graph.ndata.keys(): + num_feats += graph.ndata["feat"].shape[1] + return num_feats + + @property + def num_node_feats_with_positional_encoding(self): + """Return the number of node features in the first graph + including positional encoding features.""" + + graph = self.get_first_graph() + num_feats = 0 + if "feat" in graph.ndata.keys(): + num_feats += graph.ndata["feat"].shape[1] + if "pos_enc_feats_sign_flip" in graph.ndata.keys(): + num_feats += graph.ndata["pos_enc_feats_sign_flip"].shape[1] + if "pos_enc_feats_no_flip" in graph.ndata.keys(): + num_feats += graph.ndata["pos_enc_feats_no_flip"].shape[1] + return num_feats + + @property + def num_edge_feats(self): + """Return the number of edge features in the first graph""" + + graph = self.get_first_graph() + if "feat" in graph.edata.keys(): + return graph.edata["feat"].shape[1] # type: ignore + else: + return 0 + + def get_first_graph(self): + """ + Low memory footprint method to get the first datapoint DGL graph. + The first 10 rows of the data are read in case the first one has a featurization + error. If all 10 first element, then `None` is returned, otherwise the first + graph to not fail is returned. + """ + if self.df is None: + df = self._read_csv(self.df_path, nrows=10) + else: + df = self.df.iloc[0:10, :] + + smiles, _, _, _, _ = self._extract_smiles_labels(df, self.smiles_col, self.label_cols) + + featurization_args = self.featurization or {} + transform_smiles = functools.partial(mol_to_dglgraph, **featurization_args) + graph = None + for s in smiles: + graph = transform_smiles(s) + if graph is not None: + break + return graph + + # Private methods + + def _save_to_cache(self): + """Save the built dataset, indices and featurization arguments into a cache file.""" + + # Cache on disk + if self.cache_data_path is not None: + logger.info(f"Write prepared datamodule to {self.cache_data_path}") + cache = {} + cache["dataset"] = self.dataset + cache["train_indices"] = self.train_indices + cache["val_indices"] = self.val_indices + cache["test_indices"] = self.test_indices + + # Save featurization args used + cache["featurization_args"] = mol_to_dglgraph_signature(dict(self.featurization or {})) + + with fsspec.open(self.cache_data_path, "wb") as f: + torch.save(cache, f) + + def _load_from_cache(self): + """Attempt to reload the data from cache. Return True if data has been + reloaded from the cache. + """ + + if self.cache_data_path is None: + # Cache path is not provided. + return False + + if not fs.exists(self.cache_data_path): + # Cache patch does not exist. + return False + + # Reload from cache if it exists and is valid + logger.info(f"Try reloading the data module from {self.cache_data_path}.") + + # Load cache + with fsspec.open(self.cache_data_path, "rb") as f: + cache = torch.load(f) + + # Are the required keys present? + excepted_cache_keys = { + "dataset", + "test_indices", + "train_indices", + "val_indices", + } + if not set(cache.keys()) != excepted_cache_keys: + logger.info( + f"Cache looks invalid with keys: {cache.keys()}. Excepted keys are {excepted_cache_keys}" + ) + logger.info("Fallback to regular data preparation steps.") + return False + + # Is the featurization signature the same? + current_signature = mol_to_dglgraph_signature(dict(self.featurization or {})) + cache_signature = mol_to_dglgraph_signature(cache["featurization_args"]) + + if current_signature != cache_signature: + logger.info(f"Cache featurizer arguments are different than the provided ones.") + logger.info(f"Cache featurizer arguments: {cache_signature}") + logger.info(f"Provided featurizer arguments: {current_signature}.") + logger.info("Fallback to regular data preparation steps.") + return False + + # At this point the cache can be loaded + + self.dataset = cache["dataset"] + self.train_indices = cache["train_indices"] + self.val_indices = cache["val_indices"] + self.test_indices = cache["test_indices"] + + logger.info(f"Datamodule correctly reloaded from cache.") + + return True + + def _extract_smiles_labels( + self, + df: pd.DataFrame, + smiles_col: str = None, + label_cols: List[str] = [], + idx_col: str = None, + weights_col: str = None, + weights_type: str = None, + ) -> Tuple[ + np.ndarray, np.ndarray, Union[type(None), np.ndarray], Union[type(None), np.ndarray], np.ndarray + ]: + """For a given dataframe extract the SMILES and labels columns. Smiles is returned as a list + of string while labels are returned as a 2D numpy array. + """ + + if smiles_col is None: + smiles_col_all = [col for col in df.columns if "smile" in str(col).lower()] + if len(smiles_col_all) == 0: + raise ValueError(f"No SMILES column found in dataframe. Columns are {df.columns}") + elif len(smiles_col_all) > 1: + raise ValueError( + f"Multiple SMILES column found in dataframe. SMILES Columns are {smiles_col_all}" + ) + + smiles_col = smiles_col_all[0] + + if label_cols is None: + label_cols = df.columns.drop(smiles_col) + + label_cols = check_arg_iterator(label_cols, enforce_type=list) + + smiles = df[smiles_col].values + labels = [pd.to_numeric(df[col], errors="coerce") for col in label_cols] + labels = np.stack(labels, axis=1) + + indices = None + if idx_col is not None: + indices = df[idx_col].values + + sample_idx = df.index.values + + # Extract the weights + weights = None + if weights_col is not None: + weights = df[weights_col].values + elif weights_type is not None: + if not np.all((labels == 0) | (labels == 1)): + raise ValueError("Labels must be binary for `weights_type`") + + if weights_type == "sample_label_balanced": + ratio_pos_neg = np.sum(labels, axis=0, keepdims=1) / labels.shape[0] + weights = np.zeros(labels.shape) + weights[labels == 0] = ratio_pos_neg + weights[labels == 1] = ratio_pos_neg ** -1 + + elif weights_type == "sample_balanced": + ratio_pos_neg = np.sum(labels, axis=0, keepdims=1) / labels.shape[0] + weights = np.zeros(labels.shape) + weights[labels == 0] = ratio_pos_neg + weights[labels == 1] = ratio_pos_neg ** -1 + weights = np.prod(weights, axis=1) + + else: + raise ValueError(f"Undefined `weights_type` {weights_type}") + + weights /= np.max(weights) # Put the max weight to 1 + + return smiles, labels, indices, weights, sample_idx + + def _get_split_indices( + self, + dataset_size: int, + split_val: float, + split_test: float, + sample_idx: Optional[Iterable[int]] = None, + split_seed: int = None, + splits_path: Union[str, os.PathLike] = None, + ): + """Compute indices of random splits.""" + + if sample_idx is None: + sample_idx = np.arange(dataset_size) + + if splits_path is None: + # Random splitting + train_indices, val_test_indices = train_test_split( + sample_idx, + test_size=split_val + split_test, + random_state=split_seed, + ) + + sub_split_test = split_test / (split_test + split_val) + val_indices, test_indices = train_test_split( + val_test_indices, + test_size=sub_split_test, + random_state=split_seed, + ) + + else: + # Split from an indices file + with fsspec.open(str(splits_path)) as f: + splits = self._read_csv(splits_path) + + train_indices = splits["train"].dropna().astype("int").tolist() + val_indices = splits["val"].dropna().astype("int").tolist() + test_indices = splits["test"].dropna().astype("int").tolist() + + # Filter train, val and test indices + _, train_idx, _ = np.intersect1d(sample_idx,train_indices, return_indices=True) + train_indices = train_idx.tolist() + train_indices.sort() + _, valid_idx, _ = np.intersect1d(sample_idx,val_indices, return_indices=True) + val_indices = valid_idx.tolist() + val_indices.sort() + _, test_idx, _ = np.intersect1d(sample_idx, test_indices, return_indices=True) + test_indices = test_idx.tolist() + test_indices.sort() + + # train_indices = [ii for ii, idx in enumerate(sample_idx) if idx in train_indices] + # val_indices = [ii for ii, idx in enumerate(sample_idx) if idx in val_indices] + # test_indices = [ii for ii, idx in enumerate(sample_idx) if idx in test_indices] + + return train_indices, val_indices, test_indices + + def _sub_sample_df(self, df): + # Sub-sample the dataframe + if isinstance(self.sample_size, int): + n = min(self.sample_size, df.shape[0]) + df = df.sample(n=n) + elif isinstance(self.sample_size, float): + df = df.sample(f=self.sample_size) + elif self.sample_size is None: + pass + else: + raise ValueError(f"Wrong value for `self.sample_size`: {self.sample_size}") + + return df + + def __len__(self) -> int: + r""" + Returns the number of elements of the current DataModule + """ + if self.df is None: + df = self._read_csv(self.df_path, usecols=[self.smiles_col]) + else: + df = self.df + + return len(df) + + def to_dict(self): + obj_repr = {} + obj_repr["name"] = self.__class__.__name__ + obj_repr["len"] = len(self) + obj_repr["train_size"] = len(self.train_indices) if self.train_indices is not None else None + obj_repr["val_size"] = len(self.val_indices) if self.val_indices is not None else None + obj_repr["test_size"] = len(self.test_indices) if self.test_indices is not None else None + obj_repr["batch_size_train_val"] = self.batch_size_train_val + obj_repr["batch_size_test"] = self.batch_size_test + obj_repr["num_node_feats"] = self.num_node_feats + obj_repr["num_node_feats_with_positional_encoding"] = self.num_node_feats_with_positional_encoding + obj_repr["num_edge_feats"] = self.num_edge_feats + obj_repr["num_labels"] = len(self.label_cols) + obj_repr["collate_fn"] = self.collate_fn.__name__ + obj_repr["featurization"] = self.featurization + return obj_repr + + def __repr__(self): + """Controls how the class is printed""" + return omegaconf.OmegaConf.to_yaml(self.to_dict()) + + +class DGLOGBDataModule(DGLFromSmilesDataModule): + """Load an OGB GraphProp dataset.""" + + def __init__( + self, + dataset_name: str, + cache_data_path: Optional[Union[str, os.PathLike]] = None, + featurization: Optional[Union[Dict[str, Any], omegaconf.DictConfig]] = None, + weights_col: str = None, + weights_type: str = None, + sample_size: Union[int, float, type(None)] = None, + batch_size_train_val: int = 16, + batch_size_test: int = 16, + num_workers: int = 0, + pin_memory: bool = True, + persistent_workers: bool = False, + featurization_n_jobs: int = -1, + featurization_progress: bool = False, + collate_fn: Optional[Callable] = None, + ): + """ + + Parameters: + dataset_name: Name of the OGB dataset to load. Examples of possible datasets are + "ogbg-molhiv", "ogbg-molpcba", "ogbg-moltox21", "ogbg-molfreesolv". + cache_data_path: path where to save or reload the cached data. The path can be + remote (S3, GS, etc). + featurization: args to apply to the SMILES to DGL featurizer. + batch_size_train_val: batch size for training and val dataset. + batch_size_test: batch size for test dataset. + num_workers: Number of workers for the dataloader. Use -1 to use all available + cores. + pin_memory: Whether to pin on paginated CPU memory for the dataloader. + featurization_n_jobs: Number of cores to use for the featurization. + featurization_progress: whether to show a progress bar during featurization. + collate_fn: A custom torch collate function. Default is to `goli.data.goli_collate_fn` + sample_size: + + - `int`: The maximum number of elements to take from the dataset. + - `float`: Value between 0 and 1 representing the fraction of the dataset to consider + - `None`: all elements are considered. + """ + + self.dataset_name = dataset_name + + # Get OGB metadata + self.metadata = self._get_dataset_metadata(self.dataset_name) + + # Get dataset + df, idx_col, smiles_col, label_cols, splits_path = self._load_dataset(self.metadata) + + # Config for datamodule + dm_args = {} + dm_args["df"] = df + dm_args["cache_data_path"] = cache_data_path + dm_args["featurization"] = featurization + dm_args["smiles_col"] = smiles_col + dm_args["label_cols"] = label_cols + dm_args["idx_col"] = idx_col + dm_args["splits_path"] = splits_path + dm_args["batch_size_train_val"] = batch_size_train_val + dm_args["batch_size_test"] = batch_size_test + dm_args["num_workers"] = num_workers + dm_args["pin_memory"] = pin_memory + dm_args["featurization_n_jobs"] = featurization_n_jobs + dm_args["featurization_progress"] = featurization_progress + dm_args["persistent_workers"] = persistent_workers + dm_args["collate_fn"] = collate_fn + dm_args["weights_col"] = weights_col + dm_args["weights_type"] = weights_type + dm_args["sample_size"] = sample_size + + # Init DGLFromSmilesDataModule + super().__init__(**dm_args) + + def to_dict(self): + obj_repr = {} + obj_repr["dataset_name"] = self.dataset_name + obj_repr.update(super().to_dict()) + return obj_repr + + # Private methods + + def _load_dataset(self, metadata: dict): + """Download, extract and load an OGB dataset.""" + + base_dir = fs.get_cache_dir("ogb") + if metadata['download_name'] == "pcqm4m": + dataset_dir = base_dir / (metadata["download_name"] + "_kddcup2021") + else: + dataset_dir = base_dir / metadata["download_name"] + + if not dataset_dir.exists(): + + # Create cache filepath for zip file and associated folder + dataset_path = base_dir / f"{metadata['download_name']}.zip" + + # Download it + if not dataset_path.exists(): + logger.info(f"Downloading {metadata['url']} to {dataset_path}") + fs.copy(metadata["url"], dataset_path, progress=True) + + # Extract + zf = zipfile.ZipFile(dataset_path) + zf.extractall(base_dir) + + # Load CSV file + if metadata['download_name']== "pcqm4m": + df_path = dataset_dir / "raw" / "data.csv.gz" + else: + df_path = dataset_dir / "mapping" / "mol.csv.gz" + logger.info(f"Loading {df_path} in memory.") + df = pd.read_csv(df_path) + + # Load split from the OGB dataset and save them in a single CSV file + if metadata['download_name'] == "pcqm4m": + split_name = metadata["split"] + split_dict = torch.load(dataset_dir / "split_dict.pt") + train_split = pd.DataFrame(split_dict['train']) + val_split = pd.DataFrame(split_dict['valid']) + test_split = pd.DataFrame(split_dict['test']) + splits = pd.concat([train_split, val_split, test_split], axis=1) # type: ignore + splits.columns = ["train", "val", "test"] + + splits_path = dataset_dir / "split" + if not splits_path.exists(): + os.makedirs(splits_path) + splits_path = dataset_dir / f"{split_name}.csv.gz" + else: + splits_path = splits_path / f"{split_name}.csv.gz" + logger.info(f"Saving splits to {splits_path}") + splits.to_csv(splits_path, index=None) + else: + split_name = metadata["split"] + train_split = pd.read_csv(dataset_dir / "split" / split_name / "train.csv.gz", header=None) # type: ignore + val_split = pd.read_csv(dataset_dir / "split" / split_name / "valid.csv.gz", header=None) # type: ignore + test_split = pd.read_csv(dataset_dir / "split" / split_name / "test.csv.gz", header=None) # type: ignore + + splits = pd.concat([train_split, val_split, test_split], axis=1) # type: ignore + splits.columns = ["train", "val", "test"] + + splits_path = dataset_dir / "split" / f"{split_name}.csv.gz" + logger.info(f"Saving splits to {splits_path}") + splits.to_csv(splits_path, index=None) + + # Get column names: OGB columns are predictable + if metadata['download_name'] == "pcqm4m": + idx_col = df.columns[0] + smiles_col = df.columns[-2] + label_cols = df.columns[-1:].to_list() + else: + idx_col = df.columns[-1] + smiles_col = df.columns[-2] + label_cols = df.columns[:-2].to_list() + + return df, idx_col, smiles_col, label_cols, splits_path + + def _get_dataset_metadata(self, dataset_name: str): + ogb_metadata = self._get_ogb_metadata() + + if dataset_name not in ogb_metadata.index: + raise ValueError(f"'{dataset_name}' is not a valid dataset name.") + + return ogb_metadata.loc[dataset_name].to_dict() + + def _get_ogb_metadata(self): + """Get the metadata of OGB GraphProp datasets.""" + + with importlib.resources.open_text("ogb.graphproppred", "master.csv") as f: + ogb_metadata = pd.read_csv(f) + + ogb_metadata = ogb_metadata.set_index(ogb_metadata.columns[0]) + ogb_metadata = ogb_metadata.T + + # Only keep datasets of type 'mol' + ogb_metadata = ogb_metadata[ogb_metadata["data type"] == "mol"] + + return ogb_metadata diff --git a/goli/data/single_atom_dataset/single_atom_dataset.csv b/goli/data/single_atom_dataset/single_atom_dataset.csv index 272516c4d..4654d56d6 100644 --- a/goli/data/single_atom_dataset/single_atom_dataset.csv +++ b/goli/data/single_atom_dataset/single_atom_dataset.csv @@ -1,8 +1,8 @@ -SMILES,score -[CL-],17 -[Br-],35 -C,6 -O,8 -S,16 -[O-],8 +SMILES,score +[CL-],17 +[Br-],35 +C,6 +O,8 +S,16 +[O-],8 N,7 \ No newline at end of file diff --git a/goli/data/utils.py b/goli/data/utils.py index 820c39d09..c8c33bec6 100644 --- a/goli/data/utils.py +++ b/goli/data/utils.py @@ -1,72 +1,72 @@ -import importlib.resources -import zipfile - -import pandas as pd - -import goli - -GOLI_DATASETS_BASE_URL = "gcs://goli-public/datasets" -GOLI_DATASETS = { - "goli-zinc-micro": "goli-zinc-micro.zip", - "goli-zinc-bench-gnn": "goli-zinc-bench-gnn.zip", - "goli-htsfp": "goli-htsfp.csv.gz", - "goli-htsfp-pcba": "goli-htsfp-pcba.csv.gz", -} - - -def load_micro_zinc() -> pd.DataFrame: - """Return a dataframe of micro ZINC (1000 data points).""" - - with importlib.resources.open_text("goli.data.micro_ZINC", "micro_ZINC.csv") as f: - df = pd.read_csv(f) - - return df # type: ignore - - -def load_tiny_zinc() -> pd.DataFrame: - """Return a dataframe of tiny ZINC (100 data points).""" - - with importlib.resources.open_text("goli.data.micro_ZINC", "micro_ZINC.csv") as f: - df = pd.read_csv(f, nrows=100) - - return df # type: ignore - - -def list_goli_datasets(): - """List Goli datasets available to download.""" - return set(GOLI_DATASETS.keys()) - - -def download_goli_dataset(name: str, output_path: str, extract_zip: bool = True, progress: bool = False): - """Download a Goli dataset to a specified location. - - Args: - name: Name of the Goli dataset from `goli.data.utils.get_goli_datasets()`. - output_path: Directory path where to download the dataset to. - extract_zip: Whether to extract the dataset if it's a zip file. - progress: Whether to show a progress bar during download. - """ - - if name not in GOLI_DATASETS: - raise ValueError(f"'{name}' is not a valid Goli dataset name. Choose from {GOLI_DATASETS}") - - fname = GOLI_DATASETS[name] - - dataset_path_source = goli.utils.fs.join(GOLI_DATASETS_BASE_URL, fname) - dataset_path_destination = goli.utils.fs.join(output_path, fname) - - if not goli.utils.fs.exists(dataset_path_destination): - goli.utils.fs.copy(dataset_path_source, dataset_path_destination, progress=progress) - - if extract_zip and str(dataset_path_destination).endswith(".zip"): - - # Unzip the dataset - with zipfile.ZipFile(dataset_path_destination, "r") as zf: - zf.extractall(output_path) - - if extract_zip: - # Set the destination path to the folder - # NOTE(hadim): this is a bit fragile. - dataset_path_destination = dataset_path_destination.split(".")[0] - - return dataset_path_destination +import importlib.resources +import zipfile + +import pandas as pd + +import goli + +GOLI_DATASETS_BASE_URL = "gcs://goli-public/datasets" +GOLI_DATASETS = { + "goli-zinc-micro": "goli-zinc-micro.zip", + "goli-zinc-bench-gnn": "goli-zinc-bench-gnn.zip", + "goli-htsfp": "goli-htsfp.csv.gz", + "goli-htsfp-pcba": "goli-htsfp-pcba.csv.gz", +} + + +def load_micro_zinc() -> pd.DataFrame: + """Return a dataframe of micro ZINC (1000 data points).""" + + with importlib.resources.open_text("goli.data.micro_ZINC", "micro_ZINC.csv") as f: + df = pd.read_csv(f) + + return df # type: ignore + + +def load_tiny_zinc() -> pd.DataFrame: + """Return a dataframe of tiny ZINC (100 data points).""" + + with importlib.resources.open_text("goli.data.micro_ZINC", "micro_ZINC.csv") as f: + df = pd.read_csv(f, nrows=100) + + return df # type: ignore + + +def list_goli_datasets(): + """List Goli datasets available to download.""" + return set(GOLI_DATASETS.keys()) + + +def download_goli_dataset(name: str, output_path: str, extract_zip: bool = True, progress: bool = False): + """Download a Goli dataset to a specified location. + + Args: + name: Name of the Goli dataset from `goli.data.utils.get_goli_datasets()`. + output_path: Directory path where to download the dataset to. + extract_zip: Whether to extract the dataset if it's a zip file. + progress: Whether to show a progress bar during download. + """ + + if name not in GOLI_DATASETS: + raise ValueError(f"'{name}' is not a valid Goli dataset name. Choose from {GOLI_DATASETS}") + + fname = GOLI_DATASETS[name] + + dataset_path_source = goli.utils.fs.join(GOLI_DATASETS_BASE_URL, fname) + dataset_path_destination = goli.utils.fs.join(output_path, fname) + + if not goli.utils.fs.exists(dataset_path_destination): + goli.utils.fs.copy(dataset_path_source, dataset_path_destination, progress=progress) + + if extract_zip and str(dataset_path_destination).endswith(".zip"): + + # Unzip the dataset + with zipfile.ZipFile(dataset_path_destination, "r") as zf: + zf.extractall(output_path) + + if extract_zip: + # Set the destination path to the folder + # NOTE(hadim): this is a bit fragile. + dataset_path_destination = dataset_path_destination.split(".")[0] + + return dataset_path_destination diff --git a/goli/features/__init__.py b/goli/features/__init__.py index a6bf5e647..c5f1c51bf 100644 --- a/goli/features/__init__.py +++ b/goli/features/__init__.py @@ -1,6 +1,6 @@ -from .featurizer import get_mol_atomic_features_onehot -from .featurizer import get_mol_atomic_features_float -from .featurizer import get_mol_edge_features -from .featurizer import mol_to_adj_and_features -from .featurizer import mol_to_dglgraph -from .featurizer import mol_to_dglgraph_signature +from .featurizer import get_mol_atomic_features_onehot +from .featurizer import get_mol_atomic_features_float +from .featurizer import get_mol_edge_features +from .featurizer import mol_to_adj_and_features +from .featurizer import mol_to_dglgraph +from .featurizer import mol_to_dglgraph_signature diff --git a/goli/features/featurizer.py b/goli/features/featurizer.py index 21e66a97d..2434a0589 100644 --- a/goli/features/featurizer.py +++ b/goli/features/featurizer.py @@ -1,727 +1,740 @@ -from typing import Union, List, Callable, Dict, Tuple, Any - -import inspect -import warnings -from loguru import logger - -import numpy as np -from scipy.sparse import csr_matrix -import dgl -import torch - -from rdkit import Chem -from rdkit.Chem.rdmolops import GetAdjacencyMatrix -import datamol as dm - -from goli.features import nmp -from goli.utils.tensor import one_of_k_encoding -from goli.features.positional_encoding import get_all_positional_encoding - - -def get_mol_atomic_features_onehot(mol: Chem.rdchem.Mol, property_list: List[str]) -> Dict[str, np.ndarray]: - r""" - Get the following set of features for any given atom - - * One-hot representation of the atom - * One-hot representation of the atom degree - * One-hot representation of the atom implicit valence - * One-hot representation of the the atom hybridization - * Whether the atom is aromatic - * The atom's formal charge - * The atom's number of radical electrons - - Additionally, the following features can be set, depending on the value of input Parameters - - * One-hot representation of the number of hydrogen atom in the the current atom neighborhood if `explicit_H` is false - * One-hot encoding of the atom chirality, and whether such configuration is even possible - - Parameters: - - mol: - molecule from which to extract the properties - - property_list: - A list of integer atomic properties to get from the molecule. - The integer values are converted to a one-hot vector. - Callables are not supported by this function. - - Accepted properties are: - - - "atomic-number" - - "degree" - - "valence", "total-valence" - - "implicit-valence" - - "hybridization" - - "chirality" - - Returns: - prop_dict: - A dictionnary where the element of ``property_list`` are the keys - and the values are np.ndarray of shape (N, OH). N is the number of atoms - in ``mol`` and OH the lenght of the one-hot encoding. - - """ - - prop_dict = {} - - for prop in property_list: - - prop = prop.lower() - prop_name = prop - - property_array = [] - for ii, atom in enumerate(mol.GetAtoms()): - - if prop in ["atomic-number"]: - one_hot = one_of_k_encoding(atom.GetSymbol(), nmp.ATOM_LIST) - elif prop in ["degree"]: - one_hot = one_of_k_encoding(atom.GetDegree(), nmp.ATOM_DEGREE_LIST) - elif prop in ["valence", "total-valence"]: - prop_name = "valence" - one_hot = one_of_k_encoding(atom.GetTotalValence(), nmp.VALENCE) - elif prop in ["implicit-valence"]: - one_hot = one_of_k_encoding(atom.GetImplicitValence(), nmp.VALENCE) - elif prop in ["hybridization"]: - one_hot = one_of_k_encoding(atom.GetHybridization(), nmp.HYBRIDIZATION_LIST) - elif prop in ["chirality"]: - try: - one_hot = one_of_k_encoding(atom.GetProp("_CIPCode"), nmp.CHIRALITY_LIST) - one_hot.append(int(atom.HasProp("_ChiralityPossible"))) - except: - one_hot = [0, 0, int(atom.HasProp("_ChiralityPossible"))] - else: - raise ValueError(f"Unsupported property `{prop}`") - - property_array.append(np.asarray(one_hot, dtype=np.float32)) - - prop_dict[prop_name] = np.stack(property_array, axis=0) - - return prop_dict - - -def get_mol_atomic_features_float( - mol: Chem.rdchem.Mol, - property_list: Union[List[str], List[Callable]], - offset_carbon: bool = True, -) -> Dict[str, np.ndarray]: - """ - Get a dictionary of floating-point arrays of atomic properties. - To ensure all properties are at a similar scale, some of the properties - are divided by a constant. - - There is also the possibility of offseting by the carbon value using - the `offset_carbon` parameter. - - Parameters: - - mol: - molecule from which to extract the properties - - property_list: - A list of atomic properties to get from the molecule, such as 'atomic-number', - 'mass', 'valence', 'degree', 'electronegativity'. - Some elements are divided by a factor to avoid feature explosion. - - Accepted properties are: - - - "atomic-number" - - "mass", "weight" - - "valence", "total-valence" - - "implicit-valence" - - "hybridization" - - "chirality" - - "hybridization" - - "aromatic" - - "ring", "in-ring" - - "degree" - - "radical-electron" - - "formal-charge" - - "vdw-radius" - - "covalent-radius" - - "electronegativity" - - "ionization", "first-ionization" - - "melting-point" - - "metal" - - "single-bond" - - "aromatic-bond" - - "double-bond" - - "triple-bond" - - "is-carbon" - - offset_carbon: - Whether to subract the Carbon property from the desired atomic property. - For example, if we want the mass of the Lithium (6.941), the mass of the - Carbon (12.0107) will be subracted, resulting in a value of -5.0697 - - Returns: - - prop_dict: - A dictionnary where the element of ``property_list`` are the keys - and the values are np.ndarray of shape (N,). N is the number of atoms - in ``mol``. - - """ - - periodic_table = Chem.GetPeriodicTable() - prop_dict = {} - C = Chem.Atom("C") - offC = bool(offset_carbon) - - for prop in property_list: - - prop_name = None - - property_array = np.zeros(mol.GetNumAtoms(), dtype=np.float32) - for ii, atom in enumerate(mol.GetAtoms()): - - val = None - - if isinstance(prop, str): - - prop = prop.lower() - prop_name = prop - - if prop in ["atomic-number"]: - val = (atom.GetAtomicNum() - (offC * C.GetAtomicNum())) / 5 - elif prop in ["mass", "weight"]: - prop_name = "mass" - val = (atom.GetMass() - (offC * C.GetMass())) / 10 - elif prop in ["valence", "total-valence"]: - prop_name = "valence" - val = atom.GetTotalValence() - (offC * 4) - elif prop in ["implicit-valence"]: - val = atom.GetImplicitValence() - elif prop in ["hybridization"]: - val = atom.GetHybridization() - elif prop in ["chirality"]: - val = (atom.GetProp("_CIPCode") == "R") if atom.HasProp("_CIPCode") else 2 - elif prop in ["hybridization"]: - val = atom.GetHybridization() - elif prop in ["aromatic"]: - val = atom.GetIsAromatic() - elif prop in ["ring", "in-ring"]: - prop_name = "in-ring" - val = atom.IsInRing() - elif prop in ["degree"]: - val = atom.GetTotalDegree() - (offC * 2) - elif prop in ["radical-electron"]: - val = atom.GetNumRadicalElectrons() - elif prop in ["formal-charge"]: - val = atom.GetFormalCharge() - elif prop in ["vdw-radius"]: - val = periodic_table.GetRvdw(atom.GetAtomicNum()) - offC * periodic_table.GetRvdw( - C.GetAtomicNum() - ) - elif prop in ["covalent-radius"]: - val = periodic_table.GetRcovalent( - atom.GetAtomicNum() - ) - offC * periodic_table.GetRcovalent(C.GetAtomicNum()) - elif prop in ["electronegativity"]: - val = ( - nmp.PERIODIC_TABLE["Electronegativity"][atom.GetAtomicNum()] - - offC * nmp.PERIODIC_TABLE["Electronegativity"][C.GetAtomicNum()] - ) - elif prop in ["ionization", "first-ionization"]: - prop_name = "ionization" - val = ( - nmp.PERIODIC_TABLE["FirstIonization"][atom.GetAtomicNum()] - - offC * nmp.PERIODIC_TABLE["FirstIonization"][C.GetAtomicNum()] - ) / 5 - elif prop in ["melting-point"]: - val = ( - nmp.PERIODIC_TABLE["MeltingPoint"][atom.GetAtomicNum()] - - offC * nmp.PERIODIC_TABLE["MeltingPoint"][C.GetAtomicNum()] - ) / 200 - elif prop in ["metal"]: - if nmp.PERIODIC_TABLE["Metal"][atom.GetAtomicNum()] == "yes": - val = 2 - elif nmp.PERIODIC_TABLE["Metalloid"][atom.GetAtomicNum()] == "yes": - val = 1 - else: - val = 0 - elif "-bond" in prop: - bonds = [bond.GetBondTypeAsDouble() for bond in atom.GetBonds()] - if prop in ["single-bond"]: - val = len([bond == 1 for bond in bonds]) - elif prop in ["aromatic-bond"]: - val = len([bond == 1.5 for bond in bonds]) - elif prop in ["double-bond"]: - val = len([bond == 2 for bond in bonds]) - elif prop in ["triple-bond"]: - val = len([bond == 3 for bond in bonds]) - else: - raise ValueError(f"{prop} is not a correct bond.") - val -= offC * 1 - elif prop in ["is-carbon"]: - val = atom.GetAtomicNum() == 6 - val -= offC * 1 - else: - raise ValueError(f"Unsupported property `{prop}`") - - elif callable(prop): - prop_name = str(prop) - val = prop(atom) - else: - ValueError(f"Elements in `property_list` must be str or callable, provided `{type(prop)}`") - - if val is None: - raise ValueError("val is undefined.") - - property_array[ii] = val - - if prop_name is None: - raise ValueError("prop_name is undefined.") - - prop_dict[prop_name] = property_array - - return prop_dict - - -def get_simple_mol_conformer(mol: Chem.rdchem.Mol) -> Union[Chem.rdchem.Conformer, None]: - r""" - If the molecule has a conformer, then it will return the conformer at idx `0`. - Otherwise, it generates a simple molecule conformer using `rdkit.Chem.rdDistGeom.EmbedMolecule` - and returns it. This is meant to be used in simple functions like `GetBondLength`, - not in functions requiring complex 3D structure. - - Parameters: - - mol: Rdkit Molecule - - Returns: - conf: A conformer of the molecule, or `None` if it fails - """ - - val = 0 - if mol.GetNumConformers() == 0: - val = Chem.rdDistGeom.EmbedMolecule(mol) - if val == -1: - val = Chem.rdDistGeom.EmbedMolecule( - mol, - enforceChirality=False, - ignoreSmoothingFailures=True, - useBasicKnowledge=True, - useExpTorsionAnglePrefs=True, - forceTol=0.1, - ) - - if val == -1: - conf = None - logger.warn("Couldn't compute conformer for molecule `{}`".format(Chem.MolToSmiles(mol))) - else: - conf = mol.GetConformer(0) - - return conf - - -def get_estimated_bond_length(bond: Chem.rdchem.Bond, mol: Chem.rdchem.Mol) -> float: - r""" - Estimate the bond length between atoms by looking at the estimated atomic radius - that depends both on the atom type and the bond type. The resulting bond-length is - then the sum of the radius. - - Keep in mind that this function only provides an estimate of the bond length and not - the true one based on a conformer. The vast majority od estimated bond lengths will - have an error below 5% while some bonds can have an error up to 20%. This function - is mostly useful when conformer generation fails for some molecules, or for - increased computation speed. - - Parameters: - bond: The bond to measure its lenght - mol: The molecule containing the bond (used to get neighbouring atoms) - - Returns: - bond_length: The bond length in Angstrom, typically a value around 1-2. - - """ - - # Small function to convert strings to floats - def float_or_nan(string): - try: - val = float(string) - except: - val = float("nan") - return val - - # Get the atoms connected by the bond - idx1 = bond.GetBeginAtomIdx() - idx2 = bond.GetEndAtomIdx() - atom1 = mol.GetAtomWithIdx(idx1).GetAtomicNum() - atom2 = mol.GetAtomWithIdx(idx2).GetAtomicNum() - bond_type = bond.GetBondType() - - # Get single bond atomic radius - if bond_type == Chem.rdchem.BondType.SINGLE: - rad1 = [nmp.PERIODIC_TABLE["SingleBondRadius"][atom1]] - rad2 = [nmp.PERIODIC_TABLE["SingleBondRadius"][atom2]] - # Get double bond atomic radius - elif bond_type == Chem.rdchem.BondType.DOUBLE: - rad1 = [nmp.PERIODIC_TABLE["DoubleBondRadius"][atom1]] - rad2 = [nmp.PERIODIC_TABLE["DoubleBondRadius"][atom2]] - # Get triple bond atomic radius - elif bond_type == Chem.rdchem.BondType.TRIPLE: - rad1 = [nmp.PERIODIC_TABLE["TripleBondRadius"][atom1]] - rad2 = [nmp.PERIODIC_TABLE["TripleBondRadius"][atom2]] - # Get average of single bond and double bond atomic radius - elif bond_type == Chem.rdchem.BondType.AROMATIC: - rad1 = [nmp.PERIODIC_TABLE["SingleBondRadius"][atom1], nmp.PERIODIC_TABLE["DoubleBondRadius"][atom1]] - rad2 = [nmp.PERIODIC_TABLE["SingleBondRadius"][atom2], nmp.PERIODIC_TABLE["DoubleBondRadius"][atom2]] - - # Average the bond lengths, while ignoring nans in case some missing value - rad1_float = np.nanmean(np.array([float_or_nan(elem) for elem in rad1])) - rad2_float = np.nanmean(np.array([float_or_nan(elem) for elem in rad2])) - - # If the bond radius is still nan (this shouldn't happen), take the single bond radius - if np.isnan(rad1_float): - rad1_float = float_or_nan(nmp.PERIODIC_TABLE["SingleBondRadius"][atom1]) - if np.isnan(rad2_float): - rad2_float = float_or_nan(nmp.PERIODIC_TABLE["SingleBondRadius"][atom2]) - - bond_length = rad1_float + rad2_float - - return bond_length - - -def get_mol_edge_features(mol: Chem.rdchem.Mol, property_list: List[str]): - r""" - Get the following set of features for any given bond - See `goli.features.nmp` for allowed values in one hot encoding - - * One-hot representation of the bond type. Note that you should not kekulize your - molecules, if you expect this to take aromatic bond into account. - * Bond stereo type, following CIP classification - * Whether the bond is conjugated - * Whether the bond is in a ring - - Parameters: - mol: rdkit.Chem.Molecule - the molecule of interest - - property_list: - A list of edge properties to return for the given molecule. - Accepted properties are: - - - "bond-type-onehot" - - "bond-type-float" - - "stereo" - - "in-ring" - - "conjugated" - - "conformer-bond-length" (might cause problems with complex molecules) - - "estimated-bond-length" - - Returns: - prop_dict: - A dictionnary where the element of ``property_list`` are the keys - and the values are np.ndarray of shape (N,). N is the number of atoms - in ``mol``. - - """ - - prop_dict = {} - - # Compute features for each bond - num_bonds = mol.GetNumBonds() - for prop in property_list: - property_array = [] - for ii in range(num_bonds): - prop = prop.lower() - bond = mol.GetBondWithIdx(ii) - - if prop in ["bond-type-onehot"]: - encoding = one_of_k_encoding(bond.GetBondType(), nmp.BOND_TYPES) - elif prop in ["bond-type-float"]: - encoding = [bond.GetBondTypeAsDouble()] - elif prop in ["stereo"]: - encoding = one_of_k_encoding(bond.GetStereo(), nmp.BOND_STEREO) - elif prop in ["in-ring"]: - encoding = [bond.IsInRing()] - elif prop in ["conjugated"]: - encoding = [bond.GetIsConjugated()] - elif prop in ["conformer-bond-length"]: - conf = get_simple_mol_conformer(mol) - if conf is not None: - idx1 = bond.GetBeginAtomIdx() - idx2 = bond.GetEndAtomIdx() - encoding = [Chem.rdMolTransforms.GetBondLength(conf, idx1, idx2)] - else: - encoding = [0] - elif prop in ["estimated-bond-length"]: - encoding = [get_estimated_bond_length(bond, mol)] - - else: - raise ValueError(f"Unsupported property `{prop}`") - - property_array.append(np.asarray(encoding, dtype=np.float32)) - - if num_bonds > 0: - prop_dict[prop] = np.stack(property_array, axis=0) - else: - prop_dict[prop] = np.array([]) - - return prop_dict - - -def mol_to_adj_and_features( - mol: Union[str, Chem.rdchem.Mol], - atom_property_list_onehot: List[str] = [], - atom_property_list_float: List[Union[str, Callable]] = [], - edge_property_list: List[str] = [], - add_self_loop: bool = False, - explicit_H: bool = False, - use_bonds_weights: bool = False, - pos_encoding_as_features: Dict[str, Any] = None, - pos_encoding_as_directions: Dict[str, Any] = None, -) -> Union[csr_matrix, Union[np.ndarray, None], Union[np.ndarray, None]]: - r""" - Transforms a molecule into an adjacency matrix representing the molecular graph - and a set of atom and bond features. - - Parameters: - - mol: - The molecule to be converted - - atom_property_list_onehot: - List of the properties used to get one-hot encoding of the atom type, - such as the atom index represented as a one-hot vector. - See function `get_mol_atomic_features_onehot` - - atom_property_list_float: - List of the properties used to get floating-point encoding of the atom type, - such as the atomic mass or electronegativity. - See function `get_mol_atomic_features_float` - - edge_property_list: - List of the properties used to encode the edges, such as the edge type - and the stereo type. - - add_self_loop: - Whether to add a value of `1` on the diagonal of the adjacency matrix. - - explicit_H: - Whether to consider the Hydrogens explicitely. If `False`, the hydrogens - are implicit. - - use_bonds_weights: - Whether to use the floating-point value of the bonds in the adjacency matrix, - such that single bonds are represented by 1, double bonds 2, triple 3, aromatic 1.5 - - Returns: - - adj: - Scipy sparse adjacency matrix of the molecule - - ndata: - Concatenated node data of the atoms, based on the properties from - `atom_property_list_onehot` and `atom_property_list_float`. - If no properties are given, it returns `None` - - edata - Concatenated node edge of the molecule, based on the properties from - `edge_property_list`. - If no properties are given, it returns `None` - - """ - - if isinstance(mol, str): - mol = dm.to_mol(mol) - - # Add or remove explicit hydrogens - if explicit_H: - mol = Chem.AddHs(mol) - else: - mol = Chem.RemoveHs(mol) - - # Get the adjacency matrix - adj = GetAdjacencyMatrix(mol, useBO=use_bonds_weights, force=True) - if add_self_loop: - adj = adj + np.eye(adj.shape[0]) - adj = csr_matrix(adj) - - # Get the node features - atom_features_onehot = get_mol_atomic_features_onehot(mol, atom_property_list_onehot) - atom_features_float = get_mol_atomic_features_float(mol, atom_property_list_float) - ndata = list(atom_features_float.values()) + list(atom_features_onehot.values()) - ndata = [np.expand_dims(d, axis=1) if d.ndim == 1 else d for d in ndata] - ndata = np.concatenate(ndata, axis=1) if len(ndata) > 0 else None - - # Get the edge features - edge_features = get_mol_edge_features(mol, edge_property_list) - edata = list(edge_features.values()) - edata = [np.expand_dims(d, axis=1) if d.ndim == 1 else d for d in edata] - edata = np.concatenate(edata, axis=1) if len(edata) > 0 else None - - pos_enc_feats_sign_flip, pos_enc_feats_no_flip, pos_enc_dir = get_all_positional_encoding( - adj, pos_encoding_as_features, pos_encoding_as_directions - ) - - return adj, ndata, edata, pos_enc_feats_sign_flip, pos_enc_feats_no_flip, pos_enc_dir - - -def mol_to_dglgraph( - mol: Chem.rdchem.Mol, - atom_property_list_onehot: List[str] = [], - atom_property_list_float: List[Union[str, Callable]] = [], - edge_property_list: List[str] = [], - add_self_loop: bool = False, - explicit_H: bool = False, - use_bonds_weights: bool = False, - pos_encoding_as_features: Dict[str, Any] = None, - pos_encoding_as_directions: Dict[str, Any] = None, - dtype: torch.dtype = torch.float32, - on_error: str = "ignore", -) -> dgl.DGLGraph: - r""" - Transforms a molecule into an adjacency matrix representing the molecular graph - and a set of atom and bond features. - - Parameters: - - mol: - The molecule to be converted - - atom_property_list_onehot: - List of the properties used to get one-hot encoding of the atom type, - such as the atom index represented as a one-hot vector. - See function `get_mol_atomic_features_onehot` - - atom_property_list_float: - List of the properties used to get floating-point encoding of the atom type, - such as the atomic mass or electronegativity. - See function `get_mol_atomic_features_float` - - edge_property_list: - List of the properties used to encode the edges, such as the edge type - and the stereo type. - - add_self_loop: - Whether to add a value of `1` on the diagonal of the adjacency matrix. - - explicit_H: - Whether to consider the Hydrogens explicitely. If `False`, the hydrogens - are implicit. - - use_bonds_weights: - Whether to use the floating-point value of the bonds in the adjacency matrix, - such that single bonds are represented by 1, double bonds 2, triple 3, aromatic 1.5 - - dtype: - The torch data type used to build the graph - - on_error: - What to do when the featurization fails. - - - "raise": Raise an error - - "warn": Raise a warning and return None - - "ignore": Ignore the error and return None - - Returns: - - graph: - DGL graph, with `graph.ndata['n']` corresponding to the concatenated - node data from `atom_property_list_onehot` and `atom_property_list_float`, - `graph.edata['e']` corresponding to the concatenated edge data from `edge_property_list` - - """ - - input_mol = mol - - try: - - if isinstance(mol, str): - mol = dm.to_mol(mol) - if explicit_H: - mol = Chem.AddHs(mol) - else: - mol = Chem.RemoveHs(mol) - - # Get the adjacency, node features and edge features - ( - adj, - ndata, - edata, - pos_enc_feats_sign_flip, - pos_enc_feats_no_flip, - pos_enc_dir, - ) = mol_to_adj_and_features( - mol=mol, - atom_property_list_onehot=atom_property_list_onehot, - atom_property_list_float=atom_property_list_float, - edge_property_list=edge_property_list, - add_self_loop=add_self_loop, - explicit_H=explicit_H, - use_bonds_weights=use_bonds_weights, - pos_encoding_as_features=pos_encoding_as_features, - pos_encoding_as_directions=pos_encoding_as_directions, - ) - except Exception as e: - if on_error.lower() == "raise": - raise e - elif on_error.lower() == "warn": - smiles = input_mol - if isinstance(smiles, Chem.rdchem.Mol): - smiles = Chem.MolToSmiles(input_mol) - - msg = str(e) + "\nIgnoring following molecule:" + smiles - logger.warning(msg) - return None - elif on_error.lower() == "ignore": - return None - - # Transform the matrix and data into a DGLGraph object - graph = dgl.from_scipy(adj) - - # Assign the node data - if ndata is not None: - graph.ndata["feat"] = torch.from_numpy(ndata).to(dtype=dtype) - - # Assign the edge data. Due to DGL only supporting Hetero-graphs, we - # need to duplicate each edge information for its 2 entries - if edata is not None: - src_ids, dst_ids = graph.all_edges() - hetero_edata = np.zeros_like(edata, shape=(edata.shape[0] * 2, edata.shape[1])) - for ii in range(mol.GetNumBonds()): - bond = mol.GetBondWithIdx(ii) - src, dst = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx() - id1 = np.where((src == src_ids) & (dst == dst_ids))[0] - id2 = np.where((dst == src_ids) & (src == dst_ids))[0] - hetero_edata[id1, :] = edata[ii, :] - hetero_edata[id2, :] = edata[ii, :] - - graph.edata["feat"] = torch.from_numpy(hetero_edata).to(dtype=dtype) - - # Add sign-flip positional encoding - if pos_enc_feats_sign_flip is not None: - graph.ndata["pos_enc_feats_sign_flip"] = pos_enc_feats_sign_flip - - # Add non-sign-flip positional encoding - if pos_enc_feats_no_flip is not None: - graph.ndata["pos_enc_feats_no_flip"] = pos_enc_feats_no_flip - - # Add positional encoding for directional use - if pos_enc_dir is not None: - graph.ndata["pos_dir"] = pos_enc_dir - - return graph - - -def mol_to_dglgraph_signature(featurizer_args: Dict[str, Any] = None): - """Get the default arguments of `mol_to_dglgraph` and update it - with a provided dict of arguments in order to get a fulle signature - of the featurizer args actually used for the features computation. - """ - - # Get the signature of `mol_to_dglgraph` - signature = inspect.signature(mol_to_dglgraph) - - # Filter out empty arguments (without default value) - parameters = list(filter(lambda param: param.default is not param.empty, signature.parameters.values())) - - # Convert to dict - parameters = {param.name: param.default for param in parameters} - - # Update the parameters with the supplied ones - if featurizer_args is not None: - parameters.update(featurizer_args) - - return parameters +from typing import Union, List, Callable, Dict, Tuple, Any, Optional + +import inspect +import warnings +from loguru import logger + +import numpy as np +from scipy.sparse import csr_matrix +import dgl +import torch + +from rdkit import Chem +from rdkit.Chem.rdmolops import GetAdjacencyMatrix +import datamol as dm + +from goli.features import nmp +from goli.utils.tensor import one_of_k_encoding +from goli.features.positional_encoding import get_all_positional_encoding + + +def get_mol_atomic_features_onehot(mol: Chem.rdchem.Mol, property_list: List[str]) -> Dict[str, np.ndarray]: + r""" + Get the following set of features for any given atom + + * One-hot representation of the atom + * One-hot representation of the atom degree + * One-hot representation of the atom implicit valence + * One-hot representation of the the atom hybridization + * Whether the atom is aromatic + * The atom's formal charge + * The atom's number of radical electrons + + Additionally, the following features can be set, depending on the value of input Parameters + + * One-hot representation of the number of hydrogen atom in the the current atom neighborhood if `explicit_H` is false + * One-hot encoding of the atom chirality, and whether such configuration is even possible + + Parameters: + + mol: + molecule from which to extract the properties + + property_list: + A list of integer atomic properties to get from the molecule. + The integer values are converted to a one-hot vector. + Callables are not supported by this function. + + Accepted properties are: + + - "atomic-number" + - "degree" + - "valence", "total-valence" + - "implicit-valence" + - "hybridization" + - "chirality" + + Returns: + prop_dict: + A dictionnary where the element of ``property_list`` are the keys + and the values are np.ndarray of shape (N, OH). N is the number of atoms + in ``mol`` and OH the lenght of the one-hot encoding. + + """ + + prop_dict = {} + + for prop in property_list: + + prop = prop.lower() + prop_name = prop + + property_array = [] + for ii, atom in enumerate(mol.GetAtoms()): + + if prop in ["atomic-number"]: + one_hot = one_of_k_encoding(atom.GetSymbol(), nmp.ATOM_LIST) + elif prop in ["degree"]: + one_hot = one_of_k_encoding(atom.GetDegree(), nmp.ATOM_DEGREE_LIST) + elif prop in ["valence", "total-valence"]: + prop_name = "valence" + one_hot = one_of_k_encoding(atom.GetTotalValence(), nmp.VALENCE) + elif prop in ["implicit-valence"]: + one_hot = one_of_k_encoding(atom.GetImplicitValence(), nmp.VALENCE) + elif prop in ["hybridization"]: + one_hot = one_of_k_encoding(atom.GetHybridization(), nmp.HYBRIDIZATION_LIST) + elif prop in ["chirality"]: + try: + one_hot = one_of_k_encoding(atom.GetProp("_CIPCode"), nmp.CHIRALITY_LIST) + one_hot.append(int(atom.HasProp("_ChiralityPossible"))) + except: + one_hot = [0, 0, int(atom.HasProp("_ChiralityPossible"))] + else: + raise ValueError(f"Unsupported property `{prop}`") + + property_array.append(np.asarray(one_hot, dtype=np.float32)) + + prop_dict[prop_name] = np.stack(property_array, axis=0) + + return prop_dict + + +def get_mol_atomic_features_float( + mol: Chem.rdchem.Mol, + property_list: Union[List[str], List[Callable]], + offset_carbon: bool = True, + mask_nan: Optional[float] = 0., +) -> Dict[str, np.ndarray]: + """ + Get a dictionary of floating-point arrays of atomic properties. + To ensure all properties are at a similar scale, some of the properties + are divided by a constant. + + There is also the possibility of offseting by the carbon value using + the `offset_carbon` parameter. + + Parameters: + + mol: + molecule from which to extract the properties + + property_list: + A list of atomic properties to get from the molecule, such as 'atomic-number', + 'mass', 'valence', 'degree', 'electronegativity'. + Some elements are divided by a factor to avoid feature explosion. + + Accepted properties are: + + - "atomic-number" + - "mass", "weight" + - "valence", "total-valence" + - "implicit-valence" + - "hybridization" + - "chirality" + - "hybridization" + - "aromatic" + - "ring", "in-ring" + - "degree" + - "radical-electron" + - "formal-charge" + - "vdw-radius" + - "covalent-radius" + - "electronegativity" + - "ionization", "first-ionization" + - "melting-point" + - "metal" + - "single-bond" + - "aromatic-bond" + - "double-bond" + - "triple-bond" + - "is-carbon" + + offset_carbon: + Whether to subract the Carbon property from the desired atomic property. + For example, if we want the mass of the Lithium (6.941), the mass of the + Carbon (12.0107) will be subracted, resulting in a value of -5.0697 + + mask_nan: + Floating point value used to replace the NaNs in the atomic property. + This can happen when taking the electronegativity of a noble gas, + or other properties that are not measured for specific atoms. + If `None`, the NaNs are not masked. + + Returns: + + prop_dict: + A dictionnary where the element of ``property_list`` are the keys + and the values are np.ndarray of shape (N,). N is the number of atoms + in ``mol``. + + """ + + periodic_table = Chem.GetPeriodicTable() + prop_dict = {} + C = Chem.Atom("C") + offC = bool(offset_carbon) + + for prop in property_list: + + prop_name = None + + property_array = np.zeros(mol.GetNumAtoms(), dtype=np.float32) + for ii, atom in enumerate(mol.GetAtoms()): + + val = None + + if isinstance(prop, str): + + prop = prop.lower() + prop_name = prop + + if prop in ["atomic-number"]: + val = (atom.GetAtomicNum() - (offC * C.GetAtomicNum())) / 5 + elif prop in ["mass", "weight"]: + prop_name = "mass" + val = (atom.GetMass() - (offC * C.GetMass())) / 10 + elif prop in ["valence", "total-valence"]: + prop_name = "valence" + val = atom.GetTotalValence() - (offC * 4) + elif prop in ["implicit-valence"]: + val = atom.GetImplicitValence() + elif prop in ["hybridization"]: + val = atom.GetHybridization() + elif prop in ["chirality"]: + val = (atom.GetProp("_CIPCode") == "R") if atom.HasProp("_CIPCode") else 2 + elif prop in ["hybridization"]: + val = atom.GetHybridization() + elif prop in ["aromatic"]: + val = atom.GetIsAromatic() + elif prop in ["ring", "in-ring"]: + prop_name = "in-ring" + val = atom.IsInRing() + elif prop in ["degree"]: + val = atom.GetTotalDegree() - (offC * 2) + elif prop in ["radical-electron"]: + val = atom.GetNumRadicalElectrons() + elif prop in ["formal-charge"]: + val = atom.GetFormalCharge() + elif prop in ["vdw-radius"]: + val = periodic_table.GetRvdw(atom.GetAtomicNum()) - offC * periodic_table.GetRvdw( + C.GetAtomicNum() + ) + elif prop in ["covalent-radius"]: + val = periodic_table.GetRcovalent( + atom.GetAtomicNum() + ) - offC * periodic_table.GetRcovalent(C.GetAtomicNum()) + elif prop in ["electronegativity"]: + val = ( + nmp.PERIODIC_TABLE["Electronegativity"][atom.GetAtomicNum()] + - offC * nmp.PERIODIC_TABLE["Electronegativity"][C.GetAtomicNum()] + ) + elif prop in ["ionization", "first-ionization"]: + prop_name = "ionization" + val = ( + nmp.PERIODIC_TABLE["FirstIonization"][atom.GetAtomicNum()] + - offC * nmp.PERIODIC_TABLE["FirstIonization"][C.GetAtomicNum()] + ) / 5 + elif prop in ["melting-point"]: + val = ( + nmp.PERIODIC_TABLE["MeltingPoint"][atom.GetAtomicNum()] + - offC * nmp.PERIODIC_TABLE["MeltingPoint"][C.GetAtomicNum()] + ) / 200 + elif prop in ["metal"]: + if nmp.PERIODIC_TABLE["Metal"][atom.GetAtomicNum()] == "yes": + val = 2 + elif nmp.PERIODIC_TABLE["Metalloid"][atom.GetAtomicNum()] == "yes": + val = 1 + else: + val = 0 + elif "-bond" in prop: + bonds = [bond.GetBondTypeAsDouble() for bond in atom.GetBonds()] + if prop in ["single-bond"]: + val = len([bond == 1 for bond in bonds]) + elif prop in ["aromatic-bond"]: + val = len([bond == 1.5 for bond in bonds]) + elif prop in ["double-bond"]: + val = len([bond == 2 for bond in bonds]) + elif prop in ["triple-bond"]: + val = len([bond == 3 for bond in bonds]) + else: + raise ValueError(f"{prop} is not a correct bond.") + val -= offC * 1 + elif prop in ["is-carbon"]: + val = atom.GetAtomicNum() == 6 + val -= offC * 1 + else: + raise ValueError(f"Unsupported property `{prop}`") + + elif callable(prop): + prop_name = str(prop) + val = prop(atom) + else: + ValueError(f"Elements in `property_list` must be str or callable, provided `{type(prop)}`") + + if val is None: + raise ValueError("val is undefined.") + + property_array[ii] = val + + if prop_name is None: + raise ValueError("prop_name is undefined.") + + # Mask the NaNs + if mask_nan is not None: + property_array[np.isnan(property_array)] = mask_nan + prop_dict[prop_name] = property_array + + return prop_dict + + +def get_simple_mol_conformer(mol: Chem.rdchem.Mol) -> Union[Chem.rdchem.Conformer, None]: + r""" + If the molecule has a conformer, then it will return the conformer at idx `0`. + Otherwise, it generates a simple molecule conformer using `rdkit.Chem.rdDistGeom.EmbedMolecule` + and returns it. This is meant to be used in simple functions like `GetBondLength`, + not in functions requiring complex 3D structure. + + Parameters: + + mol: Rdkit Molecule + + Returns: + conf: A conformer of the molecule, or `None` if it fails + """ + + val = 0 + if mol.GetNumConformers() == 0: + val = Chem.rdDistGeom.EmbedMolecule(mol) + if val == -1: + val = Chem.rdDistGeom.EmbedMolecule( + mol, + enforceChirality=False, + ignoreSmoothingFailures=True, + useBasicKnowledge=True, + useExpTorsionAnglePrefs=True, + forceTol=0.1, + ) + + if val == -1: + conf = None + logger.warn("Couldn't compute conformer for molecule `{}`".format(Chem.MolToSmiles(mol))) + else: + conf = mol.GetConformer(0) + + return conf + + +def get_estimated_bond_length(bond: Chem.rdchem.Bond, mol: Chem.rdchem.Mol) -> float: + r""" + Estimate the bond length between atoms by looking at the estimated atomic radius + that depends both on the atom type and the bond type. The resulting bond-length is + then the sum of the radius. + + Keep in mind that this function only provides an estimate of the bond length and not + the true one based on a conformer. The vast majority od estimated bond lengths will + have an error below 5% while some bonds can have an error up to 20%. This function + is mostly useful when conformer generation fails for some molecules, or for + increased computation speed. + + Parameters: + bond: The bond to measure its lenght + mol: The molecule containing the bond (used to get neighbouring atoms) + + Returns: + bond_length: The bond length in Angstrom, typically a value around 1-2. + + """ + + # Small function to convert strings to floats + def float_or_nan(string): + try: + val = float(string) + except: + val = float("nan") + return val + + # Get the atoms connected by the bond + idx1 = bond.GetBeginAtomIdx() + idx2 = bond.GetEndAtomIdx() + atom1 = mol.GetAtomWithIdx(idx1).GetAtomicNum() + atom2 = mol.GetAtomWithIdx(idx2).GetAtomicNum() + bond_type = bond.GetBondType() + + # Get single bond atomic radius + if bond_type == Chem.rdchem.BondType.SINGLE: + rad1 = [nmp.PERIODIC_TABLE["SingleBondRadius"][atom1]] + rad2 = [nmp.PERIODIC_TABLE["SingleBondRadius"][atom2]] + # Get double bond atomic radius + elif bond_type == Chem.rdchem.BondType.DOUBLE: + rad1 = [nmp.PERIODIC_TABLE["DoubleBondRadius"][atom1]] + rad2 = [nmp.PERIODIC_TABLE["DoubleBondRadius"][atom2]] + # Get triple bond atomic radius + elif bond_type == Chem.rdchem.BondType.TRIPLE: + rad1 = [nmp.PERIODIC_TABLE["TripleBondRadius"][atom1]] + rad2 = [nmp.PERIODIC_TABLE["TripleBondRadius"][atom2]] + # Get average of single bond and double bond atomic radius + elif bond_type == Chem.rdchem.BondType.AROMATIC: + rad1 = [nmp.PERIODIC_TABLE["SingleBondRadius"][atom1], nmp.PERIODIC_TABLE["DoubleBondRadius"][atom1]] + rad2 = [nmp.PERIODIC_TABLE["SingleBondRadius"][atom2], nmp.PERIODIC_TABLE["DoubleBondRadius"][atom2]] + + # Average the bond lengths, while ignoring nans in case some missing value + rad1_float = np.nanmean(np.array([float_or_nan(elem) for elem in rad1])) + rad2_float = np.nanmean(np.array([float_or_nan(elem) for elem in rad2])) + + # If the bond radius is still nan (this shouldn't happen), take the single bond radius + if np.isnan(rad1_float): + rad1_float = float_or_nan(nmp.PERIODIC_TABLE["SingleBondRadius"][atom1]) + if np.isnan(rad2_float): + rad2_float = float_or_nan(nmp.PERIODIC_TABLE["SingleBondRadius"][atom2]) + + bond_length = rad1_float + rad2_float + + return bond_length + + +def get_mol_edge_features(mol: Chem.rdchem.Mol, property_list: List[str], mask_nan: Optional[float] = 0.): + r""" + Get the following set of features for any given bond + See `goli.features.nmp` for allowed values in one hot encoding + + * One-hot representation of the bond type. Note that you should not kekulize your + molecules, if you expect this to take aromatic bond into account. + * Bond stereo type, following CIP classification + * Whether the bond is conjugated + * Whether the bond is in a ring + + Parameters: + mol: rdkit.Chem.Molecule + the molecule of interest + + property_list: + A list of edge properties to return for the given molecule. + Accepted properties are: + + - "bond-type-onehot" + - "bond-type-float" + - "stereo" + - "in-ring" + - "conjugated" + - "conformer-bond-length" (might cause problems with complex molecules) + - "estimated-bond-length" + + Returns: + prop_dict: + A dictionnary where the element of ``property_list`` are the keys + and the values are np.ndarray of shape (N,). N is the number of atoms + in ``mol``. + + """ + + prop_dict = {} + + # Compute features for each bond + num_bonds = mol.GetNumBonds() + for prop in property_list: + property_array = [] + for ii in range(num_bonds): + prop = prop.lower() + bond = mol.GetBondWithIdx(ii) + + if prop in ["bond-type-onehot"]: + encoding = one_of_k_encoding(bond.GetBondType(), nmp.BOND_TYPES) + elif prop in ["bond-type-float"]: + encoding = [bond.GetBondTypeAsDouble()] + elif prop in ["stereo"]: + encoding = one_of_k_encoding(bond.GetStereo(), nmp.BOND_STEREO) + elif prop in ["in-ring"]: + encoding = [bond.IsInRing()] + elif prop in ["conjugated"]: + encoding = [bond.GetIsConjugated()] + elif prop in ["conformer-bond-length"]: + conf = get_simple_mol_conformer(mol) + if conf is not None: + idx1 = bond.GetBeginAtomIdx() + idx2 = bond.GetEndAtomIdx() + encoding = [Chem.rdMolTransforms.GetBondLength(conf, idx1, idx2)] + else: + encoding = [0] + elif prop in ["estimated-bond-length"]: + encoding = [get_estimated_bond_length(bond, mol)] + + else: + raise ValueError(f"Unsupported property `{prop}`") + + property_array.append(np.asarray(encoding, dtype=np.float32)) + + if num_bonds > 0: + property_array = np.stack(property_array, axis=0) + if mask_nan is not None: # Mask the NaNs + property_array[np.isnan(property_array)] = mask_nan + prop_dict[prop] = property_array + else: + prop_dict[prop] = np.array([]) + + return prop_dict + + +def mol_to_adj_and_features( + mol: Union[str, Chem.rdchem.Mol], + atom_property_list_onehot: List[str] = [], + atom_property_list_float: List[Union[str, Callable]] = [], + edge_property_list: List[str] = [], + add_self_loop: bool = False, + explicit_H: bool = False, + use_bonds_weights: bool = False, + pos_encoding_as_features: Dict[str, Any] = None, + pos_encoding_as_directions: Dict[str, Any] = None, +) -> Union[csr_matrix, Union[np.ndarray, None], Union[np.ndarray, None]]: + r""" + Transforms a molecule into an adjacency matrix representing the molecular graph + and a set of atom and bond features. + + Parameters: + + mol: + The molecule to be converted + + atom_property_list_onehot: + List of the properties used to get one-hot encoding of the atom type, + such as the atom index represented as a one-hot vector. + See function `get_mol_atomic_features_onehot` + + atom_property_list_float: + List of the properties used to get floating-point encoding of the atom type, + such as the atomic mass or electronegativity. + See function `get_mol_atomic_features_float` + + edge_property_list: + List of the properties used to encode the edges, such as the edge type + and the stereo type. + + add_self_loop: + Whether to add a value of `1` on the diagonal of the adjacency matrix. + + explicit_H: + Whether to consider the Hydrogens explicitely. If `False`, the hydrogens + are implicit. + + use_bonds_weights: + Whether to use the floating-point value of the bonds in the adjacency matrix, + such that single bonds are represented by 1, double bonds 2, triple 3, aromatic 1.5 + + Returns: + + adj: + Scipy sparse adjacency matrix of the molecule + + ndata: + Concatenated node data of the atoms, based on the properties from + `atom_property_list_onehot` and `atom_property_list_float`. + If no properties are given, it returns `None` + + edata + Concatenated node edge of the molecule, based on the properties from + `edge_property_list`. + If no properties are given, it returns `None` + + """ + + if isinstance(mol, str): + mol = dm.to_mol(mol) + + # Add or remove explicit hydrogens + if explicit_H: + mol = Chem.AddHs(mol) + else: + mol = Chem.RemoveHs(mol) + + # Get the adjacency matrix + adj = GetAdjacencyMatrix(mol, useBO=use_bonds_weights, force=True) + if add_self_loop: + adj = adj + np.eye(adj.shape[0]) + adj = csr_matrix(adj) + + # Get the node features + atom_features_onehot = get_mol_atomic_features_onehot(mol, atom_property_list_onehot) + atom_features_float = get_mol_atomic_features_float(mol, atom_property_list_float) + ndata = list(atom_features_float.values()) + list(atom_features_onehot.values()) + ndata = [np.expand_dims(d, axis=1) if d.ndim == 1 else d for d in ndata] + ndata = np.concatenate(ndata, axis=1) if len(ndata) > 0 else None + + # Get the edge features + edge_features = get_mol_edge_features(mol, edge_property_list) + edata = list(edge_features.values()) + edata = [np.expand_dims(d, axis=1) if d.ndim == 1 else d for d in edata] + edata = np.concatenate(edata, axis=1) if len(edata) > 0 else None + + pos_enc_feats_sign_flip, pos_enc_feats_no_flip, pos_enc_dir = get_all_positional_encoding( + adj, pos_encoding_as_features, pos_encoding_as_directions + ) + + return adj, ndata, edata, pos_enc_feats_sign_flip, pos_enc_feats_no_flip, pos_enc_dir + + +def mol_to_dglgraph( + mol: Chem.rdchem.Mol, + atom_property_list_onehot: List[str] = [], + atom_property_list_float: List[Union[str, Callable]] = [], + edge_property_list: List[str] = [], + add_self_loop: bool = False, + explicit_H: bool = False, + use_bonds_weights: bool = False, + pos_encoding_as_features: Dict[str, Any] = None, + pos_encoding_as_directions: Dict[str, Any] = None, + dtype: torch.dtype = torch.float32, + on_error: str = "ignore", +) -> dgl.DGLGraph: + r""" + Transforms a molecule into an adjacency matrix representing the molecular graph + and a set of atom and bond features. + + Parameters: + + mol: + The molecule to be converted + + atom_property_list_onehot: + List of the properties used to get one-hot encoding of the atom type, + such as the atom index represented as a one-hot vector. + See function `get_mol_atomic_features_onehot` + + atom_property_list_float: + List of the properties used to get floating-point encoding of the atom type, + such as the atomic mass or electronegativity. + See function `get_mol_atomic_features_float` + + edge_property_list: + List of the properties used to encode the edges, such as the edge type + and the stereo type. + + add_self_loop: + Whether to add a value of `1` on the diagonal of the adjacency matrix. + + explicit_H: + Whether to consider the Hydrogens explicitely. If `False`, the hydrogens + are implicit. + + use_bonds_weights: + Whether to use the floating-point value of the bonds in the adjacency matrix, + such that single bonds are represented by 1, double bonds 2, triple 3, aromatic 1.5 + + dtype: + The torch data type used to build the graph + + on_error: + What to do when the featurization fails. + + - "raise": Raise an error + - "warn": Raise a warning and return None + - "ignore": Ignore the error and return None + + Returns: + + graph: + DGL graph, with `graph.ndata['n']` corresponding to the concatenated + node data from `atom_property_list_onehot` and `atom_property_list_float`, + `graph.edata['e']` corresponding to the concatenated edge data from `edge_property_list` + + """ + + input_mol = mol + + try: + + if isinstance(mol, str): + mol = dm.to_mol(mol) + if explicit_H: + mol = Chem.AddHs(mol) + else: + mol = Chem.RemoveHs(mol) + + # Get the adjacency, node features and edge features + ( + adj, + ndata, + edata, + pos_enc_feats_sign_flip, + pos_enc_feats_no_flip, + pos_enc_dir, + ) = mol_to_adj_and_features( + mol=mol, + atom_property_list_onehot=atom_property_list_onehot, + atom_property_list_float=atom_property_list_float, + edge_property_list=edge_property_list, + add_self_loop=add_self_loop, + explicit_H=explicit_H, + use_bonds_weights=use_bonds_weights, + pos_encoding_as_features=pos_encoding_as_features, + pos_encoding_as_directions=pos_encoding_as_directions, + ) + except Exception as e: + if on_error.lower() == "raise": + raise e + elif on_error.lower() == "warn": + smiles = input_mol + if isinstance(smiles, Chem.rdchem.Mol): + smiles = Chem.MolToSmiles(input_mol) + + msg = str(e) + "\nIgnoring following molecule:" + smiles + logger.warning(msg) + return None + elif on_error.lower() == "ignore": + return None + + # Transform the matrix and data into a DGLGraph object + graph = dgl.from_scipy(adj) + + # Assign the node data + if ndata is not None: + graph.ndata["feat"] = torch.from_numpy(ndata).to(dtype=dtype) + + # Assign the edge data. Due to DGL only supporting Hetero-graphs, we + # need to duplicate each edge information for its 2 entries + if edata is not None: + src_ids, dst_ids = graph.all_edges() + hetero_edata = np.zeros_like(edata, shape=(edata.shape[0] * 2, edata.shape[1])) + for ii in range(mol.GetNumBonds()): + bond = mol.GetBondWithIdx(ii) + src, dst = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx() + id1 = np.where((src == src_ids) & (dst == dst_ids))[0] + id2 = np.where((dst == src_ids) & (src == dst_ids))[0] + hetero_edata[id1, :] = edata[ii, :] + hetero_edata[id2, :] = edata[ii, :] + + graph.edata["feat"] = torch.from_numpy(hetero_edata).to(dtype=dtype) + + # Add sign-flip positional encoding + if pos_enc_feats_sign_flip is not None: + graph.ndata["pos_enc_feats_sign_flip"] = pos_enc_feats_sign_flip + + # Add non-sign-flip positional encoding + if pos_enc_feats_no_flip is not None: + graph.ndata["pos_enc_feats_no_flip"] = pos_enc_feats_no_flip + + # Add positional encoding for directional use + if pos_enc_dir is not None: + graph.ndata["pos_dir"] = pos_enc_dir + + return graph + + +def mol_to_dglgraph_signature(featurizer_args: Dict[str, Any] = None): + """Get the default arguments of `mol_to_dglgraph` and update it + with a provided dict of arguments in order to get a fulle signature + of the featurizer args actually used for the features computation. + """ + + # Get the signature of `mol_to_dglgraph` + signature = inspect.signature(mol_to_dglgraph) + + # Filter out empty arguments (without default value) + parameters = list(filter(lambda param: param.default is not param.empty, signature.parameters.values())) + + # Convert to dict + parameters = {param.name: param.default for param in parameters} + + # Update the parameters with the supplied ones + if featurizer_args is not None: + parameters.update(featurizer_args) + + return parameters diff --git a/goli/features/nmp.py b/goli/features/nmp.py index 51ec9f3a5..1cbc5d5cc 100644 --- a/goli/features/nmp.py +++ b/goli/features/nmp.py @@ -1,87 +1,87 @@ -import importlib.resources - -import pandas as pd -from rdkit import Chem - -# NOTE(hadim): usually it's best to embed this in a function. -with importlib.resources.open_text("goli.features", "periodic_table.csv") as f: - PERIODIC_TABLE = pd.read_csv(f) -PERIODIC_TABLE = PERIODIC_TABLE.set_index("AtomicNumber") - -ATOM_LIST = [ - "C", - "N", - "O", - "S", - "F", - "Si", - "P", - "Cl", - "Br", - "Mg", - "Na", - "Ca", - "Fe", - "As", - "Al", - "I", - "B", - "V", - "K", - "Tl", - "Yb", - "Sb", - "Sn", - "Ag", - "Pd", - "Co", - "Se", - "Ti", - "Zn", - "H", - "Li", - "Ge", - "Cu", - "Au", - "Ni", - "Cd", - "In", - "Mn", - "Zr", - "Cr", - "Pt", - "Hg", - "Pb", -] - -ATOM_NUM_H = [0, 1, 2, 3, 4] -VALENCE = [0, 1, 2, 3, 4, 5, 6] -CHARGE_LIST = [-3, -2, -1, 0, 1, 2, 3] -RADICAL_E_LIST = [0, 1, 2] -ATOM_DEGREE_LIST = [0, 1, 2, 3, 4] - -HYBRIDIZATION_LIST = [ - Chem.rdchem.HybridizationType.names[k] - for k in sorted(Chem.rdchem.HybridizationType.names.keys(), reverse=True) - if k != "OTHER" -] - - -CHIRALITY_LIST = ["R"] # alternative is just S - - -BOND_TYPES = [ - Chem.rdchem.BondType.SINGLE, - Chem.rdchem.BondType.DOUBLE, - Chem.rdchem.BondType.TRIPLE, - Chem.rdchem.BondType.AROMATIC, -] - -BOND_STEREO = [ - Chem.rdchem.BondStereo.STEREONONE, - Chem.rdchem.BondStereo.STEREOANY, - Chem.rdchem.BondStereo.STEREOZ, - Chem.rdchem.BondStereo.STEREOE, - Chem.rdchem.BondStereo.STEREOCIS, - Chem.rdchem.BondStereo.STEREOTRANS, -] +import importlib.resources + +import pandas as pd +from rdkit import Chem + +# NOTE(hadim): usually it's best to embed this in a function. +with importlib.resources.open_text("goli.features", "periodic_table.csv") as f: + PERIODIC_TABLE = pd.read_csv(f) +PERIODIC_TABLE = PERIODIC_TABLE.set_index("AtomicNumber") + +ATOM_LIST = [ + "C", + "N", + "O", + "S", + "F", + "Si", + "P", + "Cl", + "Br", + "Mg", + "Na", + "Ca", + "Fe", + "As", + "Al", + "I", + "B", + "V", + "K", + "Tl", + "Yb", + "Sb", + "Sn", + "Ag", + "Pd", + "Co", + "Se", + "Ti", + "Zn", + "H", + "Li", + "Ge", + "Cu", + "Au", + "Ni", + "Cd", + "In", + "Mn", + "Zr", + "Cr", + "Pt", + "Hg", + "Pb", +] + +ATOM_NUM_H = [0, 1, 2, 3, 4] +VALENCE = [0, 1, 2, 3, 4, 5, 6] +CHARGE_LIST = [-3, -2, -1, 0, 1, 2, 3] +RADICAL_E_LIST = [0, 1, 2] +ATOM_DEGREE_LIST = [0, 1, 2, 3, 4] + +HYBRIDIZATION_LIST = [ + Chem.rdchem.HybridizationType.names[k] + for k in sorted(Chem.rdchem.HybridizationType.names.keys(), reverse=True) + if k != "OTHER" +] + + +CHIRALITY_LIST = ["R"] # alternative is just S + + +BOND_TYPES = [ + Chem.rdchem.BondType.SINGLE, + Chem.rdchem.BondType.DOUBLE, + Chem.rdchem.BondType.TRIPLE, + Chem.rdchem.BondType.AROMATIC, +] + +BOND_STEREO = [ + Chem.rdchem.BondStereo.STEREONONE, + Chem.rdchem.BondStereo.STEREOANY, + Chem.rdchem.BondStereo.STEREOZ, + Chem.rdchem.BondStereo.STEREOE, + Chem.rdchem.BondStereo.STEREOCIS, + Chem.rdchem.BondStereo.STEREOTRANS, +] diff --git a/goli/features/positional_encoding.py b/goli/features/positional_encoding.py index dda2457e9..8d88cefd7 100644 --- a/goli/features/positional_encoding.py +++ b/goli/features/positional_encoding.py @@ -1,101 +1,101 @@ -from typing import Tuple, Optional, Dict, Union -import numpy as np -from scipy.sparse import spmatrix -import torch - -from goli.features.spectral import compute_laplacian_positional_eigvecs - - -def get_all_positional_encoding( - adj: Union[np.ndarray, spmatrix], - pos_encoding_as_features: Optional[Dict] = None, - pos_encoding_as_directions: Optional[Dict] = None, -) -> Tuple[np.ndarray, np.ndarray]: - r""" - Get features positional encoding and direction positional encoding. - - Parameters: - adj: Adjacency matrix of the graph - pos_encoding_as_features: keyword arguments for function `graph_positional_encoder` - to generate positional encoding for node features. - pos_encoding_as_directions: keyword arguments for function `graph_positional_encoder` - to generate positional encoding for directional features. - """ - - pos_enc_feats_sign_flip, pos_enc_feats_no_flip, pos_enc_dir = None, None, None - pos_encoding_as_features = {} if pos_encoding_as_features is None else pos_encoding_as_features - pos_encoding_as_directions = {} if pos_encoding_as_directions is None else pos_encoding_as_directions - - # Get the positional encoding for the features - if len(pos_encoding_as_features) > 0: - pos_enc_feats_sign_flip, pos_enc_feats_no_flip = graph_positional_encoder( - adj, **pos_encoding_as_features - ) - - # Get the positional encoding for the directions - if len(pos_encoding_as_directions) > 0: - if pos_encoding_as_directions == pos_encoding_as_features: - - # Concatenate the sign-flip and non-sign-flip positional encodings - if pos_enc_feats_sign_flip is None: - pos_enc_dir = pos_enc_feats_no_flip - elif pos_enc_feats_no_flip is None: - pos_enc_dir = pos_enc_feats_sign_flip - else: - pos_enc_dir = np.concatenate((pos_enc_feats_sign_flip, pos_enc_feats_no_flip), axis=1) - - else: - pos_enc_dir1, pos_enc_dir2 = graph_positional_encoder(adj, **pos_encoding_as_directions) - # Concatenate both positional encodings - if pos_enc_dir1 is None: - pos_enc_dir = pos_enc_dir2 - elif pos_enc_dir2 is None: - pos_enc_dir = pos_enc_dir1 - else: - pos_enc_dir = np.concatenate((pos_enc_dir1, pos_enc_dir2), axis=1) - - return pos_enc_feats_sign_flip, pos_enc_feats_no_flip, pos_enc_dir - - -def graph_positional_encoder( - adj: Union[np.ndarray, spmatrix], pos_type: str, num_pos: int, disconnected_comp: bool = True, **kwargs -) -> np.ndarray: - r""" - Get a positional encoding that depends on the parameters. - - Parameters: - - adj: Adjacency matrix of the graph - - pos_type: The type of positional encoding to use. Supported types are: - - - laplacian_eigvec: the - - laplacian_eigvec_eigval - - """ - - pos_type = pos_type.lower() - pos_enc_sign_flip, pos_enc_no_flip = None, None - - if pos_type == "laplacian_eigvec": - _, eigvecs = compute_laplacian_positional_eigvecs( - adj=adj, num_pos=num_pos, disconnected_comp=disconnected_comp, **kwargs - ) - pos_enc_sign_flip = eigvecs - - elif pos_type == "laplacian_eigvec_eigval": - eigvals_tile, eigvecs = compute_laplacian_positional_eigvecs( - adj=adj, num_pos=num_pos, disconnected_comp=disconnected_comp, **kwargs - ) - pos_enc_sign_flip = eigvecs - pos_enc_no_flip = eigvals_tile - else: - raise ValueError(f"Unknown `pos_type`: {pos_type}") - - if pos_enc_sign_flip is not None: - pos_enc_sign_flip = torch.as_tensor(np.real(pos_enc_sign_flip)).to(torch.float32) - - if pos_enc_no_flip is not None: - pos_enc_no_flip = torch.as_tensor(np.real(pos_enc_no_flip)).to(torch.float32) - - return pos_enc_sign_flip, pos_enc_no_flip +from typing import Tuple, Optional, Dict, Union +import numpy as np +from scipy.sparse import spmatrix +import torch + +from goli.features.spectral import compute_laplacian_positional_eigvecs + + +def get_all_positional_encoding( + adj: Union[np.ndarray, spmatrix], + pos_encoding_as_features: Optional[Dict] = None, + pos_encoding_as_directions: Optional[Dict] = None, +) -> Tuple[np.ndarray, np.ndarray]: + r""" + Get features positional encoding and direction positional encoding. + + Parameters: + adj: Adjacency matrix of the graph + pos_encoding_as_features: keyword arguments for function `graph_positional_encoder` + to generate positional encoding for node features. + pos_encoding_as_directions: keyword arguments for function `graph_positional_encoder` + to generate positional encoding for directional features. + """ + + pos_enc_feats_sign_flip, pos_enc_feats_no_flip, pos_enc_dir = None, None, None + pos_encoding_as_features = {} if pos_encoding_as_features is None else pos_encoding_as_features + pos_encoding_as_directions = {} if pos_encoding_as_directions is None else pos_encoding_as_directions + + # Get the positional encoding for the features + if len(pos_encoding_as_features) > 0: + pos_enc_feats_sign_flip, pos_enc_feats_no_flip = graph_positional_encoder( + adj, **pos_encoding_as_features + ) + + # Get the positional encoding for the directions + if len(pos_encoding_as_directions) > 0: + if pos_encoding_as_directions == pos_encoding_as_features: + + # Concatenate the sign-flip and non-sign-flip positional encodings + if pos_enc_feats_sign_flip is None: + pos_enc_dir = pos_enc_feats_no_flip + elif pos_enc_feats_no_flip is None: + pos_enc_dir = pos_enc_feats_sign_flip + else: + pos_enc_dir = np.concatenate((pos_enc_feats_sign_flip, pos_enc_feats_no_flip), axis=1) + + else: + pos_enc_dir1, pos_enc_dir2 = graph_positional_encoder(adj, **pos_encoding_as_directions) + # Concatenate both positional encodings + if pos_enc_dir1 is None: + pos_enc_dir = pos_enc_dir2 + elif pos_enc_dir2 is None: + pos_enc_dir = pos_enc_dir1 + else: + pos_enc_dir = np.concatenate((pos_enc_dir1, pos_enc_dir2), axis=1) + + return pos_enc_feats_sign_flip, pos_enc_feats_no_flip, pos_enc_dir + + +def graph_positional_encoder( + adj: Union[np.ndarray, spmatrix], pos_type: str, num_pos: int, disconnected_comp: bool = True, **kwargs +) -> np.ndarray: + r""" + Get a positional encoding that depends on the parameters. + + Parameters: + + adj: Adjacency matrix of the graph + + pos_type: The type of positional encoding to use. Supported types are: + + - laplacian_eigvec: the + - laplacian_eigvec_eigval + + """ + + pos_type = pos_type.lower() + pos_enc_sign_flip, pos_enc_no_flip = None, None + + if pos_type == "laplacian_eigvec": + _, eigvecs = compute_laplacian_positional_eigvecs( + adj=adj, num_pos=num_pos, disconnected_comp=disconnected_comp, **kwargs + ) + pos_enc_sign_flip = eigvecs + + elif pos_type == "laplacian_eigvec_eigval": + eigvals_tile, eigvecs = compute_laplacian_positional_eigvecs( + adj=adj, num_pos=num_pos, disconnected_comp=disconnected_comp, **kwargs + ) + pos_enc_sign_flip = eigvecs + pos_enc_no_flip = eigvals_tile + else: + raise ValueError(f"Unknown `pos_type`: {pos_type}") + + if pos_enc_sign_flip is not None: + pos_enc_sign_flip = torch.as_tensor(np.real(pos_enc_sign_flip)).to(torch.float32) + + if pos_enc_no_flip is not None: + pos_enc_no_flip = torch.as_tensor(np.real(pos_enc_no_flip)).to(torch.float32) + + return pos_enc_sign_flip, pos_enc_no_flip diff --git a/goli/features/properties.py b/goli/features/properties.py index e12f4c228..4c5b44a62 100644 --- a/goli/features/properties.py +++ b/goli/features/properties.py @@ -1,106 +1,106 @@ -from typing import Union, List, Callable - -import numpy as np -import datamol as dm - -from rdkit import Chem -from rdkit.Chem import rdMolDescriptors as rdMD - -from mordred import Calculator, descriptors - - -def get_prop_or_none(prop, n, *args, **kwargs): - r""" - return properties. If error, return list of `None` with lenght `n`. - """ - try: - return prop(*args, **kwargs) - except RuntimeError: - return [None] * n - - -def get_props_from_mol(mol: Union[Chem.rdchem.Mol, str], properties: Union[List[str], str] = "autocorr3d"): - r""" - Function to get a given set of desired properties from a molecule, - and output a property list. - - Parameters: - mol: The molecule from which to compute the properties. - properties: - The list of properties to compute for each molecule. It can be the following: - - - 'descriptors' - - 'autocorr3d' - - 'rdf' - - 'morse' - - 'whim' - - 'all' - - Returns: - props: np.array(float) - The array of properties for the desired molecule - classes_start_idx: list(int) - The list of index specifying the start of each new class of - descriptor or property. For example, if props has 20 elements, - the first 5 are rotatable bonds, the next 8 are morse, and - the rest are whim, then ``classes_start_idx = [0, 5, 13]``. - This will mainly be useful to normalize the features of - each class. - classes_names: list(str) - The name of the classes associated to each starting index. - Will be usefull to understand what property is the network learning. - - """ - - if isinstance(mol, str): - mol = dm.to_mol(mol) - - if isinstance(properties, str): - properties = [properties] - - properties = [p.lower() for p in properties] - - # Initialize arrays - props = [] # Property vector for the features - classes_start_idx = [] # The starting index for each property class - classes_names = [] - - # Generate a 3D structure for the molecule - mol = Chem.AddHs(mol) # type: ignore - - if ("descriptors" in properties) or ("all" in properties): - # Calculate the descriptors of the molecule - for desc in descriptors.all: - classes_names.append(desc.__name__.replace("mordred.", "")) - classes_start_idx.append(len(props)) - calc = Calculator(desc, ignore_3D=True) - props.extend(calc(mol)) - - if ("autocorr3d" in properties) or ("all" in properties): - # Some kind of 3D description of the molecule - classes_names.append("autocorr3d") - classes_start_idx.append(len(props)) - props.extend(get_prop_or_none(rdMD.CalcAUTOCORR3D, 80, mol)) - - if ("rdf" in properties) or ("all" in properties): - # The radial distribution function (better than the inertia) - # https://en.wikipedia.org/wiki/Radial_distribution_function - classes_names.append("rdf") - classes_start_idx.append(len(props)) - props.extend(get_prop_or_none(rdMD.CalcRDF, 210, mol)) - - if ("morse" in properties) or ("all" in properties): - # Molecule Representation of Structures based on Electron diffraction descriptors - classes_names.append("morse") - classes_start_idx.append(len(props)) - props.extend(get_prop_or_none(rdMD.CalcMORSE, 224, mol)) - - if ("whim" in properties) or ("all" in properties): - # WHIM descriptors are 3D structural descriptors obtained from the - # (x,y,z)‐atomic coordinates of a molecular conformation of a chemical, - # and are used successfully in QSAR modelling. - classes_names.append("whim") - classes_start_idx.append(len(props)) - props.extend(get_prop_or_none(rdMD.CalcWHIM, 114, mol)) - - return np.array(props), classes_start_idx, classes_names +from typing import Union, List, Callable + +import numpy as np +import datamol as dm + +from rdkit import Chem +from rdkit.Chem import rdMolDescriptors as rdMD + +from mordred import Calculator, descriptors + + +def get_prop_or_none(prop, n, *args, **kwargs): + r""" + return properties. If error, return list of `None` with lenght `n`. + """ + try: + return prop(*args, **kwargs) + except RuntimeError: + return [None] * n + + +def get_props_from_mol(mol: Union[Chem.rdchem.Mol, str], properties: Union[List[str], str] = "autocorr3d"): + r""" + Function to get a given set of desired properties from a molecule, + and output a property list. + + Parameters: + mol: The molecule from which to compute the properties. + properties: + The list of properties to compute for each molecule. It can be the following: + + - 'descriptors' + - 'autocorr3d' + - 'rdf' + - 'morse' + - 'whim' + - 'all' + + Returns: + props: np.array(float) + The array of properties for the desired molecule + classes_start_idx: list(int) + The list of index specifying the start of each new class of + descriptor or property. For example, if props has 20 elements, + the first 5 are rotatable bonds, the next 8 are morse, and + the rest are whim, then ``classes_start_idx = [0, 5, 13]``. + This will mainly be useful to normalize the features of + each class. + classes_names: list(str) + The name of the classes associated to each starting index. + Will be usefull to understand what property is the network learning. + + """ + + if isinstance(mol, str): + mol = dm.to_mol(mol) + + if isinstance(properties, str): + properties = [properties] + + properties = [p.lower() for p in properties] + + # Initialize arrays + props = [] # Property vector for the features + classes_start_idx = [] # The starting index for each property class + classes_names = [] + + # Generate a 3D structure for the molecule + mol = Chem.AddHs(mol) # type: ignore + + if ("descriptors" in properties) or ("all" in properties): + # Calculate the descriptors of the molecule + for desc in descriptors.all: + classes_names.append(desc.__name__.replace("mordred.", "")) + classes_start_idx.append(len(props)) + calc = Calculator(desc, ignore_3D=True) + props.extend(calc(mol)) + + if ("autocorr3d" in properties) or ("all" in properties): + # Some kind of 3D description of the molecule + classes_names.append("autocorr3d") + classes_start_idx.append(len(props)) + props.extend(get_prop_or_none(rdMD.CalcAUTOCORR3D, 80, mol)) + + if ("rdf" in properties) or ("all" in properties): + # The radial distribution function (better than the inertia) + # https://en.wikipedia.org/wiki/Radial_distribution_function + classes_names.append("rdf") + classes_start_idx.append(len(props)) + props.extend(get_prop_or_none(rdMD.CalcRDF, 210, mol)) + + if ("morse" in properties) or ("all" in properties): + # Molecule Representation of Structures based on Electron diffraction descriptors + classes_names.append("morse") + classes_start_idx.append(len(props)) + props.extend(get_prop_or_none(rdMD.CalcMORSE, 224, mol)) + + if ("whim" in properties) or ("all" in properties): + # WHIM descriptors are 3D structural descriptors obtained from the + # (x,y,z)‐atomic coordinates of a molecular conformation of a chemical, + # and are used successfully in QSAR modelling. + classes_names.append("whim") + classes_start_idx.append(len(props)) + props.extend(get_prop_or_none(rdMD.CalcWHIM, 114, mol)) + + return np.array(props), classes_start_idx, classes_names diff --git a/goli/features/spectral.py b/goli/features/spectral.py index c35a5c0ea..92eb12863 100644 --- a/goli/features/spectral.py +++ b/goli/features/spectral.py @@ -1,139 +1,139 @@ -from typing import Tuple, Union - -from scipy.sparse.linalg import eigs -from scipy.linalg import eig -from scipy.sparse import csr_matrix, diags, issparse, spmatrix -import numpy as np -import torch -import networkx as nx - -from goli.utils.tensor import is_dtype_torch_tensor, is_dtype_numpy_array - - -def compute_laplacian_positional_eigvecs( - adj: Union[np.ndarray, spmatrix], - num_pos: int, - disconnected_comp: bool = True, - normalization: str = "none", -) -> Tuple[np.ndarray, np.ndarray]: - - # Sparsify the adjacency patrix - if issparse(adj): - adj = adj.astype(np.float64) - else: - adj = csr_matrix(adj, dtype=np.float64) - - # Compute tha Laplacian, and normalize it - D = np.array(np.sum(adj, axis=1)).flatten() - D_mat = diags(D) - L = -adj + D_mat - L_norm = normalize_matrix(L, degree_vector=D, normalization=normalization) - - if disconnected_comp: - # Get the list of connected components - components = list(nx.connected_components(nx.from_scipy_sparse_matrix(adj))) - eigvals_tile = np.zeros((L_norm.shape[0], num_pos), dtype=np.float64) - eigvecs = np.zeros_like(eigvals_tile) - - # Compute the eigenvectors for each connected component, and stack them together - for component in components: - comp = list(component) - this_L = L_norm[comp][:, comp] - this_eigvals, this_eigvecs = _get_positional_eigvecs(this_L, num_pos=num_pos) - eigvecs[comp, :] = this_eigvecs - eigvals_tile[comp, :] = this_eigvals - else: - eigvals, eigvecs = _get_positional_eigvecs(L, num_pos=num_pos) - eigvals_tile = np.tile(eigvals, (L_norm.shape[0], 1)) - - # Eigenvalues previously set to infinite are now set to 0 - eigvals_tile[np.isinf(eigvals_tile)] = 0 - - return eigvals_tile, eigvecs - - -def _get_positional_eigvecs(matrix, num_pos: int): - - mat_len = matrix.shape[0] - if num_pos < mat_len - 1: # Compute the k-lowest eigenvectors - eigvals, eigvecs = eigs(matrix, k=num_pos, which="SR", tol=0) - - else: # Compute all eigenvectors - - eigvals, eigvecs = eig(matrix.todense()) - - # Pad with non-sense eigenvectors if required - if num_pos > mat_len: - temp_EigVal = np.ones(num_pos - mat_len, dtype=np.float64) + float("inf") - temp_EigVec = np.zeros((mat_len, num_pos - mat_len), dtype=np.float64) - eigvals = np.concatenate([eigvals, temp_EigVal], axis=0) - eigvecs = np.concatenate([eigvecs, temp_EigVec], axis=1) - - # Sort and keep only the first `num_pos` elements - sort_idx = eigvals.argsort() - eigvals = eigvals[sort_idx] - eigvals = eigvals[:num_pos] - eigvecs = eigvecs[:, sort_idx] - eigvecs = eigvecs[:, :num_pos] - - # Normalize the eigvecs - eigvecs = eigvecs / (np.sqrt(np.sum(eigvecs ** 2, axis=0, keepdims=True)) + 1e-8) - - return eigvals, eigvecs - - -def normalize_matrix(matrix, degree_vector=None, normalization: str = None): - r""" - Normalize a given matrix using its degree vector - - Parameters - --------------- - - matrix: torch.tensor(N, N) or scipy.sparse.spmatrix(N, N) - A square matrix representing either an Adjacency matrix or a Laplacian. - - degree_vector: torch.tensor(N) or np.ndarray(N) or None - A vector representing the degree of ``matrix``. - ``None`` is only accepted if ``normalization==None`` - - normalization: str or None, Default='none' - Normalization to use on the eig_matrix - - - 'none' or ``None``: no normalization - - - 'sym': Symmetric normalization ``D^-0.5 L D^-0.5`` - - - 'inv': Inverse normalization ``D^-1 L`` - - Returns - ----------- - matrix: torch.tensor(N, N) or scipy.sparse.spmatrix(N, N) - The normalized matrix - - """ - - # Transform the degree vector into a matrix - if degree_vector is None: - if not ((normalization is None) or (normalization.lower() == "none")): - raise ValueError("`degree_vector` cannot be `None` if `normalization` is not `None`") - else: - if is_dtype_numpy_array(matrix.dtype): - degree_inv = np.expand_dims(degree_vector ** -0.5, axis=1) - degree_inv[np.isinf(degree_inv)] = 0 - elif is_dtype_torch_tensor(matrix.dtype): - degree_inv = torch.unsqueeze(degree_vector ** -0.5, dim=1) - degree_inv[torch.isinf(degree_inv)] = 0 - - # Compute the normalized matrix - if (normalization is None) or (normalization.lower() == "none"): - pass - elif normalization.lower() == "sym": - matrix = degree_inv * matrix * degree_inv.T - elif normalization.lower() == "inv": - matrix = (degree_inv ** 2) * matrix - else: - raise ValueError( - f'`normalization` should be `None`, `"None"`, `"sym"` or `"inv"`, but `{normalization}` was provided' - ) - - return matrix +from typing import Tuple, Union + +from scipy.sparse.linalg import eigs +from scipy.linalg import eig +from scipy.sparse import csr_matrix, diags, issparse, spmatrix +import numpy as np +import torch +import networkx as nx + +from goli.utils.tensor import is_dtype_torch_tensor, is_dtype_numpy_array + + +def compute_laplacian_positional_eigvecs( + adj: Union[np.ndarray, spmatrix], + num_pos: int, + disconnected_comp: bool = True, + normalization: str = "none", +) -> Tuple[np.ndarray, np.ndarray]: + + # Sparsify the adjacency patrix + if issparse(adj): + adj = adj.astype(np.float64) + else: + adj = csr_matrix(adj, dtype=np.float64) + + # Compute tha Laplacian, and normalize it + D = np.array(np.sum(adj, axis=1)).flatten() + D_mat = diags(D) + L = -adj + D_mat + L_norm = normalize_matrix(L, degree_vector=D, normalization=normalization) + + if disconnected_comp: + # Get the list of connected components + components = list(nx.connected_components(nx.from_scipy_sparse_matrix(adj))) + eigvals_tile = np.zeros((L_norm.shape[0], num_pos), dtype=np.float64) + eigvecs = np.zeros_like(eigvals_tile) + + # Compute the eigenvectors for each connected component, and stack them together + for component in components: + comp = list(component) + this_L = L_norm[comp][:, comp] + this_eigvals, this_eigvecs = _get_positional_eigvecs(this_L, num_pos=num_pos) + eigvecs[comp, :] = this_eigvecs + eigvals_tile[comp, :] = this_eigvals + else: + eigvals, eigvecs = _get_positional_eigvecs(L, num_pos=num_pos) + eigvals_tile = np.tile(eigvals, (L_norm.shape[0], 1)) + + # Eigenvalues previously set to infinite are now set to 0 + eigvals_tile[np.isinf(eigvals_tile)] = 0 + + return eigvals_tile, eigvecs + + +def _get_positional_eigvecs(matrix, num_pos: int): + + mat_len = matrix.shape[0] + if num_pos < mat_len - 1: # Compute the k-lowest eigenvectors + eigvals, eigvecs = eigs(matrix, k=num_pos, which="SR", tol=0) + + else: # Compute all eigenvectors + + eigvals, eigvecs = eig(matrix.todense()) + + # Pad with non-sense eigenvectors if required + if num_pos > mat_len: + temp_EigVal = np.ones(num_pos - mat_len, dtype=np.float64) + float("inf") + temp_EigVec = np.zeros((mat_len, num_pos - mat_len), dtype=np.float64) + eigvals = np.concatenate([eigvals, temp_EigVal], axis=0) + eigvecs = np.concatenate([eigvecs, temp_EigVec], axis=1) + + # Sort and keep only the first `num_pos` elements + sort_idx = eigvals.argsort() + eigvals = eigvals[sort_idx] + eigvals = eigvals[:num_pos] + eigvecs = eigvecs[:, sort_idx] + eigvecs = eigvecs[:, :num_pos] + + # Normalize the eigvecs + eigvecs = eigvecs / (np.sqrt(np.sum(eigvecs ** 2, axis=0, keepdims=True)) + 1e-8) + + return eigvals, eigvecs + + +def normalize_matrix(matrix, degree_vector=None, normalization: str = None): + r""" + Normalize a given matrix using its degree vector + + Parameters + --------------- + + matrix: torch.tensor(N, N) or scipy.sparse.spmatrix(N, N) + A square matrix representing either an Adjacency matrix or a Laplacian. + + degree_vector: torch.tensor(N) or np.ndarray(N) or None + A vector representing the degree of ``matrix``. + ``None`` is only accepted if ``normalization==None`` + + normalization: str or None, Default='none' + Normalization to use on the eig_matrix + + - 'none' or ``None``: no normalization + + - 'sym': Symmetric normalization ``D^-0.5 L D^-0.5`` + + - 'inv': Inverse normalization ``D^-1 L`` + + Returns + ----------- + matrix: torch.tensor(N, N) or scipy.sparse.spmatrix(N, N) + The normalized matrix + + """ + + # Transform the degree vector into a matrix + if degree_vector is None: + if not ((normalization is None) or (normalization.lower() == "none")): + raise ValueError("`degree_vector` cannot be `None` if `normalization` is not `None`") + else: + if is_dtype_numpy_array(matrix.dtype): + degree_inv = np.expand_dims(degree_vector ** -0.5, axis=1) + degree_inv[np.isinf(degree_inv)] = 0 + elif is_dtype_torch_tensor(matrix.dtype): + degree_inv = torch.unsqueeze(degree_vector ** -0.5, dim=1) + degree_inv[torch.isinf(degree_inv)] = 0 + + # Compute the normalized matrix + if (normalization is None) or (normalization.lower() == "none"): + pass + elif normalization.lower() == "sym": + matrix = degree_inv * matrix * degree_inv.T + elif normalization.lower() == "inv": + matrix = (degree_inv ** 2) * matrix + else: + raise ValueError( + f'`normalization` should be `None`, `"None"`, `"sym"` or `"inv"`, but `{normalization}` was provided' + ) + + return matrix diff --git a/goli/nn/__init__.py b/goli/nn/__init__.py index 1b0b8bc76..16e545e21 100644 --- a/goli/nn/__init__.py +++ b/goli/nn/__init__.py @@ -1,3 +1,3 @@ -from .architectures import FullDGLNetwork -from .architectures import FullDGLSiameseNetwork -from .architectures import FeedForwardNN +from .architectures import FullDGLNetwork +from .architectures import FullDGLSiameseNetwork +from .architectures import FeedForwardNN diff --git a/goli/nn/architectures.py b/goli/nn/architectures.py index 00df63521..b7fdd4853 100644 --- a/goli/nn/architectures.py +++ b/goli/nn/architectures.py @@ -1,1038 +1,1038 @@ -from torch import nn -import torch -import dgl -from typing import List, Dict, Tuple, Union, Callable, Any, Optional -import inspect - -from goli.nn.base_layers import FCLayer, get_activation -from goli.nn.dgl_layers import BaseDGLLayer -from goli.nn.residual_connections import ResidualConnectionBase -from goli.nn.dgl_layers.pooling import parse_pooling_layer, VirtualNode -from goli.utils.spaces import LAYERS_DICT, RESIDUALS_DICT - - -class FeedForwardNN(nn.Module): - def __init__( - self, - in_dim: int, - out_dim: int, - hidden_dims: Union[List[int], int], - depth: Optional[int] = None, - activation: Union[str, Callable] = "relu", - last_activation: Union[str, Callable] = "none", - dropout: float = 0.0, - last_dropout: float = 0.0, - batch_norm: bool = False, - last_batch_norm: bool = False, - residual_type: str = "none", - residual_skip_steps: int = 1, - name: str = "LNN", - layer_type: Union[str, nn.Module] = "fc", - layer_kwargs: Optional[Dict] = None, - ): - r""" - A flexible neural network architecture, with variable hidden dimensions, - support for multiple layer types, and support for different residual - connections. - - Parameters: - - in_dim: - Input feature dimensions of the layer - - out_dim: - Output feature dimensions of the layer - - hidden_dims: - Either an integer specifying all the hidden dimensions, - or a list of dimensions in the hidden layers. - Be careful, the "simple" residual type only supports - hidden dimensions of the same value. - - depth: - If `hidden_dims` is an integer, `depth` is 1 + the number of - hidden layers to use. If `hidden_dims` is a `list`, `depth` must - be `None`. - - activation: - activation function to use in the hidden layers. - - last_activation: - activation function to use in the last layer. - - dropout: - The ratio of units to dropout. Must be between 0 and 1 - - last_dropout: - The ratio of units to dropout for the last_layer. Must be between 0 and 1 - - batch_norm: - Whether to use batch normalization - - last_batch_norm: - Whether to use batch normalization in the last layer - - residual_type: - - "none": No residual connection - - "simple": Residual connection similar to the ResNet architecture. - See class `ResidualConnectionSimple` - - "weighted": Residual connection similar to the Resnet architecture, - but with weights applied before the summation. See class `ResidualConnectionWeighted` - - "concat": Residual connection where the residual is concatenated instead - of being added. - - "densenet": Residual connection where the residual of all previous layers - are concatenated. This leads to a strong increase in the number of parameters - if there are multiple hidden layers. - - residual_skip_steps: - The number of steps to skip between each residual connection. - If `1`, all the layers are connected. If `2`, half of the - layers are connected. - - name: - Name attributed to the current network, for display and printing - purposes. - - layer_type: - The type of layers to use in the network. - Either "fc" as the `FCLayer`, or a class representing the `nn.Module` - to use. - - layer_kwargs: - The arguments to be used in the initialization of the layer provided by `layer_type` - - """ - - super().__init__() - - # Set the class attributes - self.in_dim = in_dim - self.out_dim = out_dim - if isinstance(hidden_dims, int): - self.hidden_dims = [hidden_dims] * (depth - 1) - else: - self.hidden_dims = list(hidden_dims) - assert depth is None - self.depth = len(self.hidden_dims) + 1 - self.activation = get_activation(activation) - self.last_activation = get_activation(last_activation) - self.dropout = dropout - self.last_dropout = last_dropout - self.batch_norm = batch_norm - self.last_batch_norm = last_batch_norm - self.residual_type = None if residual_type is None else residual_type.lower() - self.residual_skip_steps = residual_skip_steps - self.layer_kwargs = layer_kwargs if layer_kwargs is not None else {} - self.name = name - - # Parse the layer and residuals - self.layer_class, self.layer_name = self._parse_class_from_dict(layer_type, LAYERS_DICT) - self.residual_class, self.residual_name = self._parse_class_from_dict(residual_type, RESIDUALS_DICT) - - self.full_dims = [self.in_dim] + self.hidden_dims + [self.out_dim] - self._create_layers() - self._check_bad_arguments() - - def _check_bad_arguments(self): - r""" - Raise comprehensive errors if the arguments seem wrong - """ - if (self.residual_type == "simple") and not (self.hidden_dims[:-1] == self.hidden_dims[1:]): - raise ValueError( - f"When using the residual_type={self.residual_type}" - + f", all elements in the hidden_dims must be equal. Provided:{self.hidden_dims}" - ) - - def _parse_class_from_dict( - self, name_or_class: Union[type, str], class_dict: Dict[str, type] - ) -> Tuple[type, str]: - r""" - Register the hyperparameters for tracking by Pytorch-lightning - """ - if isinstance(name_or_class, str): - obj_name = name_or_class.lower() - obj_class = class_dict[obj_name] - elif callable(name_or_class): - obj_name = str(name_or_class) - obj_class = name_or_class - else: - raise TypeError(f"`name_or_class` must be str or callable, provided: {type(name_or_class)}") - - return obj_class, obj_name - - def _create_residual_connection(self, out_dims: List[int]) -> Tuple[ResidualConnectionBase, List[int]]: - r""" - Create the residual connection classes. - The out_dims is only used if the residual classes requires weights - """ - if self.residual_class.has_weights: - residual_layer = self.residual_class( - skip_steps=self.residual_skip_steps, - out_dims=out_dims, - dropout=self.dropout, - activation=self.activation, - batch_norm=self.batch_norm, - bias=False, - ) - else: - residual_layer = self.residual_class(skip_steps=self.residual_skip_steps) - - residual_out_dims = residual_layer.get_true_out_dims(self.full_dims[1:]) - - return residual_layer, residual_out_dims - - def _create_layers(self): - r""" - Create all the necessary layers for the network. - It's a bit complicated to explain what's going on in this function, - but it must manage the varying features sizes caused by: - - - The presence of different types of residual connections - """ - - self.residual_layer, residual_out_dims = self._create_residual_connection(out_dims=self.full_dims[1:]) - - # Create a ModuleList of the GNN layers - self.layers = nn.ModuleList() - this_in_dim = self.full_dims[0] - this_activation = self.activation - this_batch_norm = self.batch_norm - this_dropout = self.dropout - - for ii in range(self.depth): - this_out_dim = self.full_dims[ii + 1] - if ii == self.depth - 1: - this_activation = self.last_activation - this_batch_norm = self.last_batch_norm - this_dropout = self.last_dropout - - # Create the layer - self.layers.append( - self.layer_class( - in_dim=this_in_dim, - out_dim=this_out_dim, - activation=this_activation, - dropout=this_dropout, - batch_norm=this_batch_norm, - **self.layer_kwargs, - ) - ) - - if ii < len(residual_out_dims): - this_in_dim = residual_out_dims[ii] - - def forward(self, h: torch.Tensor) -> torch.Tensor: - r""" - Apply the neural network on the input features. - - Parameters: - - h: `torch.Tensor[..., Din]`: - Input feature tensor, before the network. - `Din` is the number of input features - - Returns: - - `torch.Tensor[..., Dout]`: - Output feature tensor, after the network. - `Dout` is the number of output features - - """ - h_prev = None - for ii, layer in enumerate(self.layers): - h = layer.forward(h) - if ii < len(self.layers) - 1: - h, h_prev = self.residual_layer.forward(h, h_prev, step_idx=ii) - - return h - - def __repr__(self): - r""" - Controls how the class is printed - """ - class_str = f"{self.name}(depth={self.depth}, {self.residual_layer})\n " - layer_str = f"[{self.layer_class.__name__}[{' -> '.join(map(str, self.full_dims))}]" - - return class_str + layer_str - - -class FeedForwardDGL(FeedForwardNN): - def __init__( - self, - in_dim: int, - out_dim: int, - hidden_dims: List[int], - depth: Optional[int] = None, - activation: Union[str, Callable] = "relu", - last_activation: Union[str, Callable] = "none", - dropout: float = 0.0, - last_dropout: float = 0.0, - batch_norm: bool = False, - last_batch_norm: bool = False, - residual_type: str = "none", - residual_skip_steps: int = 1, - in_dim_edges: int = 0, - hidden_dims_edges: List[int] = [], - pooling: Union[List[str], List[Callable]] = ["sum"], - name: str = "GNN", - layer_type: Union[str, nn.Module] = "gcn", - layer_kwargs: Optional[Dict] = None, - virtual_node: str = "none", - ): - r""" - A flexible neural network architecture, with variable hidden dimensions, - support for multiple layer types, and support for different residual - connections. - - This class is meant to work with different DGL-based graph neural networks - layers. Any layer must inherit from `goli.nn.dgl_layers.BaseDGLLayer`. - - Parameters: - - in_dim: - Input feature dimensions of the layer - - out_dim: - Output feature dimensions of the layer - - hidden_dims: - List of dimensions in the hidden layers. - Be careful, the "simple" residual type only supports - hidden dimensions of the same value. - - depth: - If `hidden_dims` is an integer, `depth` is 1 + the number of - hidden layers to use. If `hidden_dims` is a `list`, `depth` must - be `None`. - - activation: - activation function to use in the hidden layers. - - last_activation: - activation function to use in the last layer. - - dropout: - The ratio of units to dropout. Must be between 0 and 1 - - last_dropout: - The ratio of units to dropout for the last layer. Must be between 0 and 1 - - batch_norm: - Whether to use batch normalization - - last_batch_norm: - Whether to use batch normalization in the last layer - - residual_type: - - "none": No residual connection - - "simple": Residual connection similar to the ResNet architecture. - See class `ResidualConnectionSimple` - - "weighted": Residual connection similar to the Resnet architecture, - but with weights applied before the summation. See class `ResidualConnectionWeighted` - - "concat": Residual connection where the residual is concatenated instead - of being added. - - "densenet": Residual connection where the residual of all previous layers - are concatenated. This leads to a strong increase in the number of parameters - if there are multiple hidden layers. - - residual_skip_steps: - The number of steps to skip between each residual connection. - If `1`, all the layers are connected. If `2`, half of the - layers are connected. - - in_dim_edges: - Input edge-feature dimensions of the network. Keep at 0 if not using - edge features, or if the layer doesn't support edges. - - hidden_dims_edges: - Hidden dimensions for the edges. Most models don't support it, so it - should only be used for those that do, i.e. `GatedGCNLayer` - - pooling: - The pooling types to use. Multiple pooling can be used, and their - results will be concatenated. - For node feature predictions, use `["none"]`. - For graph feature predictions see `goli.nn.dgl_layers.pooling.parse_pooling_layer`. - The list must either contain Callables, or the string below - - - "none": No pooling is applied - - "sum": `SumPooling` - - "mean": `MeanPooling` - - "max": `MaxPooling` - - "min": `MinPooling` - - "std": `StdPooling` - - "s2s": `Set2Set` - - name: - Name attributed to the current network, for display and printing - purposes. - - layer_type: - The type of layers to use in the network. - Either a class that inherits from `goli.nn.dgl_layers.BaseDGLLayer`, - or one of the following strings - - - "gcn": `GCNLayer` - - "gin": `GINLayer` - - "gat": `GATLayer` - - "gated-gcn": `GatedGCNLayer` - - "pna-conv": `PNAConvolutionalLayer` - - "pna-msgpass": `PNAMessagePassingLayer` - - "dgn-conv": `DGNConvolutionalLayer` - - "dgn-msgpass": `DGNMessagePassingLayer` - - layer_kwargs: - The arguments to be used in the initialization of the layer provided by `layer_type` - - virtual_node: - A string associated to the type of virtual node to use, - either `None`, "none", "mean", "sum", "max", "logsum". - See `goli.nn.dgl_layers.VirtualNode`. - - The virtual node will not use any residual connection if `residual_type` - is "none". Otherwise, it will use a simple ResNet like residual - connection. - - """ - - # Initialize the additional attributes - self.in_dim_edges = in_dim_edges - if isinstance(hidden_dims_edges, int): - self.hidden_dims_edges = [hidden_dims_edges] * (depth - 1) - elif len(hidden_dims_edges) == 0: - self.hidden_dims_edges = [] - else: - self.hidden_dims_edges = list(hidden_dims_edges) - assert depth is None - self.full_dims_edges = None - if len(self.hidden_dims_edges) > 0: - self.full_dims_edges = [self.in_dim_edges] + self.hidden_dims_edges + [self.hidden_dims_edges[-1]] - - self.virtual_node = virtual_node.lower() if virtual_node is not None else "none" - self.pooling = pooling - - # Initialize the parent `FeedForwardNN` - super().__init__( - in_dim=in_dim, - out_dim=out_dim, - hidden_dims=hidden_dims, - depth=depth, - activation=activation, - last_activation=last_activation, - batch_norm=batch_norm, - last_batch_norm=last_batch_norm, - residual_type=residual_type, - residual_skip_steps=residual_skip_steps, - name=name, - layer_type=layer_type, - dropout=dropout, - last_dropout=last_dropout, - layer_kwargs=layer_kwargs, - ) - - def _check_bad_arguments(self): - r""" - Raise comprehensive errors if the arguments seem wrong - """ - super()._check_bad_arguments() - if ( - (self.in_dim_edges > 0) or (self.full_dims_edges is not None) - ) and not self.layer_class.layer_supports_edges: - raise ValueError(f"Cannot use edge features with class `{self.layer_class}`") - - def _create_layers(self): - r""" - Create all the necessary layers for the network. - It's a bit complicated to explain what's going on in this function, - but it must manage the varying features sizes caused by: - - - The presence of different types of residual connections - - The presence or absence of edges - - The output dimensions varying for different networks i.e. `GatLayer` outputs different feature sizes according to the number of heads - - The presence or absence of virtual nodes - - The different possible pooling, and the concatenation of multiple pooling together. - """ - - residual_layer_temp, residual_out_dims = self._create_residual_connection(out_dims=self.full_dims[1:]) - - # Create a ModuleList of the GNN layers - self.layers = nn.ModuleList() - self.virtual_node_layers = nn.ModuleList() - this_in_dim = self.full_dims[0] - this_activation = self.activation - this_batch_norm = self.batch_norm - this_dropout = self.dropout - - # Find the appropriate edge dimensions, depending if edges are used, - # And if the residual is required for the edges - this_in_dim_edges, this_out_dim_edges = None, None - if self.full_dims_edges is not None: - this_in_dim_edges, this_out_dim_edges = self.full_dims_edges[0:2] - residual_out_dims_edges = residual_layer_temp.get_true_out_dims(self.full_dims_edges[1:]) - elif self.in_dim_edges > 0: - this_in_dim_edges = self.in_dim_edges - layer_out_dims_edges = [] - - # Create all the layers in a loop - for ii in range(self.depth): - this_out_dim = self.full_dims[ii + 1] - - if ii == self.depth - 1: - this_activation = self.last_activation - - # Find the edge key-word arguments depending on the layer type and residual connection - this_edge_kwargs = {} - if self.layer_class.layer_supports_edges and self.in_dim_edges > 0: - this_edge_kwargs["in_dim_edges"] = this_in_dim_edges - if "out_dim_edges" in inspect.signature(self.layer_class.__init__).parameters.keys(): - layer_out_dims_edges.append(self.full_dims_edges[ii + 1]) - this_edge_kwargs["out_dim_edges"] = layer_out_dims_edges[-1] - - # Create the GNN layer - self.layers.append( - self.layer_class( - in_dim=this_in_dim, - out_dim=this_out_dim, - activation=this_activation, - dropout=this_dropout, - batch_norm=this_batch_norm, - **self.layer_kwargs, - **this_edge_kwargs, - ) - ) - - # Create the Virtual Node layer, except at the last layer - if ii < len(residual_out_dims): - self.virtual_node_layers.append( - VirtualNode( - dim=this_out_dim * self.layers[-1].out_dim_factor, - activation=this_activation, - dropout=this_dropout, - batch_norm=this_batch_norm, - bias=True, - vn_type=self.virtual_node, - residual=self.residual_type is not None, - ) - ) - - # Get the true input dimension of the next layer, - # by factoring both the residual connection and GNN layer type - if ii < len(residual_out_dims): - this_in_dim = residual_out_dims[ii] * self.layers[ii - 1].out_dim_factor - if self.full_dims_edges is not None: - this_in_dim_edges = residual_out_dims_edges[ii] * self.layers[ii - 1].out_dim_factor - - layer_out_dims = [layer.out_dim_factor * layer.out_dim for layer in self.layers] - - # Initialize residual and pooling layers - self.residual_layer, _ = self._create_residual_connection(out_dims=layer_out_dims) - if len(layer_out_dims_edges) > 0: - self.residual_edges_layer, _ = self._create_residual_connection(out_dims=layer_out_dims_edges) - else: - self.residual_edges_layer = None - self.global_pool_layer, out_pool_dim = parse_pooling_layer(layer_out_dims[-1], self.pooling) - - # Output linear layer - self.out_linear = FCLayer( - in_dim=out_pool_dim, - out_dim=self.out_dim, - activation="none", - dropout=self.dropout, - batch_norm=self.batch_norm, - ) - - def _pool_layer_forward(self, g, h): - r""" - Apply the graph pooling layer, followed by the linear output layer. - - Parameters: - - g: dgl.DGLGraph - graph on which the convolution is done - - h (torch.Tensor[..., N, Din]): - Node feature tensor, before convolution. - `N` is the number of nodes, `Din` is the output size of the last DGL layer - - Returns: - - torch.Tensor[..., M, Din] or torch.Tensor[..., N, Din]: - Node feature tensor, after convolution. - `N` is the number of nodes, `M` is the number of graphs, `Dout` is the output dimension ``self.out_dim`` - If the pooling is `None`, the dimension is `N`, otherwise it is `M` - - """ - - if len(self.global_pool_layer) > 0: - pooled_h = [] - for this_pool in self.global_pool_layer: - pooled_h.append(this_pool(g, h)) - pooled_h = torch.cat(pooled_h, dim=-1) - else: - pooled_h = h - - pooled_h = self.out_linear(pooled_h) - - return pooled_h - - def _dgl_layer_forward( - self, - layer: BaseDGLLayer, - g: dgl.DGLGraph, - h: torch.Tensor, - e: Union[torch.Tensor, None], - h_prev: Union[torch.Tensor, None], - e_prev: Union[torch.Tensor, None], - step_idx: int, - ) -> Tuple[torch.Tensor, Union[torch.Tensor, None], Union[torch.Tensor, None], Union[torch.Tensor, None]]: - r""" - Apply the *i-th* DGL graph layer, where *i* is the index given by `step_idx`. - The layer is applied differently depending if there are edge features or not. - - Then, the residual is also applied on both the features and the edges (if applicable) - - Parameters: - - layer: - The DGL layer used for the convolution - - g: - graph on which the convolution is done - - h (torch.Tensor[..., N, Din]): - Node feature tensor, before convolution. - `N` is the number of nodes, `Din` is the input features - - e (torch.Tensor[..., N, Ein]): - Edge feature tensor, before convolution. - `N` is the number of nodes, `Ein` is the input edge features - - h_prev: - Node feature of the previous residual connection, or `None` - - e_prev: - Edge feature of the previous residual connection, or `None` - - step_idx: - The current step idx in the forward loop - - Returns: - - h (torch.Tensor[..., N, Dout]): - Node feature tensor, after convolution and residual. - `N` is the number of nodes, `Dout` is the output features of the layer and residual - - e: - Edge feature tensor, after convolution and residual. - `N` is the number of nodes, `Ein` is the input edge features - - h_prev: - Node feature tensor to be used at the next residual connection, or `None` - - e_prev: - Edge feature tensor to be used at the next residual connection, or `None` - - """ - - # Apply the GNN layer with the right inputs/outputs - if layer.layer_inputs_edges and layer.layer_outputs_edges: - h, e = layer(g=g, h=h, e=e) - elif layer.layer_inputs_edges: - h = layer(g=g, h=h, e=e) - elif layer.layer_outputs_edges: - h, e = layer(g=g, h=h) - else: - h = layer(g=g, h=h) - - # Apply the residual layers on the features and edges (if applicable) - if step_idx < len(self.layers) - 1: - h, h_prev = self.residual_layer.forward(h, h_prev, step_idx=step_idx) - if (self.residual_edges_layer is not None) and (layer.layer_outputs_edges): - e, e_prev = self.residual_edges_layer.forward(e, e_prev, step_idx=step_idx) - - return h, e, h_prev, e_prev - - def _virtual_node_forward( - self, g: dgl.DGLGraph, h: torch.Tensor, vn_h: torch.Tensor, step_idx: int - ) -> Tuple[torch.Tensor, torch.Tensor]: - r""" - Apply the *i-th* virtual node layer, where *i* is the index given by `step_idx`. - - Parameters: - - g: - graph on which the convolution is done - - h (torch.Tensor[..., N, Din]): - Node feature tensor, before convolution. - `N` is the number of nodes, `Din` is the input features - - vn_h (torch.Tensor[..., M, Din]): - Graph feature of the previous virtual node, or `None` - `M` is the number of graphs, `Din` is the input features - It is added to the result after the MLP, as a residual connection - - step_idx: - The current step idx in the forward loop - - Returns: - - `h = torch.Tensor[..., N, Dout]`: - Node feature tensor, after convolution and residual. - `N` is the number of nodes, `Dout` is the output features of the layer and residual - - `vn_h = torch.Tensor[..., M, Dout]`: - Graph feature tensor to be used at the next virtual node, or `None` - `M` is the number of graphs, `Dout` is the output features - - """ - - if step_idx == 0: - vn_h = 0.0 - if step_idx < len(self.virtual_node_layers): - h, vn_h = self.virtual_node_layers[step_idx].forward(g=g, h=h, vn_h=vn_h) - - return h, vn_h - - def forward(self, g: dgl.DGLGraph) -> torch.Tensor: - r""" - Apply the full graph neural network on the input graph and node features. - - Parameters: - - g: - graph on which the convolution is done. - Must contain the following elements: - - - `g.ndata["h"]`: `torch.Tensor[..., N, Din]`. - Input node feature tensor, before the network. - `N` is the number of nodes, `Din` is the input features - - - `g.edata["e"]`: `torch.Tensor[..., N, Ein]` **Optional**. - The edge features to use. It will be ignored if the - model doesn't supporte edge features or if - `self.in_dim_edges==0`. - - Returns: - - `torch.Tensor[..., M, Dout]` or `torch.Tensor[..., N, Dout]`: - Node or graph feature tensor, after the network. - `N` is the number of nodes, `M` is the number of graphs, - `Dout` is the output dimension ``self.out_dim`` - If the `self.pooling` is [`None`], then it returns node features and the output dimension is `N`, - otherwise it returns graph features and the output dimension is `M` - - """ - - # Get node and edge features - h = g.ndata["h"] - e = g.edata["e"] if (self.in_dim_edges > 0) else None - - # Initialize values of the residuals and virtual node - h_prev = None - e_prev = None - vn_h = 0 - - # Apply the forward loop of the layers, residuals and virtual nodes - for ii, layer in enumerate(self.layers): - h, e, h_prev, e_prev = self._dgl_layer_forward( - layer=layer, g=g, h=h, e=e, h_prev=h_prev, e_prev=e_prev, step_idx=ii - ) - h, vn_h = self._virtual_node_forward(g=g, h=h, vn_h=vn_h, step_idx=ii) - - pooled_h = self._pool_layer_forward(g=g, h=h) - - return pooled_h - - def __repr__(self): - r""" - Controls how the class is printed - """ - class_str = f"{self.name}(depth={self.depth}, {self.residual_layer})\n " - layer_str = f"{self.layer_class.__name__}[{' -> '.join(map(str, self.full_dims))}]\n " - pool_str = f"-> Pooling({self.pooling})" - out_str = f" -> {self.out_linear}" - - return class_str + layer_str + pool_str + out_str - - -class FullDGLNetwork(nn.Module): - def __init__( - self, - gnn_kwargs: Dict[str, Any], - pre_nn_kwargs: Optional[Dict[str, Any]] = None, - pre_nn_edges_kwargs: Optional[Dict[str, Any]] = None, - post_nn_kwargs: Optional[Dict[str, Any]] = None, - name: str = "DGL_GNN", - ): - r""" - Class that allows to implement a full graph neural network architecture, - including the pre-processing MLP and the post processing MLP. - - Parameters: - - gnn_kwargs: - key-word arguments to use for the initialization of the pre-processing - GNN network using the class `FeedForwardDGL`. - It must respect the following criteria: - - - gnn_kwargs["in_dim"] must be equal to pre_nn_kwargs["out_dim"] - - gnn_kwargs["out_dim"] must be equal to post_nn_kwargs["in_dim"] - - pre_nn_kwargs: - key-word arguments to use for the initialization of the pre-processing - MLP network of the node features before the GNN, using the class `FeedForwardNN`. - If `None`, there won't be a pre-processing MLP. - - pre_nn_kwargs: - key-word arguments to use for the initialization of the pre-processing - MLP network of the edge features before the GNN, using the class `FeedForwardNN`. - If `None`, there won't be a pre-processing MLP. - - post_nn_kwargs: - key-word arguments to use for the initialization of the post-processing - MLP network after the GNN, using the class `FeedForwardNN`. - If `None`, there won't be a post-processing MLP. - - name: - Name attributed to the current network, for display and printing - purposes. - """ - - super().__init__() - self.name = name - - # Initialize the networks - self.pre_nn, self.post_nn, self.pre_nn_edges = None, None, None - if pre_nn_kwargs is not None: - name = pre_nn_kwargs.pop("name", "pre-NN") - self.pre_nn = FeedForwardNN(**pre_nn_kwargs, name=name) - next_in_dim = self.pre_nn.out_dim - gnn_kwargs.setdefault("in_dim", next_in_dim) - assert next_in_dim == gnn_kwargs["in_dim"] - - if pre_nn_edges_kwargs is not None: - name = pre_nn_edges_kwargs.pop("name", "pre-NN-edges") - self.pre_nn_edges = FeedForwardNN(**pre_nn_edges_kwargs, name=name) - next_in_dim = self.pre_nn_edges.out_dim - gnn_kwargs.setdefault("in_dim_edges", next_in_dim) - assert next_in_dim == gnn_kwargs["in_dim_edges"] - - name = gnn_kwargs.pop("name", "GNN") - self.gnn = FeedForwardDGL(**gnn_kwargs, name=name) - next_in_dim = self.gnn.out_dim - - if post_nn_kwargs is not None: - name = post_nn_kwargs.pop("name", "post-NN") - post_nn_kwargs.setdefault("in_dim", next_in_dim) - self.post_nn = FeedForwardNN(**post_nn_kwargs, name=name) - assert next_in_dim == self.post_nn.in_dim - - def _check_bad_arguments(self): - r""" - Raise comprehensive errors if the arguments seem wrong - """ - if self.pre_nn is not None: - if self.pre_nn["out_dim"] != self.gnn["in_dim"]: - raise ValueError( - f"`self.pre_nn.out_dim` must be equal to `self.gnn.in_dim`." - + 'Provided" {self.pre_nn.out_dim} and {self.gnn.in_dim}' - ) - - if self.post_nn is not None: - if self.gnn["out_dim"] != self.post_nn["in_dim"]: - raise ValueError( - f"`self.gnn.out_dim` must be equal to `self.post_nn.in_dim`." - + 'Provided" {self.gnn.out_dim} and {self.post_nn.in_dim}' - ) - - def drop_post_nn_layers(self, num_layers_to_drop: int) -> None: - r""" - Remove the last layers of the model. Useful for Transfer Learning. - - Parameters: - num_layers_to_drop: The number of layers to drop from the `self.post_nn` network. - - """ - - assert num_layers_to_drop >= 0 - assert num_layers_to_drop <= len(self.post_nn.layers) - - if num_layers_to_drop > 0: - self.post_nn.layers = self.post_nn.layers[:-num_layers_to_drop] - - def extend_post_nn_layers(self, layers: nn.ModuleList): - r""" - Add layers at the end of the model. Useful for Transfer Learning. - - Parameters: - layers: A ModuleList of all the layers to extend - - """ - - assert isinstance(layers, nn.ModuleList) - if len(self.post_nn.layers) > 0: - assert layers[0].in_dim == self.post_nn.layers.out_dim[-1] - - self.post_nn.extend(layers) - - def forward(self, g: dgl.DGLGraph) -> torch.Tensor: - r""" - Apply the pre-processing neural network, the graph neural network, - and the post-processing neural network on the graph features. - - Parameters: - - g: - graph on which the convolution is done. - Must contain the following elements: - - - `g.ndata["h"]`: `torch.Tensor[..., N, Din]`. - Input node feature tensor, before the network. - `N` is the number of nodes, `Din` is the input features dimension ``self.pre_nn.in_dim`` - - - `g.edata["e"]`: `torch.Tensor[..., N, Ein]` **Optional**. - The edge features to use. It will be ignored if the - model doesn't supporte edge features or if - `self.in_dim_edges==0`. - - Returns: - - `torch.Tensor[..., M, Dout]` or `torch.Tensor[..., N, Dout]`: - Node or graph feature tensor, after the network. - `N` is the number of nodes, `M` is the number of graphs, - `Dout` is the output dimension ``self.post_nn.out_dim`` - If the `self.gnn.pooling` is [`None`], then it returns node features and the output dimension is `N`, - otherwise it returns graph features and the output dimension is `M` - - """ - - # Get the node features and positional embedding - h = g.ndata["feat"] - if "pos_enc_feats_sign_flip" in g.ndata.keys(): - pos_enc = g.ndata["pos_enc_feats_sign_flip"] - rand_sign_shape = ([1] * (pos_enc.ndim - 1)) + [pos_enc.shape[-1]] - rand_sign = torch.sign(torch.randn(rand_sign_shape, dtype=h.dtype, device=h.device)) - h = torch.cat((h, pos_enc * rand_sign), dim=-1) - if "pos_enc_feats_no_flip" in g.ndata.keys(): - pos_enc = g.ndata["pos_enc_feats_no_flip"] - h = torch.cat((h, pos_enc), dim=-1) - - g.ndata["h"] = h - - if "feat" in g.edata.keys(): - g.edata["e"] = g.edata["feat"] - - # Run the pre-processing network on node features - if self.pre_nn is not None: - h = g.ndata["h"] - h = self.pre_nn.forward(h) - g.ndata["h"] = h - - # Run the pre-processing network on edge features - # If there are no edges, skip the forward and change the dimension of e - if self.pre_nn_edges is not None: - e = g.edata["e"] - if torch.prod(torch.as_tensor(e.shape[:-1])) == 0: - e = torch.zeros( - list(e.shape[:-1]) + [self.pre_nn_edges.out_dim], device=e.device, dtype=e.dtype - ) - else: - e = self.pre_nn_edges.forward(e) - g.edata["e"] = e - - # Run the graph neural network - h = self.gnn.forward(g) - - # Run the output network - if self.post_nn is not None: - h = self.post_nn.forward(h) - - return h - - def __repr__(self): - r""" - Controls how the class is printed - """ - pre_nn_str, post_nn_str, pre_nn_edges_str = "", "", "" - if self.pre_nn is not None: - pre_nn_str = self.pre_nn.__repr__() + "\n\n" - if self.pre_nn_edges is not None: - pre_nn_edges_str = self.pre_nn_edges.__repr__() + "\n\n" - gnn_str = self.gnn.__repr__() + "\n\n" - if self.post_nn is not None: - post_nn_str = self.post_nn.__repr__() - - child_str = " " + pre_nn_str + pre_nn_edges_str + gnn_str + post_nn_str - child_str = " ".join(child_str.splitlines(True)) - - full_str = self.name + "\n" + "-" * (len(self.name) + 2) + "\n" + child_str - - return full_str - - @property - def in_dim(self): - r""" - Returns the input dimension of the network - """ - if self.pre_nn is not None: - return self.pre_nn.in_dim - else: - return self.gnn.in_dim - - @property - def out_dim(self): - r""" - Returns the output dimension of the network - """ - if self.pre_nn is not None: - return self.post_nn.out_dim - else: - return self.gnn.out_dim - - @property - def in_dim_edges(self): - r""" - Returns the input edge dimension of the network - """ - return self.gnn.in_dim_edges - - -class FullDGLSiameseNetwork(FullDGLNetwork): - def __init__(self, pre_nn_kwargs, gnn_kwargs, post_nn_kwargs, dist_method, name="Siamese_DGL_GNN"): - - # Initialize the parent nn.Module - super().__init__( - pre_nn_kwargs=pre_nn_kwargs, - gnn_kwargs=gnn_kwargs, - post_nn_kwargs=post_nn_kwargs, - name=name, - ) - - self.dist_method = dist_method.lower() - - def forward(self, graphs): - graph_1, graph_2 = graphs - - out_1 = super().forward(graph_1) - out_2 = super().forward(graph_2) - - if self.dist_method == "manhattan": - # Normalized L1 distance - out_1 = out_1 / torch.mean(out_1.abs(), dim=-1, keepdim=True) - out_2 = out_2 / torch.mean(out_2.abs(), dim=-1, keepdim=True) - dist = torch.abs(out_1 - out_2) - out = torch.mean(dist, dim=-1) - - elif self.dist_method == "euclidean": - # Normalized Euclidean distance - out_1 = out_1 / torch.norm(out_1, dim=-1, keepdim=True) - out_2 = out_2 / torch.norm(out_2, dim=-1, keepdim=True) - out = torch.norm(out_1 - out_2, dim=-1) - elif self.dist_method == "cosine": - # Cosine distance - out = torch.sum(out_1 * out_2, dim=-1) / (torch.norm(out_1, dim=-1) * torch.norm(out_2, dim=-1)) - else: - raise ValueError(f"Unsupported `dist_method`: {self.dist_method}") - - return out +from torch import nn +import torch +import dgl +from typing import List, Dict, Tuple, Union, Callable, Any, Optional +import inspect + +from goli.nn.base_layers import FCLayer, get_activation +from goli.nn.dgl_layers import BaseDGLLayer +from goli.nn.residual_connections import ResidualConnectionBase +from goli.nn.dgl_layers.pooling import parse_pooling_layer, VirtualNode +from goli.utils.spaces import LAYERS_DICT, RESIDUALS_DICT + + +class FeedForwardNN(nn.Module): + def __init__( + self, + in_dim: int, + out_dim: int, + hidden_dims: Union[List[int], int], + depth: Optional[int] = None, + activation: Union[str, Callable] = "relu", + last_activation: Union[str, Callable] = "none", + dropout: float = 0.0, + last_dropout: float = 0.0, + batch_norm: bool = False, + last_batch_norm: bool = False, + residual_type: str = "none", + residual_skip_steps: int = 1, + name: str = "LNN", + layer_type: Union[str, nn.Module] = "fc", + layer_kwargs: Optional[Dict] = None, + ): + r""" + A flexible neural network architecture, with variable hidden dimensions, + support for multiple layer types, and support for different residual + connections. + + Parameters: + + in_dim: + Input feature dimensions of the layer + + out_dim: + Output feature dimensions of the layer + + hidden_dims: + Either an integer specifying all the hidden dimensions, + or a list of dimensions in the hidden layers. + Be careful, the "simple" residual type only supports + hidden dimensions of the same value. + + depth: + If `hidden_dims` is an integer, `depth` is 1 + the number of + hidden layers to use. If `hidden_dims` is a `list`, `depth` must + be `None`. + + activation: + activation function to use in the hidden layers. + + last_activation: + activation function to use in the last layer. + + dropout: + The ratio of units to dropout. Must be between 0 and 1 + + last_dropout: + The ratio of units to dropout for the last_layer. Must be between 0 and 1 + + batch_norm: + Whether to use batch normalization + + last_batch_norm: + Whether to use batch normalization in the last layer + + residual_type: + - "none": No residual connection + - "simple": Residual connection similar to the ResNet architecture. + See class `ResidualConnectionSimple` + - "weighted": Residual connection similar to the Resnet architecture, + but with weights applied before the summation. See class `ResidualConnectionWeighted` + - "concat": Residual connection where the residual is concatenated instead + of being added. + - "densenet": Residual connection where the residual of all previous layers + are concatenated. This leads to a strong increase in the number of parameters + if there are multiple hidden layers. + + residual_skip_steps: + The number of steps to skip between each residual connection. + If `1`, all the layers are connected. If `2`, half of the + layers are connected. + + name: + Name attributed to the current network, for display and printing + purposes. + + layer_type: + The type of layers to use in the network. + Either "fc" as the `FCLayer`, or a class representing the `nn.Module` + to use. + + layer_kwargs: + The arguments to be used in the initialization of the layer provided by `layer_type` + + """ + + super().__init__() + + # Set the class attributes + self.in_dim = in_dim + self.out_dim = out_dim + if isinstance(hidden_dims, int): + self.hidden_dims = [hidden_dims] * (depth - 1) + else: + self.hidden_dims = list(hidden_dims) + assert depth is None + self.depth = len(self.hidden_dims) + 1 + self.activation = get_activation(activation) + self.last_activation = get_activation(last_activation) + self.dropout = dropout + self.last_dropout = last_dropout + self.batch_norm = batch_norm + self.last_batch_norm = last_batch_norm + self.residual_type = None if residual_type is None else residual_type.lower() + self.residual_skip_steps = residual_skip_steps + self.layer_kwargs = layer_kwargs if layer_kwargs is not None else {} + self.name = name + + # Parse the layer and residuals + self.layer_class, self.layer_name = self._parse_class_from_dict(layer_type, LAYERS_DICT) + self.residual_class, self.residual_name = self._parse_class_from_dict(residual_type, RESIDUALS_DICT) + + self.full_dims = [self.in_dim] + self.hidden_dims + [self.out_dim] + self._create_layers() + self._check_bad_arguments() + + def _check_bad_arguments(self): + r""" + Raise comprehensive errors if the arguments seem wrong + """ + if (self.residual_type == "simple") and not (self.hidden_dims[:-1] == self.hidden_dims[1:]): + raise ValueError( + f"When using the residual_type={self.residual_type}" + + f", all elements in the hidden_dims must be equal. Provided:{self.hidden_dims}" + ) + + def _parse_class_from_dict( + self, name_or_class: Union[type, str], class_dict: Dict[str, type] + ) -> Tuple[type, str]: + r""" + Register the hyperparameters for tracking by Pytorch-lightning + """ + if isinstance(name_or_class, str): + obj_name = name_or_class.lower() + obj_class = class_dict[obj_name] + elif callable(name_or_class): + obj_name = str(name_or_class) + obj_class = name_or_class + else: + raise TypeError(f"`name_or_class` must be str or callable, provided: {type(name_or_class)}") + + return obj_class, obj_name + + def _create_residual_connection(self, out_dims: List[int]) -> Tuple[ResidualConnectionBase, List[int]]: + r""" + Create the residual connection classes. + The out_dims is only used if the residual classes requires weights + """ + if self.residual_class.has_weights: + residual_layer = self.residual_class( + skip_steps=self.residual_skip_steps, + out_dims=out_dims, + dropout=self.dropout, + activation=self.activation, + batch_norm=self.batch_norm, + bias=False, + ) + else: + residual_layer = self.residual_class(skip_steps=self.residual_skip_steps) + + residual_out_dims = residual_layer.get_true_out_dims(self.full_dims[1:]) + + return residual_layer, residual_out_dims + + def _create_layers(self): + r""" + Create all the necessary layers for the network. + It's a bit complicated to explain what's going on in this function, + but it must manage the varying features sizes caused by: + + - The presence of different types of residual connections + """ + + self.residual_layer, residual_out_dims = self._create_residual_connection(out_dims=self.full_dims[1:]) + + # Create a ModuleList of the GNN layers + self.layers = nn.ModuleList() + this_in_dim = self.full_dims[0] + this_activation = self.activation + this_batch_norm = self.batch_norm + this_dropout = self.dropout + + for ii in range(self.depth): + this_out_dim = self.full_dims[ii + 1] + if ii == self.depth - 1: + this_activation = self.last_activation + this_batch_norm = self.last_batch_norm + this_dropout = self.last_dropout + + # Create the layer + self.layers.append( + self.layer_class( + in_dim=this_in_dim, + out_dim=this_out_dim, + activation=this_activation, + dropout=this_dropout, + batch_norm=this_batch_norm, + **self.layer_kwargs, + ) + ) + + if ii < len(residual_out_dims): + this_in_dim = residual_out_dims[ii] + + def forward(self, h: torch.Tensor) -> torch.Tensor: + r""" + Apply the neural network on the input features. + + Parameters: + + h: `torch.Tensor[..., Din]`: + Input feature tensor, before the network. + `Din` is the number of input features + + Returns: + + `torch.Tensor[..., Dout]`: + Output feature tensor, after the network. + `Dout` is the number of output features + + """ + h_prev = None + for ii, layer in enumerate(self.layers): + h = layer.forward(h) + if ii < len(self.layers) - 1: + h, h_prev = self.residual_layer.forward(h, h_prev, step_idx=ii) + + return h + + def __repr__(self): + r""" + Controls how the class is printed + """ + class_str = f"{self.name}(depth={self.depth}, {self.residual_layer})\n " + layer_str = f"[{self.layer_class.__name__}[{' -> '.join(map(str, self.full_dims))}]" + + return class_str + layer_str + + +class FeedForwardDGL(FeedForwardNN): + def __init__( + self, + in_dim: int, + out_dim: int, + hidden_dims: List[int], + depth: Optional[int] = None, + activation: Union[str, Callable] = "relu", + last_activation: Union[str, Callable] = "none", + dropout: float = 0.0, + last_dropout: float = 0.0, + batch_norm: bool = False, + last_batch_norm: bool = False, + residual_type: str = "none", + residual_skip_steps: int = 1, + in_dim_edges: int = 0, + hidden_dims_edges: List[int] = [], + pooling: Union[List[str], List[Callable]] = ["sum"], + name: str = "GNN", + layer_type: Union[str, nn.Module] = "gcn", + layer_kwargs: Optional[Dict] = None, + virtual_node: str = "none", + ): + r""" + A flexible neural network architecture, with variable hidden dimensions, + support for multiple layer types, and support for different residual + connections. + + This class is meant to work with different DGL-based graph neural networks + layers. Any layer must inherit from `goli.nn.dgl_layers.BaseDGLLayer`. + + Parameters: + + in_dim: + Input feature dimensions of the layer + + out_dim: + Output feature dimensions of the layer + + hidden_dims: + List of dimensions in the hidden layers. + Be careful, the "simple" residual type only supports + hidden dimensions of the same value. + + depth: + If `hidden_dims` is an integer, `depth` is 1 + the number of + hidden layers to use. If `hidden_dims` is a `list`, `depth` must + be `None`. + + activation: + activation function to use in the hidden layers. + + last_activation: + activation function to use in the last layer. + + dropout: + The ratio of units to dropout. Must be between 0 and 1 + + last_dropout: + The ratio of units to dropout for the last layer. Must be between 0 and 1 + + batch_norm: + Whether to use batch normalization + + last_batch_norm: + Whether to use batch normalization in the last layer + + residual_type: + - "none": No residual connection + - "simple": Residual connection similar to the ResNet architecture. + See class `ResidualConnectionSimple` + - "weighted": Residual connection similar to the Resnet architecture, + but with weights applied before the summation. See class `ResidualConnectionWeighted` + - "concat": Residual connection where the residual is concatenated instead + of being added. + - "densenet": Residual connection where the residual of all previous layers + are concatenated. This leads to a strong increase in the number of parameters + if there are multiple hidden layers. + + residual_skip_steps: + The number of steps to skip between each residual connection. + If `1`, all the layers are connected. If `2`, half of the + layers are connected. + + in_dim_edges: + Input edge-feature dimensions of the network. Keep at 0 if not using + edge features, or if the layer doesn't support edges. + + hidden_dims_edges: + Hidden dimensions for the edges. Most models don't support it, so it + should only be used for those that do, i.e. `GatedGCNLayer` + + pooling: + The pooling types to use. Multiple pooling can be used, and their + results will be concatenated. + For node feature predictions, use `["none"]`. + For graph feature predictions see `goli.nn.dgl_layers.pooling.parse_pooling_layer`. + The list must either contain Callables, or the string below + + - "none": No pooling is applied + - "sum": `SumPooling` + - "mean": `MeanPooling` + - "max": `MaxPooling` + - "min": `MinPooling` + - "std": `StdPooling` + - "s2s": `Set2Set` + + name: + Name attributed to the current network, for display and printing + purposes. + + layer_type: + The type of layers to use in the network. + Either a class that inherits from `goli.nn.dgl_layers.BaseDGLLayer`, + or one of the following strings + + - "gcn": `GCNLayer` + - "gin": `GINLayer` + - "gat": `GATLayer` + - "gated-gcn": `GatedGCNLayer` + - "pna-conv": `PNAConvolutionalLayer` + - "pna-msgpass": `PNAMessagePassingLayer` + - "dgn-conv": `DGNConvolutionalLayer` + - "dgn-msgpass": `DGNMessagePassingLayer` + + layer_kwargs: + The arguments to be used in the initialization of the layer provided by `layer_type` + + virtual_node: + A string associated to the type of virtual node to use, + either `None`, "none", "mean", "sum", "max", "logsum". + See `goli.nn.dgl_layers.VirtualNode`. + + The virtual node will not use any residual connection if `residual_type` + is "none". Otherwise, it will use a simple ResNet like residual + connection. + + """ + + # Initialize the additional attributes + self.in_dim_edges = in_dim_edges + if isinstance(hidden_dims_edges, int): + self.hidden_dims_edges = [hidden_dims_edges] * (depth - 1) + elif len(hidden_dims_edges) == 0: + self.hidden_dims_edges = [] + else: + self.hidden_dims_edges = list(hidden_dims_edges) + assert depth is None + self.full_dims_edges = None + if len(self.hidden_dims_edges) > 0: + self.full_dims_edges = [self.in_dim_edges] + self.hidden_dims_edges + [self.hidden_dims_edges[-1]] + + self.virtual_node = virtual_node.lower() if virtual_node is not None else "none" + self.pooling = pooling + + # Initialize the parent `FeedForwardNN` + super().__init__( + in_dim=in_dim, + out_dim=out_dim, + hidden_dims=hidden_dims, + depth=depth, + activation=activation, + last_activation=last_activation, + batch_norm=batch_norm, + last_batch_norm=last_batch_norm, + residual_type=residual_type, + residual_skip_steps=residual_skip_steps, + name=name, + layer_type=layer_type, + dropout=dropout, + last_dropout=last_dropout, + layer_kwargs=layer_kwargs, + ) + + def _check_bad_arguments(self): + r""" + Raise comprehensive errors if the arguments seem wrong + """ + super()._check_bad_arguments() + if ( + (self.in_dim_edges > 0) or (self.full_dims_edges is not None) + ) and not self.layer_class.layer_supports_edges: + raise ValueError(f"Cannot use edge features with class `{self.layer_class}`") + + def _create_layers(self): + r""" + Create all the necessary layers for the network. + It's a bit complicated to explain what's going on in this function, + but it must manage the varying features sizes caused by: + + - The presence of different types of residual connections + - The presence or absence of edges + - The output dimensions varying for different networks i.e. `GatLayer` outputs different feature sizes according to the number of heads + - The presence or absence of virtual nodes + - The different possible pooling, and the concatenation of multiple pooling together. + """ + + residual_layer_temp, residual_out_dims = self._create_residual_connection(out_dims=self.full_dims[1:]) + + # Create a ModuleList of the GNN layers + self.layers = nn.ModuleList() + self.virtual_node_layers = nn.ModuleList() + this_in_dim = self.full_dims[0] + this_activation = self.activation + this_batch_norm = self.batch_norm + this_dropout = self.dropout + + # Find the appropriate edge dimensions, depending if edges are used, + # And if the residual is required for the edges + this_in_dim_edges, this_out_dim_edges = None, None + if self.full_dims_edges is not None: + this_in_dim_edges, this_out_dim_edges = self.full_dims_edges[0:2] + residual_out_dims_edges = residual_layer_temp.get_true_out_dims(self.full_dims_edges[1:]) + elif self.in_dim_edges > 0: + this_in_dim_edges = self.in_dim_edges + layer_out_dims_edges = [] + + # Create all the layers in a loop + for ii in range(self.depth): + this_out_dim = self.full_dims[ii + 1] + + if ii == self.depth - 1: + this_activation = self.last_activation + + # Find the edge key-word arguments depending on the layer type and residual connection + this_edge_kwargs = {} + if self.layer_class.layer_supports_edges and self.in_dim_edges > 0: + this_edge_kwargs["in_dim_edges"] = this_in_dim_edges + if "out_dim_edges" in inspect.signature(self.layer_class.__init__).parameters.keys(): + layer_out_dims_edges.append(self.full_dims_edges[ii + 1]) + this_edge_kwargs["out_dim_edges"] = layer_out_dims_edges[-1] + + # Create the GNN layer + self.layers.append( + self.layer_class( + in_dim=this_in_dim, + out_dim=this_out_dim, + activation=this_activation, + dropout=this_dropout, + batch_norm=this_batch_norm, + **self.layer_kwargs, + **this_edge_kwargs, + ) + ) + + # Create the Virtual Node layer, except at the last layer + if ii < len(residual_out_dims): + self.virtual_node_layers.append( + VirtualNode( + dim=this_out_dim * self.layers[-1].out_dim_factor, + activation=this_activation, + dropout=this_dropout, + batch_norm=this_batch_norm, + bias=True, + vn_type=self.virtual_node, + residual=self.residual_type is not None, + ) + ) + + # Get the true input dimension of the next layer, + # by factoring both the residual connection and GNN layer type + if ii < len(residual_out_dims): + this_in_dim = residual_out_dims[ii] * self.layers[ii - 1].out_dim_factor + if self.full_dims_edges is not None: + this_in_dim_edges = residual_out_dims_edges[ii] * self.layers[ii - 1].out_dim_factor + + layer_out_dims = [layer.out_dim_factor * layer.out_dim for layer in self.layers] + + # Initialize residual and pooling layers + self.residual_layer, _ = self._create_residual_connection(out_dims=layer_out_dims) + if len(layer_out_dims_edges) > 0: + self.residual_edges_layer, _ = self._create_residual_connection(out_dims=layer_out_dims_edges) + else: + self.residual_edges_layer = None + self.global_pool_layer, out_pool_dim = parse_pooling_layer(layer_out_dims[-1], self.pooling) + + # Output linear layer + self.out_linear = FCLayer( + in_dim=out_pool_dim, + out_dim=self.out_dim, + activation="none", + dropout=self.dropout, + batch_norm=self.batch_norm, + ) + + def _pool_layer_forward(self, g, h): + r""" + Apply the graph pooling layer, followed by the linear output layer. + + Parameters: + + g: dgl.DGLGraph + graph on which the convolution is done + + h (torch.Tensor[..., N, Din]): + Node feature tensor, before convolution. + `N` is the number of nodes, `Din` is the output size of the last DGL layer + + Returns: + + torch.Tensor[..., M, Din] or torch.Tensor[..., N, Din]: + Node feature tensor, after convolution. + `N` is the number of nodes, `M` is the number of graphs, `Dout` is the output dimension ``self.out_dim`` + If the pooling is `None`, the dimension is `N`, otherwise it is `M` + + """ + + if len(self.global_pool_layer) > 0: + pooled_h = [] + for this_pool in self.global_pool_layer: + pooled_h.append(this_pool(g, h)) + pooled_h = torch.cat(pooled_h, dim=-1) + else: + pooled_h = h + + pooled_h = self.out_linear(pooled_h) + + return pooled_h + + def _dgl_layer_forward( + self, + layer: BaseDGLLayer, + g: dgl.DGLGraph, + h: torch.Tensor, + e: Union[torch.Tensor, None], + h_prev: Union[torch.Tensor, None], + e_prev: Union[torch.Tensor, None], + step_idx: int, + ) -> Tuple[torch.Tensor, Union[torch.Tensor, None], Union[torch.Tensor, None], Union[torch.Tensor, None]]: + r""" + Apply the *i-th* DGL graph layer, where *i* is the index given by `step_idx`. + The layer is applied differently depending if there are edge features or not. + + Then, the residual is also applied on both the features and the edges (if applicable) + + Parameters: + + layer: + The DGL layer used for the convolution + + g: + graph on which the convolution is done + + h (torch.Tensor[..., N, Din]): + Node feature tensor, before convolution. + `N` is the number of nodes, `Din` is the input features + + e (torch.Tensor[..., N, Ein]): + Edge feature tensor, before convolution. + `N` is the number of nodes, `Ein` is the input edge features + + h_prev: + Node feature of the previous residual connection, or `None` + + e_prev: + Edge feature of the previous residual connection, or `None` + + step_idx: + The current step idx in the forward loop + + Returns: + + h (torch.Tensor[..., N, Dout]): + Node feature tensor, after convolution and residual. + `N` is the number of nodes, `Dout` is the output features of the layer and residual + + e: + Edge feature tensor, after convolution and residual. + `N` is the number of nodes, `Ein` is the input edge features + + h_prev: + Node feature tensor to be used at the next residual connection, or `None` + + e_prev: + Edge feature tensor to be used at the next residual connection, or `None` + + """ + + # Apply the GNN layer with the right inputs/outputs + if layer.layer_inputs_edges and layer.layer_outputs_edges: + h, e = layer(g=g, h=h, e=e) + elif layer.layer_inputs_edges: + h = layer(g=g, h=h, e=e) + elif layer.layer_outputs_edges: + h, e = layer(g=g, h=h) + else: + h = layer(g=g, h=h) + + # Apply the residual layers on the features and edges (if applicable) + if step_idx < len(self.layers) - 1: + h, h_prev = self.residual_layer.forward(h, h_prev, step_idx=step_idx) + if (self.residual_edges_layer is not None) and (layer.layer_outputs_edges): + e, e_prev = self.residual_edges_layer.forward(e, e_prev, step_idx=step_idx) + + return h, e, h_prev, e_prev + + def _virtual_node_forward( + self, g: dgl.DGLGraph, h: torch.Tensor, vn_h: torch.Tensor, step_idx: int + ) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Apply the *i-th* virtual node layer, where *i* is the index given by `step_idx`. + + Parameters: + + g: + graph on which the convolution is done + + h (torch.Tensor[..., N, Din]): + Node feature tensor, before convolution. + `N` is the number of nodes, `Din` is the input features + + vn_h (torch.Tensor[..., M, Din]): + Graph feature of the previous virtual node, or `None` + `M` is the number of graphs, `Din` is the input features + It is added to the result after the MLP, as a residual connection + + step_idx: + The current step idx in the forward loop + + Returns: + + `h = torch.Tensor[..., N, Dout]`: + Node feature tensor, after convolution and residual. + `N` is the number of nodes, `Dout` is the output features of the layer and residual + + `vn_h = torch.Tensor[..., M, Dout]`: + Graph feature tensor to be used at the next virtual node, or `None` + `M` is the number of graphs, `Dout` is the output features + + """ + + if step_idx == 0: + vn_h = 0.0 + if step_idx < len(self.virtual_node_layers): + h, vn_h = self.virtual_node_layers[step_idx].forward(g=g, h=h, vn_h=vn_h) + + return h, vn_h + + def forward(self, g: dgl.DGLGraph) -> torch.Tensor: + r""" + Apply the full graph neural network on the input graph and node features. + + Parameters: + + g: + graph on which the convolution is done. + Must contain the following elements: + + - `g.ndata["h"]`: `torch.Tensor[..., N, Din]`. + Input node feature tensor, before the network. + `N` is the number of nodes, `Din` is the input features + + - `g.edata["e"]`: `torch.Tensor[..., N, Ein]` **Optional**. + The edge features to use. It will be ignored if the + model doesn't supporte edge features or if + `self.in_dim_edges==0`. + + Returns: + + `torch.Tensor[..., M, Dout]` or `torch.Tensor[..., N, Dout]`: + Node or graph feature tensor, after the network. + `N` is the number of nodes, `M` is the number of graphs, + `Dout` is the output dimension ``self.out_dim`` + If the `self.pooling` is [`None`], then it returns node features and the output dimension is `N`, + otherwise it returns graph features and the output dimension is `M` + + """ + + # Get node and edge features + h = g.ndata["h"] + e = g.edata["e"] if (self.in_dim_edges > 0) else None + + # Initialize values of the residuals and virtual node + h_prev = None + e_prev = None + vn_h = 0 + + # Apply the forward loop of the layers, residuals and virtual nodes + for ii, layer in enumerate(self.layers): + h, e, h_prev, e_prev = self._dgl_layer_forward( + layer=layer, g=g, h=h, e=e, h_prev=h_prev, e_prev=e_prev, step_idx=ii + ) + h, vn_h = self._virtual_node_forward(g=g, h=h, vn_h=vn_h, step_idx=ii) + + pooled_h = self._pool_layer_forward(g=g, h=h) + + return pooled_h + + def __repr__(self): + r""" + Controls how the class is printed + """ + class_str = f"{self.name}(depth={self.depth}, {self.residual_layer})\n " + layer_str = f"{self.layer_class.__name__}[{' -> '.join(map(str, self.full_dims))}]\n " + pool_str = f"-> Pooling({self.pooling})" + out_str = f" -> {self.out_linear}" + + return class_str + layer_str + pool_str + out_str + + +class FullDGLNetwork(nn.Module): + def __init__( + self, + gnn_kwargs: Dict[str, Any], + pre_nn_kwargs: Optional[Dict[str, Any]] = None, + pre_nn_edges_kwargs: Optional[Dict[str, Any]] = None, + post_nn_kwargs: Optional[Dict[str, Any]] = None, + name: str = "DGL_GNN", + ): + r""" + Class that allows to implement a full graph neural network architecture, + including the pre-processing MLP and the post processing MLP. + + Parameters: + + gnn_kwargs: + key-word arguments to use for the initialization of the pre-processing + GNN network using the class `FeedForwardDGL`. + It must respect the following criteria: + + - gnn_kwargs["in_dim"] must be equal to pre_nn_kwargs["out_dim"] + - gnn_kwargs["out_dim"] must be equal to post_nn_kwargs["in_dim"] + + pre_nn_kwargs: + key-word arguments to use for the initialization of the pre-processing + MLP network of the node features before the GNN, using the class `FeedForwardNN`. + If `None`, there won't be a pre-processing MLP. + + pre_nn_kwargs: + key-word arguments to use for the initialization of the pre-processing + MLP network of the edge features before the GNN, using the class `FeedForwardNN`. + If `None`, there won't be a pre-processing MLP. + + post_nn_kwargs: + key-word arguments to use for the initialization of the post-processing + MLP network after the GNN, using the class `FeedForwardNN`. + If `None`, there won't be a post-processing MLP. + + name: + Name attributed to the current network, for display and printing + purposes. + """ + + super().__init__() + self.name = name + + # Initialize the networks + self.pre_nn, self.post_nn, self.pre_nn_edges = None, None, None + if pre_nn_kwargs is not None: + name = pre_nn_kwargs.pop("name", "pre-NN") + self.pre_nn = FeedForwardNN(**pre_nn_kwargs, name=name) + next_in_dim = self.pre_nn.out_dim + gnn_kwargs.setdefault("in_dim", next_in_dim) + assert next_in_dim == gnn_kwargs["in_dim"] + + if pre_nn_edges_kwargs is not None: + name = pre_nn_edges_kwargs.pop("name", "pre-NN-edges") + self.pre_nn_edges = FeedForwardNN(**pre_nn_edges_kwargs, name=name) + next_in_dim = self.pre_nn_edges.out_dim + gnn_kwargs.setdefault("in_dim_edges", next_in_dim) + assert next_in_dim == gnn_kwargs["in_dim_edges"] + + name = gnn_kwargs.pop("name", "GNN") + self.gnn = FeedForwardDGL(**gnn_kwargs, name=name) + next_in_dim = self.gnn.out_dim + + if post_nn_kwargs is not None: + name = post_nn_kwargs.pop("name", "post-NN") + post_nn_kwargs.setdefault("in_dim", next_in_dim) + self.post_nn = FeedForwardNN(**post_nn_kwargs, name=name) + assert next_in_dim == self.post_nn.in_dim + + def _check_bad_arguments(self): + r""" + Raise comprehensive errors if the arguments seem wrong + """ + if self.pre_nn is not None: + if self.pre_nn["out_dim"] != self.gnn["in_dim"]: + raise ValueError( + f"`self.pre_nn.out_dim` must be equal to `self.gnn.in_dim`." + + 'Provided" {self.pre_nn.out_dim} and {self.gnn.in_dim}' + ) + + if self.post_nn is not None: + if self.gnn["out_dim"] != self.post_nn["in_dim"]: + raise ValueError( + f"`self.gnn.out_dim` must be equal to `self.post_nn.in_dim`." + + 'Provided" {self.gnn.out_dim} and {self.post_nn.in_dim}' + ) + + def drop_post_nn_layers(self, num_layers_to_drop: int) -> None: + r""" + Remove the last layers of the model. Useful for Transfer Learning. + + Parameters: + num_layers_to_drop: The number of layers to drop from the `self.post_nn` network. + + """ + + assert num_layers_to_drop >= 0 + assert num_layers_to_drop <= len(self.post_nn.layers) + + if num_layers_to_drop > 0: + self.post_nn.layers = self.post_nn.layers[:-num_layers_to_drop] + + def extend_post_nn_layers(self, layers: nn.ModuleList): + r""" + Add layers at the end of the model. Useful for Transfer Learning. + + Parameters: + layers: A ModuleList of all the layers to extend + + """ + + assert isinstance(layers, nn.ModuleList) + if len(self.post_nn.layers) > 0: + assert layers[0].in_dim == self.post_nn.layers.out_dim[-1] + + self.post_nn.extend(layers) + + def forward(self, g: dgl.DGLGraph) -> torch.Tensor: + r""" + Apply the pre-processing neural network, the graph neural network, + and the post-processing neural network on the graph features. + + Parameters: + + g: + graph on which the convolution is done. + Must contain the following elements: + + - `g.ndata["h"]`: `torch.Tensor[..., N, Din]`. + Input node feature tensor, before the network. + `N` is the number of nodes, `Din` is the input features dimension ``self.pre_nn.in_dim`` + + - `g.edata["e"]`: `torch.Tensor[..., N, Ein]` **Optional**. + The edge features to use. It will be ignored if the + model doesn't supporte edge features or if + `self.in_dim_edges==0`. + + Returns: + + `torch.Tensor[..., M, Dout]` or `torch.Tensor[..., N, Dout]`: + Node or graph feature tensor, after the network. + `N` is the number of nodes, `M` is the number of graphs, + `Dout` is the output dimension ``self.post_nn.out_dim`` + If the `self.gnn.pooling` is [`None`], then it returns node features and the output dimension is `N`, + otherwise it returns graph features and the output dimension is `M` + + """ + + # Get the node features and positional embedding + h = g.ndata["feat"] + if "pos_enc_feats_sign_flip" in g.ndata.keys(): + pos_enc = g.ndata["pos_enc_feats_sign_flip"] + rand_sign_shape = ([1] * (pos_enc.ndim - 1)) + [pos_enc.shape[-1]] + rand_sign = torch.sign(torch.randn(rand_sign_shape, dtype=h.dtype, device=h.device)) + h = torch.cat((h, pos_enc * rand_sign), dim=-1) + if "pos_enc_feats_no_flip" in g.ndata.keys(): + pos_enc = g.ndata["pos_enc_feats_no_flip"] + h = torch.cat((h, pos_enc), dim=-1) + + g.ndata["h"] = h + + if "feat" in g.edata.keys(): + g.edata["e"] = g.edata["feat"] + + # Run the pre-processing network on node features + if self.pre_nn is not None: + h = g.ndata["h"] + h = self.pre_nn.forward(h) + g.ndata["h"] = h + + # Run the pre-processing network on edge features + # If there are no edges, skip the forward and change the dimension of e + if self.pre_nn_edges is not None: + e = g.edata["e"] + if torch.prod(torch.as_tensor(e.shape[:-1])) == 0: + e = torch.zeros( + list(e.shape[:-1]) + [self.pre_nn_edges.out_dim], device=e.device, dtype=e.dtype + ) + else: + e = self.pre_nn_edges.forward(e) + g.edata["e"] = e + + # Run the graph neural network + h = self.gnn.forward(g) + + # Run the output network + if self.post_nn is not None: + h = self.post_nn.forward(h) + + return h + + def __repr__(self): + r""" + Controls how the class is printed + """ + pre_nn_str, post_nn_str, pre_nn_edges_str = "", "", "" + if self.pre_nn is not None: + pre_nn_str = self.pre_nn.__repr__() + "\n\n" + if self.pre_nn_edges is not None: + pre_nn_edges_str = self.pre_nn_edges.__repr__() + "\n\n" + gnn_str = self.gnn.__repr__() + "\n\n" + if self.post_nn is not None: + post_nn_str = self.post_nn.__repr__() + + child_str = " " + pre_nn_str + pre_nn_edges_str + gnn_str + post_nn_str + child_str = " ".join(child_str.splitlines(True)) + + full_str = self.name + "\n" + "-" * (len(self.name) + 2) + "\n" + child_str + + return full_str + + @property + def in_dim(self): + r""" + Returns the input dimension of the network + """ + if self.pre_nn is not None: + return self.pre_nn.in_dim + else: + return self.gnn.in_dim + + @property + def out_dim(self): + r""" + Returns the output dimension of the network + """ + if self.pre_nn is not None: + return self.post_nn.out_dim + else: + return self.gnn.out_dim + + @property + def in_dim_edges(self): + r""" + Returns the input edge dimension of the network + """ + return self.gnn.in_dim_edges + + +class FullDGLSiameseNetwork(FullDGLNetwork): + def __init__(self, pre_nn_kwargs, gnn_kwargs, post_nn_kwargs, dist_method, name="Siamese_DGL_GNN"): + + # Initialize the parent nn.Module + super().__init__( + pre_nn_kwargs=pre_nn_kwargs, + gnn_kwargs=gnn_kwargs, + post_nn_kwargs=post_nn_kwargs, + name=name, + ) + + self.dist_method = dist_method.lower() + + def forward(self, graphs): + graph_1, graph_2 = graphs + + out_1 = super().forward(graph_1) + out_2 = super().forward(graph_2) + + if self.dist_method == "manhattan": + # Normalized L1 distance + out_1 = out_1 / torch.mean(out_1.abs(), dim=-1, keepdim=True) + out_2 = out_2 / torch.mean(out_2.abs(), dim=-1, keepdim=True) + dist = torch.abs(out_1 - out_2) + out = torch.mean(dist, dim=-1) + + elif self.dist_method == "euclidean": + # Normalized Euclidean distance + out_1 = out_1 / torch.norm(out_1, dim=-1, keepdim=True) + out_2 = out_2 / torch.norm(out_2, dim=-1, keepdim=True) + out = torch.norm(out_1 - out_2, dim=-1) + elif self.dist_method == "cosine": + # Cosine distance + out = torch.sum(out_1 * out_2, dim=-1) / (torch.norm(out_1, dim=-1) * torch.norm(out_2, dim=-1)) + else: + raise ValueError(f"Unsupported `dist_method`: {self.dist_method}") + + return out diff --git a/goli/nn/base_layers.py b/goli/nn/base_layers.py index def6ce937..8c6cd9cfb 100644 --- a/goli/nn/base_layers.py +++ b/goli/nn/base_layers.py @@ -1,317 +1,317 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F - -from typing import Union, Callable - -SUPPORTED_ACTIVATION_MAP = {"ReLU", "Sigmoid", "Tanh", "ELU", "SELU", "GLU", "LeakyReLU", "Softplus", "None"} - - -def get_activation(activation: Union[type(None), str, Callable]) -> Union[type(None), Callable]: - r""" - returns the activation function represented by the input string - - Parameters: - activation: Callable, `None`, or string with value: - "none", "ReLU", "Sigmoid", "Tanh", "ELU", "SELU", "GLU", "LeakyReLU", "Softplus" - - Returns: - Callable or None: The activation function - """ - if (activation is not None) and callable(activation): - # activation is already a function - return activation - - if (activation is None) or (activation.lower() == "none"): - return None - - # search in SUPPORTED_ACTIVATION_MAP a torch.nn.modules.activation - activation = [x for x in SUPPORTED_ACTIVATION_MAP if activation.lower() == x.lower()] - assert len(activation) == 1 and isinstance(activation[0], str), "Unhandled activation function" - activation = activation[0] - - return vars(torch.nn.modules.activation)[activation]() - - -class FCLayer(nn.Module): - def __init__( - self, - in_dim: int, - out_dim: int, - activation: Union[str, Callable] = "relu", - dropout: float = 0.0, - batch_norm: bool = False, - bias: bool = True, - init_fn: Union[type(None), Callable] = None, - ): - - r""" - A simple fully connected and customizable layer. This layer is centered around a torch.nn.Linear module. - The order in which transformations are applied is: - - - Dense Layer - - Activation - - Dropout (if applicable) - - Batch Normalization (if applicable) - - Parameters: - in_dim: - Input dimension of the layer (the torch.nn.Linear) - out_dim: - Output dimension of the layer. - dropout: - The ratio of units to dropout. No dropout by default. - activation: - Activation function to use. - batch_norm: - Whether to use batch normalization - bias: - Whether to enable bias in for the linear layer. - init_fn: - Initialization function to use for the weight of the layer. Default is - $$\mathcal{U}(-\sqrt{k}, \sqrt{k})$$ with $$k=\frac{1}{ \text{in_dim}}$$ - - Attributes: - dropout (int): - The ratio of units to dropout. - batch_norm (int): - Whether to use batch normalization - linear (torch.nn.Linear): - The linear layer - activation (torch.nn.Module): - The activation layer - init_fn (Callable): - Initialization function used for the weight of the layer - in_dim (int): - Input dimension of the linear layer - out_dim (int): - Output dimension of the linear layer - """ - - super().__init__() - - self.__params = locals() - del self.__params["__class__"] - del self.__params["self"] - self.in_dim = in_dim - self.out_dim = out_dim - self.bias = bias - self.linear = nn.Linear(in_dim, out_dim, bias=bias) - self.dropout = None - self.batch_norm = None - if batch_norm: - self.batch_norm = nn.BatchNorm1d(out_dim) - if dropout: - self.dropout = nn.Dropout(p=dropout) - self.activation = get_activation(activation) - self.init_fn = nn.init.xavier_uniform_ - - self.reset_parameters() - - def reset_parameters(self, init_fn=None): - init_fn = init_fn or self.init_fn - if init_fn is not None: - init_fn(self.linear.weight, 1 / self.in_dim) - if self.bias: - self.linear.bias.data.zero_() - - def forward(self, h: torch.Tensor) -> torch.Tensor: - r""" - Apply the FC layer on the input features. - - Parameters: - - h: `torch.Tensor[..., Din]`: - Input feature tensor, before the FC. - `Din` is the number of input features - - Returns: - - `torch.Tensor[..., Dout]`: - Output feature tensor, after the FC. - `Dout` is the number of output features - - """ - - if torch.prod(torch.as_tensor(h.shape[:-1])) == 0: - h = torch.zeros(list(h.shape[:-1]) + [self.linear.out_features], device=h.device, dtype=h.dtype) - return h - - h = self.linear(h) - - if self.batch_norm is not None: - if h.shape[1] != self.out_dim: - h = self.batch_norm(h.transpose(1, 2)).transpose(1, 2) - else: - h = self.batch_norm(h) - if self.dropout is not None: - h = self.dropout(h) - if self.activation is not None: - h = self.activation(h) - - return h - - def __repr__(self): - return f"{self.__class__.__name__}({self.in_dim} -> {self.out_dim}, activation={self.activation})" - - -class MLP(nn.Module): - def __init__( - self, - in_dim: int, - hidden_dim: int, - out_dim: int, - layers: int, - activation: Union[str, Callable] = "relu", - last_activation: Union[str, Callable] = "none", - dropout=0.0, - batch_norm=False, - last_batch_norm=False, - ): - r""" - Simple multi-layer perceptron, built of a series of FCLayers - - Parameters: - in_dim: - Input dimension of the MLP - hidden_dim: - Hidden dimension of the MLP. All hidden dimensions will have - the same number of parameters - out_dim: - Output dimension of the MLP. - layers: - Number of hidden layers - activation: - Activation function to use in all the layers except the last. - if `layers==1`, this parameter is ignored - last_activation: - Activation function to use in the last layer. - dropout: - The ratio of units to dropout. Must be between 0 and 1 - batch_norm: - Whether to use batch normalization in the hidden layers. - if `layers==1`, this parameter is ignored - last_batch_norm: - Whether to use batch normalization in the last layer - - """ - - super().__init__() - - self.in_dim = in_dim - self.hidden_dim = hidden_dim - self.out_dim = out_dim - - self.fully_connected = nn.ModuleList() - if layers <= 1: - self.fully_connected.append( - FCLayer( - in_dim, - out_dim, - activation=last_activation, - batch_norm=last_batch_norm, - dropout=dropout, - ) - ) - else: - self.fully_connected.append( - FCLayer( - in_dim, - hidden_dim, - activation=activation, - batch_norm=batch_norm, - dropout=dropout, - ) - ) - for _ in range(layers - 2): - self.fully_connected.append( - FCLayer( - hidden_dim, - hidden_dim, - activation=activation, - batch_norm=batch_norm, - dropout=dropout, - ) - ) - self.fully_connected.append( - FCLayer( - hidden_dim, - out_dim, - activation=last_activation, - batch_norm=last_batch_norm, - dropout=dropout, - ) - ) - - def forward(self, h: torch.Tensor) -> torch.Tensor: - r""" - Apply the MLP on the input features. - - Parameters: - - h: `torch.Tensor[..., Din]`: - Input feature tensor, before the MLP. - `Din` is the number of input features - - Returns: - - `torch.Tensor[..., Dout]`: - Output feature tensor, after the MLP. - `Dout` is the number of output features - - """ - for fc in self.fully_connected: - h = fc(h) - return h - - def __repr__(self): - r""" - Controls how the class is printed - """ - return self.__class__.__name__ + " (" + str(self.in_dim) + " -> " + str(self.out_dim) + ")" - - -class GRU(nn.Module): - def __init__(self, in_dim: int, hidden_dim: int): - r""" - Wrapper class for the GRU used by the GNN framework, nn.GRU is used for the Gated Recurrent Unit itself - - Parameters: - in_dim: - Input dimension of the GRU layer - hidden_dim: - Hidden dimension of the GRU layer. - """ - - super().__init__() - self.in_dim = in_dim - self.hidden_dim = hidden_dim - self.gru = nn.GRU(in_dim=in_dim, hidden_dim=hidden_dim) - - def forward(self, x, y): - r""" - Parameters: - x: `torch.Tensor[B, N, Din]` - where Din <= in_dim (difference is padded) - y: `torch.Tensor[B, N, Dh]` - where Dh <= hidden_dim (difference is padded) - - Returns: - torch.Tensor: `torch.Tensor[B, N, Dh]` - - """ - assert x.shape[-1] <= self.in_dim and y.shape[-1] <= self.hidden_dim - - (B, N, _) = x.shape - x = x.reshape(1, B * N, -1).contiguous() - y = y.reshape(1, B * N, -1).contiguous() - - # padding if necessary - if x.shape[-1] < self.in_dim: - x = F.pad(input=x, pad=[0, self.in_dim - x.shape[-1]], mode="constant", value=0) - if y.shape[-1] < self.hidden_dim: - y = F.pad(input=y, pad=[0, self.hidden_dim - y.shape[-1]], mode="constant", value=0) - - x = self.gru(x, y)[1] - x = x.reshape(B, N, -1) - return x +import torch +import torch.nn as nn +import torch.nn.functional as F + +from typing import Union, Callable + +SUPPORTED_ACTIVATION_MAP = {"ReLU", "Sigmoid", "Tanh", "ELU", "SELU", "GLU", "LeakyReLU", "Softplus", "None"} + + +def get_activation(activation: Union[type(None), str, Callable]) -> Union[type(None), Callable]: + r""" + returns the activation function represented by the input string + + Parameters: + activation: Callable, `None`, or string with value: + "none", "ReLU", "Sigmoid", "Tanh", "ELU", "SELU", "GLU", "LeakyReLU", "Softplus" + + Returns: + Callable or None: The activation function + """ + if (activation is not None) and callable(activation): + # activation is already a function + return activation + + if (activation is None) or (activation.lower() == "none"): + return None + + # search in SUPPORTED_ACTIVATION_MAP a torch.nn.modules.activation + activation = [x for x in SUPPORTED_ACTIVATION_MAP if activation.lower() == x.lower()] + assert len(activation) == 1 and isinstance(activation[0], str), "Unhandled activation function" + activation = activation[0] + + return vars(torch.nn.modules.activation)[activation]() + + +class FCLayer(nn.Module): + def __init__( + self, + in_dim: int, + out_dim: int, + activation: Union[str, Callable] = "relu", + dropout: float = 0.0, + batch_norm: bool = False, + bias: bool = True, + init_fn: Union[type(None), Callable] = None, + ): + + r""" + A simple fully connected and customizable layer. This layer is centered around a torch.nn.Linear module. + The order in which transformations are applied is: + + - Dense Layer + - Activation + - Dropout (if applicable) + - Batch Normalization (if applicable) + + Parameters: + in_dim: + Input dimension of the layer (the torch.nn.Linear) + out_dim: + Output dimension of the layer. + dropout: + The ratio of units to dropout. No dropout by default. + activation: + Activation function to use. + batch_norm: + Whether to use batch normalization + bias: + Whether to enable bias in for the linear layer. + init_fn: + Initialization function to use for the weight of the layer. Default is + $$\mathcal{U}(-\sqrt{k}, \sqrt{k})$$ with $$k=\frac{1}{ \text{in_dim}}$$ + + Attributes: + dropout (int): + The ratio of units to dropout. + batch_norm (int): + Whether to use batch normalization + linear (torch.nn.Linear): + The linear layer + activation (torch.nn.Module): + The activation layer + init_fn (Callable): + Initialization function used for the weight of the layer + in_dim (int): + Input dimension of the linear layer + out_dim (int): + Output dimension of the linear layer + """ + + super().__init__() + + self.__params = locals() + del self.__params["__class__"] + del self.__params["self"] + self.in_dim = in_dim + self.out_dim = out_dim + self.bias = bias + self.linear = nn.Linear(in_dim, out_dim, bias=bias) + self.dropout = None + self.batch_norm = None + if batch_norm: + self.batch_norm = nn.BatchNorm1d(out_dim) + if dropout: + self.dropout = nn.Dropout(p=dropout) + self.activation = get_activation(activation) + self.init_fn = nn.init.xavier_uniform_ + + self.reset_parameters() + + def reset_parameters(self, init_fn=None): + init_fn = init_fn or self.init_fn + if init_fn is not None: + init_fn(self.linear.weight, 1 / self.in_dim) + if self.bias: + self.linear.bias.data.zero_() + + def forward(self, h: torch.Tensor) -> torch.Tensor: + r""" + Apply the FC layer on the input features. + + Parameters: + + h: `torch.Tensor[..., Din]`: + Input feature tensor, before the FC. + `Din` is the number of input features + + Returns: + + `torch.Tensor[..., Dout]`: + Output feature tensor, after the FC. + `Dout` is the number of output features + + """ + + if torch.prod(torch.as_tensor(h.shape[:-1])) == 0: + h = torch.zeros(list(h.shape[:-1]) + [self.linear.out_features], device=h.device, dtype=h.dtype) + return h + + h = self.linear(h) + + if self.batch_norm is not None: + if h.shape[1] != self.out_dim: + h = self.batch_norm(h.transpose(1, 2)).transpose(1, 2) + else: + h = self.batch_norm(h) + if self.dropout is not None: + h = self.dropout(h) + if self.activation is not None: + h = self.activation(h) + + return h + + def __repr__(self): + return f"{self.__class__.__name__}({self.in_dim} -> {self.out_dim}, activation={self.activation})" + + +class MLP(nn.Module): + def __init__( + self, + in_dim: int, + hidden_dim: int, + out_dim: int, + layers: int, + activation: Union[str, Callable] = "relu", + last_activation: Union[str, Callable] = "none", + dropout=0.0, + batch_norm=False, + last_batch_norm=False, + ): + r""" + Simple multi-layer perceptron, built of a series of FCLayers + + Parameters: + in_dim: + Input dimension of the MLP + hidden_dim: + Hidden dimension of the MLP. All hidden dimensions will have + the same number of parameters + out_dim: + Output dimension of the MLP. + layers: + Number of hidden layers + activation: + Activation function to use in all the layers except the last. + if `layers==1`, this parameter is ignored + last_activation: + Activation function to use in the last layer. + dropout: + The ratio of units to dropout. Must be between 0 and 1 + batch_norm: + Whether to use batch normalization in the hidden layers. + if `layers==1`, this parameter is ignored + last_batch_norm: + Whether to use batch normalization in the last layer + + """ + + super().__init__() + + self.in_dim = in_dim + self.hidden_dim = hidden_dim + self.out_dim = out_dim + + self.fully_connected = nn.ModuleList() + if layers <= 1: + self.fully_connected.append( + FCLayer( + in_dim, + out_dim, + activation=last_activation, + batch_norm=last_batch_norm, + dropout=dropout, + ) + ) + else: + self.fully_connected.append( + FCLayer( + in_dim, + hidden_dim, + activation=activation, + batch_norm=batch_norm, + dropout=dropout, + ) + ) + for _ in range(layers - 2): + self.fully_connected.append( + FCLayer( + hidden_dim, + hidden_dim, + activation=activation, + batch_norm=batch_norm, + dropout=dropout, + ) + ) + self.fully_connected.append( + FCLayer( + hidden_dim, + out_dim, + activation=last_activation, + batch_norm=last_batch_norm, + dropout=dropout, + ) + ) + + def forward(self, h: torch.Tensor) -> torch.Tensor: + r""" + Apply the MLP on the input features. + + Parameters: + + h: `torch.Tensor[..., Din]`: + Input feature tensor, before the MLP. + `Din` is the number of input features + + Returns: + + `torch.Tensor[..., Dout]`: + Output feature tensor, after the MLP. + `Dout` is the number of output features + + """ + for fc in self.fully_connected: + h = fc(h) + return h + + def __repr__(self): + r""" + Controls how the class is printed + """ + return self.__class__.__name__ + " (" + str(self.in_dim) + " -> " + str(self.out_dim) + ")" + + +class GRU(nn.Module): + def __init__(self, in_dim: int, hidden_dim: int): + r""" + Wrapper class for the GRU used by the GNN framework, nn.GRU is used for the Gated Recurrent Unit itself + + Parameters: + in_dim: + Input dimension of the GRU layer + hidden_dim: + Hidden dimension of the GRU layer. + """ + + super().__init__() + self.in_dim = in_dim + self.hidden_dim = hidden_dim + self.gru = nn.GRU(in_dim=in_dim, hidden_dim=hidden_dim) + + def forward(self, x, y): + r""" + Parameters: + x: `torch.Tensor[B, N, Din]` + where Din <= in_dim (difference is padded) + y: `torch.Tensor[B, N, Dh]` + where Dh <= hidden_dim (difference is padded) + + Returns: + torch.Tensor: `torch.Tensor[B, N, Dh]` + + """ + assert x.shape[-1] <= self.in_dim and y.shape[-1] <= self.hidden_dim + + (B, N, _) = x.shape + x = x.reshape(1, B * N, -1).contiguous() + y = y.reshape(1, B * N, -1).contiguous() + + # padding if necessary + if x.shape[-1] < self.in_dim: + x = F.pad(input=x, pad=[0, self.in_dim - x.shape[-1]], mode="constant", value=0) + if y.shape[-1] < self.hidden_dim: + y = F.pad(input=y, pad=[0, self.hidden_dim - y.shape[-1]], mode="constant", value=0) + + x = self.gru(x, y)[1] + x = x.reshape(B, N, -1) + return x diff --git a/goli/nn/dgl_layers/__init__.py b/goli/nn/dgl_layers/__init__.py index 71e88f002..fe4b946c6 100644 --- a/goli/nn/dgl_layers/__init__.py +++ b/goli/nn/dgl_layers/__init__.py @@ -1,7 +1,7 @@ -from goli.nn.dgl_layers.base_dgl_layer import BaseDGLLayer -from goli.nn.dgl_layers.gat_layer import GATLayer -from goli.nn.dgl_layers.gcn_layer import GCNLayer -from goli.nn.dgl_layers.gin_layer import GINLayer -from goli.nn.dgl_layers.gated_gcn_layer import GatedGCNLayer -from goli.nn.dgl_layers.pna_layer import PNAConvolutionalLayer, PNAMessagePassingLayer -from goli.nn.dgl_layers.dgn_layer import DGNConvolutionalLayer, DGNMessagePassingLayer +from goli.nn.dgl_layers.base_dgl_layer import BaseDGLLayer +from goli.nn.dgl_layers.gat_layer import GATLayer +from goli.nn.dgl_layers.gcn_layer import GCNLayer +from goli.nn.dgl_layers.gin_layer import GINLayer +from goli.nn.dgl_layers.gated_gcn_layer import GatedGCNLayer +from goli.nn.dgl_layers.pna_layer import PNAConvolutionalLayer, PNAMessagePassingLayer +from goli.nn.dgl_layers.dgn_layer import DGNConvolutionalLayer, DGNMessagePassingLayer diff --git a/goli/nn/dgl_layers/base_dgl_layer.py b/goli/nn/dgl_layers/base_dgl_layer.py index ee1d1386a..b7ca1b630 100644 --- a/goli/nn/dgl_layers/base_dgl_layer.py +++ b/goli/nn/dgl_layers/base_dgl_layer.py @@ -1,175 +1,175 @@ -import torch -import torch.nn as nn -import abc -from typing import List, Dict, Tuple, Union, Callable - -from goli.nn.base_layers import get_activation -from goli.utils.decorators import classproperty - - -class BaseDGLLayer(nn.Module): - def __init__( - self, - in_dim: int, - out_dim: int, - activation: Union[str, Callable] = "relu", - dropout: float = 0.0, - batch_norm: bool = False, - ): - r""" - Abstract class used to standardize the implementation of DGL layers - in the current library. It will allow a network to seemlesly swap between - different GNN layers by better understanding the expected inputs - and outputs. - - Parameters: - - in_dim: - Input feature dimensions of the layer - - out_dim: - Output feature dimensions of the layer - - activation: - activation function to use in the layer - - dropout: - The ratio of units to dropout. Must be between 0 and 1 - - batch_norm: - Whether to use batch normalization - """ - - super().__init__() - - # Basic attributes - self.in_dim = in_dim - self.out_dim = out_dim - self.batch_norm = batch_norm - self.dropout = dropout - self.activation = activation - - # Build the layers - self.activation_layer = get_activation(activation) - - self.dropout_layer = None - if dropout > 0: - self.dropout_layer = nn.Dropout(p=dropout) - - self.batch_norm_layer = None - if batch_norm: - self.batch_norm_layer = nn.BatchNorm1d(out_dim * self.out_dim_factor) - - def apply_norm_activation_dropout( - self, h: torch.Tensor, batch_norm: bool = True, activation: bool = True, dropout: bool = True - ): - r""" - Apply the different normalization and the dropout to the - output layer. - - Parameters: - - h: - Feature tensor, to be normalized - - batch_norm: - Whether to apply the batch_norm layer - - activation: - Whether to apply the activation layer - - dropout: - Whether to apply the dropout layer - - Returns: - - h: - Normalized and dropped-out features - - """ - - if batch_norm and (self.batch_norm_layer is not None): - h = self.batch_norm_layer(h) - - if activation and (self.activation_layer is not None): - h = self.activation_layer(h) - - if dropout and (self.dropout_layer is not None): - h = self.dropout_layer(h) - - return h - - @classproperty - def layer_supports_edges(cls) -> bool: - r""" - Abstract method. Return a boolean specifying if the layer type - supports output edges edges or not. - - Returns: - - bool: - Whether the layer supports the use of edges - """ - ... - - @property - @abc.abstractmethod - def layer_inputs_edges(self) -> bool: - r""" - Abstract method. Return a boolean specifying if the layer type - uses edges as input or not. - It is different from ``layer_supports_input_edges`` since a layer that - supports edges can decide to not use them. - - Returns: - - bool: - Whether the layer uses input edges in the forward pass - """ - ... - - @property - @abc.abstractmethod - def layer_outputs_edges(self) -> bool: - r""" - Abstract method. Return a boolean specifying if the layer type - uses edges as input or not. - It is different from ``layer_supports_output_edges`` since a layer that - supports edges can decide to not use them. - - Returns: - - bool: - Whether the layer outputs edges in the forward pass - """ - ... - - @property - @abc.abstractmethod - def out_dim_factor(self) -> int: - r""" - Abstract method. - Get the factor by which the output dimension is multiplied for - the next layer. - - For standard layers, this will return ``1``. - - But for others, such as ``GatLayer``, the output is the concatenation - of the outputs from each head, so the out_dim gets multiplied by - the number of heads, and this function should return the number - of heads. - - Returns: - - int: - The factor that multiplies the output dimensions - """ - ... - - def __repr__(self): - r""" - Controls how the class is printed - """ - f = self.out_dim_factor - out_dim_f_print = "" if f == 1 else f" * {f}" - return f"{self.__class__.__name__}({self.in_dim} -> {self.out_dim}{out_dim_f_print}, activation={self.activation})" +import torch +import torch.nn as nn +import abc +from typing import List, Dict, Tuple, Union, Callable + +from goli.nn.base_layers import get_activation +from goli.utils.decorators import classproperty + + +class BaseDGLLayer(nn.Module): + def __init__( + self, + in_dim: int, + out_dim: int, + activation: Union[str, Callable] = "relu", + dropout: float = 0.0, + batch_norm: bool = False, + ): + r""" + Abstract class used to standardize the implementation of DGL layers + in the current library. It will allow a network to seemlesly swap between + different GNN layers by better understanding the expected inputs + and outputs. + + Parameters: + + in_dim: + Input feature dimensions of the layer + + out_dim: + Output feature dimensions of the layer + + activation: + activation function to use in the layer + + dropout: + The ratio of units to dropout. Must be between 0 and 1 + + batch_norm: + Whether to use batch normalization + """ + + super().__init__() + + # Basic attributes + self.in_dim = in_dim + self.out_dim = out_dim + self.batch_norm = batch_norm + self.dropout = dropout + self.activation = activation + + # Build the layers + self.activation_layer = get_activation(activation) + + self.dropout_layer = None + if dropout > 0: + self.dropout_layer = nn.Dropout(p=dropout) + + self.batch_norm_layer = None + if batch_norm: + self.batch_norm_layer = nn.BatchNorm1d(out_dim * self.out_dim_factor) + + def apply_norm_activation_dropout( + self, h: torch.Tensor, batch_norm: bool = True, activation: bool = True, dropout: bool = True + ): + r""" + Apply the different normalization and the dropout to the + output layer. + + Parameters: + + h: + Feature tensor, to be normalized + + batch_norm: + Whether to apply the batch_norm layer + + activation: + Whether to apply the activation layer + + dropout: + Whether to apply the dropout layer + + Returns: + + h: + Normalized and dropped-out features + + """ + + if batch_norm and (self.batch_norm_layer is not None): + h = self.batch_norm_layer(h) + + if activation and (self.activation_layer is not None): + h = self.activation_layer(h) + + if dropout and (self.dropout_layer is not None): + h = self.dropout_layer(h) + + return h + + @classproperty + def layer_supports_edges(cls) -> bool: + r""" + Abstract method. Return a boolean specifying if the layer type + supports output edges edges or not. + + Returns: + + bool: + Whether the layer supports the use of edges + """ + ... + + @property + @abc.abstractmethod + def layer_inputs_edges(self) -> bool: + r""" + Abstract method. Return a boolean specifying if the layer type + uses edges as input or not. + It is different from ``layer_supports_input_edges`` since a layer that + supports edges can decide to not use them. + + Returns: + + bool: + Whether the layer uses input edges in the forward pass + """ + ... + + @property + @abc.abstractmethod + def layer_outputs_edges(self) -> bool: + r""" + Abstract method. Return a boolean specifying if the layer type + uses edges as input or not. + It is different from ``layer_supports_output_edges`` since a layer that + supports edges can decide to not use them. + + Returns: + + bool: + Whether the layer outputs edges in the forward pass + """ + ... + + @property + @abc.abstractmethod + def out_dim_factor(self) -> int: + r""" + Abstract method. + Get the factor by which the output dimension is multiplied for + the next layer. + + For standard layers, this will return ``1``. + + But for others, such as ``GatLayer``, the output is the concatenation + of the outputs from each head, so the out_dim gets multiplied by + the number of heads, and this function should return the number + of heads. + + Returns: + + int: + The factor that multiplies the output dimensions + """ + ... + + def __repr__(self): + r""" + Controls how the class is printed + """ + f = self.out_dim_factor + out_dim_f_print = "" if f == 1 else f" * {f}" + return f"{self.__class__.__name__}({self.in_dim} -> {self.out_dim}{out_dim_f_print}, activation={self.activation})" diff --git a/goli/nn/dgl_layers/dgn_layer.py b/goli/nn/dgl_layers/dgn_layer.py index d7b315d99..5120cf9af 100644 --- a/goli/nn/dgl_layers/dgn_layer.py +++ b/goli/nn/dgl_layers/dgn_layer.py @@ -1,169 +1,169 @@ -import torch -from typing import Dict, List, Tuple, Union, Callable -from functools import partial - -from goli.nn.dgl_layers.pna_layer import PNAConvolutionalLayer, PNAMessagePassingLayer -from goli.nn.pna_operations import PNA_AGGREGATORS -from goli.nn.dgn_operations import DGN_AGGREGATORS - - -class BaseDGNLayer: - def parse_aggregators(self, aggregators_name: List[str]) -> List[Callable]: - r""" - Parse the aggregators from a list of strings into a list of callables. - - The possibilities are: - - - `"mean"` - - `"sum"` - - `"min"` - - `"max"` - - `"std"` - - `"dir{dir_idx:int}/smooth/{Optional[temperature:float]}"` - - `"dir{dir_idx:int}/dx_abs/{Optional[temperature:float]}"` - - `"dir{dir_idx:int}/dx_no_abs/{Optional[temperature:float]}"` - - `"dir{dir_idx:int}/dx_abs_balanced/{Optional[temperature:float]}"` - - `"dir{dir_idx:int}/forward/{Optional[temperature:float]}"` - - `"dir{dir_idx:int}/backward/{Optional[temperature:float]}"` - - `dir_idx` is an integer specifying the index of the positional encoding - to use for direction. In the case of eigenvector-based directions, `dir_idx=1` - is chosen for the first non-trivial eigenvector and `dir_idx=2` for the second. - - `temperature` is used to harden the direction using a softmax on the directional - matrices. If it is not provided, then no softmax is applied. The larger the temperature, - the more weight is attributed to the dominant direction. - - Example: - ``` - In: self.parse_aggregators(["dir1/dx_abs", "dir2/smooth/0.2"]) - Out: [partial(aggregate_dir_dx_abs, dir_idx=1, temperature=None), - partial(aggregate_dir_smooth, dir_idx=2, temperature=0.2)] - ``` - - Parameters: - aggregators_name: The list of all aggregators names to use, selected - from the list of possible strings. - - Returns: - aggregators: The list of all callable aggregators. - - """ - aggregators = [] - - for agg_name in aggregators_name: - agg_name = agg_name.lower() - this_agg = None - - # Get the aggregator from PNA if not a directional aggregation - if agg_name in PNA_AGGREGATORS.keys(): - this_agg = PNA_AGGREGATORS[agg_name] - - # If the directional, get the right aggregator - elif "dir" == agg_name[:3]: - agg_split = agg_name.split("/") - agg_dir, agg_fn_name = agg_split[0], agg_split[1] - dir_idx = int(agg_dir[3:]) - temperature = None - if len(agg_split) > 2: - temperature = float(agg_split[2]) - this_agg = partial(DGN_AGGREGATORS[agg_fn_name], dir_idx=dir_idx, temperature=temperature) - - if this_agg is None: - raise ValueError(f"aggregator `{agg_name}` not a valid choice.") - - aggregators.append(this_agg) - - return aggregators - - def message_func(self, edges) -> Dict[str, torch.Tensor]: - r""" - The message function to generate messages along the edges. - """ - return { - "e": edges.data["e"], - "source_pos": edges.data["source_pos"], - "dest_pos": edges.data["dest_pos"], - } - - def reduce_func(self, nodes) -> Dict[str, torch.Tensor]: - r""" - The reduce function to aggregate the messages. - Apply the aggregators and scalers, and concatenate the results. - """ - h_in = nodes.data["h"] - h = nodes.mailbox["e"] - source_pos = nodes.mailbox["source_pos"] - dest_pos = nodes.mailbox["dest_pos"] - D = h.shape[-2] - - # aggregators and scalers - h_to_cat = [ - aggr(h=h, h_in=h_in, source_pos=source_pos, dest_pos=dest_pos) for aggr in self.aggregators - ] - h = torch.cat(h_to_cat, dim=-1) - if len(self.scalers) > 1: - h = torch.cat([scale(h, D=D, avg_d=self.avg_d) for scale in self.scalers], dim=1) - - return {"h": h} - - -class DGNConvolutionalLayer(BaseDGNLayer, PNAConvolutionalLayer): - r""" - Implementation of the convolutional architecture of the DGN layer, - previously known as `DGNSimpleLayer`. This layer aggregates the - neighbouring messages using multiple aggregators and scalers, - concatenates their results, then applies an MLP on the concatenated - features. - - DGN: Directional Graph Networks - Dominique Beaini, Saro Passaro, Vincent Létourneau, William L. Hamilton, Gabriele Corso, Pietro Liò - https://arxiv.org/pdf/2010.02863.pdf - """ - - def parse_aggregators(self, aggregators: List[str]) -> List[Callable]: - return BaseDGNLayer.parse_aggregators(self, aggregators) - - def message_func(self, edges) -> Dict[str, torch.Tensor]: - return BaseDGNLayer.message_func(self, edges) - - def reduce_func(self, nodes) -> Dict[str, torch.Tensor]: - return BaseDGNLayer.reduce_func(self, nodes) - - def pretrans_edges(self, edges): - pretrans = PNAConvolutionalLayer.pretrans_edges(self, edges) - pretrans.update({"source_pos": edges.src["pos_dir"], "dest_pos": edges.dst["pos_dir"]}) - return pretrans - - -class DGNMessagePassingLayer(BaseDGNLayer, PNAMessagePassingLayer): - r""" - Implementation of the message passing architecture of the DGN message passing layer, - previously known as `DGNLayerComplex`. This layer applies an MLP as - pretransformation to the concatenation of $[h_u, h_v, e_{uv}]$ to generate - the messages, with $h_u$ the node feature, $h_v$ the neighbour node features, - and $e_{uv}$ the edge feature between the nodes $u$ and $v$. - - After the pre-transformation, it aggregates the messages using - multiple aggregators and scalers, - concatenates their results, then applies an MLP on the concatenated - features. - - DGN: Directional Graph Networks - Dominique Beaini, Saro Passaro, Vincent Létourneau, William L. Hamilton, Gabriele Corso, Pietro Liò - https://arxiv.org/pdf/2010.02863.pdf - """ - - def parse_aggregators(self, aggregators: List[str]) -> List[Callable]: - return BaseDGNLayer.parse_aggregators(self, aggregators) - - def message_func(self, edges) -> Dict[str, torch.Tensor]: - return BaseDGNLayer.message_func(self, edges) - - def reduce_func(self, nodes) -> Dict[str, torch.Tensor]: - return BaseDGNLayer.reduce_func(self, nodes) - - def pretrans_edges(self, edges): - pretrans = PNAMessagePassingLayer.pretrans_edges(self, edges) - pretrans.update({"source_pos": edges.src["pos_dir"], "dest_pos": edges.dst["pos_dir"]}) - return pretrans +import torch +from typing import Dict, List, Tuple, Union, Callable +from functools import partial + +from goli.nn.dgl_layers.pna_layer import PNAConvolutionalLayer, PNAMessagePassingLayer +from goli.nn.pna_operations import PNA_AGGREGATORS +from goli.nn.dgn_operations import DGN_AGGREGATORS + + +class BaseDGNLayer: + def parse_aggregators(self, aggregators_name: List[str]) -> List[Callable]: + r""" + Parse the aggregators from a list of strings into a list of callables. + + The possibilities are: + + - `"mean"` + - `"sum"` + - `"min"` + - `"max"` + - `"std"` + - `"dir{dir_idx:int}/smooth/{Optional[temperature:float]}"` + - `"dir{dir_idx:int}/dx_abs/{Optional[temperature:float]}"` + - `"dir{dir_idx:int}/dx_no_abs/{Optional[temperature:float]}"` + - `"dir{dir_idx:int}/dx_abs_balanced/{Optional[temperature:float]}"` + - `"dir{dir_idx:int}/forward/{Optional[temperature:float]}"` + - `"dir{dir_idx:int}/backward/{Optional[temperature:float]}"` + + `dir_idx` is an integer specifying the index of the positional encoding + to use for direction. In the case of eigenvector-based directions, `dir_idx=1` + is chosen for the first non-trivial eigenvector and `dir_idx=2` for the second. + + `temperature` is used to harden the direction using a softmax on the directional + matrices. If it is not provided, then no softmax is applied. The larger the temperature, + the more weight is attributed to the dominant direction. + + Example: + ``` + In: self.parse_aggregators(["dir1/dx_abs", "dir2/smooth/0.2"]) + Out: [partial(aggregate_dir_dx_abs, dir_idx=1, temperature=None), + partial(aggregate_dir_smooth, dir_idx=2, temperature=0.2)] + ``` + + Parameters: + aggregators_name: The list of all aggregators names to use, selected + from the list of possible strings. + + Returns: + aggregators: The list of all callable aggregators. + + """ + aggregators = [] + + for agg_name in aggregators_name: + agg_name = agg_name.lower() + this_agg = None + + # Get the aggregator from PNA if not a directional aggregation + if agg_name in PNA_AGGREGATORS.keys(): + this_agg = PNA_AGGREGATORS[agg_name] + + # If the directional, get the right aggregator + elif "dir" == agg_name[:3]: + agg_split = agg_name.split("/") + agg_dir, agg_fn_name = agg_split[0], agg_split[1] + dir_idx = int(agg_dir[3:]) + temperature = None + if len(agg_split) > 2: + temperature = float(agg_split[2]) + this_agg = partial(DGN_AGGREGATORS[agg_fn_name], dir_idx=dir_idx, temperature=temperature) + + if this_agg is None: + raise ValueError(f"aggregator `{agg_name}` not a valid choice.") + + aggregators.append(this_agg) + + return aggregators + + def message_func(self, edges) -> Dict[str, torch.Tensor]: + r""" + The message function to generate messages along the edges. + """ + return { + "e": edges.data["e"], + "source_pos": edges.data["source_pos"], + "dest_pos": edges.data["dest_pos"], + } + + def reduce_func(self, nodes) -> Dict[str, torch.Tensor]: + r""" + The reduce function to aggregate the messages. + Apply the aggregators and scalers, and concatenate the results. + """ + h_in = nodes.data["h"] + h = nodes.mailbox["e"] + source_pos = nodes.mailbox["source_pos"] + dest_pos = nodes.mailbox["dest_pos"] + D = h.shape[-2] + + # aggregators and scalers + h_to_cat = [ + aggr(h=h, h_in=h_in, source_pos=source_pos, dest_pos=dest_pos) for aggr in self.aggregators + ] + h = torch.cat(h_to_cat, dim=-1) + if len(self.scalers) > 1: + h = torch.cat([scale(h, D=D, avg_d=self.avg_d) for scale in self.scalers], dim=1) + + return {"h": h} + + +class DGNConvolutionalLayer(BaseDGNLayer, PNAConvolutionalLayer): + r""" + Implementation of the convolutional architecture of the DGN layer, + previously known as `DGNSimpleLayer`. This layer aggregates the + neighbouring messages using multiple aggregators and scalers, + concatenates their results, then applies an MLP on the concatenated + features. + + DGN: Directional Graph Networks + Dominique Beaini, Saro Passaro, Vincent Létourneau, William L. Hamilton, Gabriele Corso, Pietro Liò + https://arxiv.org/pdf/2010.02863.pdf + """ + + def parse_aggregators(self, aggregators: List[str]) -> List[Callable]: + return BaseDGNLayer.parse_aggregators(self, aggregators) + + def message_func(self, edges) -> Dict[str, torch.Tensor]: + return BaseDGNLayer.message_func(self, edges) + + def reduce_func(self, nodes) -> Dict[str, torch.Tensor]: + return BaseDGNLayer.reduce_func(self, nodes) + + def pretrans_edges(self, edges): + pretrans = PNAConvolutionalLayer.pretrans_edges(self, edges) + pretrans.update({"source_pos": edges.src["pos_dir"], "dest_pos": edges.dst["pos_dir"]}) + return pretrans + + +class DGNMessagePassingLayer(BaseDGNLayer, PNAMessagePassingLayer): + r""" + Implementation of the message passing architecture of the DGN message passing layer, + previously known as `DGNLayerComplex`. This layer applies an MLP as + pretransformation to the concatenation of $[h_u, h_v, e_{uv}]$ to generate + the messages, with $h_u$ the node feature, $h_v$ the neighbour node features, + and $e_{uv}$ the edge feature between the nodes $u$ and $v$. + + After the pre-transformation, it aggregates the messages using + multiple aggregators and scalers, + concatenates their results, then applies an MLP on the concatenated + features. + + DGN: Directional Graph Networks + Dominique Beaini, Saro Passaro, Vincent Létourneau, William L. Hamilton, Gabriele Corso, Pietro Liò + https://arxiv.org/pdf/2010.02863.pdf + """ + + def parse_aggregators(self, aggregators: List[str]) -> List[Callable]: + return BaseDGNLayer.parse_aggregators(self, aggregators) + + def message_func(self, edges) -> Dict[str, torch.Tensor]: + return BaseDGNLayer.message_func(self, edges) + + def reduce_func(self, nodes) -> Dict[str, torch.Tensor]: + return BaseDGNLayer.reduce_func(self, nodes) + + def pretrans_edges(self, edges): + pretrans = PNAMessagePassingLayer.pretrans_edges(self, edges) + pretrans.update({"source_pos": edges.src["pos_dir"], "dest_pos": edges.dst["pos_dir"]}) + return pretrans diff --git a/goli/nn/dgl_layers/gat_layer.py b/goli/nn/dgl_layers/gat_layer.py index e3e9ca91e..5a7655507 100644 --- a/goli/nn/dgl_layers/gat_layer.py +++ b/goli/nn/dgl_layers/gat_layer.py @@ -1,160 +1,160 @@ -import torch - -from dgl.nn.pytorch import GATConv -from dgl import DGLGraph - -from goli.nn.dgl_layers.base_dgl_layer import BaseDGLLayer -from goli.utils.decorators import classproperty - -""" - GAT: Graph Attention Network - Graph Attention Networks (Veličković et al., ICLR 2018) - https://arxiv.org/abs/1710.10903 -""" - - -class GATLayer(BaseDGLLayer): - def __init__( - self, - in_dim: int, - out_dim: int, - num_heads: int, - activation="elu", - dropout: float = 0.0, - batch_norm: bool = False, - ): - r""" - GAT: Graph Attention Network - Graph Attention Networks (Veličković et al., ICLR 2018) - https://arxiv.org/abs/1710.10903 - - The implementation is built on top of the DGL ``GATCONV`` layer - - Parameters: - - in_dim: int - Input feature dimensions of the layer - - out_dim: int - Output feature dimensions of the layer - - num_heads: int - Number of heads in Multi-Head Attention - - activation: str, Callable - activation function to use in the layer - - dropout: float - The ratio of units to dropout. Must be between 0 and 1 - - batch_norm: bool - Whether to use batch normalization - """ - - self.num_heads = num_heads - - super().__init__( - in_dim=in_dim, - out_dim=out_dim, - activation=activation, - dropout=dropout, - batch_norm=batch_norm, - ) - - self.gatconv = GATConv( - in_feats=self.in_dim, - out_feats=self.out_dim, - num_heads=self.num_heads, - feat_drop=self.dropout, - attn_drop=self.dropout, - activation=None, # Activation is applied after - ) - - def forward(self, g: DGLGraph, h: torch.Tensor) -> torch.Tensor: - r""" - Apply the graph convolutional layer, with the specified activations, - normalizations and dropout. - - Parameters: - - g: dgl.DGLGraph - graph on which the convolution is done - - h: `torch.Tensor[..., N, Din]` - Node feature tensor, before convolution. - N is the number of nodes, Din is the input dimension ``self.in_dim`` - - Returns: - - `torch.Tensor[..., N, Dout]`: - Node feature tensor, after convolution. - N is the number of nodes, Dout is the output dimension ``self.out_dim`` - - """ - - h = self.gatconv(g, h).flatten(1) - self.apply_norm_activation_dropout(h, batch_norm=True, activation=True, dropout=False) - - return h - - @classproperty - def layer_supports_edges(cls) -> bool: - r""" - Return a boolean specifying if the layer type supports edges or not. - - Returns: - - supports_edges: bool - Always ``False`` for the current class - """ - return False - - @property - def layer_inputs_edges(self) -> bool: - r""" - Return a boolean specifying if the layer type - uses edges as input or not. - It is different from ``layer_supports_edges`` since a layer that - supports edges can decide to not use them. - - Returns: - - uses_edges: bool - Always ``False`` for the current class - """ - return False - - @property - def layer_outputs_edges(self) -> bool: - r""" - Abstract method. Return a boolean specifying if the layer type - uses edges as input or not. - It is different from ``layer_supports_edges`` since a layer that - supports edges can decide to not use them. - - Returns: - - uses_edges: bool - Always ``False`` for the current class - """ - return False - - @property - def out_dim_factor(self) -> int: - r""" - Get the factor by which the output dimension is multiplied for - the next layer. - - For standard layers, this will return ``1``. - - But for others, such as ``GatLayer``, the output is the concatenation - of the outputs from each head, so the out_dim gets multiplied by - the number of heads, and this function should return the number - of heads. - - Returns: - - dim_factor: int - Always ``self.num_heads`` for the current class - """ - return self.num_heads +import torch + +from dgl.nn.pytorch import GATConv +from dgl import DGLGraph + +from goli.nn.dgl_layers.base_dgl_layer import BaseDGLLayer +from goli.utils.decorators import classproperty + +""" + GAT: Graph Attention Network + Graph Attention Networks (Veličković et al., ICLR 2018) + https://arxiv.org/abs/1710.10903 +""" + + +class GATLayer(BaseDGLLayer): + def __init__( + self, + in_dim: int, + out_dim: int, + num_heads: int, + activation="elu", + dropout: float = 0.0, + batch_norm: bool = False, + ): + r""" + GAT: Graph Attention Network + Graph Attention Networks (Veličković et al., ICLR 2018) + https://arxiv.org/abs/1710.10903 + + The implementation is built on top of the DGL ``GATCONV`` layer + + Parameters: + + in_dim: int + Input feature dimensions of the layer + + out_dim: int + Output feature dimensions of the layer + + num_heads: int + Number of heads in Multi-Head Attention + + activation: str, Callable + activation function to use in the layer + + dropout: float + The ratio of units to dropout. Must be between 0 and 1 + + batch_norm: bool + Whether to use batch normalization + """ + + self.num_heads = num_heads + + super().__init__( + in_dim=in_dim, + out_dim=out_dim, + activation=activation, + dropout=dropout, + batch_norm=batch_norm, + ) + + self.gatconv = GATConv( + in_feats=self.in_dim, + out_feats=self.out_dim, + num_heads=self.num_heads, + feat_drop=self.dropout, + attn_drop=self.dropout, + activation=None, # Activation is applied after + ) + + def forward(self, g: DGLGraph, h: torch.Tensor) -> torch.Tensor: + r""" + Apply the graph convolutional layer, with the specified activations, + normalizations and dropout. + + Parameters: + + g: dgl.DGLGraph + graph on which the convolution is done + + h: `torch.Tensor[..., N, Din]` + Node feature tensor, before convolution. + N is the number of nodes, Din is the input dimension ``self.in_dim`` + + Returns: + + `torch.Tensor[..., N, Dout]`: + Node feature tensor, after convolution. + N is the number of nodes, Dout is the output dimension ``self.out_dim`` + + """ + + h = self.gatconv(g, h).flatten(1) + self.apply_norm_activation_dropout(h, batch_norm=True, activation=True, dropout=False) + + return h + + @classproperty + def layer_supports_edges(cls) -> bool: + r""" + Return a boolean specifying if the layer type supports edges or not. + + Returns: + + supports_edges: bool + Always ``False`` for the current class + """ + return False + + @property + def layer_inputs_edges(self) -> bool: + r""" + Return a boolean specifying if the layer type + uses edges as input or not. + It is different from ``layer_supports_edges`` since a layer that + supports edges can decide to not use them. + + Returns: + + uses_edges: bool + Always ``False`` for the current class + """ + return False + + @property + def layer_outputs_edges(self) -> bool: + r""" + Abstract method. Return a boolean specifying if the layer type + uses edges as input or not. + It is different from ``layer_supports_edges`` since a layer that + supports edges can decide to not use them. + + Returns: + + uses_edges: bool + Always ``False`` for the current class + """ + return False + + @property + def out_dim_factor(self) -> int: + r""" + Get the factor by which the output dimension is multiplied for + the next layer. + + For standard layers, this will return ``1``. + + But for others, such as ``GatLayer``, the output is the concatenation + of the outputs from each head, so the out_dim gets multiplied by + the number of heads, and this function should return the number + of heads. + + Returns: + + dim_factor: int + Always ``self.num_heads`` for the current class + """ + return self.num_heads diff --git a/goli/nn/dgl_layers/gated_gcn_layer.py b/goli/nn/dgl_layers/gated_gcn_layer.py index fcdb6a49b..3d073d4b8 100644 --- a/goli/nn/dgl_layers/gated_gcn_layer.py +++ b/goli/nn/dgl_layers/gated_gcn_layer.py @@ -1,190 +1,190 @@ -import torch -import torch.nn as nn -from typing import Tuple, Union, Callable -from dgl import DGLGraph - -from goli.nn.dgl_layers.base_dgl_layer import BaseDGLLayer -from goli.utils.decorators import classproperty - -""" - ResGatedGCN: Residual Gated Graph ConvNets - An Experimental Study of Neural Networks for Variable Graphs (Xavier Bresson and Thomas Laurent, ICLR 2018) - https://arxiv.org/pdf/1711.07553v2.pdf -""" - - -class GatedGCNLayer(BaseDGLLayer): - def __init__( - self, - in_dim: int, - out_dim: int, - in_dim_edges: int, - out_dim_edges: int, - activation: Union[Callable, str] = "relu", - dropout: float = 0.0, - batch_norm: bool = False, - ): - r""" - ResGatedGCN: Residual Gated Graph ConvNets - An Experimental Study of Neural Networks for Variable Graphs (Xavier Bresson and Thomas Laurent, ICLR 2018) - https://arxiv.org/pdf/1711.07553v2.pdf - - Parameters: - - in_dim: - Input feature dimensions of the layer - - out_dim: - Output feature dimensions of the layer, and for the edges - - in_dim_edges: - Input edge-feature dimensions of the layer - - activation: - activation function to use in the layer - - dropout: - The ratio of units to dropout. Must be between 0 and 1 - - batch_norm: - Whether to use batch normalization - - """ - - super().__init__( - in_dim=in_dim, - out_dim=out_dim, - activation=activation, - dropout=dropout, - batch_norm=batch_norm, - ) - - self.A = nn.Linear(in_dim, out_dim, bias=True) - self.B = nn.Linear(in_dim, out_dim, bias=True) - self.C = nn.Linear(in_dim_edges, out_dim, bias=True) - self.D = nn.Linear(in_dim, out_dim, bias=True) - self.E = nn.Linear(in_dim, out_dim, bias=True) - - def message_func(self, edges): - Bh_j = edges.src["Bh"] - e_ij = edges.data["Ce"] + edges.src["Dh"] + edges.dst["Eh"] # e_ij = Ce_ij + Dhi + Ehj - edges.data["e"] = e_ij - return {"Bh_j": Bh_j, "e_ij": e_ij} - - def reduce_func(self, nodes): - Ah_i = nodes.data["Ah"] - Bh_j = nodes.mailbox["Bh_j"] - e = nodes.mailbox["e_ij"] - sigma_ij = torch.sigmoid(e) # sigma_ij = sigmoid(e_ij) - # h = Ah_i + torch.mean( sigma_ij * Bh_j, dim=1 ) # hi = Ahi + mean_j alpha_ij * Bhj - h = Ah_i + torch.sum(sigma_ij * Bh_j, dim=1) / ( - torch.sum(sigma_ij, dim=1) + 1e-6 - ) # hi = Ahi + sum_j eta_ij/sum_j' eta_ij' * Bhj <= dense attention - return {"h": h} - - def forward(self, g: DGLGraph, h: torch.Tensor, e: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - r""" - Apply the graph convolutional layer, with the specified activations, - normalizations and dropout. - - Parameters: - - g: - graph on which the convolution is done - - h: `torch.Tensor[..., N, Din]` - Node feature tensor, before convolution. - N is the number of nodes, Din is the input dimension ``self.in_dim`` - - e: `torch.Tensor[..., N, Din_edges]` - Edge feature tensor, before convolution. - N is the number of nodes, Din is the input edge dimension ``self.in_dim_edges`` - - Returns: - `torch.Tensor[..., N, Dout]`: - Node feature tensor, after convolution. - N is the number of nodes, Dout is the output dimension ``self.out_dim`` - - `torch.Tensor[..., N, Dout]`: - Edge feature tensor, after convolution. - N is the number of nodes, Dout_edges is the output edge dimension ``self.out_dim`` - - """ - - g.ndata["h"] = h - g.ndata["Ah"] = self.A(h) - g.ndata["Bh"] = self.B(h) - g.ndata["Dh"] = self.D(h) - g.ndata["Eh"] = self.E(h) - g.edata["e"] = e - g.edata["Ce"] = self.C(e) - g.update_all(self.message_func, self.reduce_func) - h = g.ndata["h"] # result of graph convolution - e = g.edata["e"] # result of graph convolution - - h = self.apply_norm_activation_dropout(h) - e = self.apply_norm_activation_dropout(e) - - return h, e - - @classproperty - def layer_supports_edges(cls) -> bool: - r""" - Return a boolean specifying if the layer type supports edges or not. - - Returns: - - bool: - Always ``True`` for the current class - """ - return True - - @property - def layer_inputs_edges(self) -> bool: - r""" - Return a boolean specifying if the layer type - uses edges as input or not. - It is different from ``layer_supports_edges`` since a layer that - supports edges can decide to not use them. - - Returns: - - bool: - Always ``True`` for the current class - """ - return True - - @property - def layer_outputs_edges(self) -> bool: - r""" - Abstract method. Return a boolean specifying if the layer type - uses edges as input or not. - It is different from ``layer_supports_edges`` since a layer that - supports edges can decide to not use them. - - Returns: - - bool: - Always ``True`` for the current class - """ - return True - - @property - def out_dim_factor(self) -> int: - r""" - Get the factor by which the output dimension is multiplied for - the next layer. - - For standard layers, this will return ``1``. - - But for others, such as ``GatLayer``, the output is the concatenation - of the outputs from each head, so the out_dim gets multiplied by - the number of heads, and this function should return the number - of heads. - - Returns: - - int: - Always ``1`` for the current class - """ - return 1 +import torch +import torch.nn as nn +from typing import Tuple, Union, Callable +from dgl import DGLGraph + +from goli.nn.dgl_layers.base_dgl_layer import BaseDGLLayer +from goli.utils.decorators import classproperty + +""" + ResGatedGCN: Residual Gated Graph ConvNets + An Experimental Study of Neural Networks for Variable Graphs (Xavier Bresson and Thomas Laurent, ICLR 2018) + https://arxiv.org/pdf/1711.07553v2.pdf +""" + + +class GatedGCNLayer(BaseDGLLayer): + def __init__( + self, + in_dim: int, + out_dim: int, + in_dim_edges: int, + out_dim_edges: int, + activation: Union[Callable, str] = "relu", + dropout: float = 0.0, + batch_norm: bool = False, + ): + r""" + ResGatedGCN: Residual Gated Graph ConvNets + An Experimental Study of Neural Networks for Variable Graphs (Xavier Bresson and Thomas Laurent, ICLR 2018) + https://arxiv.org/pdf/1711.07553v2.pdf + + Parameters: + + in_dim: + Input feature dimensions of the layer + + out_dim: + Output feature dimensions of the layer, and for the edges + + in_dim_edges: + Input edge-feature dimensions of the layer + + activation: + activation function to use in the layer + + dropout: + The ratio of units to dropout. Must be between 0 and 1 + + batch_norm: + Whether to use batch normalization + + """ + + super().__init__( + in_dim=in_dim, + out_dim=out_dim, + activation=activation, + dropout=dropout, + batch_norm=batch_norm, + ) + + self.A = nn.Linear(in_dim, out_dim, bias=True) + self.B = nn.Linear(in_dim, out_dim, bias=True) + self.C = nn.Linear(in_dim_edges, out_dim, bias=True) + self.D = nn.Linear(in_dim, out_dim, bias=True) + self.E = nn.Linear(in_dim, out_dim, bias=True) + + def message_func(self, edges): + Bh_j = edges.src["Bh"] + e_ij = edges.data["Ce"] + edges.src["Dh"] + edges.dst["Eh"] # e_ij = Ce_ij + Dhi + Ehj + edges.data["e"] = e_ij + return {"Bh_j": Bh_j, "e_ij": e_ij} + + def reduce_func(self, nodes): + Ah_i = nodes.data["Ah"] + Bh_j = nodes.mailbox["Bh_j"] + e = nodes.mailbox["e_ij"] + sigma_ij = torch.sigmoid(e) # sigma_ij = sigmoid(e_ij) + # h = Ah_i + torch.mean( sigma_ij * Bh_j, dim=1 ) # hi = Ahi + mean_j alpha_ij * Bhj + h = Ah_i + torch.sum(sigma_ij * Bh_j, dim=1) / ( + torch.sum(sigma_ij, dim=1) + 1e-6 + ) # hi = Ahi + sum_j eta_ij/sum_j' eta_ij' * Bhj <= dense attention + return {"h": h} + + def forward(self, g: DGLGraph, h: torch.Tensor, e: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Apply the graph convolutional layer, with the specified activations, + normalizations and dropout. + + Parameters: + + g: + graph on which the convolution is done + + h: `torch.Tensor[..., N, Din]` + Node feature tensor, before convolution. + N is the number of nodes, Din is the input dimension ``self.in_dim`` + + e: `torch.Tensor[..., N, Din_edges]` + Edge feature tensor, before convolution. + N is the number of nodes, Din is the input edge dimension ``self.in_dim_edges`` + + Returns: + `torch.Tensor[..., N, Dout]`: + Node feature tensor, after convolution. + N is the number of nodes, Dout is the output dimension ``self.out_dim`` + + `torch.Tensor[..., N, Dout]`: + Edge feature tensor, after convolution. + N is the number of nodes, Dout_edges is the output edge dimension ``self.out_dim`` + + """ + + g.ndata["h"] = h + g.ndata["Ah"] = self.A(h) + g.ndata["Bh"] = self.B(h) + g.ndata["Dh"] = self.D(h) + g.ndata["Eh"] = self.E(h) + g.edata["e"] = e + g.edata["Ce"] = self.C(e) + g.update_all(self.message_func, self.reduce_func) + h = g.ndata["h"] # result of graph convolution + e = g.edata["e"] # result of graph convolution + + h = self.apply_norm_activation_dropout(h) + e = self.apply_norm_activation_dropout(e) + + return h, e + + @classproperty + def layer_supports_edges(cls) -> bool: + r""" + Return a boolean specifying if the layer type supports edges or not. + + Returns: + + bool: + Always ``True`` for the current class + """ + return True + + @property + def layer_inputs_edges(self) -> bool: + r""" + Return a boolean specifying if the layer type + uses edges as input or not. + It is different from ``layer_supports_edges`` since a layer that + supports edges can decide to not use them. + + Returns: + + bool: + Always ``True`` for the current class + """ + return True + + @property + def layer_outputs_edges(self) -> bool: + r""" + Abstract method. Return a boolean specifying if the layer type + uses edges as input or not. + It is different from ``layer_supports_edges`` since a layer that + supports edges can decide to not use them. + + Returns: + + bool: + Always ``True`` for the current class + """ + return True + + @property + def out_dim_factor(self) -> int: + r""" + Get the factor by which the output dimension is multiplied for + the next layer. + + For standard layers, this will return ``1``. + + But for others, such as ``GatLayer``, the output is the concatenation + of the outputs from each head, so the out_dim gets multiplied by + the number of heads, and this function should return the number + of heads. + + Returns: + + int: + Always ``1`` for the current class + """ + return 1 diff --git a/goli/nn/dgl_layers/gcn_layer.py b/goli/nn/dgl_layers/gcn_layer.py index e7c97f660..c20206626 100644 --- a/goli/nn/dgl_layers/gcn_layer.py +++ b/goli/nn/dgl_layers/gcn_layer.py @@ -1,154 +1,154 @@ -import torch -from typing import List, Dict, Tuple, Union, Callable - -from dgl.nn.pytorch import GraphConv -from dgl import DGLGraph - -from goli.nn.dgl_layers.base_dgl_layer import BaseDGLLayer -from goli.utils.decorators import classproperty - -""" - GCN: Graph Convolutional Networks - Thomas N. Kipf, Max Welling, Semi-Supervised Classification with Graph Convolutional Networks (ICLR 2017) - http://arxiv.org/abs/1609.02907 -""" - - -class GCNLayer(BaseDGLLayer): - def __init__( - self, - in_dim: int, - out_dim: int, - activation: Union[str, Callable] = "relu", - dropout: float = 0.0, - batch_norm: bool = False, - ): - r""" - Graph convolutional network (GCN) layer from - Thomas N. Kipf, Max Welling, Semi-Supervised Classification with Graph Convolutional Networks (ICLR 2017) - http://arxiv.org/abs/1609.02907 - - Parameters: - - in_dim: - Input feature dimensions of the layer - - out_dim: - Output feature dimensions of the layer - - activation: - activation function to use in the layer - - dropout: - The ratio of units to dropout. Must be between 0 and 1 - - batch_norm: - Whether to use batch normalization - """ - - super().__init__( - in_dim=in_dim, - out_dim=out_dim, - activation=activation, - dropout=dropout, - batch_norm=batch_norm, - ) - - self.conv = GraphConv( - in_feats=in_dim, - out_feats=out_dim, - norm="both", - weight=True, - bias=True, - activation=None, - allow_zero_in_degree=False, - ) - - def forward(self, g: DGLGraph, h: torch.Tensor) -> torch.Tensor: - r""" - Apply the graph convolutional layer, with the specified activations, - normalizations and dropout. - - Parameters: - - g: - graph on which the convolution is done - - h: `torch.Tensor[..., N, Din]` - Node feature tensor, before convolution. - N is the number of nodes, Din is the input dimension ``self.in_dim`` - - Returns: - - `torch.Tensor[..., N, Dout]`: - Node feature tensor, after convolution. - N is the number of nodes, Dout is the output dimension ``self.out_dim`` - - """ - - h = self.conv(g, h) - h = self.apply_norm_activation_dropout(h) - - return h - - @classproperty - def layer_supports_edges(cls) -> bool: - r""" - Return a boolean specifying if the layer type supports edges or not. - - Returns: - - bool - Always ``False`` for the current class - """ - return False - - @property - def layer_inputs_edges(self) -> bool: - r""" - Return a boolean specifying if the layer type - uses edges as input or not. - It is different from ``layer_supports_edges`` since a layer that - supports edges can decide to not use them. - - Returns: - - bool: - Always ``False`` for the current class - """ - return False - - @property - def layer_outputs_edges(self) -> bool: - r""" - Abstract method. Return a boolean specifying if the layer type - uses edges as input or not. - It is different from ``layer_supports_edges`` since a layer that - supports edges can decide to not use them. - - Returns: - - bool: - Always ``False`` for the current class - """ - return False - - @property - def out_dim_factor(self) -> int: - r""" - Get the factor by which the output dimension is multiplied for - the next layer. - - For standard layers, this will return ``1``. - - But for others, such as ``GatLayer``, the output is the concatenation - of the outputs from each head, so the out_dim gets multiplied by - the number of heads, and this function should return the number - of heads. - - Returns: - - int: - Always ``1`` for the current class - """ - return 1 +import torch +from typing import List, Dict, Tuple, Union, Callable + +from dgl.nn.pytorch import GraphConv +from dgl import DGLGraph + +from goli.nn.dgl_layers.base_dgl_layer import BaseDGLLayer +from goli.utils.decorators import classproperty + +""" + GCN: Graph Convolutional Networks + Thomas N. Kipf, Max Welling, Semi-Supervised Classification with Graph Convolutional Networks (ICLR 2017) + http://arxiv.org/abs/1609.02907 +""" + + +class GCNLayer(BaseDGLLayer): + def __init__( + self, + in_dim: int, + out_dim: int, + activation: Union[str, Callable] = "relu", + dropout: float = 0.0, + batch_norm: bool = False, + ): + r""" + Graph convolutional network (GCN) layer from + Thomas N. Kipf, Max Welling, Semi-Supervised Classification with Graph Convolutional Networks (ICLR 2017) + http://arxiv.org/abs/1609.02907 + + Parameters: + + in_dim: + Input feature dimensions of the layer + + out_dim: + Output feature dimensions of the layer + + activation: + activation function to use in the layer + + dropout: + The ratio of units to dropout. Must be between 0 and 1 + + batch_norm: + Whether to use batch normalization + """ + + super().__init__( + in_dim=in_dim, + out_dim=out_dim, + activation=activation, + dropout=dropout, + batch_norm=batch_norm, + ) + + self.conv = GraphConv( + in_feats=in_dim, + out_feats=out_dim, + norm="both", + weight=True, + bias=True, + activation=None, + allow_zero_in_degree=False, + ) + + def forward(self, g: DGLGraph, h: torch.Tensor) -> torch.Tensor: + r""" + Apply the graph convolutional layer, with the specified activations, + normalizations and dropout. + + Parameters: + + g: + graph on which the convolution is done + + h: `torch.Tensor[..., N, Din]` + Node feature tensor, before convolution. + N is the number of nodes, Din is the input dimension ``self.in_dim`` + + Returns: + + `torch.Tensor[..., N, Dout]`: + Node feature tensor, after convolution. + N is the number of nodes, Dout is the output dimension ``self.out_dim`` + + """ + + h = self.conv(g, h) + h = self.apply_norm_activation_dropout(h) + + return h + + @classproperty + def layer_supports_edges(cls) -> bool: + r""" + Return a boolean specifying if the layer type supports edges or not. + + Returns: + + bool + Always ``False`` for the current class + """ + return False + + @property + def layer_inputs_edges(self) -> bool: + r""" + Return a boolean specifying if the layer type + uses edges as input or not. + It is different from ``layer_supports_edges`` since a layer that + supports edges can decide to not use them. + + Returns: + + bool: + Always ``False`` for the current class + """ + return False + + @property + def layer_outputs_edges(self) -> bool: + r""" + Abstract method. Return a boolean specifying if the layer type + uses edges as input or not. + It is different from ``layer_supports_edges`` since a layer that + supports edges can decide to not use them. + + Returns: + + bool: + Always ``False`` for the current class + """ + return False + + @property + def out_dim_factor(self) -> int: + r""" + Get the factor by which the output dimension is multiplied for + the next layer. + + For standard layers, this will return ``1``. + + But for others, such as ``GatLayer``, the output is the concatenation + of the outputs from each head, so the out_dim gets multiplied by + the number of heads, and this function should return the number + of heads. + + Returns: + + int: + Always ``1`` for the current class + """ + return 1 diff --git a/goli/nn/dgl_layers/gin_layer.py b/goli/nn/dgl_layers/gin_layer.py index ecb8509c3..23ac79c1e 100644 --- a/goli/nn/dgl_layers/gin_layer.py +++ b/goli/nn/dgl_layers/gin_layer.py @@ -1,194 +1,194 @@ -import torch -import dgl.function as fn -from dgl import DGLGraph -from typing import Callable, Union - -from goli.nn.dgl_layers.base_dgl_layer import BaseDGLLayer -from goli.nn.base_layers import MLP -from goli.utils.decorators import classproperty - -""" - GIN: Graph Isomorphism Networks - HOW POWERFUL ARE GRAPH NEURAL NETWORKS? (Keyulu Xu, Weihua Hu, Jure Leskovec and Stefanie Jegelka, ICLR 2019) - https://arxiv.org/pdf/1810.00826.pdf -""" - - -class GINLayer(BaseDGLLayer): - def __init__( - self, - in_dim: int, - out_dim: int, - activation: Union[Callable, str] = "relu", - dropout: float = 0.0, - batch_norm: bool = False, - init_eps: float = 0.0, - learn_eps: bool = True, - ): - r""" - GIN: Graph Isomorphism Networks - HOW POWERFUL ARE GRAPH NEURAL NETWORKS? (Keyulu Xu, Weihua Hu, Jure Leskovec and Stefanie Jegelka, ICLR 2019) - https://arxiv.org/pdf/1810.00826.pdf - - [!] code adapted from dgl implementation of GINConv - - Parameters: - - in_dim: - Input feature dimensions of the layer - - out_dim: - Output feature dimensions of the layer - - activation: - activation function to use in the layer - - dropout: - The ratio of units to dropout. Must be between 0 and 1 - - batch_norm: - Whether to use batch normalization - - init_eps : - Initial :math:`\epsilon` value, default: ``0``. - - learn_eps : - If True, :math:`\epsilon` will be a learnable parameter. - - """ - - super().__init__( - in_dim=in_dim, - out_dim=out_dim, - activation=activation, - dropout=dropout, - batch_norm=batch_norm, - ) - - # Specify to consider the edges weight in the aggregation - - # to specify whether eps is trainable or not. - if learn_eps: - self.eps = torch.nn.Parameter(torch.FloatTensor([init_eps])) - else: - self.register_buffer("eps", torch.FloatTensor([init_eps])) - - # The weights of the model, applied after the aggregation - self.mlp = MLP( - in_dim=self.in_dim, - hidden_dim=self.in_dim, - out_dim=self.out_dim, - layers=2, - activation=self.activation_layer, - last_activation="none", - batch_norm=self.batch_norm, - last_batch_norm=False, - ) - - def message_func(self, g): - r""" - If edge weights are provided, use them to weight the messages - """ - - if "w" in g.edata.keys(): - func = fn.u_mul_e("h", "w", "m") - else: - func = fn.copy_u("h", "m") - return func - - def forward(self, g: DGLGraph, h: torch.Tensor) -> torch.Tensor: - r""" - Apply the GIN convolutional layer, with the specified activations, - normalizations and dropout. - - Parameters: - - g: - graph on which the convolution is done - - h: `torch.Tensor[..., N, Din]` - Node feature tensor, before convolution. - N is the number of nodes, Din is the input dimension ``self.in_dim`` - - Returns: - - `torch.Tensor[..., N, Dout]`: - Node feature tensor, after convolution. - N is the number of nodes, Dout is the output dimension ``self.out_dim`` - - """ - - # Aggregate the message - g = g.local_var() - g.ndata["h"] = h - func = fn.copy_u("h", "m") - g.update_all(self.message_func(g), fn.sum("m", "neigh")) - h = (1 + self.eps) * h + g.ndata["neigh"] - - # Apply the MLP - h = self.mlp(h) - h = self.apply_norm_activation_dropout(h) - - return h - - @classproperty - def layer_supports_edges(cls) -> bool: - r""" - Return a boolean specifying if the layer type supports edges or not. - - Returns: - - supports_edges: bool - Always ``False`` for the current class - """ - return False - - @property - def layer_inputs_edges(self) -> bool: - r""" - Return a boolean specifying if the layer type - uses edges as input or not. - It is different from ``layer_supports_edges`` since a layer that - supports edges can decide to not use them. - - Returns: - - bool: - Always ``False`` for the current class - """ - return False - - @property - def layer_outputs_edges(self) -> bool: - r""" - Abstract method. Return a boolean specifying if the layer type - uses edges as input or not. - It is different from ``layer_supports_edges`` since a layer that - supports edges can decide to not use them. - - Returns: - - bool: - Always ``False`` for the current class - """ - return False - - @property - def out_dim_factor(self) -> int: - r""" - Get the factor by which the output dimension is multiplied for - the next layer. - - For standard layers, this will return ``1``. - - But for others, such as ``GatLayer``, the output is the concatenation - of the outputs from each head, so the out_dim gets multiplied by - the number of heads, and this function should return the number - of heads. - - Returns: - - int: - Always ``1`` for the current class - """ - return 1 +import torch +import dgl.function as fn +from dgl import DGLGraph +from typing import Callable, Union + +from goli.nn.dgl_layers.base_dgl_layer import BaseDGLLayer +from goli.nn.base_layers import MLP +from goli.utils.decorators import classproperty + +""" + GIN: Graph Isomorphism Networks + HOW POWERFUL ARE GRAPH NEURAL NETWORKS? (Keyulu Xu, Weihua Hu, Jure Leskovec and Stefanie Jegelka, ICLR 2019) + https://arxiv.org/pdf/1810.00826.pdf +""" + + +class GINLayer(BaseDGLLayer): + def __init__( + self, + in_dim: int, + out_dim: int, + activation: Union[Callable, str] = "relu", + dropout: float = 0.0, + batch_norm: bool = False, + init_eps: float = 0.0, + learn_eps: bool = True, + ): + r""" + GIN: Graph Isomorphism Networks + HOW POWERFUL ARE GRAPH NEURAL NETWORKS? (Keyulu Xu, Weihua Hu, Jure Leskovec and Stefanie Jegelka, ICLR 2019) + https://arxiv.org/pdf/1810.00826.pdf + + [!] code adapted from dgl implementation of GINConv + + Parameters: + + in_dim: + Input feature dimensions of the layer + + out_dim: + Output feature dimensions of the layer + + activation: + activation function to use in the layer + + dropout: + The ratio of units to dropout. Must be between 0 and 1 + + batch_norm: + Whether to use batch normalization + + init_eps : + Initial :math:`\epsilon` value, default: ``0``. + + learn_eps : + If True, :math:`\epsilon` will be a learnable parameter. + + """ + + super().__init__( + in_dim=in_dim, + out_dim=out_dim, + activation=activation, + dropout=dropout, + batch_norm=batch_norm, + ) + + # Specify to consider the edges weight in the aggregation + + # to specify whether eps is trainable or not. + if learn_eps: + self.eps = torch.nn.Parameter(torch.FloatTensor([init_eps])) + else: + self.register_buffer("eps", torch.FloatTensor([init_eps])) + + # The weights of the model, applied after the aggregation + self.mlp = MLP( + in_dim=self.in_dim, + hidden_dim=self.in_dim, + out_dim=self.out_dim, + layers=2, + activation=self.activation_layer, + last_activation="none", + batch_norm=self.batch_norm, + last_batch_norm=False, + ) + + def message_func(self, g): + r""" + If edge weights are provided, use them to weight the messages + """ + + if "w" in g.edata.keys(): + func = fn.u_mul_e("h", "w", "m") + else: + func = fn.copy_u("h", "m") + return func + + def forward(self, g: DGLGraph, h: torch.Tensor) -> torch.Tensor: + r""" + Apply the GIN convolutional layer, with the specified activations, + normalizations and dropout. + + Parameters: + + g: + graph on which the convolution is done + + h: `torch.Tensor[..., N, Din]` + Node feature tensor, before convolution. + N is the number of nodes, Din is the input dimension ``self.in_dim`` + + Returns: + + `torch.Tensor[..., N, Dout]`: + Node feature tensor, after convolution. + N is the number of nodes, Dout is the output dimension ``self.out_dim`` + + """ + + # Aggregate the message + g = g.local_var() + g.ndata["h"] = h + func = fn.copy_u("h", "m") + g.update_all(self.message_func(g), fn.sum("m", "neigh")) + h = (1 + self.eps) * h + g.ndata["neigh"] + + # Apply the MLP + h = self.mlp(h) + h = self.apply_norm_activation_dropout(h) + + return h + + @classproperty + def layer_supports_edges(cls) -> bool: + r""" + Return a boolean specifying if the layer type supports edges or not. + + Returns: + + supports_edges: bool + Always ``False`` for the current class + """ + return False + + @property + def layer_inputs_edges(self) -> bool: + r""" + Return a boolean specifying if the layer type + uses edges as input or not. + It is different from ``layer_supports_edges`` since a layer that + supports edges can decide to not use them. + + Returns: + + bool: + Always ``False`` for the current class + """ + return False + + @property + def layer_outputs_edges(self) -> bool: + r""" + Abstract method. Return a boolean specifying if the layer type + uses edges as input or not. + It is different from ``layer_supports_edges`` since a layer that + supports edges can decide to not use them. + + Returns: + + bool: + Always ``False`` for the current class + """ + return False + + @property + def out_dim_factor(self) -> int: + r""" + Get the factor by which the output dimension is multiplied for + the next layer. + + For standard layers, this will return ``1``. + + But for others, such as ``GatLayer``, the output is the concatenation + of the outputs from each head, so the out_dim gets multiplied by + the number of heads, and this function should return the number + of heads. + + Returns: + + int: + Always ``1`` for the current class + """ + return 1 diff --git a/goli/nn/dgl_layers/pna_layer.py b/goli/nn/dgl_layers/pna_layer.py index 238fc69a0..111488b77 100644 --- a/goli/nn/dgl_layers/pna_layer.py +++ b/goli/nn/dgl_layers/pna_layer.py @@ -1,584 +1,584 @@ -import torch -import dgl -from dgl import DGLGraph -from typing import Dict, List, Tuple, Union, Callable -from copy import deepcopy - -from goli.nn.pna_operations import PNA_AGGREGATORS, PNA_SCALERS -from goli.nn.base_layers import MLP, get_activation -from goli.nn.dgl_layers.base_dgl_layer import BaseDGLLayer -from goli.utils.decorators import classproperty - -""" - PNA: Principal Neighbourhood Aggregation - Gabriele Corso, Luca Cavalleri, Dominique Beaini, Pietro Lio, Petar Velickovic - https://arxiv.org/abs/2004.05718 -""" - - -class BasePNALayer(BaseDGLLayer): - def __init__( - self, - in_dim: int, - out_dim: int, - aggregators: List[str], - scalers: List[str], - activation: Union[Callable, str] = "relu", - dropout: float = 0.0, - batch_norm: bool = False, - avg_d: float = 1.0, - last_activation: Union[Callable, str] = "none", - in_dim_edges: int = 0, - ): - r""" - Abstract class used to standardize the implementation of PNA layers - in the current library. - - PNA: Principal Neighbourhood Aggregation - Gabriele Corso, Luca Cavalleri, Dominique Beaini, Pietro Lio, Petar Velickovic - https://arxiv.org/abs/2004.05718 - - Method ``layer_inputs_edges()`` needs to be implemented in children classes - - Parameters: - - in_dim: - Input feature dimensions of the layer - - out_dim: - Output feature dimensions of the layer - - aggregators: - Set of aggregation function identifiers, - e.g. "mean", "max", "min", "std", "sum", "var", "moment3". - The results from all aggregators will be concatenated. - - scalers: - Set of scaling functions identifiers - e.g. "identidy", "amplification", "attenuation" - The results from all scalers will be concatenated - - activation: - activation function to use in the layer - - dropout: - The ratio of units to dropout. Must be between 0 and 1 - - batch_norm: - Whether to use batch normalization - - avg_d: - Average degree of nodes in the training set, used by scalers to normalize - - last_activation: - activation function to use in the last layer of the internal MLP - - in_dim_edges: - size of the edge features. If 0, edges are ignored - - """ - - super().__init__( - in_dim=in_dim, - out_dim=out_dim, - activation=activation, - dropout=dropout, - batch_norm=batch_norm, - ) - - # Edge dimensions - self.in_dim_edges = in_dim_edges - self.edge_features = self.in_dim_edges > 0 - - # Initializing basic attributes - self.avg_d = avg_d - self.last_activation = get_activation(last_activation) - - # Initializing aggregators and scalers - self.aggregators = self.parse_aggregators(aggregators) - self.scalers = self._parse_scalers(scalers) - - def parse_aggregators(self, aggregators: List[str]) -> List[Callable]: - r""" - Parse the aggregators from a list of strings into a list of callables - """ - return [PNA_AGGREGATORS[aggr] for aggr in aggregators] - - def _parse_scalers(self, scalers: List[str]) -> List[Callable]: - r""" - Parse the scalers from a list of strings into a list of callables - """ - return [PNA_SCALERS[scale] for scale in scalers] - - def message_func(self, edges) -> Dict[str, torch.Tensor]: - r""" - The message function to generate messages along the edges. - """ - return {"e": edges.data["e"]} - - def reduce_func(self, nodes) -> Dict[str, torch.Tensor]: - r""" - The reduce function to aggregate the messages. - Apply the aggregators and scalers, and concatenate the results. - """ - h_in = nodes.data["h"] - h = nodes.mailbox["e"] - D = h.shape[-2] - h_to_cat = [aggr(h=h, h_in=h_in) for aggr in self.aggregators] - h = torch.cat(h_to_cat, dim=-1) - - if len(self.scalers) > 1: - h = torch.cat([scale(h, D=D, avg_d=self.avg_d) for scale in self.scalers], dim=-1) - - return {"h": h} - - def add_virtual_graph_if_no_edges(self, g: DGLGraph) -> Tuple[DGLGraph, bool]: - r""" - When all elements of a given batch don't have any edges - (e.g. molecule with a single atom), the message function will - be skipped, and the number of features will be inconsistent due - to the variable number of aggregators. - - To fix this issue, this method creates a new graph with self-loop - and appends it to the batch, only if all the elements of the batch - have degree 0. - - Parameters: - g: The batched graphs - - Returns: - g: The batched graphs, with a possible new graph appended at the end - no_edges: - Whether a graph was appended to the end of the batch - if all the elements of the batch had no edges. - """ - no_edges = torch.all(g.in_degrees(range(g.num_nodes())) == 0) - if no_edges: - new_g = deepcopy(dgl.unbatch(g)[0]) - new_g.add_edges(0, 0) - g = dgl.batch([g, new_g]) - - return g, no_edges - - def remove_virtual_graph_if_no_edges(self, g: DGLGraph, no_edges: bool) -> DGLGraph: - r""" - This removes the added graph from the method - `add_virtual_graph_if_no_edges`. - - Parameters: - g: The batched graphs - no_edges: - Whether to remove the last graph of the batch - if all the elements of the batch had no edges. - - Returns: - g: The batched graphs, with a possible new graph appended at the end - """ - if no_edges: - g = dgl.batch(dgl.unbatch(g)[:-1]) - - return g - - @property - def layer_outputs_edges(self) -> bool: - r""" - Abstract method. Return a boolean specifying if the layer type - uses edges as input or not. - It is different from ``layer_supports_edges`` since a layer that - supports edges can decide to not use them. - - Returns: - - bool: - Always ``False`` for the current class - """ - return False - - @property - def out_dim_factor(self) -> int: - r""" - Get the factor by which the output dimension is multiplied for - the next layer. - - For standard layers, this will return ``1``. - - But for others, such as ``GatLayer``, the output is the concatenation - of the outputs from each head, so the out_dim gets multiplied by - the number of heads, and this function should return the number - of heads. - - Returns: - - int: - Always ``1`` for the current class - """ - return 1 - - @classproperty - def layer_supports_edges(cls) -> bool: - r""" - Return a boolean specifying if the layer type supports edges or not. - - Returns: - - bool: - Always ``True`` for the current class - """ - return True - - @property - def layer_inputs_edges(self) -> bool: - r""" - Return a boolean specifying if the layer type - uses edges as input or not. - It is different from ``layer_supports_edges`` since a layer that - supports edges can decide to not use them. - - Returns: - - bool: - Returns ``self.edge_features`` - """ - return self.edge_features - - -class PNAConvolutionalLayer(BasePNALayer): - r""" - Implementation of the convolutional architecture of the PNA layer, - previously known as `PNASimpleLayer`. This layer aggregates the - neighbouring messages using multiple aggregators and scalers, - concatenates their results, then applies an MLP on the concatenated - features. - - PNA: Principal Neighbourhood Aggregation - Gabriele Corso, Luca Cavalleri, Dominique Beaini, Pietro Lio, Petar Velickovic - https://arxiv.org/abs/2004.05718 - """ - - def __init__( - self, - in_dim: int, - out_dim: int, - aggregators: List[str], - scalers: List[str], - activation: Union[Callable, str] = "relu", - dropout: float = 0.0, - batch_norm: bool = False, - avg_d: Dict[str, float] = {"log": 1.0}, - last_activation: Union[Callable, str] = "none", - posttrans_layers: int = 1, - in_dim_edges: int = 0, - ): - r""" - - Parameters: - - in_dim: - Input feature dimensions of the layer - - out_dim: - Output feature dimensions of the layer - - aggregators: - Set of aggregation function identifiers, - e.g. "mean", "max", "min", "std", "sum", "var", "moment3". - The results from all aggregators will be concatenated. - - scalers: - Set of scaling functions identifiers - e.g. "identidy", "amplification", "attenuation" - The results from all scalers will be concatenated - - activation: - activation function to use in the layer - - dropout: - The ratio of units to dropout. Must be between 0 and 1 - - batch_norm: - Whether to use batch normalization - - avg_d: - Average degree of nodes in the training set, used by scalers to normalize - - last_activation: - activation function to use in the last layer of the internal MLP - - posttrans_layers: - number of layers in the MLP transformation after the aggregation - - in_dim_edges: - size of the edge features. If 0, edges are ignored - - """ - - super().__init__( - in_dim=in_dim, - out_dim=out_dim, - aggregators=aggregators, - scalers=scalers, - avg_d=avg_d, - activation=activation, - dropout=0, - batch_norm=False, - last_activation=last_activation, - in_dim_edges=in_dim_edges, - ) - - # MLP used on the aggregated messages of the neighbours - self.posttrans = MLP( - in_dim=(len(aggregators) * len(scalers)) * (self.in_dim + self.in_dim_edges), - hidden_dim=self.out_dim, - out_dim=self.out_dim, - layers=posttrans_layers, - activation=self.activation, - last_activation=self.last_activation, - dropout=dropout, - batch_norm=batch_norm, - last_batch_norm=batch_norm, - ) - - def pretrans_edges(self, edges) -> Dict[str, torch.Tensor]: - r""" - Return a mapping to the features of the source nodes, concatenated to the - edge data. - """ - if self.edge_features: - edata = torch.cat([edges.src["h"], edges.data["ef"]], dim=-1) - else: - edata = edges.src["h"] - return {"e": edata} - - def forward(self, g: DGLGraph, h: torch.Tensor, e: torch.Tensor = None) -> torch.Tensor: - r""" - Apply the PNA convolutional layer, with the specified post transformation - - Parameters: - - g: - graph on which the convolution is done - - h: `torch.Tensor[..., N, Din]` - Node feature tensor, before convolution. - N is the number of nodes, Din is the input dimension ``self.in_dim`` - - e: `torch.Tensor[..., N, Din_edges]` or `None` - Edge feature tensor, before convolution. - N is the number of nodes, Din is the input edge dimension - - Can be set to None if the layer does not use edge features - i.e. ``self.layer_inputs_edges -> False`` - - Returns: - - `torch.Tensor[..., N, Dout]`: - Node feature tensor, after convolution. - N is the number of nodes, Dout is the output dimension ``self.out_dim`` - - """ - - g.ndata["h"] = h - if self.edge_features: # add the edges information only if edge_features = True - g.edata["ef"] = e - - g, no_edges = self.add_virtual_graph_if_no_edges(g) - - g.apply_edges(self.pretrans_edges) - - # aggregation - g.update_all(self.message_func, self.reduce_func) - g = self.remove_virtual_graph_if_no_edges(g, no_edges=no_edges) - h = g.ndata["h"] - - # post-transformation - h = self.posttrans(h) - - return h - - -class PNAMessagePassingLayer(BasePNALayer): - r""" - Implementation of the message passing architecture of the PNA message passing layer, - previously known as `PNALayerComplex`. This layer applies an MLP as - pretransformation to the concatenation of $[h_u, h_v, e_{uv}]$ to generate - the messages, with $h_u$ the node feature, $h_v$ the neighbour node features, - and $e_{uv}$ the edge feature between the nodes $u$ and $v$. - - After the pre-transformation, it aggregates the messages - multiple aggregators and scalers, - concatenates their results, then applies an MLP on the concatenated - features. - - PNA: Principal Neighbourhood Aggregation - Gabriele Corso, Luca Cavalleri, Dominique Beaini, Pietro Lio, Petar Velickovic - https://arxiv.org/abs/2004.05718 - """ - - def __init__( - self, - in_dim: int, - out_dim: int, - aggregators: List[str], - scalers: List[str], - activation: Union[Callable, str] = "relu", - dropout: float = 0.0, - batch_norm: bool = False, - avg_d: Dict[str, float] = {"log": 1.0}, - last_activation: Union[Callable, str] = "none", - posttrans_layers: int = 1, - pretrans_layers: int = 1, - in_dim_edges: int = 0, - ): - r""" - - Parameters: - - in_dim: - Input feature dimensions of the layer - - out_dim: - Output feature dimensions of the layer - - aggregators: - Set of aggregation function identifiers, - e.g. "mean", "max", "min", "std", "sum", "var", "moment3". - The results from all aggregators will be concatenated. - - scalers: - Set of scaling functions identifiers - e.g. "identidy", "amplification", "attenuation" - The results from all scalers will be concatenated - - activation: - activation function to use in the layer - - dropout: - The ratio of units to dropout. Must be between 0 and 1 - - batch_norm: - Whether to use batch normalization - - avg_d: - Average degree of nodes in the training set, used by scalers to normalize - - last_activation: - activation function to use in the last layer of the internal MLP - - posttrans_layers: - number of layers in the MLP transformation after the aggregation - - pretrans_layers: - number of layers in the transformation before the aggregation - - in_dim_edges: - size of the edge features. If 0, edges are ignored - - """ - - super().__init__( - in_dim=in_dim, - out_dim=out_dim, - aggregators=aggregators, - scalers=scalers, - avg_d=avg_d, - activation=activation, - dropout=0, - batch_norm=False, - last_activation=last_activation, - in_dim_edges=in_dim_edges, - ) - - # MLP used on each pair of nodes with their edge MLP(h_u, h_v, e_uv) - self.pretrans = MLP( - in_dim=2 * in_dim + in_dim_edges, - hidden_dim=in_dim, - out_dim=in_dim, - layers=pretrans_layers, - activation=self.activation, - last_activation=self.last_activation, - dropout=dropout, - batch_norm=batch_norm, - last_batch_norm=batch_norm, - ) - - # MLP used on the aggregated messages MLP(h'_u) - self.posttrans = MLP( - in_dim=(len(self.aggregators) * len(self.scalers) + 1) * in_dim, - hidden_dim=out_dim, - out_dim=out_dim, - layers=posttrans_layers, - activation=self.activation, - last_activation=self.last_activation, - dropout=dropout, - batch_norm=batch_norm, - last_batch_norm=batch_norm, - ) - - def pretrans_edges(self, edges) -> Dict[str, torch.Tensor]: - r""" - Return a mapping to the concatenation of the features from - the source node, the destination node, and the edge between them (if applicable). - """ - if self.edge_features: - z2 = torch.cat([edges.src["h"], edges.dst["h"], edges.data["ef"]], dim=-1) - else: - z2 = torch.cat([edges.src["h"], edges.dst["h"]], dim=-1) - return {"e": self.pretrans(z2)} - - def forward(self, g: DGLGraph, h: torch.Tensor, e: torch.Tensor = None) -> torch.Tensor: - r""" - Apply the PNA Message passing layer, with the specified pre/post transformations - - Parameters: - - g: - graph on which the convolution is done - - h: `torch.Tensor[..., N, Din]` - Node feature tensor, before convolution. - N is the number of nodes, Din is the input dimension ``self.in_dim`` - - e: `torch.Tensor[..., N, Din_edges]` or `None` - Edge feature tensor, before convolution. - N is the number of nodes, Din is the input edge dimension - - Can be set to None if the layer does not use edge features - i.e. ``self.layer_inputs_edges -> False`` - - Returns: - - `torch.Tensor[..., N, Dout]`: - Node feature tensor, after convolution. - N is the number of nodes, Dout is the output dimension ``self.out_dim`` - - """ - - g.ndata["h"] = h - if self.edge_features: # add the edges information only if edge_features = True - g.edata["ef"] = e - g, no_edges = self.add_virtual_graph_if_no_edges(g) - - # pretransformation - g.apply_edges(self.pretrans_edges) - - # aggregation - g.update_all(self.message_func, self.reduce_func) - g = self.remove_virtual_graph_if_no_edges(g, no_edges=no_edges) - h = torch.cat([h, g.ndata["h"]], dim=-1) - - # post-transformation - h = self.posttrans(h) - - return h - - @classproperty - def layer_supports_edges(cls) -> bool: - r""" - Return a boolean specifying if the layer type supports edges or not. - - Returns: - - bool: - Always ``True`` for the current class - """ - return True +import torch +import dgl +from dgl import DGLGraph +from typing import Dict, List, Tuple, Union, Callable +from copy import deepcopy + +from goli.nn.pna_operations import PNA_AGGREGATORS, PNA_SCALERS +from goli.nn.base_layers import MLP, get_activation +from goli.nn.dgl_layers.base_dgl_layer import BaseDGLLayer +from goli.utils.decorators import classproperty + +""" + PNA: Principal Neighbourhood Aggregation + Gabriele Corso, Luca Cavalleri, Dominique Beaini, Pietro Lio, Petar Velickovic + https://arxiv.org/abs/2004.05718 +""" + + +class BasePNALayer(BaseDGLLayer): + def __init__( + self, + in_dim: int, + out_dim: int, + aggregators: List[str], + scalers: List[str], + activation: Union[Callable, str] = "relu", + dropout: float = 0.0, + batch_norm: bool = False, + avg_d: float = 1.0, + last_activation: Union[Callable, str] = "none", + in_dim_edges: int = 0, + ): + r""" + Abstract class used to standardize the implementation of PNA layers + in the current library. + + PNA: Principal Neighbourhood Aggregation + Gabriele Corso, Luca Cavalleri, Dominique Beaini, Pietro Lio, Petar Velickovic + https://arxiv.org/abs/2004.05718 + + Method ``layer_inputs_edges()`` needs to be implemented in children classes + + Parameters: + + in_dim: + Input feature dimensions of the layer + + out_dim: + Output feature dimensions of the layer + + aggregators: + Set of aggregation function identifiers, + e.g. "mean", "max", "min", "std", "sum", "var", "moment3". + The results from all aggregators will be concatenated. + + scalers: + Set of scaling functions identifiers + e.g. "identidy", "amplification", "attenuation" + The results from all scalers will be concatenated + + activation: + activation function to use in the layer + + dropout: + The ratio of units to dropout. Must be between 0 and 1 + + batch_norm: + Whether to use batch normalization + + avg_d: + Average degree of nodes in the training set, used by scalers to normalize + + last_activation: + activation function to use in the last layer of the internal MLP + + in_dim_edges: + size of the edge features. If 0, edges are ignored + + """ + + super().__init__( + in_dim=in_dim, + out_dim=out_dim, + activation=activation, + dropout=dropout, + batch_norm=batch_norm, + ) + + # Edge dimensions + self.in_dim_edges = in_dim_edges + self.edge_features = self.in_dim_edges > 0 + + # Initializing basic attributes + self.avg_d = avg_d + self.last_activation = get_activation(last_activation) + + # Initializing aggregators and scalers + self.aggregators = self.parse_aggregators(aggregators) + self.scalers = self._parse_scalers(scalers) + + def parse_aggregators(self, aggregators: List[str]) -> List[Callable]: + r""" + Parse the aggregators from a list of strings into a list of callables + """ + return [PNA_AGGREGATORS[aggr] for aggr in aggregators] + + def _parse_scalers(self, scalers: List[str]) -> List[Callable]: + r""" + Parse the scalers from a list of strings into a list of callables + """ + return [PNA_SCALERS[scale] for scale in scalers] + + def message_func(self, edges) -> Dict[str, torch.Tensor]: + r""" + The message function to generate messages along the edges. + """ + return {"e": edges.data["e"]} + + def reduce_func(self, nodes) -> Dict[str, torch.Tensor]: + r""" + The reduce function to aggregate the messages. + Apply the aggregators and scalers, and concatenate the results. + """ + h_in = nodes.data["h"] + h = nodes.mailbox["e"] + D = h.shape[-2] + h_to_cat = [aggr(h=h, h_in=h_in) for aggr in self.aggregators] + h = torch.cat(h_to_cat, dim=-1) + + if len(self.scalers) > 1: + h = torch.cat([scale(h, D=D, avg_d=self.avg_d) for scale in self.scalers], dim=-1) + + return {"h": h} + + def add_virtual_graph_if_no_edges(self, g: DGLGraph) -> Tuple[DGLGraph, bool]: + r""" + When all elements of a given batch don't have any edges + (e.g. molecule with a single atom), the message function will + be skipped, and the number of features will be inconsistent due + to the variable number of aggregators. + + To fix this issue, this method creates a new graph with self-loop + and appends it to the batch, only if all the elements of the batch + have degree 0. + + Parameters: + g: The batched graphs + + Returns: + g: The batched graphs, with a possible new graph appended at the end + no_edges: + Whether a graph was appended to the end of the batch + if all the elements of the batch had no edges. + """ + no_edges = torch.all(g.in_degrees(range(g.num_nodes())) == 0) + if no_edges: + new_g = deepcopy(dgl.unbatch(g)[0]) + new_g.add_edges(0, 0) + g = dgl.batch([g, new_g]) + + return g, no_edges + + def remove_virtual_graph_if_no_edges(self, g: DGLGraph, no_edges: bool) -> DGLGraph: + r""" + This removes the added graph from the method + `add_virtual_graph_if_no_edges`. + + Parameters: + g: The batched graphs + no_edges: + Whether to remove the last graph of the batch + if all the elements of the batch had no edges. + + Returns: + g: The batched graphs, with a possible new graph appended at the end + """ + if no_edges: + g = dgl.batch(dgl.unbatch(g)[:-1]) + + return g + + @property + def layer_outputs_edges(self) -> bool: + r""" + Abstract method. Return a boolean specifying if the layer type + uses edges as input or not. + It is different from ``layer_supports_edges`` since a layer that + supports edges can decide to not use them. + + Returns: + + bool: + Always ``False`` for the current class + """ + return False + + @property + def out_dim_factor(self) -> int: + r""" + Get the factor by which the output dimension is multiplied for + the next layer. + + For standard layers, this will return ``1``. + + But for others, such as ``GatLayer``, the output is the concatenation + of the outputs from each head, so the out_dim gets multiplied by + the number of heads, and this function should return the number + of heads. + + Returns: + + int: + Always ``1`` for the current class + """ + return 1 + + @classproperty + def layer_supports_edges(cls) -> bool: + r""" + Return a boolean specifying if the layer type supports edges or not. + + Returns: + + bool: + Always ``True`` for the current class + """ + return True + + @property + def layer_inputs_edges(self) -> bool: + r""" + Return a boolean specifying if the layer type + uses edges as input or not. + It is different from ``layer_supports_edges`` since a layer that + supports edges can decide to not use them. + + Returns: + + bool: + Returns ``self.edge_features`` + """ + return self.edge_features + + +class PNAConvolutionalLayer(BasePNALayer): + r""" + Implementation of the convolutional architecture of the PNA layer, + previously known as `PNASimpleLayer`. This layer aggregates the + neighbouring messages using multiple aggregators and scalers, + concatenates their results, then applies an MLP on the concatenated + features. + + PNA: Principal Neighbourhood Aggregation + Gabriele Corso, Luca Cavalleri, Dominique Beaini, Pietro Lio, Petar Velickovic + https://arxiv.org/abs/2004.05718 + """ + + def __init__( + self, + in_dim: int, + out_dim: int, + aggregators: List[str], + scalers: List[str], + activation: Union[Callable, str] = "relu", + dropout: float = 0.0, + batch_norm: bool = False, + avg_d: Dict[str, float] = {"log": 1.0}, + last_activation: Union[Callable, str] = "none", + posttrans_layers: int = 1, + in_dim_edges: int = 0, + ): + r""" + + Parameters: + + in_dim: + Input feature dimensions of the layer + + out_dim: + Output feature dimensions of the layer + + aggregators: + Set of aggregation function identifiers, + e.g. "mean", "max", "min", "std", "sum", "var", "moment3". + The results from all aggregators will be concatenated. + + scalers: + Set of scaling functions identifiers + e.g. "identidy", "amplification", "attenuation" + The results from all scalers will be concatenated + + activation: + activation function to use in the layer + + dropout: + The ratio of units to dropout. Must be between 0 and 1 + + batch_norm: + Whether to use batch normalization + + avg_d: + Average degree of nodes in the training set, used by scalers to normalize + + last_activation: + activation function to use in the last layer of the internal MLP + + posttrans_layers: + number of layers in the MLP transformation after the aggregation + + in_dim_edges: + size of the edge features. If 0, edges are ignored + + """ + + super().__init__( + in_dim=in_dim, + out_dim=out_dim, + aggregators=aggregators, + scalers=scalers, + avg_d=avg_d, + activation=activation, + dropout=0, + batch_norm=False, + last_activation=last_activation, + in_dim_edges=in_dim_edges, + ) + + # MLP used on the aggregated messages of the neighbours + self.posttrans = MLP( + in_dim=(len(aggregators) * len(scalers)) * (self.in_dim + self.in_dim_edges), + hidden_dim=self.out_dim, + out_dim=self.out_dim, + layers=posttrans_layers, + activation=self.activation, + last_activation=self.last_activation, + dropout=dropout, + batch_norm=batch_norm, + last_batch_norm=batch_norm, + ) + + def pretrans_edges(self, edges) -> Dict[str, torch.Tensor]: + r""" + Return a mapping to the features of the source nodes, concatenated to the + edge data. + """ + if self.edge_features: + edata = torch.cat([edges.src["h"], edges.data["ef"]], dim=-1) + else: + edata = edges.src["h"] + return {"e": edata} + + def forward(self, g: DGLGraph, h: torch.Tensor, e: torch.Tensor = None) -> torch.Tensor: + r""" + Apply the PNA convolutional layer, with the specified post transformation + + Parameters: + + g: + graph on which the convolution is done + + h: `torch.Tensor[..., N, Din]` + Node feature tensor, before convolution. + N is the number of nodes, Din is the input dimension ``self.in_dim`` + + e: `torch.Tensor[..., N, Din_edges]` or `None` + Edge feature tensor, before convolution. + N is the number of nodes, Din is the input edge dimension + + Can be set to None if the layer does not use edge features + i.e. ``self.layer_inputs_edges -> False`` + + Returns: + + `torch.Tensor[..., N, Dout]`: + Node feature tensor, after convolution. + N is the number of nodes, Dout is the output dimension ``self.out_dim`` + + """ + + g.ndata["h"] = h + if self.edge_features: # add the edges information only if edge_features = True + g.edata["ef"] = e + + g, no_edges = self.add_virtual_graph_if_no_edges(g) + + g.apply_edges(self.pretrans_edges) + + # aggregation + g.update_all(self.message_func, self.reduce_func) + g = self.remove_virtual_graph_if_no_edges(g, no_edges=no_edges) + h = g.ndata["h"] + + # post-transformation + h = self.posttrans(h) + + return h + + +class PNAMessagePassingLayer(BasePNALayer): + r""" + Implementation of the message passing architecture of the PNA message passing layer, + previously known as `PNALayerComplex`. This layer applies an MLP as + pretransformation to the concatenation of $[h_u, h_v, e_{uv}]$ to generate + the messages, with $h_u$ the node feature, $h_v$ the neighbour node features, + and $e_{uv}$ the edge feature between the nodes $u$ and $v$. + + After the pre-transformation, it aggregates the messages + multiple aggregators and scalers, + concatenates their results, then applies an MLP on the concatenated + features. + + PNA: Principal Neighbourhood Aggregation + Gabriele Corso, Luca Cavalleri, Dominique Beaini, Pietro Lio, Petar Velickovic + https://arxiv.org/abs/2004.05718 + """ + + def __init__( + self, + in_dim: int, + out_dim: int, + aggregators: List[str], + scalers: List[str], + activation: Union[Callable, str] = "relu", + dropout: float = 0.0, + batch_norm: bool = False, + avg_d: Dict[str, float] = {"log": 1.0}, + last_activation: Union[Callable, str] = "none", + posttrans_layers: int = 1, + pretrans_layers: int = 1, + in_dim_edges: int = 0, + ): + r""" + + Parameters: + + in_dim: + Input feature dimensions of the layer + + out_dim: + Output feature dimensions of the layer + + aggregators: + Set of aggregation function identifiers, + e.g. "mean", "max", "min", "std", "sum", "var", "moment3". + The results from all aggregators will be concatenated. + + scalers: + Set of scaling functions identifiers + e.g. "identidy", "amplification", "attenuation" + The results from all scalers will be concatenated + + activation: + activation function to use in the layer + + dropout: + The ratio of units to dropout. Must be between 0 and 1 + + batch_norm: + Whether to use batch normalization + + avg_d: + Average degree of nodes in the training set, used by scalers to normalize + + last_activation: + activation function to use in the last layer of the internal MLP + + posttrans_layers: + number of layers in the MLP transformation after the aggregation + + pretrans_layers: + number of layers in the transformation before the aggregation + + in_dim_edges: + size of the edge features. If 0, edges are ignored + + """ + + super().__init__( + in_dim=in_dim, + out_dim=out_dim, + aggregators=aggregators, + scalers=scalers, + avg_d=avg_d, + activation=activation, + dropout=0, + batch_norm=False, + last_activation=last_activation, + in_dim_edges=in_dim_edges, + ) + + # MLP used on each pair of nodes with their edge MLP(h_u, h_v, e_uv) + self.pretrans = MLP( + in_dim=2 * in_dim + in_dim_edges, + hidden_dim=in_dim, + out_dim=in_dim, + layers=pretrans_layers, + activation=self.activation, + last_activation=self.last_activation, + dropout=dropout, + batch_norm=batch_norm, + last_batch_norm=batch_norm, + ) + + # MLP used on the aggregated messages MLP(h'_u) + self.posttrans = MLP( + in_dim=(len(self.aggregators) * len(self.scalers) + 1) * in_dim, + hidden_dim=out_dim, + out_dim=out_dim, + layers=posttrans_layers, + activation=self.activation, + last_activation=self.last_activation, + dropout=dropout, + batch_norm=batch_norm, + last_batch_norm=batch_norm, + ) + + def pretrans_edges(self, edges) -> Dict[str, torch.Tensor]: + r""" + Return a mapping to the concatenation of the features from + the source node, the destination node, and the edge between them (if applicable). + """ + if self.edge_features: + z2 = torch.cat([edges.src["h"], edges.dst["h"], edges.data["ef"]], dim=-1) + else: + z2 = torch.cat([edges.src["h"], edges.dst["h"]], dim=-1) + return {"e": self.pretrans(z2)} + + def forward(self, g: DGLGraph, h: torch.Tensor, e: torch.Tensor = None) -> torch.Tensor: + r""" + Apply the PNA Message passing layer, with the specified pre/post transformations + + Parameters: + + g: + graph on which the convolution is done + + h: `torch.Tensor[..., N, Din]` + Node feature tensor, before convolution. + N is the number of nodes, Din is the input dimension ``self.in_dim`` + + e: `torch.Tensor[..., N, Din_edges]` or `None` + Edge feature tensor, before convolution. + N is the number of nodes, Din is the input edge dimension + + Can be set to None if the layer does not use edge features + i.e. ``self.layer_inputs_edges -> False`` + + Returns: + + `torch.Tensor[..., N, Dout]`: + Node feature tensor, after convolution. + N is the number of nodes, Dout is the output dimension ``self.out_dim`` + + """ + + g.ndata["h"] = h + if self.edge_features: # add the edges information only if edge_features = True + g.edata["ef"] = e + g, no_edges = self.add_virtual_graph_if_no_edges(g) + + # pretransformation + g.apply_edges(self.pretrans_edges) + + # aggregation + g.update_all(self.message_func, self.reduce_func) + g = self.remove_virtual_graph_if_no_edges(g, no_edges=no_edges) + h = torch.cat([h, g.ndata["h"]], dim=-1) + + # post-transformation + h = self.posttrans(h) + + return h + + @classproperty + def layer_supports_edges(cls) -> bool: + r""" + Return a boolean specifying if the layer type supports edges or not. + + Returns: + + bool: + Always ``True`` for the current class + """ + return True diff --git a/goli/nn/dgl_layers/pooling.py b/goli/nn/dgl_layers/pooling.py index 76304858a..73a166802 100644 --- a/goli/nn/dgl_layers/pooling.py +++ b/goli/nn/dgl_layers/pooling.py @@ -1,336 +1,336 @@ -import torch -import torch.nn as nn -from typing import List, Union, Callable, Tuple - -import dgl -from dgl.nn.pytorch.glob import SumPooling, AvgPooling, MaxPooling, Set2Set, GlobalAttentionPooling -from dgl import mean_nodes, sum_nodes, max_nodes - -from goli.nn.base_layers import MLP, FCLayer -from goli.utils.tensor import ModuleListConcat - - -EPS = 1e-6 - - -class S2SReadout(nn.Module): - r""" - Performs a Set2Set aggregation of all the graph nodes' features followed by a series of fully connected layers - """ - - def __init__( - self, - in_dim, - hidden_dim, - out_dim, - fc_layers=3, - device="cpu", - final_activation: Union[str, Callable] = "relu", - ): - super().__init__() - - # set2set aggregation - self.set2set = Set2Set(in_dim, device=device) - - # fully connected layers - self.mlp = MLP( - in_dim=2 * in_dim, - hidden_dim=hidden_dim, - out_dim=out_dim, - layers=fc_layers, - activation="relu", - last_activation=final_activation, - batch_norm=True, - last_batch_norm=False, - device=device, - ) - - def forward(self, x): - x = self.set2set(x) - return self.mlp(x) - - -class StdPooling(nn.Module): - r"""Apply standard deviation pooling over the nodes in the graph. - - $$r^{(i)} = \sigma_{k=1}^{N_i}\left( x^{(i)}_k \right)$$ - """ - - def __init__(self): - super().__init__() - self.sum_pooler = SumPooling() - self.relu = nn.ReLU() - - def forward(self, graph, feat): - r"""Compute standard deviation pooling. - - Parameters: - graph : DGLGraph - The graph. - feat : torch.Tensor - The input feature with shape :math:`(N, *)` where - :math:`N` is the number of nodes in the graph. - - Returns: - torch.Tensor - The output feature with shape :math:`(B, *)`, where - :math:`B` refers to the batch size. - """ - - readout = torch.sqrt( - self.relu((self.sum_pooler(graph, feat ** 2)) - (self.sum_pooler(graph, feat) ** 2)) + EPS - ) - return readout - - -class MinPooling(MaxPooling): - r"""Apply min pooling over the nodes in the graph. - - $$r^{(i)} = \min_{k=1}^{N_i}\left( x^{(i)}_k \right)$$ - """ - - def forward(self, graph, feat): - r"""Compute max pooling. - - Parameters: - graph : DGLGraph - The graph. - feat : torch.Tensor - The input feature with shape :math:`(N, *)` where - :math:`N` is the number of nodes in the graph. - - Returns: - readout: torch.Tensor - The output feature with shape :math:`(B, *)`, where - :math:`B` refers to the batch size. - """ - - return -super().forward(graph, -feat) - - -class DirPooling(nn.Module): - r""" - Apply pooling over the nodes in the graph using a directional potential - with an inner product. - - In most cases, this is a pooling using the Fiedler vector. - This is basically equivalent to computing a Fourier transform for the - Fiedler vector. Then, we use the absolute value due to the sign ambiguity - - """ - - def __init__(self, dir_idx): - super().__init__() - self.sum_pooler = SumPooling() - self.dir_idx = dir_idx - - def forward(self, graph, feat): - r"""Compute directional inner-product pooling, and return absolute value. - - Parameters: - graph : DGLGraph - The graph. Must have the key `graph.ndata["pos_dir"]` - feat : torch.Tensor - The input feature with shape :math:`(N, *)` where - :math:`N` is the number of nodes in the graph. - - Returns: - readout: torch.Tensor - The output feature with shape :math:`(B, *)`, where - :math:`B` refers to the batch size. - """ - - dir = graph.ndata["pos_dir"][:, self.dir_idx].unsqueeze(-1) - pooled = torch.abs(self.sum_pooler(graph, feat * dir)) - - return pooled - - -def parse_pooling_layer(in_dim: int, pooling: Union[str, List[str]], n_iters: int = 2, n_layers: int = 2): - r""" - Select the pooling layers from a list of strings, and put them - in a Module that concatenates their outputs. - - Parameters: - - in_dim: - The dimension at the input layer of the pooling - - pooling: - The list of pooling layers to use. The accepted strings are: - - - "sum": `SumPooling` - - "mean": `MeanPooling` - - "max": `MaxPooling` - - "min": `MinPooling` - - "std": `StdPooling` - - "s2s": `Set2Set` - - "dir{int}": `DirPooling` - - n_iters: - IGNORED FOR ALL POOLING LAYERS, EXCEPT "s2s". - The number of iterations. - - n_layers: - IGNORED FOR ALL POOLING LAYERS, EXCEPT "s2s". - The number of recurrent layers. - """ - - # TODO: Add configuration for the pooling layer kwargs - - # Create the pooling layer - pool_layer = ModuleListConcat() - out_pool_dim = 0 - if isinstance(pooling, str): - pooling = [pooling] - - for this_pool in pooling: - this_pool = None if this_pool is None else this_pool.lower() - out_pool_dim += in_dim - if this_pool == "sum": - pool_layer.append(SumPooling()) - elif this_pool == "mean": - pool_layer.append(AvgPooling()) - elif this_pool == "max": - pool_layer.append(MaxPooling()) - elif this_pool == "min": - pool_layer.append(MinPooling()) - elif this_pool == "std": - pool_layer.append(StdPooling()) - elif this_pool == "s2s": - pool_layer.append(Set2Set(input_dim=in_dim, n_iters=n_iters, n_layers=n_layers)) - out_pool_dim += in_dim - elif isinstance(this_pool, str) and (this_pool[:3] == "dir"): - dir_idx = int(this_pool[3:]) - pool_layer.append(DirPooling(dir_idx=dir_idx)) - elif (this_pool == "none") or (this_pool is None): - pass - else: - raise NotImplementedError(f"Undefined pooling `{this_pool}`") - - return pool_layer, out_pool_dim - - -class VirtualNode(nn.Module): - def __init__( - self, - dim: int, - vn_type: Union[type(None), str] = "sum", - activation: Union[str, Callable] = "relu", - dropout: float = 0.0, - batch_norm: bool = False, - bias: bool = True, - residual: bool = True, - ): - r""" - The VirtualNode is a layer that pool the features of the graph, - applies a neural network layer on the pooled features, - then add the result back to the node features of every node. - - Parameters: - - in_dim: - Input feature dimensions of the virtual node layer - - activation: - activation function to use in the neural network layer. - - dropout: - The ratio of units to dropout. Must be between 0 and 1 - - batch_norm: - Whether to use batch normalization - - bias: - Whether to add a bias to the neural network - - residual: - Whether all virtual nodes should be connected together - via a residual connection - - """ - super().__init__() - if (vn_type is None) or (vn_type.lower() == "none"): - self.vn_type = None - self.fc_layer = None - self.residual = None - return - - self.vn_type = vn_type.lower() - self.residual = residual - self.fc_layer = FCLayer( - in_dim=dim, - out_dim=dim, - activation=activation, - dropout=dropout, - batch_norm=batch_norm, - bias=bias, - ) - - def forward( - self, g: dgl.DGLGraph, h: torch.Tensor, vn_h: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: - r""" - Apply the virtual node layer. - - Parameters: - - g: - graph on which the convolution is done - - h (torch.Tensor[..., N, Din]): - Node feature tensor, before convolution. - `N` is the number of nodes, `Din` is the input features - - vn_h (torch.Tensor[..., M, Din]): - Graph feature of the previous virtual node, or `None` - `M` is the number of graphs, `Din` is the input features. - It is added to the result after the MLP, as a residual connection - - Returns: - - `h = torch.Tensor[..., N, Dout]`: - Node feature tensor, after convolution and residual. - `N` is the number of nodes, `Dout` is the output features of the layer and residual - - `vn_h = torch.Tensor[..., M, Dout]`: - Graph feature tensor to be used at the next virtual node, or `None` - `M` is the number of graphs, `Dout` is the output features - - """ - - g.ndata["h"] = h - - # Pool the features - if self.vn_type is None: - return h, vn_h - elif self.vn_type == "mean": - pool = mean_nodes(g, "h") - elif self.vn_type == "max": - pool = max_nodes(g, "h") - elif self.vn_type == "sum": - pool = sum_nodes(g, "h") - elif self.vn_type == "logsum": - pool = mean_nodes(g, "h") - lognum = torch.log(torch.tensor(g.batch_num_nodes, dtype=h.dtype, device=h.device)) - pool = pool * lognum.unsqueeze(-1) - else: - raise ValueError( - f'Undefined input "{self.pooling}". Accepted values are "none", "sum", "mean", "logsum"' - ) - - # Compute the new virtual node features - vn_h_temp = self.fc_layer.forward(vn_h + pool) - if self.residual: - vn_h = vn_h + vn_h_temp - else: - vn_h = vn_h_temp - - # Add the virtual node value to the graph features - temp_h = torch.cat( - [vn_h[ii : ii + 1].repeat(num_nodes, 1) for ii, num_nodes in enumerate(g.batch_num_nodes())], - dim=0, - ) - h = h + temp_h - - return h, vn_h +import torch +import torch.nn as nn +from typing import List, Union, Callable, Tuple + +import dgl +from dgl.nn.pytorch.glob import SumPooling, AvgPooling, MaxPooling, Set2Set, GlobalAttentionPooling +from dgl import mean_nodes, sum_nodes, max_nodes + +from goli.nn.base_layers import MLP, FCLayer +from goli.utils.tensor import ModuleListConcat + + +EPS = 1e-6 + + +class S2SReadout(nn.Module): + r""" + Performs a Set2Set aggregation of all the graph nodes' features followed by a series of fully connected layers + """ + + def __init__( + self, + in_dim, + hidden_dim, + out_dim, + fc_layers=3, + device="cpu", + final_activation: Union[str, Callable] = "relu", + ): + super().__init__() + + # set2set aggregation + self.set2set = Set2Set(in_dim, device=device) + + # fully connected layers + self.mlp = MLP( + in_dim=2 * in_dim, + hidden_dim=hidden_dim, + out_dim=out_dim, + layers=fc_layers, + activation="relu", + last_activation=final_activation, + batch_norm=True, + last_batch_norm=False, + device=device, + ) + + def forward(self, x): + x = self.set2set(x) + return self.mlp(x) + + +class StdPooling(nn.Module): + r"""Apply standard deviation pooling over the nodes in the graph. + + $$r^{(i)} = \sigma_{k=1}^{N_i}\left( x^{(i)}_k \right)$$ + """ + + def __init__(self): + super().__init__() + self.sum_pooler = SumPooling() + self.relu = nn.ReLU() + + def forward(self, graph, feat): + r"""Compute standard deviation pooling. + + Parameters: + graph : DGLGraph + The graph. + feat : torch.Tensor + The input feature with shape :math:`(N, *)` where + :math:`N` is the number of nodes in the graph. + + Returns: + torch.Tensor + The output feature with shape :math:`(B, *)`, where + :math:`B` refers to the batch size. + """ + + readout = torch.sqrt( + self.relu((self.sum_pooler(graph, feat ** 2)) - (self.sum_pooler(graph, feat) ** 2)) + EPS + ) + return readout + + +class MinPooling(MaxPooling): + r"""Apply min pooling over the nodes in the graph. + + $$r^{(i)} = \min_{k=1}^{N_i}\left( x^{(i)}_k \right)$$ + """ + + def forward(self, graph, feat): + r"""Compute max pooling. + + Parameters: + graph : DGLGraph + The graph. + feat : torch.Tensor + The input feature with shape :math:`(N, *)` where + :math:`N` is the number of nodes in the graph. + + Returns: + readout: torch.Tensor + The output feature with shape :math:`(B, *)`, where + :math:`B` refers to the batch size. + """ + + return -super().forward(graph, -feat) + + +class DirPooling(nn.Module): + r""" + Apply pooling over the nodes in the graph using a directional potential + with an inner product. + + In most cases, this is a pooling using the Fiedler vector. + This is basically equivalent to computing a Fourier transform for the + Fiedler vector. Then, we use the absolute value due to the sign ambiguity + + """ + + def __init__(self, dir_idx): + super().__init__() + self.sum_pooler = SumPooling() + self.dir_idx = dir_idx + + def forward(self, graph, feat): + r"""Compute directional inner-product pooling, and return absolute value. + + Parameters: + graph : DGLGraph + The graph. Must have the key `graph.ndata["pos_dir"]` + feat : torch.Tensor + The input feature with shape :math:`(N, *)` where + :math:`N` is the number of nodes in the graph. + + Returns: + readout: torch.Tensor + The output feature with shape :math:`(B, *)`, where + :math:`B` refers to the batch size. + """ + + dir = graph.ndata["pos_dir"][:, self.dir_idx].unsqueeze(-1) + pooled = torch.abs(self.sum_pooler(graph, feat * dir)) + + return pooled + + +def parse_pooling_layer(in_dim: int, pooling: Union[str, List[str]], n_iters: int = 2, n_layers: int = 2): + r""" + Select the pooling layers from a list of strings, and put them + in a Module that concatenates their outputs. + + Parameters: + + in_dim: + The dimension at the input layer of the pooling + + pooling: + The list of pooling layers to use. The accepted strings are: + + - "sum": `SumPooling` + - "mean": `MeanPooling` + - "max": `MaxPooling` + - "min": `MinPooling` + - "std": `StdPooling` + - "s2s": `Set2Set` + - "dir{int}": `DirPooling` + + n_iters: + IGNORED FOR ALL POOLING LAYERS, EXCEPT "s2s". + The number of iterations. + + n_layers: + IGNORED FOR ALL POOLING LAYERS, EXCEPT "s2s". + The number of recurrent layers. + """ + + # TODO: Add configuration for the pooling layer kwargs + + # Create the pooling layer + pool_layer = ModuleListConcat() + out_pool_dim = 0 + if isinstance(pooling, str): + pooling = [pooling] + + for this_pool in pooling: + this_pool = None if this_pool is None else this_pool.lower() + out_pool_dim += in_dim + if this_pool == "sum": + pool_layer.append(SumPooling()) + elif this_pool == "mean": + pool_layer.append(AvgPooling()) + elif this_pool == "max": + pool_layer.append(MaxPooling()) + elif this_pool == "min": + pool_layer.append(MinPooling()) + elif this_pool == "std": + pool_layer.append(StdPooling()) + elif this_pool == "s2s": + pool_layer.append(Set2Set(input_dim=in_dim, n_iters=n_iters, n_layers=n_layers)) + out_pool_dim += in_dim + elif isinstance(this_pool, str) and (this_pool[:3] == "dir"): + dir_idx = int(this_pool[3:]) + pool_layer.append(DirPooling(dir_idx=dir_idx)) + elif (this_pool == "none") or (this_pool is None): + pass + else: + raise NotImplementedError(f"Undefined pooling `{this_pool}`") + + return pool_layer, out_pool_dim + + +class VirtualNode(nn.Module): + def __init__( + self, + dim: int, + vn_type: Union[type(None), str] = "sum", + activation: Union[str, Callable] = "relu", + dropout: float = 0.0, + batch_norm: bool = False, + bias: bool = True, + residual: bool = True, + ): + r""" + The VirtualNode is a layer that pool the features of the graph, + applies a neural network layer on the pooled features, + then add the result back to the node features of every node. + + Parameters: + + in_dim: + Input feature dimensions of the virtual node layer + + activation: + activation function to use in the neural network layer. + + dropout: + The ratio of units to dropout. Must be between 0 and 1 + + batch_norm: + Whether to use batch normalization + + bias: + Whether to add a bias to the neural network + + residual: + Whether all virtual nodes should be connected together + via a residual connection + + """ + super().__init__() + if (vn_type is None) or (vn_type.lower() == "none"): + self.vn_type = None + self.fc_layer = None + self.residual = None + return + + self.vn_type = vn_type.lower() + self.residual = residual + self.fc_layer = FCLayer( + in_dim=dim, + out_dim=dim, + activation=activation, + dropout=dropout, + batch_norm=batch_norm, + bias=bias, + ) + + def forward( + self, g: dgl.DGLGraph, h: torch.Tensor, vn_h: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Apply the virtual node layer. + + Parameters: + + g: + graph on which the convolution is done + + h (torch.Tensor[..., N, Din]): + Node feature tensor, before convolution. + `N` is the number of nodes, `Din` is the input features + + vn_h (torch.Tensor[..., M, Din]): + Graph feature of the previous virtual node, or `None` + `M` is the number of graphs, `Din` is the input features. + It is added to the result after the MLP, as a residual connection + + Returns: + + `h = torch.Tensor[..., N, Dout]`: + Node feature tensor, after convolution and residual. + `N` is the number of nodes, `Dout` is the output features of the layer and residual + + `vn_h = torch.Tensor[..., M, Dout]`: + Graph feature tensor to be used at the next virtual node, or `None` + `M` is the number of graphs, `Dout` is the output features + + """ + + g.ndata["h"] = h + + # Pool the features + if self.vn_type is None: + return h, vn_h + elif self.vn_type == "mean": + pool = mean_nodes(g, "h") + elif self.vn_type == "max": + pool = max_nodes(g, "h") + elif self.vn_type == "sum": + pool = sum_nodes(g, "h") + elif self.vn_type == "logsum": + pool = mean_nodes(g, "h") + lognum = torch.log(torch.tensor(g.batch_num_nodes, dtype=h.dtype, device=h.device)) + pool = pool * lognum.unsqueeze(-1) + else: + raise ValueError( + f'Undefined input "{self.pooling}". Accepted values are "none", "sum", "mean", "logsum"' + ) + + # Compute the new virtual node features + vn_h_temp = self.fc_layer.forward(vn_h + pool) + if self.residual: + vn_h = vn_h + vn_h_temp + else: + vn_h = vn_h_temp + + # Add the virtual node value to the graph features + temp_h = torch.cat( + [vn_h[ii : ii + 1].repeat(num_nodes, 1) for ii, num_nodes in enumerate(g.batch_num_nodes())], + dim=0, + ) + h = h + temp_h + + return h, vn_h diff --git a/goli/nn/dgn_operations.py b/goli/nn/dgn_operations.py index 935c75284..9ecbb78e6 100644 --- a/goli/nn/dgn_operations.py +++ b/goli/nn/dgn_operations.py @@ -1,297 +1,297 @@ -from typing import Optional -import torch -from torch import Tensor - - -EPS = 1e-8 - - -def get_grad_of_pos( - source_pos: Tensor, dest_pos: Tensor, dir_idx: int, temperature: Optional[float] = None -) -> Tensor: - r""" - Get the vector field associated to the gradient of the positional - encoding. - - $$F_k = \nabla pos_k$$ - - or, if a temperature $T$ is provided - - $$F_k = softmax((\nabla pos_k)^+) - softmax((\nabla pos_k)^-)$$ - - Where $F_k$ is the *k-th* directional field associated to the $k-th$ positional - encoding. - - Parameters: - - source_pos: The positional encoding at the source node, used to compute the directional field - dest_pos: The positional encoding at the destination node, used to compute the directional field - h_in: The input features of the layer, before any operation. - dir_idx: The index of the positional encoding ($k$ in the equation above) - temperature: The temperature to use in the softmax of the directional field. - If `None`, then the softmax is not applied on the field - - """ - - grad = source_pos[:, :, dir_idx] - dest_pos[:, :, dir_idx] - if temperature is not None: - grad_pos, grad_neg = grad >= 0, grad <= 0 - grad_plus, grad_minus = grad, -grad - - # Compute softmax, considering one sign at a time. - grad_plus[grad_neg] = -float("inf") - grad_minus[grad_pos] = -float("inf") - grad_plus = torch.nn.Softmax(1)(temperature * grad_plus) - grad_minus = torch.nn.Softmax(1)(temperature * grad_minus) - grad_plus[grad_neg] = 0 - grad_minus[grad_pos] = 0 - - grad = grad_plus - grad_minus - - return grad - - -def aggregate_dir_smooth( - h: Tensor, - source_pos: Tensor, - dest_pos: Tensor, - h_in: Tensor, - dir_idx: int, - temperature: Optional[float] = None, - **kwargs, -) -> Tensor: - r""" - The aggregation is the following: - - $$y^{(l)} = |\hat{F}_k| h^{(l)}$$ - - - $\hat{F}^+_k$ is the normalized directional field *k-th* directional field $F_k$ - - $y^{(l)}$ is the returned aggregated result at the *l-th* layer. - - $h^{(l)}$ is the node features at the *l-th* layer. - - Parameters: - - h: The features to aggregate $h^{(l)}$ - source_pos: The positional encoding at the source node, used to compute the directional field - dest_pos: The positional encoding at the destination node, used to compute the directional field - h_in: The input features of the layer, before any operation. - dir_idx: The index of the positional encoding ($k$ in the equation above) - temperature: The temperature to use in the softmax of the directional field. - If `None`, then the softmax is not applied on the field - - Returns: - h_mod: The aggregated features $y^{(l)}$ - - """ - grad = get_grad_of_pos(source_pos=source_pos, dest_pos=dest_pos, dir_idx=dir_idx, temperature=temperature) - h_mod = h * (grad.abs() / (torch.sum(grad.abs(), keepdim=True, dim=1) + EPS)).unsqueeze(-1) - return torch.sum(h_mod, dim=1) - - -def aggregate_dir_dx_abs( - h: Tensor, - source_pos: Tensor, - dest_pos: Tensor, - h_in: Tensor, - dir_idx: int, - temperature: Optional[float] = None, - **kwargs, -) -> Tensor: - r""" - The aggregation is the following: - - $$y^{(l)} = |B_{dx}^k h^{(l)}|$$ - - $$B_{dx}^k = \hat{F}_k - diag \left(\sum_j{\hat{F}_{k_{(:, j)}}} \right)$$ - - - $\hat{F}^+_k$ is the normalized positive component of the directional field *k-th* directional field $F_k$ - - $y^{(l)}$ is the returned aggregated result at the *l-th* layer. - - $h^{(l)}$ is the node features at the *l-th* layer. - - Parameters: - - h: The features to aggregate $h^{(l)}$ - source_pos: The positional encoding at the source node, used to compute the directional field - dest_pos: The positional encoding at the destination node, used to compute the directional field - h_in: The input features of the layer, before any operation. - dir_idx: The index of the positional encoding ($k$ in the equation above) - temperature: The temperature to use in the softmax of the directional field. - If `None`, then the softmax is not applied on the field - - Returns: - h_mod: The aggregated features $y^{(l)}$ - - """ - return torch.abs(aggregate_dir_dx_no_abs(h, source_pos, dest_pos, h_in, dir_idx, temperature, **kwargs)) - - -def aggregate_dir_dx_no_abs( - h: Tensor, - source_pos: Tensor, - dest_pos: Tensor, - h_in: Tensor, - dir_idx: int, - temperature: Optional[float] = None, - **kwargs, -) -> Tensor: - r""" - The aggregation is the following: - - $$y^{(l)} = B_{dx}^k h^{(l)}$$ - - $$B_{dx}^k = \hat{F}_k - diag \left(\sum_j{\hat{F}_{k_{(:, j)}}} \right)$$ - - - $\hat{F}^+_k$ is the normalized positive component of the directional field *k-th* directional field $F_k$ - - $y^{(l)}$ is the returned aggregated result at the *l-th* layer. - - $h^{(l)}$ is the node features at the *l-th* layer. - - Parameters: - - h: The features to aggregate $h^{(l)}$ - source_pos: The positional encoding at the source node, used to compute the directional field - dest_pos: The positional encoding at the destination node, used to compute the directional field - h_in: The input features of the layer, before any operation. - dir_idx: The index of the positional encoding ($k$ in the equation above) - temperature: The temperature to use in the softmax of the directional field. - If `None`, then the softmax is not applied on the field - - Returns: - h_mod: The aggregated features $y^{(l)}$ - - """ - grad = get_grad_of_pos(source_pos=source_pos, dest_pos=dest_pos, dir_idx=dir_idx, temperature=temperature) - dir_weight = (grad / (torch.sum(grad.abs(), keepdim=True, dim=1) + EPS)).unsqueeze(-1) - h_mod = h * dir_weight - - h_dx = torch.sum(h_mod, dim=1) - h_self = -torch.sum(dir_weight, dim=1) * h_in - - # In case h_in has more parameters than h (for example when concatenating edges), - # the derivative is only computed for the features contained in h_in. - h_dx[..., : h_in.shape[-1]] = h_dx[..., : h_in.shape[-1]] + h_self - return h_dx - - -def aggregate_dir_dx_abs_balanced( - h: Tensor, - source_pos: Tensor, - dest_pos: Tensor, - h_in: Tensor, - dir_idx: int, - temperature: Optional[float] = None, - **kwargs, -) -> Tensor: - r""" - The aggregation is the same as `aggregate_dir_dx_no_abs`, but the positive and - negative parts of the field are normalized separately. - - - Parameters: - - h: The features to aggregate $h^{(l)}$ - source_pos: The positional encoding at the source node, used to compute the directional field - dest_pos: The positional encoding at the destination node, used to compute the directional field - h_in: The input features of the layer, before any operation. - dir_idx: The index of the positional encoding ($k$ in the equation above) - temperature: The temperature to use in the softmax of the directional field. - If `None`, then the softmax is not applied on the field - - Returns: - h_mod: The aggregated features $y^{(l)}$ - - """ - grad = get_grad_of_pos(source_pos=source_pos, dest_pos=dest_pos, dir_idx=dir_idx, temperature=temperature) - eig_front = torch.relu(grad) / (torch.sum(torch.relu(grad), keepdim=True, dim=1) + EPS) - eig_back = torch.relu(-grad) / (torch.sum(torch.relu(-grad), keepdim=True, dim=1) + EPS) - - dir_weight = (eig_front.unsqueeze(-1) + eig_back.unsqueeze(-1)) / 2 - h_mod = h * dir_weight - - h_dx = torch.sum(h_mod, dim=1) - h_self = -torch.sum(dir_weight, dim=1) * h_in - - # In case h_in has more parameters than h (for example when concatenating edges), - # the derivative is only computed for the features contained in h_in. - h_dx[..., : h_in.shape[-1]] = h_dx[..., : h_in.shape[-1]] + h_self - return torch.abs(h_dx) - - -def aggregate_dir_forward( - h: Tensor, - source_pos: Tensor, - dest_pos: Tensor, - h_in: Tensor, - dir_idx: int, - temperature: Optional[float] = None, - **kwargs, -) -> Tensor: - r""" - The aggregation is the following: - - $$y^{(l)} = \hat{F}^+_k h^{(l)}$$ - - - $\hat{F}^+_k$ is the normalized positive component of the directional field *k-th* directional field $F_k$ - - $y^{(l)}$ is the returned aggregated result at the *l-th* layer. - - $h^{(l)}$ is the node features at the *l-th* layer. - - Parameters: - - h: The features to aggregate $h^{(l)}$ - source_pos: The positional encoding at the source node, used to compute the directional field - dest_pos: The positional encoding at the destination node, used to compute the directional field - h_in: The input features of the layer, before any operation. - dir_idx: The index of the positional encoding ($k$ in the equation above) - temperature: The temperature to use in the softmax of the directional field. - If `None`, then the softmax is not applied on the field - - Returns: - h_mod: The aggregated features $y^{(l)}$ - """ - grad = get_grad_of_pos(source_pos=source_pos, dest_pos=dest_pos, dir_idx=dir_idx, temperature=temperature) - eig_front = torch.relu(grad) / (torch.sum(torch.relu(grad), keepdim=True, dim=1) + EPS) - h_mod = h * eig_front.unsqueeze(-1) - return torch.sum(h_mod, dim=1) - - -def aggregate_dir_backward( - h: Tensor, - source_pos: Tensor, - dest_pos: Tensor, - h_in: Tensor, - dir_idx: int, - temperature: Optional[float] = None, - **kwargs, -) -> Tensor: - r""" - The aggregation is the following: - - $$y^{(l)} = \hat{F}^-_k h^{(l)}$$ - - - $\hat{F}^-_k$ is the normalized positive component of the directional field *k-th* directional field $F_k$ - - $y^{(l)}$ is the returned aggregated result at the *l-th* layer. - - $h^{(l)}$ is the node features at the *l-th* layer. - - Parameters: - - h: The features to aggregate $h^{(l)}$ - source_pos: The positional encoding at the source node, used to compute the directional field - dest_pos: The positional encoding at the destination node, used to compute the directional field - h_in: The input features of the layer, before any operation. - dir_idx: The index of the positional encoding ($k$ in the equation above) - temperature: The temperature to use in the softmax of the directional field. - If `None`, then the softmax is not applied on the field - - Returns: - h_mod: The aggregated features $y^{(l)}$ - """ - return aggregate_dir_forward(h, -source_pos, -dest_pos, h_in, dir_idx, temperature, **kwargs) - - -DGN_AGGREGATORS = { - "smooth": aggregate_dir_smooth, - "dx_abs": aggregate_dir_dx_abs, - "dx_no_abs": aggregate_dir_dx_no_abs, - "dx_abs_balanced": aggregate_dir_dx_abs_balanced, - "forward": aggregate_dir_forward, - "backward": aggregate_dir_backward, -} +from typing import Optional +import torch +from torch import Tensor + + +EPS = 1e-8 + + +def get_grad_of_pos( + source_pos: Tensor, dest_pos: Tensor, dir_idx: int, temperature: Optional[float] = None +) -> Tensor: + r""" + Get the vector field associated to the gradient of the positional + encoding. + + $$F_k = \nabla pos_k$$ + + or, if a temperature $T$ is provided + + $$F_k = softmax((\nabla pos_k)^+) - softmax((\nabla pos_k)^-)$$ + + Where $F_k$ is the *k-th* directional field associated to the $k-th$ positional + encoding. + + Parameters: + + source_pos: The positional encoding at the source node, used to compute the directional field + dest_pos: The positional encoding at the destination node, used to compute the directional field + h_in: The input features of the layer, before any operation. + dir_idx: The index of the positional encoding ($k$ in the equation above) + temperature: The temperature to use in the softmax of the directional field. + If `None`, then the softmax is not applied on the field + + """ + + grad = source_pos[:, :, dir_idx] - dest_pos[:, :, dir_idx] + if temperature is not None: + grad_pos, grad_neg = grad >= 0, grad <= 0 + grad_plus, grad_minus = grad, -grad + + # Compute softmax, considering one sign at a time. + grad_plus[grad_neg] = -float("inf") + grad_minus[grad_pos] = -float("inf") + grad_plus = torch.nn.Softmax(1)(temperature * grad_plus) + grad_minus = torch.nn.Softmax(1)(temperature * grad_minus) + grad_plus[grad_neg] = 0 + grad_minus[grad_pos] = 0 + + grad = grad_plus - grad_minus + + return grad + + +def aggregate_dir_smooth( + h: Tensor, + source_pos: Tensor, + dest_pos: Tensor, + h_in: Tensor, + dir_idx: int, + temperature: Optional[float] = None, + **kwargs, +) -> Tensor: + r""" + The aggregation is the following: + + $$y^{(l)} = |\hat{F}_k| h^{(l)}$$ + + - $\hat{F}^+_k$ is the normalized directional field *k-th* directional field $F_k$ + - $y^{(l)}$ is the returned aggregated result at the *l-th* layer. + - $h^{(l)}$ is the node features at the *l-th* layer. + + Parameters: + + h: The features to aggregate $h^{(l)}$ + source_pos: The positional encoding at the source node, used to compute the directional field + dest_pos: The positional encoding at the destination node, used to compute the directional field + h_in: The input features of the layer, before any operation. + dir_idx: The index of the positional encoding ($k$ in the equation above) + temperature: The temperature to use in the softmax of the directional field. + If `None`, then the softmax is not applied on the field + + Returns: + h_mod: The aggregated features $y^{(l)}$ + + """ + grad = get_grad_of_pos(source_pos=source_pos, dest_pos=dest_pos, dir_idx=dir_idx, temperature=temperature) + h_mod = h * (grad.abs() / (torch.sum(grad.abs(), keepdim=True, dim=1) + EPS)).unsqueeze(-1) + return torch.sum(h_mod, dim=1) + + +def aggregate_dir_dx_abs( + h: Tensor, + source_pos: Tensor, + dest_pos: Tensor, + h_in: Tensor, + dir_idx: int, + temperature: Optional[float] = None, + **kwargs, +) -> Tensor: + r""" + The aggregation is the following: + + $$y^{(l)} = |B_{dx}^k h^{(l)}|$$ + + $$B_{dx}^k = \hat{F}_k - diag \left(\sum_j{\hat{F}_{k_{(:, j)}}} \right)$$ + + - $\hat{F}^+_k$ is the normalized positive component of the directional field *k-th* directional field $F_k$ + - $y^{(l)}$ is the returned aggregated result at the *l-th* layer. + - $h^{(l)}$ is the node features at the *l-th* layer. + + Parameters: + + h: The features to aggregate $h^{(l)}$ + source_pos: The positional encoding at the source node, used to compute the directional field + dest_pos: The positional encoding at the destination node, used to compute the directional field + h_in: The input features of the layer, before any operation. + dir_idx: The index of the positional encoding ($k$ in the equation above) + temperature: The temperature to use in the softmax of the directional field. + If `None`, then the softmax is not applied on the field + + Returns: + h_mod: The aggregated features $y^{(l)}$ + + """ + return torch.abs(aggregate_dir_dx_no_abs(h, source_pos, dest_pos, h_in, dir_idx, temperature, **kwargs)) + + +def aggregate_dir_dx_no_abs( + h: Tensor, + source_pos: Tensor, + dest_pos: Tensor, + h_in: Tensor, + dir_idx: int, + temperature: Optional[float] = None, + **kwargs, +) -> Tensor: + r""" + The aggregation is the following: + + $$y^{(l)} = B_{dx}^k h^{(l)}$$ + + $$B_{dx}^k = \hat{F}_k - diag \left(\sum_j{\hat{F}_{k_{(:, j)}}} \right)$$ + + - $\hat{F}^+_k$ is the normalized positive component of the directional field *k-th* directional field $F_k$ + - $y^{(l)}$ is the returned aggregated result at the *l-th* layer. + - $h^{(l)}$ is the node features at the *l-th* layer. + + Parameters: + + h: The features to aggregate $h^{(l)}$ + source_pos: The positional encoding at the source node, used to compute the directional field + dest_pos: The positional encoding at the destination node, used to compute the directional field + h_in: The input features of the layer, before any operation. + dir_idx: The index of the positional encoding ($k$ in the equation above) + temperature: The temperature to use in the softmax of the directional field. + If `None`, then the softmax is not applied on the field + + Returns: + h_mod: The aggregated features $y^{(l)}$ + + """ + grad = get_grad_of_pos(source_pos=source_pos, dest_pos=dest_pos, dir_idx=dir_idx, temperature=temperature) + dir_weight = (grad / (torch.sum(grad.abs(), keepdim=True, dim=1) + EPS)).unsqueeze(-1) + h_mod = h * dir_weight + + h_dx = torch.sum(h_mod, dim=1) + h_self = -torch.sum(dir_weight, dim=1) * h_in + + # In case h_in has more parameters than h (for example when concatenating edges), + # the derivative is only computed for the features contained in h_in. + h_dx[..., : h_in.shape[-1]] = h_dx[..., : h_in.shape[-1]] + h_self + return h_dx + + +def aggregate_dir_dx_abs_balanced( + h: Tensor, + source_pos: Tensor, + dest_pos: Tensor, + h_in: Tensor, + dir_idx: int, + temperature: Optional[float] = None, + **kwargs, +) -> Tensor: + r""" + The aggregation is the same as `aggregate_dir_dx_no_abs`, but the positive and + negative parts of the field are normalized separately. + + + Parameters: + + h: The features to aggregate $h^{(l)}$ + source_pos: The positional encoding at the source node, used to compute the directional field + dest_pos: The positional encoding at the destination node, used to compute the directional field + h_in: The input features of the layer, before any operation. + dir_idx: The index of the positional encoding ($k$ in the equation above) + temperature: The temperature to use in the softmax of the directional field. + If `None`, then the softmax is not applied on the field + + Returns: + h_mod: The aggregated features $y^{(l)}$ + + """ + grad = get_grad_of_pos(source_pos=source_pos, dest_pos=dest_pos, dir_idx=dir_idx, temperature=temperature) + eig_front = torch.relu(grad) / (torch.sum(torch.relu(grad), keepdim=True, dim=1) + EPS) + eig_back = torch.relu(-grad) / (torch.sum(torch.relu(-grad), keepdim=True, dim=1) + EPS) + + dir_weight = (eig_front.unsqueeze(-1) + eig_back.unsqueeze(-1)) / 2 + h_mod = h * dir_weight + + h_dx = torch.sum(h_mod, dim=1) + h_self = -torch.sum(dir_weight, dim=1) * h_in + + # In case h_in has more parameters than h (for example when concatenating edges), + # the derivative is only computed for the features contained in h_in. + h_dx[..., : h_in.shape[-1]] = h_dx[..., : h_in.shape[-1]] + h_self + return torch.abs(h_dx) + + +def aggregate_dir_forward( + h: Tensor, + source_pos: Tensor, + dest_pos: Tensor, + h_in: Tensor, + dir_idx: int, + temperature: Optional[float] = None, + **kwargs, +) -> Tensor: + r""" + The aggregation is the following: + + $$y^{(l)} = \hat{F}^+_k h^{(l)}$$ + + - $\hat{F}^+_k$ is the normalized positive component of the directional field *k-th* directional field $F_k$ + - $y^{(l)}$ is the returned aggregated result at the *l-th* layer. + - $h^{(l)}$ is the node features at the *l-th* layer. + + Parameters: + + h: The features to aggregate $h^{(l)}$ + source_pos: The positional encoding at the source node, used to compute the directional field + dest_pos: The positional encoding at the destination node, used to compute the directional field + h_in: The input features of the layer, before any operation. + dir_idx: The index of the positional encoding ($k$ in the equation above) + temperature: The temperature to use in the softmax of the directional field. + If `None`, then the softmax is not applied on the field + + Returns: + h_mod: The aggregated features $y^{(l)}$ + """ + grad = get_grad_of_pos(source_pos=source_pos, dest_pos=dest_pos, dir_idx=dir_idx, temperature=temperature) + eig_front = torch.relu(grad) / (torch.sum(torch.relu(grad), keepdim=True, dim=1) + EPS) + h_mod = h * eig_front.unsqueeze(-1) + return torch.sum(h_mod, dim=1) + + +def aggregate_dir_backward( + h: Tensor, + source_pos: Tensor, + dest_pos: Tensor, + h_in: Tensor, + dir_idx: int, + temperature: Optional[float] = None, + **kwargs, +) -> Tensor: + r""" + The aggregation is the following: + + $$y^{(l)} = \hat{F}^-_k h^{(l)}$$ + + - $\hat{F}^-_k$ is the normalized positive component of the directional field *k-th* directional field $F_k$ + - $y^{(l)}$ is the returned aggregated result at the *l-th* layer. + - $h^{(l)}$ is the node features at the *l-th* layer. + + Parameters: + + h: The features to aggregate $h^{(l)}$ + source_pos: The positional encoding at the source node, used to compute the directional field + dest_pos: The positional encoding at the destination node, used to compute the directional field + h_in: The input features of the layer, before any operation. + dir_idx: The index of the positional encoding ($k$ in the equation above) + temperature: The temperature to use in the softmax of the directional field. + If `None`, then the softmax is not applied on the field + + Returns: + h_mod: The aggregated features $y^{(l)}$ + """ + return aggregate_dir_forward(h, -source_pos, -dest_pos, h_in, dir_idx, temperature, **kwargs) + + +DGN_AGGREGATORS = { + "smooth": aggregate_dir_smooth, + "dx_abs": aggregate_dir_dx_abs, + "dx_no_abs": aggregate_dir_dx_no_abs, + "dx_abs_balanced": aggregate_dir_dx_abs_balanced, + "forward": aggregate_dir_forward, + "backward": aggregate_dir_backward, +} diff --git a/goli/nn/pna_operations.py b/goli/nn/pna_operations.py index b819bb0b2..d717358e8 100644 --- a/goli/nn/pna_operations.py +++ b/goli/nn/pna_operations.py @@ -1,88 +1,88 @@ -import torch -import numpy as np -from functools import partial - -EPS = 1e-5 - - -def aggregate_mean(h, **kwargs): - return torch.mean(h, dim=-2) - - -def aggregate_max(h, **kwargs): - return torch.max(h, dim=-2)[0] - - -def aggregate_min(h, **kwargs): - return torch.min(h, dim=-2)[0] - - -def aggregate_std(h, **kwargs): - return torch.sqrt(aggregate_var(h) + EPS) - - -def aggregate_mean_laplacian(h, h_in, **kwargs): - # In case h_in has more parameters than h (for example when concatenating edges), - # the laplacian is only computed for the features contained in h_in. - lap = -aggregate_mean(h, **kwargs) - lap[..., : h_in.shape[-1]] = h_in + lap[..., : h_in.shape[-1]] - return lap - - -def aggregate_var(h, **kwargs): - h_mean_squares = torch.mean(h * h, dim=-2) - h_mean = torch.mean(h, dim=-2) - var = torch.relu(h_mean_squares - h_mean * h_mean) - return var - - -def aggregate_moment(h, n=3, **kwargs): - # for each node (E[(X-E[X])^n])^{1/n} - # EPS is added to the absolute value of expectation before taking the nth root for stability - h_mean = torch.mean(h, dim=-2, keepdim=True) - h_n = torch.mean(torch.pow(h - h_mean, n), dim=-2) - rooted_h_n = torch.sign(h_n) * torch.pow(torch.abs(h_n) + EPS, 1.0 / n) - return rooted_h_n - - -def aggregate_sum(h, **kwargs): - return torch.sum(h, dim=-2) - - -# each scaler is a function that takes as input X (B x N x Din), adj (B x N x N) and -# avg_d (dictionary containing averages over training set) and returns X_scaled (B x N x Din) as output - - -def scale_identity(h, D=None, avg_d=None): - return h - - -def scale_amplification(h, D, avg_d): - # log(D + 1) / d * h where d is the average of the ``log(D + 1)`` in the training set - return h * (np.log(D + 1) / avg_d["log"]) - - -def scale_attenuation(h, D, avg_d): - # (log(D + 1))^-1 / d * X where d is the average of the ``log(D + 1))^-1`` in the training set - return h * (avg_d["log"] / np.log(D + 1)) - - -PNA_AGGREGATORS = { - "mean": aggregate_mean, - "sum": aggregate_sum, - "max": aggregate_max, - "min": aggregate_min, - "std": aggregate_std, - "var": aggregate_var, - "lap": aggregate_mean_laplacian, - "moment3": partial(aggregate_moment, n=3), - "moment4": partial(aggregate_moment, n=4), - "moment5": partial(aggregate_moment, n=5), -} - - -PNA_SCALERS = { - "identity": scale_identity, - "amplification": scale_amplification, - "attenuation": scale_attenuation, -} +import torch +import numpy as np +from functools import partial + +EPS = 1e-5 + + +def aggregate_mean(h, **kwargs): + return torch.mean(h, dim=-2) + + +def aggregate_max(h, **kwargs): + return torch.max(h, dim=-2)[0] + + +def aggregate_min(h, **kwargs): + return torch.min(h, dim=-2)[0] + + +def aggregate_std(h, **kwargs): + return torch.sqrt(aggregate_var(h) + EPS) + + +def aggregate_mean_laplacian(h, h_in, **kwargs): + # In case h_in has more parameters than h (for example when concatenating edges), + # the laplacian is only computed for the features contained in h_in. + lap = -aggregate_mean(h, **kwargs) + lap[..., : h_in.shape[-1]] = h_in + lap[..., : h_in.shape[-1]] + return lap + + +def aggregate_var(h, **kwargs): + h_mean_squares = torch.mean(h * h, dim=-2) + h_mean = torch.mean(h, dim=-2) + var = torch.relu(h_mean_squares - h_mean * h_mean) + return var + + +def aggregate_moment(h, n=3, **kwargs): + # for each node (E[(X-E[X])^n])^{1/n} + # EPS is added to the absolute value of expectation before taking the nth root for stability + h_mean = torch.mean(h, dim=-2, keepdim=True) + h_n = torch.mean(torch.pow(h - h_mean, n), dim=-2) + rooted_h_n = torch.sign(h_n) * torch.pow(torch.abs(h_n) + EPS, 1.0 / n) + return rooted_h_n + + +def aggregate_sum(h, **kwargs): + return torch.sum(h, dim=-2) + + +# each scaler is a function that takes as input X (B x N x Din), adj (B x N x N) and +# avg_d (dictionary containing averages over training set) and returns X_scaled (B x N x Din) as output + + +def scale_identity(h, D=None, avg_d=None): + return h + + +def scale_amplification(h, D, avg_d): + # log(D + 1) / d * h where d is the average of the ``log(D + 1)`` in the training set + return h * (np.log(D + 1) / avg_d["log"]) + + +def scale_attenuation(h, D, avg_d): + # (log(D + 1))^-1 / d * X where d is the average of the ``log(D + 1))^-1`` in the training set + return h * (avg_d["log"] / np.log(D + 1)) + + +PNA_AGGREGATORS = { + "mean": aggregate_mean, + "sum": aggregate_sum, + "max": aggregate_max, + "min": aggregate_min, + "std": aggregate_std, + "var": aggregate_var, + "lap": aggregate_mean_laplacian, + "moment3": partial(aggregate_moment, n=3), + "moment4": partial(aggregate_moment, n=4), + "moment5": partial(aggregate_moment, n=5), +} + + +PNA_SCALERS = { + "identity": scale_identity, + "amplification": scale_amplification, + "attenuation": scale_attenuation, +} diff --git a/goli/nn/residual_connections.py b/goli/nn/residual_connections.py index f23399f98..ea9b1eec2 100644 --- a/goli/nn/residual_connections.py +++ b/goli/nn/residual_connections.py @@ -1,523 +1,523 @@ -""" -Different types of residual connections, including None, Simple (ResNet-like), -Concat and DenseNet -""" - -import abc -import torch -import torch.nn as nn -from typing import List, Union, Callable - -from goli.nn.base_layers import FCLayer -from goli.utils.decorators import classproperty - - -class ResidualConnectionBase(nn.Module): - def __init__(self, skip_steps: int = 1): - r""" - Abstract class for the residual connections. Using this class, - we implement different types of residual connections, such as - the ResNet, weighted-ResNet, skip-concat and DensNet. - - The following methods must be implemented in a children class - - - ``h_dim_increase_type()`` - - ``has_weights()`` - - Parameters: - - skip_steps: int - The number of steps to skip between the residual connections. - If `1`, all the layers are connected. If `2`, half of the - layers are connected. - """ - - super().__init__() - self.skip_steps = skip_steps - - def _bool_apply_skip_step(self, step_idx: int): - r""" - Whether to apply the skip connection, depending on the - ``step_idx`` and ``self.skip_steps``. - - Parameters: - - step_idx: int - The current layer step index. - - """ - return (self.skip_steps != 0) and ((step_idx % self.skip_steps) == 0) - - def __repr__(self): - r""" - Controls how the class is printed - """ - return f"{self.__class__.__name__}(skip_steps={self.skip_steps})" - - @classproperty - @abc.abstractmethod - def h_dim_increase_type(cls): - r""" - How does the dimension of the output features increases after each layer? - - Returns: - - h_dim_increase_type: None or str - - ``None``: The dimension of the output features do not change at each layer. - E.g. ResNet. - - - "previous": The dimension of the output features is the concatenation of - the previous layer with the new layer. - - - "cumulative": The dimension of the output features is the concatenation - of all previous layers. - - """ - ... - - def get_true_out_dims(self, out_dims: List): - - true_out_dims = [out_dims[0]] - out_dims_at_skip = [out_dims[0]] - for ii in range(1, len(out_dims) - 1): - - # For the `None` type, don't change the output dims - if self.h_dim_increase_type is None: - true_out_dims.append(out_dims[ii]) - - # For the "previous" type, add the previous layers when the skip connection applies - elif self.h_dim_increase_type == "previous": - if self._bool_apply_skip_step(step_idx=ii): - true_out_dims.append(out_dims[ii] + out_dims_at_skip[-1]) - out_dims_at_skip.append(out_dims[ii]) - else: - true_out_dims.append(out_dims[ii]) - - # For the "cumulative" type, add all previous layers when the skip connection applies - elif self.h_dim_increase_type == "cumulative": - if self._bool_apply_skip_step(step_idx=ii): - true_out_dims.append(out_dims[ii] + out_dims_at_skip[-1]) - out_dims_at_skip.append(true_out_dims[ii]) - else: - true_out_dims.append(out_dims[ii]) - else: - raise ValueError(f"undefined value: {self.h_dim_increase_type}") - - return true_out_dims - - @classproperty - @abc.abstractmethod - def has_weights(cls): - r""" - Returns: - - has_weights: bool - Whether the residual connection uses weights - - """ - ... - - -class ResidualConnectionNone(ResidualConnectionBase): - r""" - No residual connection. - This class is only used for simpler code compatibility - """ - - def __init__(self, skip_steps: int = 1): - super().__init__(skip_steps=skip_steps) - - def __repr__(self): - r""" - Controls how the class is printed - """ - return f"{self.__class__.__name__}" - - @classproperty - def h_dim_increase_type(cls): - r""" - Returns: - - None: - The dimension of the output features do not change at each layer. - """ - - return None - - @classproperty - def has_weights(cls): - r""" - Returns: - - False - The current class does not use weights - - """ - return False - - def forward(self, h: torch.Tensor, h_prev: torch.Tensor, step_idx: int): - r""" - Ignore the skip connection. - - Returns: - - h: torch.Tensor(..., m) - Return same as input. - - h_prev: torch.Tensor(..., m) - Return same as input. - - """ - return h, h_prev - - -class ResidualConnectionSimple(ResidualConnectionBase): - def __init__(self, skip_steps: int = 1): - r""" - Class for the simple residual connections proposed by ResNet, - where the current layer output is summed to a - previous layer output. - - Parameters: - - skip_steps: int - The number of steps to skip between the residual connections. - If `1`, all the layers are connected. If `2`, half of the - layers are connected. - """ - super().__init__(skip_steps=skip_steps) - - @classproperty - def h_dim_increase_type(cls): - r""" - Returns: - - None: - The dimension of the output features do not change at each layer. - """ - - return None - - @classproperty - def has_weights(cls): - r""" - Returns: - - False - The current class does not use weights - - """ - return False - - def forward(self, h: torch.Tensor, h_prev: torch.Tensor, step_idx: int): - r""" - Add ``h`` with the previous layers with skip connection ``h_prev``, - similar to ResNet. - - Parameters: - - h: torch.Tensor(..., m) - The current layer features - - h_prev: torch.Tensor(..., m), None - The features from the previous layer with a skip connection. - At ``step_idx==0``, ``h_prev`` can be set to ``None``. - - step_idx: int - Current layer index or step index in the forward loop of the architecture. - - Returns: - - h: torch.Tensor(..., m) - Either return ``h`` unchanged, or the sum with - on ``h_prev``, depending on the ``step_idx`` and ``self.skip_steps``. - - h_prev: torch.Tensor(..., m) - Either return ``h_prev`` unchanged, or the same value as ``h``, - depending on the ``step_idx`` and ``self.skip_steps``. - - """ - if self._bool_apply_skip_step(step_idx): - if step_idx > 0: - h = h + h_prev - h_prev = h - - return h, h_prev - - -class ResidualConnectionWeighted(ResidualConnectionBase): - def __init__( - self, - out_dims, - skip_steps: int = 1, - dropout=0.0, - activation: Union[str, Callable] = "none", - batch_norm=False, - bias=False, - ): - r""" - Class for the simple residual connections proposed by ResNet, - with an added layer in the residual connection itself. - The layer output is summed to a a non-linear transformation - of a previous layer output. - - Parameters: - - skip_steps: int - The number of steps to skip between the residual connections. - If `1`, all the layers are connected. If `2`, half of the - layers are connected. - - out_dims: list(int) - list of all output dimensions for the network - that will use this residual connection. - E.g. ``out_dims = [4, 8, 8, 8, 2]``. - - dropout: float - value between 0 and 1.0 representing the percentage of dropout - to use in the weights - - activation: str, Callable - The activation function to use after the skip weights - - batch_norm: bool - Whether to apply batch normalisation after the weights - - bias: bool - Whether to apply add a bias after the weights - - """ - - super().__init__(skip_steps=skip_steps) - - self.residual_list = nn.ModuleList() - self.skip_count = 0 - self.out_dims = out_dims - - for ii in range(0, len(self.out_dims) - 1, self.skip_steps): - this_dim = self.out_dims[ii] - self.residual_list.append( - FCLayer( - this_dim, - this_dim, - activation=activation, - dropout=dropout, - batch_norm=batch_norm, - bias=False, - ) - ) - - @classproperty - def h_dim_increase_type(cls): - r""" - Returns: - - None: - The dimension of the output features do not change at each layer. - """ - return None - - @classproperty - def has_weights(cls): - r""" - Returns: - - True - The current class uses weights - - """ - return True - - def forward(self, h: torch.Tensor, h_prev: torch.Tensor, step_idx: int): - r""" - Add ``h`` with the previous layers with skip connection ``h_prev``, after - a feed-forward layer. - - Parameters: - - h: torch.Tensor(..., m) - The current layer features - - h_prev: torch.Tensor(..., m), None - The features from the previous layer with a skip connection. - At ``step_idx==0``, ``h_prev`` can be set to ``None``. - - step_idx: int - Current layer index or step index in the forward loop of the architecture. - - Returns: - - h: torch.Tensor(..., m) - Either return ``h`` unchanged, or the sum with the output of a NN layer - on ``h_prev``, depending on the ``step_idx`` and ``self.skip_steps``. - - h_prev: torch.Tensor(..., m) - Either return ``h_prev`` unchanged, or the same value as ``h``, - depending on the ``step_idx`` and ``self.skip_steps``. - - """ - - if self._bool_apply_skip_step(step_idx): - if step_idx > 0: - h = h + self.residual_list[self.skip_count].forward(h_prev) - self.skip_count += 1 - h_prev = h - - return h, h_prev - - def _bool_apply_skip_step(self, step_idx: int): - return super()._bool_apply_skip_step(step_idx) and self.skip_count < len(self.residual_list) - - -class ResidualConnectionConcat(ResidualConnectionBase): - def __init__(self, skip_steps: int = 1): - r""" - Class for the simple residual connections proposed but where - the skip connection features are concatenated to the current - layer features. - - Parameters: - - skip_steps: int - The number of steps to skip between the residual connections. - If `1`, all the layers are connected. If `2`, half of the - layers are connected. - """ - - super().__init__(skip_steps=skip_steps) - - @classproperty - def h_dim_increase_type(cls): - r""" - Returns: - - "previous": - The dimension of the output layer is the concatenation with the previous layer. - """ - - return "previous" - - @classproperty - def has_weights(cls): - r""" - Returns: - - False - The current class does not use weights - - """ - return False - - def forward(self, h: torch.Tensor, h_prev: torch.Tensor, step_idx: int): - r""" - Concatenate ``h`` with the previous layers with skip connection ``h_prev``. - - Parameters: - - h: torch.Tensor(..., m) - The current layer features - - h_prev: torch.Tensor(..., n), None - The features from the previous layer with a skip connection. - Usually, we have ``n`` equal to ``m``. - At ``step_idx==0``, ``h_prev`` can be set to ``None``. - - step_idx: int - Current layer index or step index in the forward loop of the architecture. - - Returns: - - h: torch.Tensor(..., m) or torch.Tensor(..., m + n) - Either return ``h`` unchanged, or the concatenation - with ``h_prev``, depending on the ``step_idx`` and ``self.skip_steps``. - - h_prev: torch.Tensor(..., m) or torch.Tensor(..., m + n) - Either return ``h_prev`` unchanged, or the same value as ``h``, - depending on the ``step_idx`` and ``self.skip_steps``. - - """ - - if self._bool_apply_skip_step(step_idx): - h_in = h - if step_idx > 0: - h = torch.cat([h, h_prev], dim=-1) - h_prev = h_in - - return h, h_prev - - -class ResidualConnectionDenseNet(ResidualConnectionBase): - def __init__(self, skip_steps: int = 1): - r""" - Class for the residual connections proposed by DenseNet, where - all previous skip connection features are concatenated to the current - layer features. - - Parameters: - - skip_steps: int - The number of steps to skip between the residual connections. - If `1`, all the layers are connected. If `2`, half of the - layers are connected. - """ - - super().__init__(skip_steps=skip_steps) - - @classproperty - def h_dim_increase_type(cls): - r""" - Returns: - - "cumulative": - The dimension of the output layer is the concatenation of all the previous layer. - """ - - return "cumulative" - - @classproperty - def has_weights(cls): - r""" - Returns: - - False - The current class does not use weights - - """ - return False - - def forward(self, h: torch.Tensor, h_prev: torch.Tensor, step_idx: int): - r""" - Concatenate ``h`` with all the previous layers with skip connection ``h_prev``. - - Parameters: - - h: torch.Tensor(..., m) - The current layer features - - h_prev: torch.Tensor(..., n), None - The features from the previous layers. - n = ((step_idx // self.skip_steps) + 1) * m - - At ``step_idx==0``, ``h_prev`` can be set to ``None``. - - step_idx: int - Current layer index or step index in the forward loop of the architecture. - - Returns: - - h: torch.Tensor(..., m) or torch.Tensor(..., m + n) - Either return ``h`` unchanged, or the concatenation - with ``h_prev``, depending on the ``step_idx`` and ``self.skip_steps``. - - h_prev: torch.Tensor(..., m) or torch.Tensor(..., m + n) - Either return ``h_prev`` unchanged, or the same value as ``h``, - depending on the ``step_idx`` and ``self.skip_steps``. - - """ - - if self._bool_apply_skip_step(step_idx): - if step_idx > 0: - h = torch.cat([h, h_prev], dim=-1) - h_prev = h - - return h, h_prev +""" +Different types of residual connections, including None, Simple (ResNet-like), +Concat and DenseNet +""" + +import abc +import torch +import torch.nn as nn +from typing import List, Union, Callable + +from goli.nn.base_layers import FCLayer +from goli.utils.decorators import classproperty + + +class ResidualConnectionBase(nn.Module): + def __init__(self, skip_steps: int = 1): + r""" + Abstract class for the residual connections. Using this class, + we implement different types of residual connections, such as + the ResNet, weighted-ResNet, skip-concat and DensNet. + + The following methods must be implemented in a children class + + - ``h_dim_increase_type()`` + - ``has_weights()`` + + Parameters: + + skip_steps: int + The number of steps to skip between the residual connections. + If `1`, all the layers are connected. If `2`, half of the + layers are connected. + """ + + super().__init__() + self.skip_steps = skip_steps + + def _bool_apply_skip_step(self, step_idx: int): + r""" + Whether to apply the skip connection, depending on the + ``step_idx`` and ``self.skip_steps``. + + Parameters: + + step_idx: int + The current layer step index. + + """ + return (self.skip_steps != 0) and ((step_idx % self.skip_steps) == 0) + + def __repr__(self): + r""" + Controls how the class is printed + """ + return f"{self.__class__.__name__}(skip_steps={self.skip_steps})" + + @classproperty + @abc.abstractmethod + def h_dim_increase_type(cls): + r""" + How does the dimension of the output features increases after each layer? + + Returns: + + h_dim_increase_type: None or str + - ``None``: The dimension of the output features do not change at each layer. + E.g. ResNet. + + - "previous": The dimension of the output features is the concatenation of + the previous layer with the new layer. + + - "cumulative": The dimension of the output features is the concatenation + of all previous layers. + + """ + ... + + def get_true_out_dims(self, out_dims: List): + + true_out_dims = [out_dims[0]] + out_dims_at_skip = [out_dims[0]] + for ii in range(1, len(out_dims) - 1): + + # For the `None` type, don't change the output dims + if self.h_dim_increase_type is None: + true_out_dims.append(out_dims[ii]) + + # For the "previous" type, add the previous layers when the skip connection applies + elif self.h_dim_increase_type == "previous": + if self._bool_apply_skip_step(step_idx=ii): + true_out_dims.append(out_dims[ii] + out_dims_at_skip[-1]) + out_dims_at_skip.append(out_dims[ii]) + else: + true_out_dims.append(out_dims[ii]) + + # For the "cumulative" type, add all previous layers when the skip connection applies + elif self.h_dim_increase_type == "cumulative": + if self._bool_apply_skip_step(step_idx=ii): + true_out_dims.append(out_dims[ii] + out_dims_at_skip[-1]) + out_dims_at_skip.append(true_out_dims[ii]) + else: + true_out_dims.append(out_dims[ii]) + else: + raise ValueError(f"undefined value: {self.h_dim_increase_type}") + + return true_out_dims + + @classproperty + @abc.abstractmethod + def has_weights(cls): + r""" + Returns: + + has_weights: bool + Whether the residual connection uses weights + + """ + ... + + +class ResidualConnectionNone(ResidualConnectionBase): + r""" + No residual connection. + This class is only used for simpler code compatibility + """ + + def __init__(self, skip_steps: int = 1): + super().__init__(skip_steps=skip_steps) + + def __repr__(self): + r""" + Controls how the class is printed + """ + return f"{self.__class__.__name__}" + + @classproperty + def h_dim_increase_type(cls): + r""" + Returns: + + None: + The dimension of the output features do not change at each layer. + """ + + return None + + @classproperty + def has_weights(cls): + r""" + Returns: + + False + The current class does not use weights + + """ + return False + + def forward(self, h: torch.Tensor, h_prev: torch.Tensor, step_idx: int): + r""" + Ignore the skip connection. + + Returns: + + h: torch.Tensor(..., m) + Return same as input. + + h_prev: torch.Tensor(..., m) + Return same as input. + + """ + return h, h_prev + + +class ResidualConnectionSimple(ResidualConnectionBase): + def __init__(self, skip_steps: int = 1): + r""" + Class for the simple residual connections proposed by ResNet, + where the current layer output is summed to a + previous layer output. + + Parameters: + + skip_steps: int + The number of steps to skip between the residual connections. + If `1`, all the layers are connected. If `2`, half of the + layers are connected. + """ + super().__init__(skip_steps=skip_steps) + + @classproperty + def h_dim_increase_type(cls): + r""" + Returns: + + None: + The dimension of the output features do not change at each layer. + """ + + return None + + @classproperty + def has_weights(cls): + r""" + Returns: + + False + The current class does not use weights + + """ + return False + + def forward(self, h: torch.Tensor, h_prev: torch.Tensor, step_idx: int): + r""" + Add ``h`` with the previous layers with skip connection ``h_prev``, + similar to ResNet. + + Parameters: + + h: torch.Tensor(..., m) + The current layer features + + h_prev: torch.Tensor(..., m), None + The features from the previous layer with a skip connection. + At ``step_idx==0``, ``h_prev`` can be set to ``None``. + + step_idx: int + Current layer index or step index in the forward loop of the architecture. + + Returns: + + h: torch.Tensor(..., m) + Either return ``h`` unchanged, or the sum with + on ``h_prev``, depending on the ``step_idx`` and ``self.skip_steps``. + + h_prev: torch.Tensor(..., m) + Either return ``h_prev`` unchanged, or the same value as ``h``, + depending on the ``step_idx`` and ``self.skip_steps``. + + """ + if self._bool_apply_skip_step(step_idx): + if step_idx > 0: + h = h + h_prev + h_prev = h + + return h, h_prev + + +class ResidualConnectionWeighted(ResidualConnectionBase): + def __init__( + self, + out_dims, + skip_steps: int = 1, + dropout=0.0, + activation: Union[str, Callable] = "none", + batch_norm=False, + bias=False, + ): + r""" + Class for the simple residual connections proposed by ResNet, + with an added layer in the residual connection itself. + The layer output is summed to a a non-linear transformation + of a previous layer output. + + Parameters: + + skip_steps: int + The number of steps to skip between the residual connections. + If `1`, all the layers are connected. If `2`, half of the + layers are connected. + + out_dims: list(int) + list of all output dimensions for the network + that will use this residual connection. + E.g. ``out_dims = [4, 8, 8, 8, 2]``. + + dropout: float + value between 0 and 1.0 representing the percentage of dropout + to use in the weights + + activation: str, Callable + The activation function to use after the skip weights + + batch_norm: bool + Whether to apply batch normalisation after the weights + + bias: bool + Whether to apply add a bias after the weights + + """ + + super().__init__(skip_steps=skip_steps) + + self.residual_list = nn.ModuleList() + self.skip_count = 0 + self.out_dims = out_dims + + for ii in range(0, len(self.out_dims) - 1, self.skip_steps): + this_dim = self.out_dims[ii] + self.residual_list.append( + FCLayer( + this_dim, + this_dim, + activation=activation, + dropout=dropout, + batch_norm=batch_norm, + bias=False, + ) + ) + + @classproperty + def h_dim_increase_type(cls): + r""" + Returns: + + None: + The dimension of the output features do not change at each layer. + """ + return None + + @classproperty + def has_weights(cls): + r""" + Returns: + + True + The current class uses weights + + """ + return True + + def forward(self, h: torch.Tensor, h_prev: torch.Tensor, step_idx: int): + r""" + Add ``h`` with the previous layers with skip connection ``h_prev``, after + a feed-forward layer. + + Parameters: + + h: torch.Tensor(..., m) + The current layer features + + h_prev: torch.Tensor(..., m), None + The features from the previous layer with a skip connection. + At ``step_idx==0``, ``h_prev`` can be set to ``None``. + + step_idx: int + Current layer index or step index in the forward loop of the architecture. + + Returns: + + h: torch.Tensor(..., m) + Either return ``h`` unchanged, or the sum with the output of a NN layer + on ``h_prev``, depending on the ``step_idx`` and ``self.skip_steps``. + + h_prev: torch.Tensor(..., m) + Either return ``h_prev`` unchanged, or the same value as ``h``, + depending on the ``step_idx`` and ``self.skip_steps``. + + """ + + if self._bool_apply_skip_step(step_idx): + if step_idx > 0: + h = h + self.residual_list[self.skip_count].forward(h_prev) + self.skip_count += 1 + h_prev = h + + return h, h_prev + + def _bool_apply_skip_step(self, step_idx: int): + return super()._bool_apply_skip_step(step_idx) and self.skip_count < len(self.residual_list) + + +class ResidualConnectionConcat(ResidualConnectionBase): + def __init__(self, skip_steps: int = 1): + r""" + Class for the simple residual connections proposed but where + the skip connection features are concatenated to the current + layer features. + + Parameters: + + skip_steps: int + The number of steps to skip between the residual connections. + If `1`, all the layers are connected. If `2`, half of the + layers are connected. + """ + + super().__init__(skip_steps=skip_steps) + + @classproperty + def h_dim_increase_type(cls): + r""" + Returns: + + "previous": + The dimension of the output layer is the concatenation with the previous layer. + """ + + return "previous" + + @classproperty + def has_weights(cls): + r""" + Returns: + + False + The current class does not use weights + + """ + return False + + def forward(self, h: torch.Tensor, h_prev: torch.Tensor, step_idx: int): + r""" + Concatenate ``h`` with the previous layers with skip connection ``h_prev``. + + Parameters: + + h: torch.Tensor(..., m) + The current layer features + + h_prev: torch.Tensor(..., n), None + The features from the previous layer with a skip connection. + Usually, we have ``n`` equal to ``m``. + At ``step_idx==0``, ``h_prev`` can be set to ``None``. + + step_idx: int + Current layer index or step index in the forward loop of the architecture. + + Returns: + + h: torch.Tensor(..., m) or torch.Tensor(..., m + n) + Either return ``h`` unchanged, or the concatenation + with ``h_prev``, depending on the ``step_idx`` and ``self.skip_steps``. + + h_prev: torch.Tensor(..., m) or torch.Tensor(..., m + n) + Either return ``h_prev`` unchanged, or the same value as ``h``, + depending on the ``step_idx`` and ``self.skip_steps``. + + """ + + if self._bool_apply_skip_step(step_idx): + h_in = h + if step_idx > 0: + h = torch.cat([h, h_prev], dim=-1) + h_prev = h_in + + return h, h_prev + + +class ResidualConnectionDenseNet(ResidualConnectionBase): + def __init__(self, skip_steps: int = 1): + r""" + Class for the residual connections proposed by DenseNet, where + all previous skip connection features are concatenated to the current + layer features. + + Parameters: + + skip_steps: int + The number of steps to skip between the residual connections. + If `1`, all the layers are connected. If `2`, half of the + layers are connected. + """ + + super().__init__(skip_steps=skip_steps) + + @classproperty + def h_dim_increase_type(cls): + r""" + Returns: + + "cumulative": + The dimension of the output layer is the concatenation of all the previous layer. + """ + + return "cumulative" + + @classproperty + def has_weights(cls): + r""" + Returns: + + False + The current class does not use weights + + """ + return False + + def forward(self, h: torch.Tensor, h_prev: torch.Tensor, step_idx: int): + r""" + Concatenate ``h`` with all the previous layers with skip connection ``h_prev``. + + Parameters: + + h: torch.Tensor(..., m) + The current layer features + + h_prev: torch.Tensor(..., n), None + The features from the previous layers. + n = ((step_idx // self.skip_steps) + 1) * m + + At ``step_idx==0``, ``h_prev`` can be set to ``None``. + + step_idx: int + Current layer index or step index in the forward loop of the architecture. + + Returns: + + h: torch.Tensor(..., m) or torch.Tensor(..., m + n) + Either return ``h`` unchanged, or the concatenation + with ``h_prev``, depending on the ``step_idx`` and ``self.skip_steps``. + + h_prev: torch.Tensor(..., m) or torch.Tensor(..., m + n) + Either return ``h_prev`` unchanged, or the same value as ``h``, + depending on the ``step_idx`` and ``self.skip_steps``. + + """ + + if self._bool_apply_skip_step(step_idx): + if step_idx > 0: + h = torch.cat([h, h_prev], dim=-1) + h_prev = h + + return h, h_prev diff --git a/goli/trainer/__init__.py b/goli/trainer/__init__.py index 4bf6fe9ea..971baaa7c 100644 --- a/goli/trainer/__init__.py +++ b/goli/trainer/__init__.py @@ -1,5 +1,5 @@ -from . import predictor -from . import metrics -from . import model_summary - -from .predictor import PredictorModule +from . import predictor +from . import metrics +from . import model_summary + +from .predictor import PredictorModule diff --git a/goli/trainer/metrics.py b/goli/trainer/metrics.py index 3d505c255..98b21fd1c 100644 --- a/goli/trainer/metrics.py +++ b/goli/trainer/metrics.py @@ -1,314 +1,314 @@ -from typing import Union, Callable, Optional, Dict, Any - -from copy import deepcopy -import torch -from torch.nn import functional as F -import torch.nn as nn -import operator as op - -from pytorch_lightning.metrics.utils import reduce -from pytorch_lightning.metrics.functional import auroc -from pytorch_lightning.metrics.functional import ( - accuracy, - average_precision, - confusion_matrix, - f1, - fbeta, - precision_recall_curve, - precision, - recall, - auroc, - multiclass_auroc, - mean_absolute_error, - mean_squared_error, -) - -from goli.utils.tensor import nan_mean - -EPS = 1e-5 - - -class Thresholder: - def __init__( - self, - threshold: float, - operator: str = "greater", - th_on_preds: bool = True, - th_on_target: bool = False, - target_to_int: bool = False, - ): - - # Basic params - self.threshold = threshold - self.th_on_target = th_on_target - self.th_on_preds = th_on_preds - self.target_to_int = target_to_int - - # Operator can either be a string, or a callable - if isinstance(operator, str): - op_name = operator.lower() - if op_name in ["greater", "gt"]: - op_str = ">" - operator = op.gt - elif op_name in ["lower", "lt"]: - op_str = "<" - operator = op.lt - else: - raise ValueError(f"operator `{op_name}` not supported") - elif callable(operator): - op_str = operator.__name__ - elif operator is None: - pass - else: - raise TypeError(f"operator must be either `str` or `callable`, provided: `{type(operator)}`") - - self.operator = operator - self.op_str = op_str - - def compute(self, preds: torch.Tensor, target: torch.Tensor): - # Apply the threshold on the predictions - if self.th_on_preds: - preds = self.operator(preds, self.threshold) - - # Apply the threshold on the targets - if self.th_on_target: - target = self.operator(target, self.threshold) - - if self.target_to_int: - target = target.to(int) - - return preds, target - - def __call__(self, preds: torch.Tensor, target: torch.Tensor): - return self.compute(preds, target) - - def __repr__(self): - r""" - Control how the class is printed - """ - - return f"{self.op_str}{self.threshold}" - - -def pearsonr(preds: torch.Tensor, target: torch.Tensor, reduction: str = "elementwise_mean") -> torch.Tensor: - r""" - Computes the pearsonr correlation. - - Parameters: - preds: estimated labels - target: ground truth labels - reduction: a method to reduce metric score over labels. - - - ``'elementwise_mean'``: takes the mean (default) - - ``'sum'``: takes the sum - - ``'none'``: no reduction will be applied - - Returns: - Tensor with the pearsonr - - !!! Example - ``` python linenums="1" - x = torch.tensor([0., 1, 2, 3]) - y = torch.tensor([0., 1, 2, 2]) - pearsonr(x, y) - >>> tensor(0.9439) - ``` - """ - - preds, target = preds.to(torch.float32), target.to(torch.float32) - - shifted_x = preds - torch.mean(preds, dim=0) - shifted_y = target - torch.mean(target, dim=0) - sigma_x = torch.sqrt(torch.sum(shifted_x ** 2, dim=0)) - sigma_y = torch.sqrt(torch.sum(shifted_y ** 2, dim=0)) - - pearson = torch.sum(shifted_x * shifted_y, dim=0) / (sigma_x * sigma_y + EPS) - pearson = torch.clamp(pearson, min=-1, max=1) - pearson = reduce(pearson, reduction=reduction) - return pearson - - -def _get_rank(values): - - arange = torch.arange(values.shape[0], dtype=values.dtype, device=values.device) - - val_sorter = torch.argsort(values, dim=0) - val_rank = torch.empty_like(values) - if values.ndim == 1: - val_rank[val_sorter] = arange - elif values.ndim == 2: - for ii in range(val_rank.shape[1]): - val_rank[val_sorter[:, ii], ii] = arange - else: - raise ValueError(f"Only supports tensors of dimensions 1 and 2, provided dim=`{preds.ndim}`") - - return val_rank - - -def spearmanr(preds: torch.Tensor, target: torch.Tensor, reduction: str = "elementwise_mean") -> torch.Tensor: - r""" - Computes the spearmanr correlation. - - Parameters: - preds: estimated labels - target: ground truth labels - reduction: a method to reduce metric score over labels. - - ``'elementwise_mean'``: takes the mean (default) - - ``'sum'``: takes the sum - - ``'none'``: no reduction will be applied - - Returns: - Tensor with the spearmanr - - !!! Example - x = torch.tensor([0., 1, 2, 3]) - y = torch.tensor([0., 1, 2, 1.5]) - spearmanr(x, y) - tensor(0.8) - """ - - spearman = pearsonr(_get_rank(preds), _get_rank(target), reduction=reduction) - return spearman - - -METRICS_CLASSIFICATION = { - "accuracy": accuracy, - "averageprecision": average_precision, - "auroc": auroc, - "confusionmatrix": confusion_matrix, - "f1": f1, - "fbeta": fbeta, - "precisionrecallcurve": precision_recall_curve, - "precision": precision, - "recall": recall, - "multiclass_auroc": multiclass_auroc, -} - -METRICS_REGRESSION = { - "mae": mean_absolute_error, - "mse": mean_squared_error, - "pearsonr": pearsonr, - "spearmanr": spearmanr, -} - -METRICS_DICT = deepcopy(METRICS_CLASSIFICATION) -METRICS_DICT.update(METRICS_REGRESSION) - - -class MetricWrapper: - r""" - Allows to initialize a metric from a name or Callable, and initialize the - `Thresholder` in case the metric requires a threshold. - """ - - def __init__( - self, - metric: Union[str, Callable], - threshold_kwargs: Optional[Dict[str, Any]] = None, - target_nan_mask: Optional[Union[str, int]] = None, - **kwargs, - ): - r""" - Parameters - metric: - The metric to use. See `METRICS_DICT` - - threshold_kwargs: - If `None`, no threshold is applied. - Otherwise, we use the class `Thresholder` is initialized with the - provided argument, and called before the `compute` - - target_nan_mask: - - - None: Do not change behaviour if there are NaNs - - - int, float: Value used to replace NaNs. For example, if `target_nan_mask==0`, then - all NaNs will be replaced by zeros - - - 'ignore-flatten': The Tensor will be reduced to a vector without the NaN values. - - - 'ignore-mean-label': NaNs will be ignored when computing the loss. Note that each column - has a different number of NaNs, so the metric will be computed separately - on each column, and the metric result will be averaged over all columns. - *This option might slowdown the computation if there are too many labels* - - kwargs: - Other arguments to call with the metric - """ - - self.metric = METRICS_DICT[metric] if isinstance(metric, str) else metric - - self.thresholder = None - if threshold_kwargs is not None: - self.thresholder = Thresholder(**threshold_kwargs) - - self.target_nan_mask = target_nan_mask - - self.kwargs = kwargs - - def compute(self, preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: - r""" - Compute the metric, apply the thresholder if provided, and manage the NaNs - """ - - if preds.ndim == 1: - preds = preds.unsqueeze(-1) - - if target.ndim == 1: - target = target.unsqueeze(-1) - - target_nans = torch.isnan(target) - - # Threshold the prediction - if self.thresholder is not None: - preds, target = self.thresholder(preds, target) - - # Manage the NaNs - if self.target_nan_mask is None: - pass - elif isinstance(self.target_nan_mask, (int, float)): - target = target.clone() - target[torch.isnan(target)] = self.target_nan_mask - elif self.target_nan_mask == "ignore-flatten": - target = target[~target_nans] - preds = preds[~target_nans] - elif self.target_nan_mask == "ignore-mean-label": - target_list = [target[..., ii][~target_nans[..., ii]] for ii in range(target.shape[-1])] - preds_list = [preds[..., ii][~target_nans[..., ii]] for ii in range(preds.shape[-1])] - target = target_list - preds = preds_list - else: - raise ValueError(f"Invalid option `{self.target_nan_mask}`") - - if self.target_nan_mask == "ignore-mean-label": - - # Compute the metric for each column, and output nan if there's an error on a given column - metric_val = [] - for ii in range(len(target)): - try: - metric_val.append(self.metric(preds[ii], target[ii], **self.kwargs)) - except: - pass - - # Average the metric - metric_val = nan_mean(torch.stack(metric_val)) - - else: - metric_val = self.metric(preds, target, **self.kwargs) - return metric_val - - def __call__(self, preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: - r""" - Compute the metric with the method `self.compute` - """ - return self.compute(preds, target) - - def __repr__(self): - r""" - Control how the class is printed - """ - full_str = f"{self.metric.__name__}" - if self.thresholder is not None: - full_str += f"({self.thresholder})" - - return full_str +from typing import Union, Callable, Optional, Dict, Any + +from copy import deepcopy +import torch +from torch.nn import functional as F +import torch.nn as nn +import operator as op + +from pytorch_lightning.metrics.utils import reduce +from pytorch_lightning.metrics.functional import auroc +from pytorch_lightning.metrics.functional import ( + accuracy, + average_precision, + confusion_matrix, + f1, + fbeta, + precision_recall_curve, + precision, + recall, + auroc, + multiclass_auroc, + mean_absolute_error, + mean_squared_error, +) + +from goli.utils.tensor import nan_mean + +EPS = 1e-5 + + +class Thresholder: + def __init__( + self, + threshold: float, + operator: str = "greater", + th_on_preds: bool = True, + th_on_target: bool = False, + target_to_int: bool = False, + ): + + # Basic params + self.threshold = threshold + self.th_on_target = th_on_target + self.th_on_preds = th_on_preds + self.target_to_int = target_to_int + + # Operator can either be a string, or a callable + if isinstance(operator, str): + op_name = operator.lower() + if op_name in ["greater", "gt"]: + op_str = ">" + operator = op.gt + elif op_name in ["lower", "lt"]: + op_str = "<" + operator = op.lt + else: + raise ValueError(f"operator `{op_name}` not supported") + elif callable(operator): + op_str = operator.__name__ + elif operator is None: + pass + else: + raise TypeError(f"operator must be either `str` or `callable`, provided: `{type(operator)}`") + + self.operator = operator + self.op_str = op_str + + def compute(self, preds: torch.Tensor, target: torch.Tensor): + # Apply the threshold on the predictions + if self.th_on_preds: + preds = self.operator(preds, self.threshold) + + # Apply the threshold on the targets + if self.th_on_target: + target = self.operator(target, self.threshold) + + if self.target_to_int: + target = target.to(int) + + return preds, target + + def __call__(self, preds: torch.Tensor, target: torch.Tensor): + return self.compute(preds, target) + + def __repr__(self): + r""" + Control how the class is printed + """ + + return f"{self.op_str}{self.threshold}" + + +def pearsonr(preds: torch.Tensor, target: torch.Tensor, reduction: str = "elementwise_mean") -> torch.Tensor: + r""" + Computes the pearsonr correlation. + + Parameters: + preds: estimated labels + target: ground truth labels + reduction: a method to reduce metric score over labels. + + - ``'elementwise_mean'``: takes the mean (default) + - ``'sum'``: takes the sum + - ``'none'``: no reduction will be applied + + Returns: + Tensor with the pearsonr + + !!! Example + ``` python linenums="1" + x = torch.tensor([0., 1, 2, 3]) + y = torch.tensor([0., 1, 2, 2]) + pearsonr(x, y) + >>> tensor(0.9439) + ``` + """ + + preds, target = preds.to(torch.float32), target.to(torch.float32) + + shifted_x = preds - torch.mean(preds, dim=0) + shifted_y = target - torch.mean(target, dim=0) + sigma_x = torch.sqrt(torch.sum(shifted_x ** 2, dim=0)) + sigma_y = torch.sqrt(torch.sum(shifted_y ** 2, dim=0)) + + pearson = torch.sum(shifted_x * shifted_y, dim=0) / (sigma_x * sigma_y + EPS) + pearson = torch.clamp(pearson, min=-1, max=1) + pearson = reduce(pearson, reduction=reduction) + return pearson + + +def _get_rank(values): + + arange = torch.arange(values.shape[0], dtype=values.dtype, device=values.device) + + val_sorter = torch.argsort(values, dim=0) + val_rank = torch.empty_like(values) + if values.ndim == 1: + val_rank[val_sorter] = arange + elif values.ndim == 2: + for ii in range(val_rank.shape[1]): + val_rank[val_sorter[:, ii], ii] = arange + else: + raise ValueError(f"Only supports tensors of dimensions 1 and 2, provided dim=`{preds.ndim}`") + + return val_rank + + +def spearmanr(preds: torch.Tensor, target: torch.Tensor, reduction: str = "elementwise_mean") -> torch.Tensor: + r""" + Computes the spearmanr correlation. + + Parameters: + preds: estimated labels + target: ground truth labels + reduction: a method to reduce metric score over labels. + - ``'elementwise_mean'``: takes the mean (default) + - ``'sum'``: takes the sum + - ``'none'``: no reduction will be applied + + Returns: + Tensor with the spearmanr + + !!! Example + x = torch.tensor([0., 1, 2, 3]) + y = torch.tensor([0., 1, 2, 1.5]) + spearmanr(x, y) + tensor(0.8) + """ + + spearman = pearsonr(_get_rank(preds), _get_rank(target), reduction=reduction) + return spearman + + +METRICS_CLASSIFICATION = { + "accuracy": accuracy, + "averageprecision": average_precision, + "auroc": auroc, + "confusionmatrix": confusion_matrix, + "f1": f1, + "fbeta": fbeta, + "precisionrecallcurve": precision_recall_curve, + "precision": precision, + "recall": recall, + "multiclass_auroc": multiclass_auroc, +} + +METRICS_REGRESSION = { + "mae": mean_absolute_error, + "mse": mean_squared_error, + "pearsonr": pearsonr, + "spearmanr": spearmanr, +} + +METRICS_DICT = deepcopy(METRICS_CLASSIFICATION) +METRICS_DICT.update(METRICS_REGRESSION) + + +class MetricWrapper: + r""" + Allows to initialize a metric from a name or Callable, and initialize the + `Thresholder` in case the metric requires a threshold. + """ + + def __init__( + self, + metric: Union[str, Callable], + threshold_kwargs: Optional[Dict[str, Any]] = None, + target_nan_mask: Optional[Union[str, int]] = None, + **kwargs, + ): + r""" + Parameters + metric: + The metric to use. See `METRICS_DICT` + + threshold_kwargs: + If `None`, no threshold is applied. + Otherwise, we use the class `Thresholder` is initialized with the + provided argument, and called before the `compute` + + target_nan_mask: + + - None: Do not change behaviour if there are NaNs + + - int, float: Value used to replace NaNs. For example, if `target_nan_mask==0`, then + all NaNs will be replaced by zeros + + - 'ignore-flatten': The Tensor will be reduced to a vector without the NaN values. + + - 'ignore-mean-label': NaNs will be ignored when computing the loss. Note that each column + has a different number of NaNs, so the metric will be computed separately + on each column, and the metric result will be averaged over all columns. + *This option might slowdown the computation if there are too many labels* + + kwargs: + Other arguments to call with the metric + """ + + self.metric = METRICS_DICT[metric] if isinstance(metric, str) else metric + + self.thresholder = None + if threshold_kwargs is not None: + self.thresholder = Thresholder(**threshold_kwargs) + + self.target_nan_mask = target_nan_mask + + self.kwargs = kwargs + + def compute(self, preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + r""" + Compute the metric, apply the thresholder if provided, and manage the NaNs + """ + + if preds.ndim == 1: + preds = preds.unsqueeze(-1) + + if target.ndim == 1: + target = target.unsqueeze(-1) + + target_nans = torch.isnan(target) + + # Threshold the prediction + if self.thresholder is not None: + preds, target = self.thresholder(preds, target) + + # Manage the NaNs + if self.target_nan_mask is None: + pass + elif isinstance(self.target_nan_mask, (int, float)): + target = target.clone() + target[torch.isnan(target)] = self.target_nan_mask + elif self.target_nan_mask == "ignore-flatten": + target = target[~target_nans] + preds = preds[~target_nans] + elif self.target_nan_mask == "ignore-mean-label": + target_list = [target[..., ii][~target_nans[..., ii]] for ii in range(target.shape[-1])] + preds_list = [preds[..., ii][~target_nans[..., ii]] for ii in range(preds.shape[-1])] + target = target_list + preds = preds_list + else: + raise ValueError(f"Invalid option `{self.target_nan_mask}`") + + if self.target_nan_mask == "ignore-mean-label": + + # Compute the metric for each column, and output nan if there's an error on a given column + metric_val = [] + for ii in range(len(target)): + try: + metric_val.append(self.metric(preds[ii], target[ii], **self.kwargs)) + except: + pass + + # Average the metric + metric_val = nan_mean(torch.stack(metric_val)) + + else: + metric_val = self.metric(preds, target, **self.kwargs) + return metric_val + + def __call__(self, preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + r""" + Compute the metric with the method `self.compute` + """ + return self.compute(preds, target) + + def __repr__(self): + r""" + Control how the class is printed + """ + full_str = f"{self.metric.__name__}" + if self.thresholder is not None: + full_str += f"({self.thresholder})" + + return full_str diff --git a/goli/trainer/model_summary.py b/goli/trainer/model_summary.py index 049c473ed..7e94afaeb 100644 --- a/goli/trainer/model_summary.py +++ b/goli/trainer/model_summary.py @@ -1,74 +1,74 @@ -from typing import List, Tuple -import torch.nn as nn -import pytorch_lightning as pl -from pytorch_lightning.core.memory import ModelSummary - - -class ModelSummaryExtended(ModelSummary): - r""" - Generates a summary of all layers in a :class:`~pytorch_lightning.core.lightning.LightningModule`. - The summary is extended to allow different levels. - - Args: - model: The model to summarize (also referred to as the root module) - mode: Can be one of - - - `top` (default): only the top-level modules will be recorded (the children of the root module) - - `full`: summarizes all layers and their submodules in the root module - - `top1`, `top2`, ..., `top11`: summarizes the k-top-level modules - - The string representation of this summary prints a table with columns containing - the name, type and number of parameters for each layer. - - The root module may also have an attribute ``example_input_array`` as shown in the example below. - If present, the root module will be called with it as input to determine the - intermediate input- and output shapes of all layers. Supported are tensors and - nested lists and tuples of tensors. All other types of inputs will be skipped and show as `?` - in the summary table. The summary will also display `?` for layers not used in the forward pass. - - """ - - MODE_TOP = "top" - MODE_TOP2 = "top2" - MODE_TOP3 = "top3" - MODE_TOP4 = "top4" - MODE_TOP5 = "top5" - MODE_TOP6 = "top6" - MODE_TOP7 = "top7" - MODE_TOP8 = "top8" - MODE_TOP9 = "top9" - MODE_TOP10 = "top10" - MODE_TOP11 = "top11" - MODE_FULL = "full" - MODE_DEFAULT = MODE_TOP2 - MODES = [ - MODE_TOP, - MODE_TOP2, - MODE_TOP3, - MODE_TOP4, - MODE_TOP5, - MODE_TOP6, - MODE_TOP7, - MODE_TOP8, - MODE_TOP9, - MODE_TOP10, - MODE_TOP11, - MODE_FULL, - ] - - @property - def named_modules(self) -> List[Tuple[str, nn.Module]]: - if self._mode == ModelSummaryExtended.MODE_FULL: - mods = self._model.named_modules() - mods = list(mods)[1:] # do not include root module (LightningModule) - elif self._mode == ModelSummaryExtended.MODE_TOP: - # the children are the top-level modules - mods = self._model.named_children() - elif self._mode[:3] == "top": - depth = int(self._mode[3:]) - mods_full = self._model.named_modules() - mods_full = list(mods_full)[1:] # do not include root module (LightningModule) - mods = [mod for mod in mods_full if mod[0].count(".") < depth] - else: - mods = [] - return list(mods) +from typing import List, Tuple +import torch.nn as nn +import pytorch_lightning as pl +from pytorch_lightning.core.memory import ModelSummary + + +class ModelSummaryExtended(ModelSummary): + r""" + Generates a summary of all layers in a :class:`~pytorch_lightning.core.lightning.LightningModule`. + The summary is extended to allow different levels. + + Args: + model: The model to summarize (also referred to as the root module) + mode: Can be one of + + - `top` (default): only the top-level modules will be recorded (the children of the root module) + - `full`: summarizes all layers and their submodules in the root module + - `top1`, `top2`, ..., `top11`: summarizes the k-top-level modules + + The string representation of this summary prints a table with columns containing + the name, type and number of parameters for each layer. + + The root module may also have an attribute ``example_input_array`` as shown in the example below. + If present, the root module will be called with it as input to determine the + intermediate input- and output shapes of all layers. Supported are tensors and + nested lists and tuples of tensors. All other types of inputs will be skipped and show as `?` + in the summary table. The summary will also display `?` for layers not used in the forward pass. + + """ + + MODE_TOP = "top" + MODE_TOP2 = "top2" + MODE_TOP3 = "top3" + MODE_TOP4 = "top4" + MODE_TOP5 = "top5" + MODE_TOP6 = "top6" + MODE_TOP7 = "top7" + MODE_TOP8 = "top8" + MODE_TOP9 = "top9" + MODE_TOP10 = "top10" + MODE_TOP11 = "top11" + MODE_FULL = "full" + MODE_DEFAULT = MODE_TOP2 + MODES = [ + MODE_TOP, + MODE_TOP2, + MODE_TOP3, + MODE_TOP4, + MODE_TOP5, + MODE_TOP6, + MODE_TOP7, + MODE_TOP8, + MODE_TOP9, + MODE_TOP10, + MODE_TOP11, + MODE_FULL, + ] + + @property + def named_modules(self) -> List[Tuple[str, nn.Module]]: + if self._mode == ModelSummaryExtended.MODE_FULL: + mods = self._model.named_modules() + mods = list(mods)[1:] # do not include root module (LightningModule) + elif self._mode == ModelSummaryExtended.MODE_TOP: + # the children are the top-level modules + mods = self._model.named_children() + elif self._mode[:3] == "top": + depth = int(self._mode[3:]) + mods_full = self._model.named_modules() + mods_full = list(mods_full)[1:] # do not include root module (LightningModule) + mods = [mod for mod in mods_full if mod[0].count(".") < depth] + else: + mods = [] + return list(mods) diff --git a/goli/trainer/predictor.py b/goli/trainer/predictor.py index 508c3cdc3..b786897a9 100644 --- a/goli/trainer/predictor.py +++ b/goli/trainer/predictor.py @@ -1,572 +1,572 @@ -from typing import Dict, List, Any, Union, Any, Callable, Tuple, Type, Optional -import os -import numpy as np -from copy import deepcopy -import yaml - -import torch -from torch import nn, Tensor -from torch.optim.lr_scheduler import ReduceLROnPlateau - -import pytorch_lightning as pl -from pytorch_lightning import _logger as log -from pytorch_lightning.utilities.exceptions import MisconfigurationException - -from goli.trainer.model_summary import ModelSummaryExtended -from goli.config.config_convert import recursive_config_reformating -from goli.utils.tensor import nan_mean, nan_std - -LOSS_DICT = { - "mse": torch.nn.MSELoss(), - "bce": torch.nn.BCELoss(), - "l1": torch.nn.L1Loss(), - "mae": torch.nn.L1Loss(), - "cosine": torch.nn.CosineEmbeddingLoss(), -} - -GOLI_PRETRAINED_MODELS = { - "goli-zinc-micro-dummy-test": "gcs://goli-public/pretrained-models/goli-zinc-micro-dummy-test.ckpt" -} - - -class EpochSummary: - r"""Container for collecting epoch-wise results""" - - def __init__(self, monitor="loss", mode: str = "min", metrics_on_progress_bar=[]): - self.monitor = monitor - self.mode = mode - self.metrics_on_progress_bar = metrics_on_progress_bar - self.summaries = {} - self.best_summaries = {} - - class Results: - def __init__( - self, - targets: Tensor, - predictions: Tensor, - loss: float, - metrics: dict, - monitored_metric: str, - n_epochs: int, - ): - self - self.predictions = predictions.detach().cpu() - self.loss = loss.detach().cpu().item() - self.monitored_metric = monitored_metric - self.monitored = metrics[monitored_metric].detach().cpu() - self.metrics = {key: value.tolist() for key, value in metrics.items()} - self.n_epochs = n_epochs - - def set_results(self, name, targets, predictions, loss, metrics, n_epochs) -> float: - metrics[f"loss/{name}"] = loss - self.summaries[name] = EpochSummary.Results( - targets=targets, - predictions=predictions, - loss=loss, - metrics=metrics, - monitored_metric=f"{self.monitor}/{name}", - n_epochs=n_epochs, - ) - if self.is_best_epoch(name, loss, metrics): - self.best_summaries[name] = self.summaries[name] - - def is_best_epoch(self, name, loss, metrics): - if not (name in self.best_summaries.keys()): - return True - - metrics[f"loss/{name}"] = loss - monitor_name = f"{self.monitor}/{name}" - if self.mode == "max": - return metrics[monitor_name] > self.best_summaries[name].monitored - elif self.mode == "min": - return metrics[monitor_name] < self.best_summaries[name].monitored - else: - return ValueError(f"Mode must be 'min' or 'max', provided `{self.mode}`") - - def get_results(self, name): - return self.summaries[name] - - def get_best_results(self, name): - return self.best_summaries[name] - - def get_results_on_progress_bar(self, name): - results = self.summaries[name] - results_prog = { - f"{kk}/{name}": results.metrics[f"{kk}/{name}"] for kk in self.metrics_on_progress_bar - } - return results_prog - - def get_dict_summary(self): - full_dict = {} - # Get metric summaries - full_dict["metric_summaries"] = {} - for key, val in self.summaries.items(): - full_dict["metric_summaries"][key] = {k: v for k, v in val.metrics.items()} - full_dict["metric_summaries"][key]["n_epochs"] = val.n_epochs - - # Get metric summaries at best epoch - full_dict["best_epoch_metric_summaries"] = {} - for key, val in self.best_summaries.items(): - full_dict["best_epoch_metric_summaries"][key] = val.metrics - full_dict["best_epoch_metric_summaries"][key]["n_epochs"] = val.n_epochs - - return full_dict - - -class PredictorModule(pl.LightningModule): - def __init__( - self, - model_class: Type[nn.Module], - model_kwargs: Dict[str, Any], - loss_fun: Union[str, Callable], - random_seed: int = 42, - optim_kwargs: Optional[Dict[str, Any]] = None, - lr_reduce_on_plateau_kwargs: Optional[Dict[str, Any]] = None, - scheduler_kwargs: Optional[Dict[str, Any]] = None, - target_nan_mask: Optional[Union[int, float, str]] = None, - metrics: Dict[str, Callable] = None, - metrics_on_progress_bar: List[str] = [], - metrics_on_training_set: Optional[List[str]] = None, - ): - r""" - A class that allows to use regression or classification models easily - with Pytorch-Lightning. - - Parameters: - model_class: - pytorch module used to create a model - - model_kwargs: - Key-word arguments used to initialize the model from `model_class`. - - loss_fun: - Loss function used during training. - Acceptable strings are 'mse', 'bce', 'mae', 'cosine'. - Otherwise, a callable object must be provided, with a method `loss_fun._get_name()`. - - random_seed: - The random seed used by Pytorch to initialize random tensors. - - optim_kwargs: - Dictionnary used to initialize the optimizer, with possible keys below. - - - lr `float`: Learning rate (Default=`1e-3`) - - weight_decay `float`: Weight decay used to regularize the optimizer (Default=`0.`) - - lr_reduce_on_plateau_kwargs: - Dictionnary for the reduction of learning rate when reaching plateau, with possible keys below. - - - factor `float`: Factor by which to reduce the learning rate (Default=`0.5`) - - patience `int`: Number of epochs without improvement to wait before reducing - the learning rate (Default=`10`) - - mode `str`: One of min, max. In min mode, lr will be reduced when the quantity - monitored has stopped decreasing; in max mode it will be reduced when the quantity - monitored has stopped increasing. (Default=`"min"`). - - min_lr `float`: A scalar or a list of scalars. A lower bound on the learning rate - of all param groups or each group respectively (Default=`1e-4`) - - scheduler_kwargs: - Dictionnary for the scheduling of the learning rate modification - - - monitor `str`: metric to track (Default=`"loss/val"`) - - interval `str`: Whether to look at iterations or epochs (Default=`"epoch"`) - - strict `bool`: if set to True will enforce that value specified in monitor is available - while trying to call scheduler.step(), and stop training if not found. If False will - only give a warning and continue training (without calling the scheduler). (Default=`True`) - - frequency `int`: **TODO: NOT REALLY SURE HOW IT WORKS!** (Default=`1`) - - target_nan_mask: - TODO: It's not implemented for the metrics yet!! - - - None: Do not change behaviour if there are nans - - - int, float: Value used to replace nans. For example, if `target_nan_mask==0`, then - all nans will be replaced by zeros - - - 'ignore': Nans will be ignored when computing the loss. - - metrics: - A dictionnary of metrics to compute on the prediction, other than the loss function. - These metrics will be logged into TensorBoard. - - metrics_on_progress_bar: - The metrics names from `metrics` to display also on the progress bar of the training - - metrics_on_training_set: - The metrics names from `metrics` to be computed on the training set for each iteration. - If `None`, all the metrics are computed. Using less metrics can significantly improve - performance, depending on the number of readouts. - - """ - - self.save_hyperparameters() - - torch.random.manual_seed(random_seed) - np.random.seed(random_seed) - - super().__init__() - self.model = model_class(**model_kwargs) - - # Basic attributes - self.loss_fun = self.parse_loss_fun(loss_fun) - self.random_seed = random_seed - self.target_nan_mask = target_nan_mask - self.metrics = metrics if metrics is not None else {} - self.metrics_on_progress_bar = metrics_on_progress_bar - self.metrics_on_training_set = {} if metrics_on_training_set is None else metrics_on_training_set - self.n_params = sum(p.numel() for p in self.parameters() if p.requires_grad) - self.lr_reduce_on_plateau_kwargs = lr_reduce_on_plateau_kwargs - self.optim_kwargs = optim_kwargs - self.scheduler_kwargs = scheduler_kwargs - - # Set the default value for the optimizer - self.optim_kwargs = optim_kwargs if optim_kwargs is not None else {} - self.optim_kwargs.setdefault("lr", 1e-3) - self.optim_kwargs.setdefault("weight_decay", 0.0) - - self.lr_reduce_on_plateau_kwargs = ( - lr_reduce_on_plateau_kwargs if lr_reduce_on_plateau_kwargs is not None else {} - ) - self.lr_reduce_on_plateau_kwargs.setdefault("factor", 0.5) - self.lr_reduce_on_plateau_kwargs.setdefault("patience", 10) - self.lr_reduce_on_plateau_kwargs.setdefault("min_lr", 1e-4) - - self.optim_kwargs = optim_kwargs if optim_kwargs is not None else {} - self.scheduler_kwargs.setdefault("monitor", "loss/val") - self.scheduler_kwargs.setdefault("interval", "epoch") - self.scheduler_kwargs.setdefault("mode", "min") - self.scheduler_kwargs.setdefault("frequency", 1) - self.scheduler_kwargs.setdefault("strict", True) - - monitor = scheduler_kwargs["monitor"].split("/")[0] - mode = scheduler_kwargs["mode"] - self.epoch_summary = EpochSummary( - monitor, mode=mode, metrics_on_progress_bar=self.metrics_on_progress_bar - ) - - # This helps avoid a bug when saving hparams to yaml with different - # dict or str formats - self._set_hparams(recursive_config_reformating(self.hparams)) - - @staticmethod - def parse_loss_fun(loss_fun: Union[str, Callable]) -> Callable: - r""" - Parse the loss function from a string - - Parameters: - loss_fun: - A callable corresponding to the loss function or a string - specifying the loss function from `LOSS_DICT`. Accepted strings are: - "mse", "bce", "l1", "mae", "cosine". - - Returns: - Callable: - Function or callable to compute the loss, takes `preds` and `targets` as inputs. - """ - - if isinstance(loss_fun, str): - loss_fun = LOSS_DICT[loss_fun] - elif not callable(loss_fun): - raise ValueError(f"`loss_fun` must be `str` or `callable`. Provided: {type(loss_fun)}") - - return loss_fun - - def forward(self, inputs: Dict): - r""" - Returns the result of `self.model.forward(*inputs)` on the inputs. - """ - out = self.model.forward(inputs["features"]) - return out - - def configure_optimizers(self): - optimiser = torch.optim.Adam(self.parameters(), **self.optim_kwargs) - - scheduler = { - "scheduler": ReduceLROnPlateau( - optimizer=optimiser, mode=self.scheduler_kwargs["mode"], **self.lr_reduce_on_plateau_kwargs - ), - **self.scheduler_kwargs, - } - return [optimiser], [scheduler] - - @staticmethod - def compute_loss( - preds: Tensor, - targets: Tensor, - weights: Optional[Tensor], - loss_fun: Callable, - target_nan_mask: Union[Type, str] = "ignore", - ) -> Tensor: - r""" - Compute the loss using the specified loss function, and dealing with - the nans in the `targets`. - - Parameters: - preds: - Predicted values - - targets: - Target values - - target_nan_mask: - - - None: Do not change behaviour if there are nans - - - int, float: Value used to replace nans. For example, if `target_nan_mask==0`, then - all nans will be replaced by zeros - - - 'ignore': Nans will be ignored when computing the loss. - - loss_fun: - Loss function to use - - Returns: - Tensor: - Resulting loss - """ - if target_nan_mask is None: - pass - elif isinstance(target_nan_mask, (int, float)): - targets = targets.clone() - targets[torch.isnan(targets)] = target_nan_mask - elif target_nan_mask == "ignore": - nans = torch.isnan(targets) - targets = targets[~nans] - preds = preds[~nans] - else: - raise ValueError(f"Invalid option `{target_nan_mask}`") - - if weights is None: - loss = loss_fun(preds, targets) - else: - loss = loss_fun(preds, targets, weights=weights) - - return loss - - def get_metrics_logs( - self, preds: Tensor, targets: Tensor, weights: Optional[Tensor], step_name: str, loss: Tensor - ) -> Dict[str, Any]: - r""" - Get the logs for the loss and the different metrics, in a format compatible with - Pytorch-Lightning. - - Parameters: - preds: - Predicted values - - targets: - Target values - - step_name: - A string to mention whether the metric is computed on the training, - validation or test set. - - - "train": On the training set - - "val": On the validation set - - "test": On the test set - - """ - - targets = targets.to(dtype=preds.dtype, device=preds.device) - - # Compute the metrics always used in regression tasks - metric_logs = {} - metric_logs[f"mean_pred/{step_name}"] = nan_mean(preds) - metric_logs[f"std_pred/{step_name}"] = nan_std(preds) - metric_logs[f"mean_target/{step_name}"] = nan_mean(targets) - metric_logs[f"std_target/{step_name}"] = nan_std(targets) - - # Specify which metrics to use - metrics_to_use = self.metrics - if step_name == "train": - metrics_to_use = { - key: metric for key, metric in metrics_to_use.items() if key in self.metrics_on_training_set - } - - # Compute the additional metrics - for key, metric in metrics_to_use.items(): - metric_name = f"{key}/{step_name}" - try: - metric_logs[metric_name] = metric(preds, targets) - except Exception as e: - metric_logs[metric_name] = torch.as_tensor(float("nan")) - - # Convert all metrics to CPU, except for the loss - metric_logs[f"{self.loss_fun._get_name()}/{step_name}"] = loss.detach().cpu() - metric_logs = {key: metric.detach().cpu() for key, metric in metric_logs.items()} - - return metric_logs - - def _general_step(self, batch: Dict[str, Tensor], batch_idx: int, step_name: str) -> Dict[str, Any]: - r"""Common code for training_step, validation_step and testing_step""" - preds = self.forward(batch) - targets = batch.pop("labels").to(dtype=preds.dtype) - weights = batch.pop("weights", None) - - loss = self.compute_loss( - preds=preds, - targets=targets, - weights=weights, - target_nan_mask=self.target_nan_mask, - loss_fun=self.loss_fun, - ) - - preds = preds.detach().cpu() - targets = targets.detach().cpu() - if weights is not None: - weights = weights.detach().cpu() - - step_dict = {"preds": preds, "targets": targets, "weights": weights} - step_dict[f"{self.loss_fun._get_name()}/{step_name}"] = loss.detach().cpu() - return loss, step_dict - - def training_step(self, batch: Dict[str, Tensor], batch_idx: int) -> Dict[str, Any]: - loss, step_dict = self._general_step(batch=batch, batch_idx=batch_idx, step_name="train") - metrics_logs = self.get_metrics_logs( - preds=step_dict["preds"], - targets=step_dict["targets"], - weights=step_dict["weights"], - step_name="train", - loss=loss, - ) - - step_dict.update(metrics_logs) - step_dict["loss"] = loss - - self.logger.log_metrics(metrics_logs, step=self.global_step) - - return step_dict - - def validation_step(self, batch: Dict[str, Tensor], batch_idx: int) -> Dict[str, Any]: - return self._general_step(batch=batch, batch_idx=batch_idx, step_name="val")[1] - - def test_step(self, batch: Dict[str, Tensor], batch_idx: int) -> Dict[str, Any]: - return self._general_step(batch=batch, batch_idx=batch_idx, step_name="val")[1] - - def _general_epoch_end(self, outputs: Dict[str, Any], step_name: str) -> None: - r"""Common code for training_epoch_end, validation_epoch_end and testing_epoch_end""" - - # Transform the list of dict of dict, into a dict of list of dict - preds = torch.cat([out["preds"] for out in outputs], dim=0) - targets = torch.cat([out["targets"] for out in outputs], dim=0) - if outputs[0]["weights"] is not None: - weights = torch.cat([out["weights"] for out in outputs], dim=0) - else: - weights = None - loss = self.compute_loss( - preds=preds, - targets=targets, - weights=weights, - target_nan_mask=self.target_nan_mask, - loss_fun=self.loss_fun, - ) - metrics_logs = self.get_metrics_logs( - preds=preds, targets=targets, weights=weights, step_name=step_name, loss=loss - ) - - self.epoch_summary.set_results( - name=step_name, - predictions=preds, - targets=targets, - loss=loss, - metrics=metrics_logs, - n_epochs=self.current_epoch, - ) - - return metrics_logs - - def training_epoch_end(self, outputs: Dict): - - self._general_epoch_end(outputs=outputs, step_name="train") - - def validation_epoch_end(self, outputs: List): - - metrics_logs = self._general_epoch_end(outputs=outputs, step_name="val") - - lr = self.optimizers().param_groups[0]["lr"] - metrics_logs["lr"] = lr - metrics_logs["n_epochs"] = self.current_epoch - self.log_dict(metrics_logs) - - # Save yaml file with the metrics summaries - full_dict = {} - full_dict.update(self.epoch_summary.get_dict_summary()) - tb_path = self.logger.log_dir - - # Write the YAML file with the metrics - if self.current_epoch >= 1: - with open(os.path.join(tb_path, "metrics.yaml"), "w") as file: - yaml.dump(full_dict, file) - - def test_epoch_end(self, outputs: List): - - metrics_logs = self._general_epoch_end(outputs=outputs, step_name="test") - self.log_dict(metrics_logs) - - # Save yaml file with the metrics summaries - full_dict = {} - full_dict.update(self.epoch_summary.get_dict_summary()) - tb_path = self.logger.log_dir - os.makedirs(tb_path, exist_ok=True) - with open(f"{tb_path}/metrics.yaml", "w") as file: - yaml.dump(full_dict, file) - - def on_train_start(self): - self.logger.log_hyperparams(self.hparams, self.epoch_summary.get_results("val").metrics) - - def get_progress_bar_dict(self) -> Dict[str, float]: - prog_dict = super().get_progress_bar_dict() - results_on_progress_bar = self.epoch_summary.get_results_on_progress_bar("val") - prog_dict["loss/val"] = self.epoch_summary.summaries["val"].loss - prog_dict.update(results_on_progress_bar) - return prog_dict - - def summarize(self, mode: str = ModelSummaryExtended.MODE_DEFAULT, to_print=True) -> ModelSummaryExtended: - r""" - Provide a summary of the class, usually to be printed - """ - model_summary = None - - if isinstance(mode, int): - mode = ModelSummaryExtended.MODES[mode - 1] - - if mode in ModelSummaryExtended.MODES: - model_summary = ModelSummaryExtended(self, mode=mode) - if to_print: - log.info("\n" + str(model_summary)) - elif mode is not None: - raise MisconfigurationException( - f"`mode` can be None, {', '.join(ModelSummaryExtended.MODES)}, got {mode}" - ) - - return model_summary - - def __repr__(self) -> str: - r""" - Controls how the class is printed - """ - model_str = self.model.__repr__() - summary_str = self.summarize(to_print=False).__repr__() - - return model_str + "\n\n" + summary_str - - @staticmethod - def list_pretrained_models(): - """List available pretrained models.""" - return GOLI_PRETRAINED_MODELS - - @staticmethod - def load_pretrained_models(name: str): - """Load a pretrained model from its name. - - Args: - name: Name of the model to load. List available - from `goli.trainer.PredictorModule.list_pretrained_models()`. - """ - - if name not in GOLI_PRETRAINED_MODELS: - raise ValueError( - f"The model '{name}' is not available. Choose from {set(GOLI_PRETRAINED_MODELS.keys())}." - ) - - return PredictorModule.load_from_checkpoint(GOLI_PRETRAINED_MODELS[name]) +from typing import Dict, List, Any, Union, Any, Callable, Tuple, Type, Optional +import os +import numpy as np +from copy import deepcopy +import yaml + +import torch +from torch import nn, Tensor +from torch.optim.lr_scheduler import ReduceLROnPlateau + +import pytorch_lightning as pl +from pytorch_lightning import _logger as log +from pytorch_lightning.utilities.exceptions import MisconfigurationException + +from goli.trainer.model_summary import ModelSummaryExtended +from goli.config.config_convert import recursive_config_reformating +from goli.utils.tensor import nan_mean, nan_std + +LOSS_DICT = { + "mse": torch.nn.MSELoss(), + "bce": torch.nn.BCELoss(), + "l1": torch.nn.L1Loss(), + "mae": torch.nn.L1Loss(), + "cosine": torch.nn.CosineEmbeddingLoss(), +} + +GOLI_PRETRAINED_MODELS = { + "goli-zinc-micro-dummy-test": "gcs://goli-public/pretrained-models/goli-zinc-micro-dummy-test.ckpt" +} + + +class EpochSummary: + r"""Container for collecting epoch-wise results""" + + def __init__(self, monitor="loss", mode: str = "min", metrics_on_progress_bar=[]): + self.monitor = monitor + self.mode = mode + self.metrics_on_progress_bar = metrics_on_progress_bar + self.summaries = {} + self.best_summaries = {} + + class Results: + def __init__( + self, + targets: Tensor, + predictions: Tensor, + loss: float, + metrics: dict, + monitored_metric: str, + n_epochs: int, + ): + self + self.predictions = predictions.detach().cpu() + self.loss = loss.detach().cpu().item() + self.monitored_metric = monitored_metric + self.monitored = metrics[monitored_metric].detach().cpu() + self.metrics = {key: value.tolist() for key, value in metrics.items()} + self.n_epochs = n_epochs + + def set_results(self, name, targets, predictions, loss, metrics, n_epochs) -> float: + metrics[f"loss/{name}"] = loss + self.summaries[name] = EpochSummary.Results( + targets=targets, + predictions=predictions, + loss=loss, + metrics=metrics, + monitored_metric=f"{self.monitor}/{name}", + n_epochs=n_epochs, + ) + if self.is_best_epoch(name, loss, metrics): + self.best_summaries[name] = self.summaries[name] + + def is_best_epoch(self, name, loss, metrics): + if not (name in self.best_summaries.keys()): + return True + + metrics[f"loss/{name}"] = loss + monitor_name = f"{self.monitor}/{name}" + if self.mode == "max": + return metrics[monitor_name] > self.best_summaries[name].monitored + elif self.mode == "min": + return metrics[monitor_name] < self.best_summaries[name].monitored + else: + return ValueError(f"Mode must be 'min' or 'max', provided `{self.mode}`") + + def get_results(self, name): + return self.summaries[name] + + def get_best_results(self, name): + return self.best_summaries[name] + + def get_results_on_progress_bar(self, name): + results = self.summaries[name] + results_prog = { + f"{kk}/{name}": results.metrics[f"{kk}/{name}"] for kk in self.metrics_on_progress_bar + } + return results_prog + + def get_dict_summary(self): + full_dict = {} + # Get metric summaries + full_dict["metric_summaries"] = {} + for key, val in self.summaries.items(): + full_dict["metric_summaries"][key] = {k: v for k, v in val.metrics.items()} + full_dict["metric_summaries"][key]["n_epochs"] = val.n_epochs + + # Get metric summaries at best epoch + full_dict["best_epoch_metric_summaries"] = {} + for key, val in self.best_summaries.items(): + full_dict["best_epoch_metric_summaries"][key] = val.metrics + full_dict["best_epoch_metric_summaries"][key]["n_epochs"] = val.n_epochs + + return full_dict + + +class PredictorModule(pl.LightningModule): + def __init__( + self, + model_class: Type[nn.Module], + model_kwargs: Dict[str, Any], + loss_fun: Union[str, Callable], + random_seed: int = 42, + optim_kwargs: Optional[Dict[str, Any]] = None, + lr_reduce_on_plateau_kwargs: Optional[Dict[str, Any]] = None, + scheduler_kwargs: Optional[Dict[str, Any]] = None, + target_nan_mask: Optional[Union[int, float, str]] = None, + metrics: Dict[str, Callable] = None, + metrics_on_progress_bar: List[str] = [], + metrics_on_training_set: Optional[List[str]] = None, + ): + r""" + A class that allows to use regression or classification models easily + with Pytorch-Lightning. + + Parameters: + model_class: + pytorch module used to create a model + + model_kwargs: + Key-word arguments used to initialize the model from `model_class`. + + loss_fun: + Loss function used during training. + Acceptable strings are 'mse', 'bce', 'mae', 'cosine'. + Otherwise, a callable object must be provided, with a method `loss_fun._get_name()`. + + random_seed: + The random seed used by Pytorch to initialize random tensors. + + optim_kwargs: + Dictionnary used to initialize the optimizer, with possible keys below. + + - lr `float`: Learning rate (Default=`1e-3`) + - weight_decay `float`: Weight decay used to regularize the optimizer (Default=`0.`) + + lr_reduce_on_plateau_kwargs: + Dictionnary for the reduction of learning rate when reaching plateau, with possible keys below. + + - factor `float`: Factor by which to reduce the learning rate (Default=`0.5`) + - patience `int`: Number of epochs without improvement to wait before reducing + the learning rate (Default=`10`) + - mode `str`: One of min, max. In min mode, lr will be reduced when the quantity + monitored has stopped decreasing; in max mode it will be reduced when the quantity + monitored has stopped increasing. (Default=`"min"`). + - min_lr `float`: A scalar or a list of scalars. A lower bound on the learning rate + of all param groups or each group respectively (Default=`1e-4`) + + scheduler_kwargs: + Dictionnary for the scheduling of the learning rate modification + + - monitor `str`: metric to track (Default=`"loss/val"`) + - interval `str`: Whether to look at iterations or epochs (Default=`"epoch"`) + - strict `bool`: if set to True will enforce that value specified in monitor is available + while trying to call scheduler.step(), and stop training if not found. If False will + only give a warning and continue training (without calling the scheduler). (Default=`True`) + - frequency `int`: **TODO: NOT REALLY SURE HOW IT WORKS!** (Default=`1`) + + target_nan_mask: + TODO: It's not implemented for the metrics yet!! + + - None: Do not change behaviour if there are nans + + - int, float: Value used to replace nans. For example, if `target_nan_mask==0`, then + all nans will be replaced by zeros + + - 'ignore': Nans will be ignored when computing the loss. + + metrics: + A dictionnary of metrics to compute on the prediction, other than the loss function. + These metrics will be logged into TensorBoard. + + metrics_on_progress_bar: + The metrics names from `metrics` to display also on the progress bar of the training + + metrics_on_training_set: + The metrics names from `metrics` to be computed on the training set for each iteration. + If `None`, all the metrics are computed. Using less metrics can significantly improve + performance, depending on the number of readouts. + + """ + + self.save_hyperparameters() + + torch.random.manual_seed(random_seed) + np.random.seed(random_seed) + + super().__init__() + self.model = model_class(**model_kwargs) + + # Basic attributes + self.loss_fun = self.parse_loss_fun(loss_fun) + self.random_seed = random_seed + self.target_nan_mask = target_nan_mask + self.metrics = metrics if metrics is not None else {} + self.metrics_on_progress_bar = metrics_on_progress_bar + self.metrics_on_training_set = list(self.metrics.keys()) if metrics_on_training_set is None else metrics_on_training_set + self.n_params = sum(p.numel() for p in self.parameters() if p.requires_grad) + self.lr_reduce_on_plateau_kwargs = lr_reduce_on_plateau_kwargs + self.optim_kwargs = optim_kwargs + self.scheduler_kwargs = scheduler_kwargs + + # Set the default value for the optimizer + self.optim_kwargs = optim_kwargs if optim_kwargs is not None else {} + self.optim_kwargs.setdefault("lr", 1e-3) + self.optim_kwargs.setdefault("weight_decay", 0.0) + + self.lr_reduce_on_plateau_kwargs = ( + lr_reduce_on_plateau_kwargs if lr_reduce_on_plateau_kwargs is not None else {} + ) + self.lr_reduce_on_plateau_kwargs.setdefault("factor", 0.5) + self.lr_reduce_on_plateau_kwargs.setdefault("patience", 10) + self.lr_reduce_on_plateau_kwargs.setdefault("min_lr", 1e-4) + + self.optim_kwargs = optim_kwargs if optim_kwargs is not None else {} + self.scheduler_kwargs.setdefault("monitor", "loss/val") + self.scheduler_kwargs.setdefault("interval", "epoch") + self.scheduler_kwargs.setdefault("mode", "min") + self.scheduler_kwargs.setdefault("frequency", 1) + self.scheduler_kwargs.setdefault("strict", True) + + monitor = scheduler_kwargs["monitor"].split("/")[0] + mode = scheduler_kwargs["mode"] + self.epoch_summary = EpochSummary( + monitor, mode=mode, metrics_on_progress_bar=self.metrics_on_progress_bar + ) + + # This helps avoid a bug when saving hparams to yaml with different + # dict or str formats + self._set_hparams(recursive_config_reformating(self.hparams)) + + @staticmethod + def parse_loss_fun(loss_fun: Union[str, Callable]) -> Callable: + r""" + Parse the loss function from a string + + Parameters: + loss_fun: + A callable corresponding to the loss function or a string + specifying the loss function from `LOSS_DICT`. Accepted strings are: + "mse", "bce", "l1", "mae", "cosine". + + Returns: + Callable: + Function or callable to compute the loss, takes `preds` and `targets` as inputs. + """ + + if isinstance(loss_fun, str): + loss_fun = LOSS_DICT[loss_fun] + elif not callable(loss_fun): + raise ValueError(f"`loss_fun` must be `str` or `callable`. Provided: {type(loss_fun)}") + + return loss_fun + + def forward(self, inputs: Dict): + r""" + Returns the result of `self.model.forward(*inputs)` on the inputs. + """ + out = self.model.forward(inputs["features"]) + return out + + def configure_optimizers(self): + optimiser = torch.optim.Adam(self.parameters(), **self.optim_kwargs) + + scheduler = { + "scheduler": ReduceLROnPlateau( + optimizer=optimiser, mode=self.scheduler_kwargs["mode"], **self.lr_reduce_on_plateau_kwargs + ), + **self.scheduler_kwargs, + } + return [optimiser], [scheduler] + + @staticmethod + def compute_loss( + preds: Tensor, + targets: Tensor, + weights: Optional[Tensor], + loss_fun: Callable, + target_nan_mask: Union[Type, str] = "ignore", + ) -> Tensor: + r""" + Compute the loss using the specified loss function, and dealing with + the nans in the `targets`. + + Parameters: + preds: + Predicted values + + targets: + Target values + + target_nan_mask: + + - None: Do not change behaviour if there are nans + + - int, float: Value used to replace nans. For example, if `target_nan_mask==0`, then + all nans will be replaced by zeros + + - 'ignore': Nans will be ignored when computing the loss. + + loss_fun: + Loss function to use + + Returns: + Tensor: + Resulting loss + """ + if target_nan_mask is None: + pass + elif isinstance(target_nan_mask, (int, float)): + targets = targets.clone() + targets[torch.isnan(targets)] = target_nan_mask + elif target_nan_mask == "ignore": + nans = torch.isnan(targets) + targets = targets[~nans] + preds = preds[~nans] + else: + raise ValueError(f"Invalid option `{target_nan_mask}`") + + if weights is None: + loss = loss_fun(preds, targets) + else: + loss = loss_fun(preds, targets, weights=weights) + + return loss + + def get_metrics_logs( + self, preds: Tensor, targets: Tensor, weights: Optional[Tensor], step_name: str, loss: Tensor + ) -> Dict[str, Any]: + r""" + Get the logs for the loss and the different metrics, in a format compatible with + Pytorch-Lightning. + + Parameters: + preds: + Predicted values + + targets: + Target values + + step_name: + A string to mention whether the metric is computed on the training, + validation or test set. + + - "train": On the training set + - "val": On the validation set + - "test": On the test set + + """ + + targets = targets.to(dtype=preds.dtype, device=preds.device) + + # Compute the metrics always used in regression tasks + metric_logs = {} + metric_logs[f"mean_pred/{step_name}"] = nan_mean(preds) + metric_logs[f"std_pred/{step_name}"] = nan_std(preds) + metric_logs[f"mean_target/{step_name}"] = nan_mean(targets) + metric_logs[f"std_target/{step_name}"] = nan_std(targets) + + # Specify which metrics to use + metrics_to_use = self.metrics + if step_name == "train": + metrics_to_use = { + key: metric for key, metric in metrics_to_use.items() if key in self.metrics_on_training_set + } + + # Compute the additional metrics + for key, metric in metrics_to_use.items(): + metric_name = f"{key}/{step_name}" + try: + metric_logs[metric_name] = metric(preds, targets) + except Exception as e: + metric_logs[metric_name] = torch.as_tensor(float("nan")) + + # Convert all metrics to CPU, except for the loss + metric_logs[f"{self.loss_fun._get_name()}/{step_name}"] = loss.detach().cpu() + metric_logs = {key: metric.detach().cpu() for key, metric in metric_logs.items()} + + return metric_logs + + def _general_step(self, batch: Dict[str, Tensor], batch_idx: int, step_name: str) -> Dict[str, Any]: + r"""Common code for training_step, validation_step and testing_step""" + preds = self.forward(batch) + targets = batch.pop("labels").to(dtype=preds.dtype) + weights = batch.pop("weights", None) + + loss = self.compute_loss( + preds=preds, + targets=targets, + weights=weights, + target_nan_mask=self.target_nan_mask, + loss_fun=self.loss_fun, + ) + + preds = preds.detach().cpu() + targets = targets.detach().cpu() + if weights is not None: + weights = weights.detach().cpu() + + step_dict = {"preds": preds, "targets": targets, "weights": weights} + step_dict[f"{self.loss_fun._get_name()}/{step_name}"] = loss.detach().cpu() + return loss, step_dict + + def training_step(self, batch: Dict[str, Tensor], batch_idx: int) -> Dict[str, Any]: + loss, step_dict = self._general_step(batch=batch, batch_idx=batch_idx, step_name="train") + metrics_logs = self.get_metrics_logs( + preds=step_dict["preds"], + targets=step_dict["targets"], + weights=step_dict["weights"], + step_name="train", + loss=loss, + ) + + step_dict.update(metrics_logs) + step_dict["loss"] = loss + + self.logger.log_metrics(metrics_logs, step=self.global_step) + + return step_dict + + def validation_step(self, batch: Dict[str, Tensor], batch_idx: int) -> Dict[str, Any]: + return self._general_step(batch=batch, batch_idx=batch_idx, step_name="val")[1] + + def test_step(self, batch: Dict[str, Tensor], batch_idx: int) -> Dict[str, Any]: + return self._general_step(batch=batch, batch_idx=batch_idx, step_name="val")[1] + + def _general_epoch_end(self, outputs: Dict[str, Any], step_name: str) -> None: + r"""Common code for training_epoch_end, validation_epoch_end and testing_epoch_end""" + + # Transform the list of dict of dict, into a dict of list of dict + preds = torch.cat([out["preds"] for out in outputs], dim=0) + targets = torch.cat([out["targets"] for out in outputs], dim=0) + if outputs[0]["weights"] is not None: + weights = torch.cat([out["weights"] for out in outputs], dim=0) + else: + weights = None + loss = self.compute_loss( + preds=preds, + targets=targets, + weights=weights, + target_nan_mask=self.target_nan_mask, + loss_fun=self.loss_fun, + ) + metrics_logs = self.get_metrics_logs( + preds=preds, targets=targets, weights=weights, step_name=step_name, loss=loss + ) + + self.epoch_summary.set_results( + name=step_name, + predictions=preds, + targets=targets, + loss=loss, + metrics=metrics_logs, + n_epochs=self.current_epoch, + ) + + return metrics_logs + + def training_epoch_end(self, outputs: Dict): + + self._general_epoch_end(outputs=outputs, step_name="train") + + def validation_epoch_end(self, outputs: List): + + metrics_logs = self._general_epoch_end(outputs=outputs, step_name="val") + + lr = self.optimizers().param_groups[0]["lr"] + metrics_logs["lr"] = lr + metrics_logs["n_epochs"] = self.current_epoch + self.log_dict(metrics_logs) + + # Save yaml file with the metrics summaries + full_dict = {} + full_dict.update(self.epoch_summary.get_dict_summary()) + tb_path = self.logger.log_dir + + # Write the YAML file with the metrics + if self.current_epoch >= 1: + with open(os.path.join(tb_path, "metrics.yaml"), "w") as file: + yaml.dump(full_dict, file) + + def test_epoch_end(self, outputs: List): + + metrics_logs = self._general_epoch_end(outputs=outputs, step_name="test") + self.log_dict(metrics_logs) + + # Save yaml file with the metrics summaries + full_dict = {} + full_dict.update(self.epoch_summary.get_dict_summary()) + tb_path = self.logger.log_dir + os.makedirs(tb_path, exist_ok=True) + with open(f"{tb_path}/metrics.yaml", "w") as file: + yaml.dump(full_dict, file) + + def on_train_start(self): + self.logger.log_hyperparams(self.hparams, self.epoch_summary.get_results("val").metrics) + + def get_progress_bar_dict(self) -> Dict[str, float]: + prog_dict = super().get_progress_bar_dict() + results_on_progress_bar = self.epoch_summary.get_results_on_progress_bar("val") + prog_dict["loss/val"] = self.epoch_summary.summaries["val"].loss + prog_dict.update(results_on_progress_bar) + return prog_dict + + def summarize(self, mode: str = ModelSummaryExtended.MODE_DEFAULT, to_print=True) -> ModelSummaryExtended: + r""" + Provide a summary of the class, usually to be printed + """ + model_summary = None + + if isinstance(mode, int): + mode = ModelSummaryExtended.MODES[mode - 1] + + if mode in ModelSummaryExtended.MODES: + model_summary = ModelSummaryExtended(self, mode=mode) + if to_print: + log.info("\n" + str(model_summary)) + elif mode is not None: + raise MisconfigurationException( + f"`mode` can be None, {', '.join(ModelSummaryExtended.MODES)}, got {mode}" + ) + + return model_summary + + def __repr__(self) -> str: + r""" + Controls how the class is printed + """ + model_str = self.model.__repr__() + summary_str = self.summarize(to_print=False).__repr__() + + return model_str + "\n\n" + summary_str + + @staticmethod + def list_pretrained_models(): + """List available pretrained models.""" + return GOLI_PRETRAINED_MODELS + + @staticmethod + def load_pretrained_models(name: str): + """Load a pretrained model from its name. + + Args: + name: Name of the model to load. List available + from `goli.trainer.PredictorModule.list_pretrained_models()`. + """ + + if name not in GOLI_PRETRAINED_MODELS: + raise ValueError( + f"The model '{name}' is not available. Choose from {set(GOLI_PRETRAINED_MODELS.keys())}." + ) + + return PredictorModule.load_from_checkpoint(GOLI_PRETRAINED_MODELS[name]) diff --git a/goli/utils/__init__.py b/goli/utils/__init__.py index 8157fdd0f..5d8634457 100644 --- a/goli/utils/__init__.py +++ b/goli/utils/__init__.py @@ -1,2 +1,2 @@ -from . import fs -from . import tensor +from . import fs +from . import tensor diff --git a/goli/utils/arg_checker.py b/goli/utils/arg_checker.py index 22712ff58..060c721f8 100644 --- a/goli/utils/arg_checker.py +++ b/goli/utils/arg_checker.py @@ -1,209 +1,209 @@ -""" Argument checker module """ -import collections -import numpy as np - -# Global variable of accepted string types -KNOWN_TYPES = { - "none": None, - "str": str, - "list": list, - "tuple": tuple, - "dict": dict, - "int": int, - "float": float, - "complex": complex, - "bool": bool, - "callable": callable, -} - - -def _parse_type(type_to_validate, accepted_types): - # Check if the provided type is accepted - if (type_to_validate is not None) and (not isinstance(type_to_validate, accepted_types)): - raise TypeError( - "type_to_validate should be None, type or str. {} provided".format(type(type_to_validate)) - ) - if isinstance(type_to_validate, str): - type_to_validate = type_to_validate.lower() - if type_to_validate in KNOWN_TYPES.keys(): - type_to_validate = KNOWN_TYPES[type_to_validate] - else: - raise TypeError( - "type_to_validate is not a known type. Known types are :" - " \n{}\n Provided : \n{}".format(KNOWN_TYPES.keys(), type_to_validate) - ) - return type_to_validate - - -def _enforce_iter_type(arg, enforce_type): - # Cast the arg to be either a list or a tuple - if enforce_type is not None: - if (enforce_type == list) and (not isinstance(arg, list)): - arg = list(arg) - elif (enforce_type == tuple) and (not isinstance(arg, tuple)): - arg = tuple(arg) - elif enforce_type not in (list, tuple): - raise TypeError('enforce_type should be None, "list" or "tuple", but is {}'.format(enforce_type)) - return arg - - -def check_arg_iterator(arg, enforce_type=None, enforce_subtype=None, cast_subtype: bool = True): - r""" - Verify if the type is an iterator. If it is `None`, convert to an empty list/tuple. If it is - not a list/tuple/str, try to convert to an iterator. If it is a str or cannot be converted to - an iterator, then put the `arg` inside an iterator. - Possibly enforce the iterator type to `list` or `tuple`, if `enfoce_type` is not None. - Possibly enforce the subtype to any given type if `enforce_subtype` is not None, - and decide whether to cast the subtype or to throw an error. - - Parameters: - arg (any type): - The input to verify/convert to an iterator (list or tuple). If None, an empty iterator - is returned. - enforce_type (str or type): - The type to enforce the iterator. The valid choices are : - `None`, `list`, `tuple`, `'none'`, `'list'`, `'tuple'`. - If `None`, then the iterator type is not enforced. - - enforce_subtype (type, np.dtype or str representing basic type): - Verify if all the elements inside the iterator are the desired type. - If `None`, then the sub-type is not enforced. - Accepted strings are ['none', 'str', 'list', 'tuple', 'dict', 'int', - 'float', 'complex', 'bool', 'callable'] - - cast_subtype: - If True, then the type specified by `enforce_subtype` is used to cast the - elements inside the iterator. If False, then an error is thrown if the - types do not match. - - Returns: - output (iterator): - An iterator based on the input of the desired type (list or tuple) and - the desired subtypes. - - """ - - # If not list or tuple, put into a list - if arg is None: - arg = [] - elif isinstance(arg, str): - arg = [arg] - elif isinstance(arg, tuple): - if enforce_type is None: - enforce_type = tuple - arg = list(arg) - elif not isinstance(arg, (tuple, list)): - try: - arg = list(arg) - except Exception: - arg = [arg] - - output = arg - - # Make sure that enforce_type and enforce_subtype are a good inputs - enforce_type = _parse_type(enforce_type, (type, str)) - enforce_subtype = _parse_type(enforce_subtype, (type, str, np.dtype)) - - # Cast all the subtypes of the list/tuple into the desired subtype - if enforce_subtype is not None: - if enforce_type is None: - arg2 = output - elif not isinstance(output, enforce_type): - arg2 = list(output) - else: - arg2 = output - try: - for idx, a in enumerate(output): - if not isinstance(a, enforce_subtype): - if cast_subtype: - arg2[idx] = enforce_subtype(a) - else: - raise TypeError( - "iter subtype is {}, desired subtype is {}, " - "but cast_subtype is set to False".format(type(arg2[idx]), enforce_subtype) - ) - except Exception as e: - raise TypeError( - "iterator subtype is {} and cannot be casted to {}\n{}".format(type(a), enforce_subtype, e) - ) - - output = _enforce_iter_type(arg2, enforce_type) - - output = _enforce_iter_type(output, enforce_type) - - return output - - -def check_list1_in_list2(list1, list2, throw_error=True): - r""" - Verify if the list1 (iterator) is included in list2 (iterator). If not, raise an error. - - Parameters: - list1, list2: list, tuple or object - A list or tuple containing the elements to verify the inclusion. - If an object is provided other than a list or tuple, - then it is considered as a list of a single element. - throw_error: bool - Whether to throw an error if list1 is not in list2 - - Returns: - list1_in_list2: bool - A boolean representing the inclusion of list1 in list2. It is returned if - throw_error is set to false - - - """ - - list1 = check_arg_iterator(list1) - list2 = check_arg_iterator(list2) - - # If all elements of list1 are not in list2, throw an error - list1_in_list2 = all(elem in list2 for elem in list1) - if not list1_in_list2 and throw_error: - raise ValueError( - ("Elements in list1 should be contained in list2." + "\n\nlist1 = {} \n\n list2 = {}").format( - list1, list2 - ) - ) - - return list1_in_list2 - - -def check_columns_choice(dataframe, columns_choice, extra_accepted_cols=None, enforce_type="list"): - r""" - Verify if the choice of column `columns_choice` is inside the dataframe or - the extra_accepted_cols. Otherwise, errors are thrown by the sub-functions. - - Parameters: - dataframe: (pd.DataFrame) - The dataframe on which to verify if the column choice is valid. - columns_choice: str, iterator(str) - The columns chosen from the dataframe - extra_accepted_cols: str, iterator(str) - A list - - enforce_type: str or type - The type to enforce the iterator. The valid choices are : - `None`, `list`, `tuple`, `'none'`, `'list'`, `'tuple'`. - If `None`, then the iterator type is not enforced. - - - Returns: - output: iterator - A str iterator based on the input of the desired type (list or tuple) - - """ - extra_accepted_cols = [] if extra_accepted_cols is None else extra_accepted_cols - valid_columns = list(dataframe.columns) - kwargs_iterator = { - "enforce_type": enforce_type, - "enforce_subtype": None, - "cast_subtype": False, - } - columns_choice = check_arg_iterator(columns_choice, **kwargs_iterator) - extra_accepted_cols = check_arg_iterator(extra_accepted_cols, **kwargs_iterator) - valid_columns = check_arg_iterator(valid_columns, **kwargs_iterator) - valid_columns_full = valid_columns + extra_accepted_cols - check_list1_in_list2(columns_choice, valid_columns_full) - - return columns_choice +""" Argument checker module """ +import collections +import numpy as np + +# Global variable of accepted string types +KNOWN_TYPES = { + "none": None, + "str": str, + "list": list, + "tuple": tuple, + "dict": dict, + "int": int, + "float": float, + "complex": complex, + "bool": bool, + "callable": callable, +} + + +def _parse_type(type_to_validate, accepted_types): + # Check if the provided type is accepted + if (type_to_validate is not None) and (not isinstance(type_to_validate, accepted_types)): + raise TypeError( + "type_to_validate should be None, type or str. {} provided".format(type(type_to_validate)) + ) + if isinstance(type_to_validate, str): + type_to_validate = type_to_validate.lower() + if type_to_validate in KNOWN_TYPES.keys(): + type_to_validate = KNOWN_TYPES[type_to_validate] + else: + raise TypeError( + "type_to_validate is not a known type. Known types are :" + " \n{}\n Provided : \n{}".format(KNOWN_TYPES.keys(), type_to_validate) + ) + return type_to_validate + + +def _enforce_iter_type(arg, enforce_type): + # Cast the arg to be either a list or a tuple + if enforce_type is not None: + if (enforce_type == list) and (not isinstance(arg, list)): + arg = list(arg) + elif (enforce_type == tuple) and (not isinstance(arg, tuple)): + arg = tuple(arg) + elif enforce_type not in (list, tuple): + raise TypeError('enforce_type should be None, "list" or "tuple", but is {}'.format(enforce_type)) + return arg + + +def check_arg_iterator(arg, enforce_type=None, enforce_subtype=None, cast_subtype: bool = True): + r""" + Verify if the type is an iterator. If it is `None`, convert to an empty list/tuple. If it is + not a list/tuple/str, try to convert to an iterator. If it is a str or cannot be converted to + an iterator, then put the `arg` inside an iterator. + Possibly enforce the iterator type to `list` or `tuple`, if `enfoce_type` is not None. + Possibly enforce the subtype to any given type if `enforce_subtype` is not None, + and decide whether to cast the subtype or to throw an error. + + Parameters: + arg (any type): + The input to verify/convert to an iterator (list or tuple). If None, an empty iterator + is returned. + enforce_type (str or type): + The type to enforce the iterator. The valid choices are : + `None`, `list`, `tuple`, `'none'`, `'list'`, `'tuple'`. + If `None`, then the iterator type is not enforced. + + enforce_subtype (type, np.dtype or str representing basic type): + Verify if all the elements inside the iterator are the desired type. + If `None`, then the sub-type is not enforced. + Accepted strings are ['none', 'str', 'list', 'tuple', 'dict', 'int', + 'float', 'complex', 'bool', 'callable'] + + cast_subtype: + If True, then the type specified by `enforce_subtype` is used to cast the + elements inside the iterator. If False, then an error is thrown if the + types do not match. + + Returns: + output (iterator): + An iterator based on the input of the desired type (list or tuple) and + the desired subtypes. + + """ + + # If not list or tuple, put into a list + if arg is None: + arg = [] + elif isinstance(arg, str): + arg = [arg] + elif isinstance(arg, tuple): + if enforce_type is None: + enforce_type = tuple + arg = list(arg) + elif not isinstance(arg, (tuple, list)): + try: + arg = list(arg) + except Exception: + arg = [arg] + + output = arg + + # Make sure that enforce_type and enforce_subtype are a good inputs + enforce_type = _parse_type(enforce_type, (type, str)) + enforce_subtype = _parse_type(enforce_subtype, (type, str, np.dtype)) + + # Cast all the subtypes of the list/tuple into the desired subtype + if enforce_subtype is not None: + if enforce_type is None: + arg2 = output + elif not isinstance(output, enforce_type): + arg2 = list(output) + else: + arg2 = output + try: + for idx, a in enumerate(output): + if not isinstance(a, enforce_subtype): + if cast_subtype: + arg2[idx] = enforce_subtype(a) + else: + raise TypeError( + "iter subtype is {}, desired subtype is {}, " + "but cast_subtype is set to False".format(type(arg2[idx]), enforce_subtype) + ) + except Exception as e: + raise TypeError( + "iterator subtype is {} and cannot be casted to {}\n{}".format(type(a), enforce_subtype, e) + ) + + output = _enforce_iter_type(arg2, enforce_type) + + output = _enforce_iter_type(output, enforce_type) + + return output + + +def check_list1_in_list2(list1, list2, throw_error=True): + r""" + Verify if the list1 (iterator) is included in list2 (iterator). If not, raise an error. + + Parameters: + list1, list2: list, tuple or object + A list or tuple containing the elements to verify the inclusion. + If an object is provided other than a list or tuple, + then it is considered as a list of a single element. + throw_error: bool + Whether to throw an error if list1 is not in list2 + + Returns: + list1_in_list2: bool + A boolean representing the inclusion of list1 in list2. It is returned if + throw_error is set to false + + + """ + + list1 = check_arg_iterator(list1) + list2 = check_arg_iterator(list2) + + # If all elements of list1 are not in list2, throw an error + list1_in_list2 = all(elem in list2 for elem in list1) + if not list1_in_list2 and throw_error: + raise ValueError( + ("Elements in list1 should be contained in list2." + "\n\nlist1 = {} \n\n list2 = {}").format( + list1, list2 + ) + ) + + return list1_in_list2 + + +def check_columns_choice(dataframe, columns_choice, extra_accepted_cols=None, enforce_type="list"): + r""" + Verify if the choice of column `columns_choice` is inside the dataframe or + the extra_accepted_cols. Otherwise, errors are thrown by the sub-functions. + + Parameters: + dataframe: (pd.DataFrame) + The dataframe on which to verify if the column choice is valid. + columns_choice: str, iterator(str) + The columns chosen from the dataframe + extra_accepted_cols: str, iterator(str) + A list + + enforce_type: str or type + The type to enforce the iterator. The valid choices are : + `None`, `list`, `tuple`, `'none'`, `'list'`, `'tuple'`. + If `None`, then the iterator type is not enforced. + + + Returns: + output: iterator + A str iterator based on the input of the desired type (list or tuple) + + """ + extra_accepted_cols = [] if extra_accepted_cols is None else extra_accepted_cols + valid_columns = list(dataframe.columns) + kwargs_iterator = { + "enforce_type": enforce_type, + "enforce_subtype": None, + "cast_subtype": False, + } + columns_choice = check_arg_iterator(columns_choice, **kwargs_iterator) + extra_accepted_cols = check_arg_iterator(extra_accepted_cols, **kwargs_iterator) + valid_columns = check_arg_iterator(valid_columns, **kwargs_iterator) + valid_columns_full = valid_columns + extra_accepted_cols + check_list1_in_list2(columns_choice, valid_columns_full) + + return columns_choice diff --git a/goli/utils/decorators.py b/goli/utils/decorators.py index f75f490a5..900629de6 100644 --- a/goli/utils/decorators.py +++ b/goli/utils/decorators.py @@ -1,16 +1,16 @@ -class classproperty(property): - r""" - Decorator used to declare a class property, defined for the class - without needing to instanciate an object. - - !!! Example - - ``` python linenums="1" - @classproperty - def my_class_property(cls): - return 5 - ``` - """ - - def __get__(self, cls, owner): - return classmethod(self.fget).__get__(None, owner)() +class classproperty(property): + r""" + Decorator used to declare a class property, defined for the class + without needing to instanciate an object. + + !!! Example + + ``` python linenums="1" + @classproperty + def my_class_property(cls): + return 5 + ``` + """ + + def __get__(self, cls, owner): + return classmethod(self.fget).__get__(None, owner)() diff --git a/goli/utils/fs.py b/goli/utils/fs.py index 4ca7e3871..24ada941c 100644 --- a/goli/utils/fs.py +++ b/goli/utils/fs.py @@ -1,209 +1,209 @@ -from typing import Union -from typing import Optional - -import os -import io -import appdirs -import pathlib - -from tqdm.auto import tqdm -import fsspec - - -def get_cache_dir(suffix: str = None, create: bool = True) -> pathlib.Path: - """Get a local cache directory. You can append a suffix folder to it and optionnaly create - the folder if it doesn't exist. - """ - - cache_dir = pathlib.Path(appdirs.user_cache_dir(appname="goli")) - - if suffix is not None: - cache_dir /= suffix - - if create: - cache_dir.mkdir(exist_ok=True, parents=True) - - return cache_dir - - -def get_mapper(path: Union[str, os.PathLike]): - """Get the fsspec mapper. - Args: - path: a path supported by `fsspec` such as local, s3, gcs, etc. - """ - return fsspec.get_mapper(str(path)) - - -def get_basename(path: Union[str, os.PathLike]): - """Get the basename of a file or a folder. - Args: - path: a path supported by `fsspec` such as local, s3, gcs, etc. - """ - path = str(path) - mapper = get_mapper(path) - clean_path = path.rstrip(mapper.fs.sep) - return str(clean_path).split(mapper.fs.sep)[-1] - - -def get_extension(path: Union[str, os.PathLike]): - """Get the extension of a file. - Args: - path: a path supported by `fsspec` such as local, s3, gcs, etc. - """ - basename = get_basename(path) - return basename.split(".")[-1] - - -def exists(path: Union[str, os.PathLike, fsspec.core.OpenFile, io.IOBase]): - """Check whether a file exists. - Args: - path: a path supported by `fsspec` such as local, s3, gcs, etc. - """ - - if isinstance(path, fsspec.core.OpenFile): - return path.fs.exists(path.path) - - elif isinstance(path, (str, pathlib.Path)): - mapper = get_mapper(str(path)) - return mapper.fs.exists(path) - - else: - # NOTE(hadim): file-like objects always exist right? - return True - - -def exists_and_not_empty(path: Union[str, os.PathLike]): - """Check whether a directory exists and is not empty.""" - - if not exists(path): - return False - - fs = get_mapper(path).fs - - return len(fs.ls(path)) > 0 - - -def mkdir(path: Union[str, os.PathLike], exist_ok: bool = True): - """Create directory including potential parents.""" - fs = get_mapper(path).fs - fs.mkdirs(path, exist_ok=exist_ok) - - -def join(*paths): - """Join paths together. The first element determine the - filesystem to use (and so the separator. - Args: - paths: a list of paths supported by `fsspec` such as local, s3, gcs, etc. - """ - paths = [str(path) for path in paths] - source_path = paths[0] - fs = get_mapper(source_path).fs - full_path = fs.sep.join(paths) - return full_path - - -def get_size(file: Union[str, os.PathLike, io.IOBase, fsspec.core.OpenFile]) -> Optional[int]: - """Get the size of a file given its path. Return None if the - size can't be retrieved. - """ - - if isinstance(file, io.IOBase) and hasattr(file, "name"): - fs_local = fsspec.filesystem("file") - file_size = fs_local.size(getattr(file, "name")) - - elif isinstance(file, (str, pathlib.Path)): - fs = get_mapper(str(file)).fs - file_size = fs.size(str(file)) - - elif isinstance(file, fsspec.core.OpenFile): - file_size = file.fs.size(file.path) - - else: - file_size = None - - return file_size - - -def copy( - source: Union[str, os.PathLike, io.IOBase, fsspec.core.OpenFile], - destination: Union[str, os.PathLike, io.IOBase, fsspec.core.OpenFile], - chunk_size: int = None, - force: bool = False, - progress: bool = False, - leave_progress: bool = True, -): - """Copy one file to another location across different filesystem (local, S3, GCS, etc). - - Args: - source: path or file-like object to copy from. - destination: path or file-like object to copy to. - chunk_size: the chunk size to use. If progress is enabled the chunk - size is `None`, it is set to 2048. - force: whether to overwrite the destination file it it exists. - progress: whether to display a progress bar. - leave_progress: whether to hide the progress bar once the copy is done. - """ - - if progress and chunk_size is None: - chunk_size = 2048 - - if isinstance(source, (str, os.PathLike)): - source_file = fsspec.open(str(source), "rb") - else: - source_file = source - - if isinstance(destination, (str, os.PathLike)): - - # adapt the file mode of the destination depending on the source file. - destination_mode = "wb" - if hasattr(source_file, "mode"): - destination_mode = "wb" if "b" in getattr(source_file, "mode") else "w" - elif isinstance(source_file, io.BytesIO): - destination_mode = "wb" - elif isinstance(source_file, io.StringIO): - destination_mode = "w" - - destination_file = fsspec.open(str(destination), destination_mode) - else: - destination_file = destination - - if not exists(source_file): - raise ValueError(f"The file being copied does not exist: {source}") - - if not force and exists(destination_file): - raise ValueError(f"The destination file to copy already exists: {destination}") - - with source_file as source_stream: - with destination_file as destination_stream: - - if chunk_size is None: - # copy without chunks - destination_stream.write(source_stream.read()) - - else: - # copy with chunks - - # determine the size of the source file - source_size = None - if progress: - source_size = get_size(source) - - # init progress bar - pbar = tqdm( - total=source_size, - leave=leave_progress, - disable=not progress, - unit="B", - unit_divisor=1024, - unit_scale=True, - ) - - # start the loop - while True: - data = source_stream.read(chunk_size) - if not data: - break - destination_stream.write(data) - pbar.update(chunk_size) - - pbar.close() +from typing import Union +from typing import Optional + +import os +import io +import appdirs +import pathlib + +from tqdm.auto import tqdm +import fsspec + + +def get_cache_dir(suffix: str = None, create: bool = True) -> pathlib.Path: + """Get a local cache directory. You can append a suffix folder to it and optionnaly create + the folder if it doesn't exist. + """ + + cache_dir = pathlib.Path(appdirs.user_cache_dir(appname="goli")) + + if suffix is not None: + cache_dir /= suffix + + if create: + cache_dir.mkdir(exist_ok=True, parents=True) + + return cache_dir + + +def get_mapper(path: Union[str, os.PathLike]): + """Get the fsspec mapper. + Args: + path: a path supported by `fsspec` such as local, s3, gcs, etc. + """ + return fsspec.get_mapper(str(path)) + + +def get_basename(path: Union[str, os.PathLike]): + """Get the basename of a file or a folder. + Args: + path: a path supported by `fsspec` such as local, s3, gcs, etc. + """ + path = str(path) + mapper = get_mapper(path) + clean_path = path.rstrip(mapper.fs.sep) + return str(clean_path).split(mapper.fs.sep)[-1] + + +def get_extension(path: Union[str, os.PathLike]): + """Get the extension of a file. + Args: + path: a path supported by `fsspec` such as local, s3, gcs, etc. + """ + basename = get_basename(path) + return basename.split(".")[-1] + + +def exists(path: Union[str, os.PathLike, fsspec.core.OpenFile, io.IOBase]): + """Check whether a file exists. + Args: + path: a path supported by `fsspec` such as local, s3, gcs, etc. + """ + + if isinstance(path, fsspec.core.OpenFile): + return path.fs.exists(path.path) + + elif isinstance(path, (str, pathlib.Path)): + mapper = get_mapper(str(path)) + return mapper.fs.exists(path) + + else: + # NOTE(hadim): file-like objects always exist right? + return True + + +def exists_and_not_empty(path: Union[str, os.PathLike]): + """Check whether a directory exists and is not empty.""" + + if not exists(path): + return False + + fs = get_mapper(path).fs + + return len(fs.ls(path)) > 0 + + +def mkdir(path: Union[str, os.PathLike], exist_ok: bool = True): + """Create directory including potential parents.""" + fs = get_mapper(path).fs + fs.mkdirs(path, exist_ok=exist_ok) + + +def join(*paths): + """Join paths together. The first element determine the + filesystem to use (and so the separator. + Args: + paths: a list of paths supported by `fsspec` such as local, s3, gcs, etc. + """ + paths = [str(path) for path in paths] + source_path = paths[0] + fs = get_mapper(source_path).fs + full_path = fs.sep.join(paths) + return full_path + + +def get_size(file: Union[str, os.PathLike, io.IOBase, fsspec.core.OpenFile]) -> Optional[int]: + """Get the size of a file given its path. Return None if the + size can't be retrieved. + """ + + if isinstance(file, io.IOBase) and hasattr(file, "name"): + fs_local = fsspec.filesystem("file") + file_size = fs_local.size(getattr(file, "name")) + + elif isinstance(file, (str, pathlib.Path)): + fs = get_mapper(str(file)).fs + file_size = fs.size(str(file)) + + elif isinstance(file, fsspec.core.OpenFile): + file_size = file.fs.size(file.path) + + else: + file_size = None + + return file_size + + +def copy( + source: Union[str, os.PathLike, io.IOBase, fsspec.core.OpenFile], + destination: Union[str, os.PathLike, io.IOBase, fsspec.core.OpenFile], + chunk_size: int = None, + force: bool = False, + progress: bool = False, + leave_progress: bool = True, +): + """Copy one file to another location across different filesystem (local, S3, GCS, etc). + + Args: + source: path or file-like object to copy from. + destination: path or file-like object to copy to. + chunk_size: the chunk size to use. If progress is enabled the chunk + size is `None`, it is set to 2048. + force: whether to overwrite the destination file it it exists. + progress: whether to display a progress bar. + leave_progress: whether to hide the progress bar once the copy is done. + """ + + if progress and chunk_size is None: + chunk_size = 2048 + + if isinstance(source, (str, os.PathLike)): + source_file = fsspec.open(str(source), "rb") + else: + source_file = source + + if isinstance(destination, (str, os.PathLike)): + + # adapt the file mode of the destination depending on the source file. + destination_mode = "wb" + if hasattr(source_file, "mode"): + destination_mode = "wb" if "b" in getattr(source_file, "mode") else "w" + elif isinstance(source_file, io.BytesIO): + destination_mode = "wb" + elif isinstance(source_file, io.StringIO): + destination_mode = "w" + + destination_file = fsspec.open(str(destination), destination_mode) + else: + destination_file = destination + + if not exists(source_file): + raise ValueError(f"The file being copied does not exist: {source}") + + if not force and exists(destination_file): + raise ValueError(f"The destination file to copy already exists: {destination}") + + with source_file as source_stream: + with destination_file as destination_stream: + + if chunk_size is None: + # copy without chunks + destination_stream.write(source_stream.read()) + + else: + # copy with chunks + + # determine the size of the source file + source_size = None + if progress: + source_size = get_size(source) + + # init progress bar + pbar = tqdm( + total=source_size, + leave=leave_progress, + disable=not progress, + unit="B", + unit_divisor=1024, + unit_scale=True, + ) + + # start the loop + while True: + data = source_stream.read(chunk_size) + if not data: + break + destination_stream.write(data) + pbar.update(chunk_size) + + pbar.close() diff --git a/goli/utils/read_file.py b/goli/utils/read_file.py index 4e4b69a01..3aaf35534 100644 --- a/goli/utils/read_file.py +++ b/goli/utils/read_file.py @@ -1,159 +1,159 @@ -""" Utiles for data parsing""" -import os -import warnings -import numpy as np -import pandas as pd -import datamol as dm -from functools import partial -from copy import copy -import fsspec - -from loguru import logger -from rdkit import Chem -from rdkit.Chem.Descriptors import ExactMolWt - -from goli.utils.tensor import parse_valid_args, arg_in_func - - -def read_file(filepath, as_ext=None, **kwargs): - r""" - Allow to read different file format and parse them into a MolecularDataFrame. - Supported formats are: - * csv (.csv, .smile, .smiles, .tsv) - * txt (.txt) - * xls (.xls, .xlsx, .xlsm, .xls*) - * sdf (.sdf) - * pkl (.pkl) - - Arguments - ----------- - - filepath: str - The full path and name of the file to read. - It also supports the s3 url path. - as_ext: str, Optional - The file extension used to read the file. If None, the extension is deduced - from the extension of the file. Otherwise, no matter the file extension, - the file will be read according to the specified ``as_ext``. - (Default=None) - **kwargs: All the optional parameters required for the desired file reader. - - TODO: unit test to make sure it works well with all extensions - - Returns - --------- - df: pandas.DataFrame - The ``pandas.DataFrame`` containing the parsed data - - """ - - # Get the file extension - if as_ext is None: - file_ext = os.path.splitext(filepath)[-1].lower()[1:] - else: - file_ext = as_ext - if not isinstance(file_ext, str): - raise "`file_type` must be a `str`. Provided: {}".format(file_ext) - - open_mode = "r" - - # Read the file according to the right extension - if file_ext in ["csv", "smile", "smiles", "smi", "tsv"]: - file_reader = pd.read_csv - elif file_ext == "txt": - file_reader = pd.read_table - elif file_ext[0:3] == "xls": - open_mode = "rb" - file_reader = partial(pd.read_excel, engine="openpyxl") - elif file_ext == "sdf": - file_reader = parse_sdf_to_dataframe - elif file_ext == "pkl": - open_mode = "rb" - file_reader = pd.read_pickle - else: - raise 'File extension "{}" not supported'.format(file_ext) - - kwargs = parse_valid_args(fn=file_reader, param_dict=kwargs) - - if file_ext[0:3] not in ["sdf", "xls"]: - with file_opener(filepath, open_mode) as file_in: - data = file_reader(file_in, **kwargs) - else: - data = file_reader(filepath, **kwargs) - return data - - -def parse_sdf_to_dataframe(sdf_path, as_cxsmiles=True, skiprows=None): - r""" - Allows to read an SDF file containing molecular informations, convert - it to a pandas DataFrame and convert the molecules to SMILES. It also - lists a warning of all the molecules that couldn't be read. - - Arguments - ----------- - - sdf_path: str - The full path and name of the sdf file to read - as_cxsmiles: bool, optional - Whether to use the CXSMILES notation, which preserves atomic coordinates, - stereocenters, and much more. - See `https://dl.chemaxon.com/marvin-archive/latest/help/formats/cxsmiles-doc.html` - (Default = True) - skiprows: int, list - The rows to skip from dataset. The enumerate index starts from 1 insted of 0. - (Default = None) - - """ - - # read the SDF file - # locally or from s3 - data = dm.read_sdf(sdf_path) - - # For each molecule in the SDF file, read all the properties and add it to a list of dict. - # Also count the number of molecules that cannot be read. - data_list = [] - count_none = 0 - if skiprows is not None: - if isinstance(skiprows, int): - skiprows = range(0, skiprows - 1) - skiprows = np.array(skiprows) - 1 - - for idx, mol in enumerate(data): - if (skiprows is not None) and (idx in skiprows): - continue - - if (mol is not None) and (ExactMolWt(mol) > 0): - mol_dict = mol.GetPropsAsDict() - data_list.append(mol_dict) - if as_cxsmiles: - smiles = Chem.rdmolfiles.MolToCXSmiles(mol, canonical=True) - else: - smiles = dm.to_smiles(mol, canonical=True) - data_list[-1]["SMILES"] = smiles - else: - count_none += 1 - logger.info(f"Could not read molecule # {idx}") - - # Display a message or warning after the SDF is done parsing - if count_none == 0: - logger.info("Successfully read the SDF file without error: {}".format(sdf_path)) - else: - warnings.warn( - ( - 'Error reading {} molecules from the "{}" file.\ - {} molecules read successfully.' - ).format(count_none, sdf_path, len(data_list)) - ) - return pd.DataFrame(data_list) - - -def file_opener(filename, mode="r"): - """File reader stream""" - filename = str(filename) - if "w" in mode: - filename = "simplecache::" + filename - if filename.endswith(".gz"): - instream = fsspec.open(filename, mode=mode, compression="gzip") - else: - instream = fsspec.open(filename, mode=mode) - return instream +""" Utiles for data parsing""" +import os +import warnings +import numpy as np +import pandas as pd +import datamol as dm +from functools import partial +from copy import copy +import fsspec + +from loguru import logger +from rdkit import Chem +from rdkit.Chem.Descriptors import ExactMolWt + +from goli.utils.tensor import parse_valid_args, arg_in_func + + +def read_file(filepath, as_ext=None, **kwargs): + r""" + Allow to read different file format and parse them into a MolecularDataFrame. + Supported formats are: + * csv (.csv, .smile, .smiles, .tsv) + * txt (.txt) + * xls (.xls, .xlsx, .xlsm, .xls*) + * sdf (.sdf) + * pkl (.pkl) + + Arguments + ----------- + + filepath: str + The full path and name of the file to read. + It also supports the s3 url path. + as_ext: str, Optional + The file extension used to read the file. If None, the extension is deduced + from the extension of the file. Otherwise, no matter the file extension, + the file will be read according to the specified ``as_ext``. + (Default=None) + **kwargs: All the optional parameters required for the desired file reader. + + TODO: unit test to make sure it works well with all extensions + + Returns + --------- + df: pandas.DataFrame + The ``pandas.DataFrame`` containing the parsed data + + """ + + # Get the file extension + if as_ext is None: + file_ext = os.path.splitext(filepath)[-1].lower()[1:] + else: + file_ext = as_ext + if not isinstance(file_ext, str): + raise "`file_type` must be a `str`. Provided: {}".format(file_ext) + + open_mode = "r" + + # Read the file according to the right extension + if file_ext in ["csv", "smile", "smiles", "smi", "tsv"]: + file_reader = pd.read_csv + elif file_ext == "txt": + file_reader = pd.read_table + elif file_ext[0:3] == "xls": + open_mode = "rb" + file_reader = partial(pd.read_excel, engine="openpyxl") + elif file_ext == "sdf": + file_reader = parse_sdf_to_dataframe + elif file_ext == "pkl": + open_mode = "rb" + file_reader = pd.read_pickle + else: + raise 'File extension "{}" not supported'.format(file_ext) + + kwargs = parse_valid_args(fn=file_reader, param_dict=kwargs) + + if file_ext[0:3] not in ["sdf", "xls"]: + with file_opener(filepath, open_mode) as file_in: + data = file_reader(file_in, **kwargs) + else: + data = file_reader(filepath, **kwargs) + return data + + +def parse_sdf_to_dataframe(sdf_path, as_cxsmiles=True, skiprows=None): + r""" + Allows to read an SDF file containing molecular informations, convert + it to a pandas DataFrame and convert the molecules to SMILES. It also + lists a warning of all the molecules that couldn't be read. + + Arguments + ----------- + + sdf_path: str + The full path and name of the sdf file to read + as_cxsmiles: bool, optional + Whether to use the CXSMILES notation, which preserves atomic coordinates, + stereocenters, and much more. + See `https://dl.chemaxon.com/marvin-archive/latest/help/formats/cxsmiles-doc.html` + (Default = True) + skiprows: int, list + The rows to skip from dataset. The enumerate index starts from 1 insted of 0. + (Default = None) + + """ + + # read the SDF file + # locally or from s3 + data = dm.read_sdf(sdf_path) + + # For each molecule in the SDF file, read all the properties and add it to a list of dict. + # Also count the number of molecules that cannot be read. + data_list = [] + count_none = 0 + if skiprows is not None: + if isinstance(skiprows, int): + skiprows = range(0, skiprows - 1) + skiprows = np.array(skiprows) - 1 + + for idx, mol in enumerate(data): + if (skiprows is not None) and (idx in skiprows): + continue + + if (mol is not None) and (ExactMolWt(mol) > 0): + mol_dict = mol.GetPropsAsDict() + data_list.append(mol_dict) + if as_cxsmiles: + smiles = Chem.rdmolfiles.MolToCXSmiles(mol, canonical=True) + else: + smiles = dm.to_smiles(mol, canonical=True) + data_list[-1]["SMILES"] = smiles + else: + count_none += 1 + logger.info(f"Could not read molecule # {idx}") + + # Display a message or warning after the SDF is done parsing + if count_none == 0: + logger.info("Successfully read the SDF file without error: {}".format(sdf_path)) + else: + warnings.warn( + ( + 'Error reading {} molecules from the "{}" file.\ + {} molecules read successfully.' + ).format(count_none, sdf_path, len(data_list)) + ) + return pd.DataFrame(data_list) + + +def file_opener(filename, mode="r"): + """File reader stream""" + filename = str(filename) + if "w" in mode: + filename = "simplecache::" + filename + if filename.endswith(".gz"): + instream = fsspec.open(filename, mode=mode, compression="gzip") + else: + instream = fsspec.open(filename, mode=mode) + return instream diff --git a/goli/utils/spaces.py b/goli/utils/spaces.py index e29f7d47c..20d1524a8 100644 --- a/goli/utils/spaces.py +++ b/goli/utils/spaces.py @@ -1,50 +1,50 @@ -from copy import deepcopy - -from goli.nn.base_layers import FCLayer - -from goli.nn.dgl_layers import ( - GATLayer, - GCNLayer, - GINLayer, - GatedGCNLayer, - PNAConvolutionalLayer, - PNAMessagePassingLayer, - DGNConvolutionalLayer, - DGNMessagePassingLayer, -) - -from goli.nn.residual_connections import ( - ResidualConnectionConcat, - ResidualConnectionDenseNet, - ResidualConnectionNone, - ResidualConnectionSimple, - ResidualConnectionWeighted, -) - - -FC_LAYERS_DICT = { - "fc": FCLayer, -} - -DGL_LAYERS_DICT = { - "gcn": GCNLayer, - "gin": GINLayer, - "gat": GATLayer, - "gated-gcn": GatedGCNLayer, - "pna-conv": PNAConvolutionalLayer, - "pna-msgpass": PNAMessagePassingLayer, - "dgn-conv": DGNConvolutionalLayer, - "dgn-msgpass": DGNMessagePassingLayer, -} - -LAYERS_DICT = deepcopy(DGL_LAYERS_DICT) -LAYERS_DICT.update(deepcopy(FC_LAYERS_DICT)) - - -RESIDUALS_DICT = { - "none": ResidualConnectionNone, - "simple": ResidualConnectionSimple, - "weighted": ResidualConnectionWeighted, - "concat": ResidualConnectionConcat, - "densenet": ResidualConnectionDenseNet, -} +from copy import deepcopy + +from goli.nn.base_layers import FCLayer + +from goli.nn.dgl_layers import ( + GATLayer, + GCNLayer, + GINLayer, + GatedGCNLayer, + PNAConvolutionalLayer, + PNAMessagePassingLayer, + DGNConvolutionalLayer, + DGNMessagePassingLayer, +) + +from goli.nn.residual_connections import ( + ResidualConnectionConcat, + ResidualConnectionDenseNet, + ResidualConnectionNone, + ResidualConnectionSimple, + ResidualConnectionWeighted, +) + + +FC_LAYERS_DICT = { + "fc": FCLayer, +} + +DGL_LAYERS_DICT = { + "gcn": GCNLayer, + "gin": GINLayer, + "gat": GATLayer, + "gated-gcn": GatedGCNLayer, + "pna-conv": PNAConvolutionalLayer, + "pna-msgpass": PNAMessagePassingLayer, + "dgn-conv": DGNConvolutionalLayer, + "dgn-msgpass": DGNMessagePassingLayer, +} + +LAYERS_DICT = deepcopy(DGL_LAYERS_DICT) +LAYERS_DICT.update(deepcopy(FC_LAYERS_DICT)) + + +RESIDUALS_DICT = { + "none": ResidualConnectionNone, + "simple": ResidualConnectionSimple, + "weighted": ResidualConnectionWeighted, + "concat": ResidualConnectionConcat, + "densenet": ResidualConnectionDenseNet, +} diff --git a/goli/utils/tensor.py b/goli/utils/tensor.py index a3f6f6894..02bac1606 100644 --- a/goli/utils/tensor.py +++ b/goli/utils/tensor.py @@ -1,268 +1,268 @@ -import os -import torch -import numpy as np -import pandas as pd -from matplotlib import pyplot as plt -from typing import List, Union -from inspect import getfullargspec -from copy import copy, deepcopy -from loguru import logger - -from rdkit.Chem import AllChem -from torch.tensor import Tensor - - -def save_im(im_dir, im_name: str, ext: List[str] = ["svg", "png"], dpi: int = 600) -> None: - - if not os.path.exists(im_dir): - if im_dir[-1] not in ["/", "\\"]: - im_dir += "/" - os.makedirs(im_dir) - - if isinstance(ext, str): - ext = [ext] - - full_name = os.path.join(im_dir, im_name) - for this_ext in ext: - plt.savefig(f"{full_name}.{this_ext}", dpi=dpi, bbox_inches="tight", pad_inches=0) - - -def is_dtype_torch_tensor(dtype: Union[np.dtype, torch.dtype]) -> bool: - r""" - Verify if the dtype is a torch dtype - - Parameters: - dtype: dtype - The dtype of a value. E.g. np.int32, str, torch.float - - Returns: - A boolean saying if the dtype is a torch dtype - """ - return isinstance(dtype, torch.dtype) or (dtype == Tensor) - - -def is_dtype_numpy_array(dtype: Union[np.dtype, torch.dtype]) -> bool: - r""" - Verify if the dtype is a numpy dtype - - Parameters: - dtype: dtype - The dtype of a value. E.g. np.int32, str, torch.float - - Returns: - A boolean saying if the dtype is a numpy dtype - """ - is_torch = is_dtype_torch_tensor(dtype) - is_num = dtype in (int, float, complex) - if hasattr(dtype, "__module__"): - is_numpy = dtype.__module__ == "numpy" - else: - is_numpy = False - - return (is_num or is_numpy) and not is_torch - - -def one_of_k_encoding(val: int, num_classes: int, dtype=int) -> np.ndarray: - r"""Converts a single value to a one-hot vector. - - Parameters: - val: int - class to be converted into a one hot vector - (integers from 0 to num_classes). - num_classes: iterator - a list or 1D array of allowed - choices for val to take - dtype: type - data type of the the return. - Possible types are int, float, bool, ... - Returns: - A numpy 1D array of length len(num_classes) + 1 - """ - - encoding = np.zeros(len(num_classes) + 1, dtype=dtype) - # not using index of, in case, someone fuck up - # and there are duplicates in the allowed choices - for i, v in enumerate(num_classes): - if v == val: - encoding[i] = 1 - if np.sum(encoding) == 0: # aka not found - encoding[-1] = 1 - return encoding - - -def is_device_cuda(device: torch.device, ignore_errors: bool = False) -> bool: - r"""Check wheter the given device is a cuda device. - - Parameters: - device: str, torch.device - object to check for cuda - ignore_errors: bool - Whether to ignore the error if the device is not recognized. - Otherwise, ``False`` is returned in case of errors. - Returns: - is_cuda: bool - """ - - if ignore_errors: - is_cuda = False - try: - is_cuda = torch.device(device).type == "cuda" - except: - pass - else: - is_cuda = torch.device(device).type == "cuda" - return is_cuda - - -def nan_mean(input: Tensor, **kwargs) -> Tensor: - r""" - Return the mean of all elements, while ignoring the NaNs. - - Parameters: - - input: The input tensor. - - dim (int or tuple(int)): The dimension or dimensions to reduce. - - keepdim (bool): whether the output tensor has dim retained or not. - - dtype (torch.dtype, optional): - The desired data type of returned tensor. - If specified, the input tensor is casted to dtype before the operation is performed. - This is useful for preventing data type overflows. Default: None. - - Returns: - output: The resulting mean of the tensor - """ - - sum = torch.nansum(input, **kwargs) - num = torch.sum(~torch.isnan(input), **kwargs) - mean = sum / num - return mean - - -def nan_var(input: Tensor, unbiased: bool = True, **kwargs) -> Tensor: - r""" - Return the variace of all elements, while ignoring the NaNs. - If unbiased is True, Bessel’s correction will be used. - Otherwise, the sample deviation is calculated, without any correction. - - Parameters: - - input: The input tensor. - - unbiased: whether to use Bessel’s correction (δN=1\delta N = 1δN=1). - - dim (int or tuple(int)): The dimension or dimensions to reduce. - - keepdim (bool): whether the output tensor has dim retained or not. - - dtype (torch.dtype, optional): - The desired data type of returned tensor. - If specified, the input tensor is casted to dtype before the operation is performed. - This is useful for preventing data type overflows. Default: None. - - Returns: - output: The resulting variance of the tensor - """ - - mean_kwargs = deepcopy(kwargs) - mean_kwargs.pop("keepdim", None) - dim = mean_kwargs.pop("dim", [ii for ii in range(input.ndim)]) - mean = nan_mean(input, dim=dim, keepdim=True, **mean_kwargs) - dist = (input - mean).abs() ** 2 - var = nan_mean(dist, **kwargs) - - if unbiased: - num = torch.sum(~torch.isnan(input), **kwargs) - var = var * num / (num - 1) - - return var - - -def nan_std(input: Tensor, unbiased: bool = True, **kwargs) -> Tensor: - r""" - Return the standard deviation of all elements, while ignoring the NaNs. - If unbiased is True, Bessel’s correction will be used. - Otherwise, the sample deviation is calculated, without any correction. - - Parameters: - - input: The input tensor. - - unbiased: whether to use Bessel’s correction (δN=1\delta N = 1δN=1). - - dim (int or tuple(int)): The dimension or dimensions to reduce. - - keepdim (bool): whether the output tensor has dim retained or not. - - dtype (torch.dtype, optional): - The desired data type of returned tensor. - If specified, the input tensor is casted to dtype before the operation is performed. - This is useful for preventing data type overflows. Default: None. - - Returns: - output: The resulting standard deviation of the tensor - """ - - return torch.sqrt(nan_var(input=input, unbiased=unbiased, **kwargs)) - - -class ModuleListConcat(torch.nn.ModuleList): - def __init__(self, dim: int = -1): - super().__init__() - self.dim = dim - - def forward(self, *args, **kwargs) -> Tensor: - h = [] - for module in self: - h.append(module.forward(*args, **kwargs)) - - return torch.cat(h, dim=self.dim) - - -def parse_valid_args(param_dict, fn): - r""" - Check if a function takes the given argument. - - Parameters - ---------- - fn: func - The function to check the argument. - param_dict: dict - Dictionary of the argument. - - Returns - ------- - param_dict: dict - Valid paramter dictionary for the given fucntions. - """ - param_dict_cp = copy(param_dict) - for key, param in param_dict.items(): - if not arg_in_func(fn=fn, arg=key): - logger.warning( - f"{key} is not an available argument for {fn.__name__}, and is ignored by default." - ) - param_dict_cp.pop(key) - - return param_dict_cp - - -def arg_in_func(fn, arg): - r""" - Check if a function takes the given argument. - - Parameters - ---------- - fn: func - The function to check the argument. - arg: str - The name of the argument. - - Returns - ------- - res: bool - True if the function contains the argument, otherwise False. - """ - fn_args = getfullargspec(fn) - return (fn_args.varkw is not None) or (arg in fn_args[0]) +import os +import torch +import numpy as np +import pandas as pd +from matplotlib import pyplot as plt +from typing import List, Union +from inspect import getfullargspec +from copy import copy, deepcopy +from loguru import logger + +from rdkit.Chem import AllChem +from torch.tensor import Tensor + + +def save_im(im_dir, im_name: str, ext: List[str] = ["svg", "png"], dpi: int = 600) -> None: + + if not os.path.exists(im_dir): + if im_dir[-1] not in ["/", "\\"]: + im_dir += "/" + os.makedirs(im_dir) + + if isinstance(ext, str): + ext = [ext] + + full_name = os.path.join(im_dir, im_name) + for this_ext in ext: + plt.savefig(f"{full_name}.{this_ext}", dpi=dpi, bbox_inches="tight", pad_inches=0) + + +def is_dtype_torch_tensor(dtype: Union[np.dtype, torch.dtype]) -> bool: + r""" + Verify if the dtype is a torch dtype + + Parameters: + dtype: dtype + The dtype of a value. E.g. np.int32, str, torch.float + + Returns: + A boolean saying if the dtype is a torch dtype + """ + return isinstance(dtype, torch.dtype) or (dtype == Tensor) + + +def is_dtype_numpy_array(dtype: Union[np.dtype, torch.dtype]) -> bool: + r""" + Verify if the dtype is a numpy dtype + + Parameters: + dtype: dtype + The dtype of a value. E.g. np.int32, str, torch.float + + Returns: + A boolean saying if the dtype is a numpy dtype + """ + is_torch = is_dtype_torch_tensor(dtype) + is_num = dtype in (int, float, complex) + if hasattr(dtype, "__module__"): + is_numpy = dtype.__module__ == "numpy" + else: + is_numpy = False + + return (is_num or is_numpy) and not is_torch + + +def one_of_k_encoding(val: int, num_classes: int, dtype=int) -> np.ndarray: + r"""Converts a single value to a one-hot vector. + + Parameters: + val: int + class to be converted into a one hot vector + (integers from 0 to num_classes). + num_classes: iterator + a list or 1D array of allowed + choices for val to take + dtype: type + data type of the the return. + Possible types are int, float, bool, ... + Returns: + A numpy 1D array of length len(num_classes) + 1 + """ + + encoding = np.zeros(len(num_classes) + 1, dtype=dtype) + # not using index of, in case, someone fuck up + # and there are duplicates in the allowed choices + for i, v in enumerate(num_classes): + if v == val: + encoding[i] = 1 + if np.sum(encoding) == 0: # aka not found + encoding[-1] = 1 + return encoding + + +def is_device_cuda(device: torch.device, ignore_errors: bool = False) -> bool: + r"""Check wheter the given device is a cuda device. + + Parameters: + device: str, torch.device + object to check for cuda + ignore_errors: bool + Whether to ignore the error if the device is not recognized. + Otherwise, ``False`` is returned in case of errors. + Returns: + is_cuda: bool + """ + + if ignore_errors: + is_cuda = False + try: + is_cuda = torch.device(device).type == "cuda" + except: + pass + else: + is_cuda = torch.device(device).type == "cuda" + return is_cuda + + +def nan_mean(input: Tensor, **kwargs) -> Tensor: + r""" + Return the mean of all elements, while ignoring the NaNs. + + Parameters: + + input: The input tensor. + + dim (int or tuple(int)): The dimension or dimensions to reduce. + + keepdim (bool): whether the output tensor has dim retained or not. + + dtype (torch.dtype, optional): + The desired data type of returned tensor. + If specified, the input tensor is casted to dtype before the operation is performed. + This is useful for preventing data type overflows. Default: None. + + Returns: + output: The resulting mean of the tensor + """ + + sum = torch.nansum(input, **kwargs) + num = torch.sum(~torch.isnan(input), **kwargs) + mean = sum / num + return mean + + +def nan_var(input: Tensor, unbiased: bool = True, **kwargs) -> Tensor: + r""" + Return the variace of all elements, while ignoring the NaNs. + If unbiased is True, Bessel’s correction will be used. + Otherwise, the sample deviation is calculated, without any correction. + + Parameters: + + input: The input tensor. + + unbiased: whether to use Bessel’s correction (δN=1\delta N = 1δN=1). + + dim (int or tuple(int)): The dimension or dimensions to reduce. + + keepdim (bool): whether the output tensor has dim retained or not. + + dtype (torch.dtype, optional): + The desired data type of returned tensor. + If specified, the input tensor is casted to dtype before the operation is performed. + This is useful for preventing data type overflows. Default: None. + + Returns: + output: The resulting variance of the tensor + """ + + mean_kwargs = deepcopy(kwargs) + mean_kwargs.pop("keepdim", None) + dim = mean_kwargs.pop("dim", [ii for ii in range(input.ndim)]) + mean = nan_mean(input, dim=dim, keepdim=True, **mean_kwargs) + dist = (input - mean).abs() ** 2 + var = nan_mean(dist, **kwargs) + + if unbiased: + num = torch.sum(~torch.isnan(input), **kwargs) + var = var * num / (num - 1) + + return var + + +def nan_std(input: Tensor, unbiased: bool = True, **kwargs) -> Tensor: + r""" + Return the standard deviation of all elements, while ignoring the NaNs. + If unbiased is True, Bessel’s correction will be used. + Otherwise, the sample deviation is calculated, without any correction. + + Parameters: + + input: The input tensor. + + unbiased: whether to use Bessel’s correction (δN=1\delta N = 1δN=1). + + dim (int or tuple(int)): The dimension or dimensions to reduce. + + keepdim (bool): whether the output tensor has dim retained or not. + + dtype (torch.dtype, optional): + The desired data type of returned tensor. + If specified, the input tensor is casted to dtype before the operation is performed. + This is useful for preventing data type overflows. Default: None. + + Returns: + output: The resulting standard deviation of the tensor + """ + + return torch.sqrt(nan_var(input=input, unbiased=unbiased, **kwargs)) + + +class ModuleListConcat(torch.nn.ModuleList): + def __init__(self, dim: int = -1): + super().__init__() + self.dim = dim + + def forward(self, *args, **kwargs) -> Tensor: + h = [] + for module in self: + h.append(module.forward(*args, **kwargs)) + + return torch.cat(h, dim=self.dim) + + +def parse_valid_args(param_dict, fn): + r""" + Check if a function takes the given argument. + + Parameters + ---------- + fn: func + The function to check the argument. + param_dict: dict + Dictionary of the argument. + + Returns + ------- + param_dict: dict + Valid paramter dictionary for the given fucntions. + """ + param_dict_cp = copy(param_dict) + for key, param in param_dict.items(): + if not arg_in_func(fn=fn, arg=key): + logger.warning( + f"{key} is not an available argument for {fn.__name__}, and is ignored by default." + ) + param_dict_cp.pop(key) + + return param_dict_cp + + +def arg_in_func(fn, arg): + r""" + Check if a function takes the given argument. + + Parameters + ---------- + fn: func + The function to check the argument. + arg: str + The name of the argument. + + Returns + ------- + res: bool + True if the function contains the argument, otherwise False. + """ + fn_args = getfullargspec(fn) + return (fn_args.varkw is not None) or (arg in fn_args[0]) diff --git a/goli/visualization/vis_utils.py b/goli/visualization/vis_utils.py index 97248b527..5cf481a28 100644 --- a/goli/visualization/vis_utils.py +++ b/goli/visualization/vis_utils.py @@ -1,36 +1,36 @@ -import inspect -from copy import deepcopy - -from matplotlib.offsetbox import AnchoredText - - -def annotate_metric(ax, metrics, x, y, fontsize=10, loc="upper left", **kwargs): - # Compute each metric from `metrics` on the `x` and `y` data. - # Then Annotate the plot with with the the results of each metric - # Compute the metrics and generate strings for each metric - - stat_text = "" - - for metric_name, metric in metrics.items(): - kwargs_copy = deepcopy(kwargs) - stat = metric(x, y, **kwargs_copy) - stat_text += "\n" + metric_name + " = {:0.3f}".format(stat) - - stat_text = stat_text[1:] - - # Display the metrics on the plot - _annotate(ax=ax.ax_joint, text=stat_text, loc=loc) - - -def _annotate(ax, text, loc="upper left", bbox_to_anchor=(1.2, 1)): - text = text.strip() - - text_loc_outside = dict() - if loc == "outside": - text_loc_outside["bbox_to_anchor"] = bbox_to_anchor - text_loc_outside["bbox_transform"] = ax.transAxes - loc = "upper left" - - anchored_text = AnchoredText(text, loc=loc, **text_loc_outside) - anchored_text.patch._alpha = 0.25 - ax.add_artist(anchored_text) +import inspect +from copy import deepcopy + +from matplotlib.offsetbox import AnchoredText + + +def annotate_metric(ax, metrics, x, y, fontsize=10, loc="upper left", **kwargs): + # Compute each metric from `metrics` on the `x` and `y` data. + # Then Annotate the plot with with the the results of each metric + # Compute the metrics and generate strings for each metric + + stat_text = "" + + for metric_name, metric in metrics.items(): + kwargs_copy = deepcopy(kwargs) + stat = metric(x, y, **kwargs_copy) + stat_text += "\n" + metric_name + " = {:0.3f}".format(stat) + + stat_text = stat_text[1:] + + # Display the metrics on the plot + _annotate(ax=ax.ax_joint, text=stat_text, loc=loc) + + +def _annotate(ax, text, loc="upper left", bbox_to_anchor=(1.2, 1)): + text = text.strip() + + text_loc_outside = dict() + if loc == "outside": + text_loc_outside["bbox_to_anchor"] = bbox_to_anchor + text_loc_outside["bbox_transform"] = ax.transAxes + loc = "upper left" + + anchored_text = AnchoredText(text, loc=loc, **text_loc_outside) + anchored_text.patch._alpha = 0.25 + ax.add_artist(anchored_text) diff --git a/mkdocs.yml b/mkdocs.yml index 358bf3217..f398cd8ed 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -1,83 +1,83 @@ -site_name: "goli" -site_description: "A deep learning library focused on graph representation learning for real-world chemical tasks." -site_url: "https://github.com/valence-discovery/goli" -repo_url: "https://github.com/valence-discovery/goli" -repo_name: "valence-discovery/goli" -copyright: Copyright 2020 - 2021 Valence Discovery - -remote_branch: "privpage" -use_directory_urls: false -docs_dir: "docs" - -nav: - - Overview: index.md - - Tutorials: - - Using GNN layers: tutorials/basics/using_gnn_layers.ipynb - - Implementing GNN layers: tutorials/basics/implementing_gnn_layers.ipynb - - Making GNN networks: tutorials/basics/making_gnn_networks.ipynb - - Simple Molecular Model: tutorials/model_training/simple-molecular-model.ipynb - - Design: design.md - - Datasets: datasets.md - - Pretrained Models: pretrained_models.md - - Contribute: contribute.md - - License: license.md - - API: - - goli.nn: api/goli.nn.md - - goli.features: api/goli.features.md - - goli.trainer: api/goli.trainer.md - - goli.data: api/goli.data.md - - goli.utils: api/goli.utils.md - - CLI: cli_references.md - -theme: - name: material - palette: - primary: red - accent: indigo - features: - - navigation.expand - favicon: images/logo.png - logo: images/logo.png - -extra_css: - - _assets/css/custom.css - -extra_javascript: - - javascripts/config.js - - https://polyfill.io/v3/polyfill.min.js?features=es6 - - https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js - -markdown_extensions: - - admonition - - markdown_include.include - - pymdownx.emoji - - pymdownx.magiclink - - pymdownx.superfences - - pymdownx.tabbed - - pymdownx.tasklist - - pymdownx.details - - mkdocs-click - - pymdownx.arithmatex: - generic: true - - toc: - permalink: true - -plugins: - - search - - mkdocstrings: - watch: - - goli/ - handlers: - python: - setup_commands: - - import sys - - sys.path.append("docs") - - sys.path.append("goli") - selection: - new_path_syntax: yes - rendering: - show_root_heading: yes - heading_level: 3 - - mkdocs-jupyter: - execute: False - # kernel_name: python3 +site_name: "goli" +site_description: "A deep learning library focused on graph representation learning for real-world chemical tasks." +site_url: "https://github.com/valence-discovery/goli" +repo_url: "https://github.com/valence-discovery/goli" +repo_name: "valence-discovery/goli" +copyright: Copyright 2020 - 2021 Valence Discovery + +remote_branch: "privpage" +use_directory_urls: false +docs_dir: "docs" + +nav: + - Overview: index.md + - Tutorials: + - Using GNN layers: tutorials/basics/using_gnn_layers.ipynb + - Implementing GNN layers: tutorials/basics/implementing_gnn_layers.ipynb + - Making GNN networks: tutorials/basics/making_gnn_networks.ipynb + - Simple Molecular Model: tutorials/model_training/simple-molecular-model.ipynb + - Design: design.md + - Datasets: datasets.md + - Pretrained Models: pretrained_models.md + - Contribute: contribute.md + - License: license.md + - API: + - goli.nn: api/goli.nn.md + - goli.features: api/goli.features.md + - goli.trainer: api/goli.trainer.md + - goli.data: api/goli.data.md + - goli.utils: api/goli.utils.md + - CLI: cli_references.md + +theme: + name: material + palette: + primary: red + accent: indigo + features: + - navigation.expand + favicon: images/logo.png + logo: images/logo.png + +extra_css: + - _assets/css/custom.css + +extra_javascript: + - javascripts/config.js + - https://polyfill.io/v3/polyfill.min.js?features=es6 + - https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js + +markdown_extensions: + - admonition + - markdown_include.include + - pymdownx.emoji + - pymdownx.magiclink + - pymdownx.superfences + - pymdownx.tabbed + - pymdownx.tasklist + - pymdownx.details + - mkdocs-click + - pymdownx.arithmatex: + generic: true + - toc: + permalink: true + +plugins: + - search + - mkdocstrings: + watch: + - goli/ + handlers: + python: + setup_commands: + - import sys + - sys.path.append("docs") + - sys.path.append("goli") + selection: + new_path_syntax: yes + rendering: + show_root_heading: yes + heading_level: 3 + - mkdocs-jupyter: + execute: False + # kernel_name: python3 diff --git a/news/TEMPLATE.rst b/news/TEMPLATE.rst index 790d30b19..a16e51283 100644 --- a/news/TEMPLATE.rst +++ b/news/TEMPLATE.rst @@ -1,23 +1,23 @@ -**Added:** - -* - -**Changed:** - -* - -**Deprecated:** - -* - -**Removed:** - -* - -**Fixed:** - -* - -**Security:** - -* +**Added:** + +* + +**Changed:** + +* + +**Deprecated:** + +* + +**Removed:** + +* + +**Fixed:** + +* + +**Security:** + +* diff --git a/news/cache.rst b/news/cache.rst index 1f13c59ee..f2d0045fb 100644 --- a/news/cache.rst +++ b/news/cache.rst @@ -1,23 +1,23 @@ -**Added:** - -* - -**Changed:** - -* Save featurization args in datamodule cache and prevent reloading when the feature args are different than the one in the cache. - -**Deprecated:** - -* - -**Removed:** - -* - -**Fixed:** - -* - -**Security:** - -* +**Added:** + +* + +**Changed:** + +* Save featurization args in datamodule cache and prevent reloading when the feature args are different than the one in the cache. + +**Deprecated:** + +* + +**Removed:** + +* + +**Fixed:** + +* + +**Security:** + +* diff --git a/news/datasets.rst b/news/datasets.rst index 0224cba02..c9e85ae54 100644 --- a/news/datasets.rst +++ b/news/datasets.rst @@ -1,24 +1,24 @@ -**Added:** - -* Add functions and CLI to list and download datasets from Goli public GCS bucket. -* Add logic to load a pretrained model from the Goli GCS bucket. - -**Changed:** - -* Remove examples folder in doc to tutorials. - -**Deprecated:** - -* - -**Removed:** - -* - -**Fixed:** - -* - -**Security:** - -* +**Added:** + +* Add functions and CLI to list and download datasets from Goli public GCS bucket. +* Add logic to load a pretrained model from the Goli GCS bucket. + +**Changed:** + +* Remove examples folder in doc to tutorials. + +**Deprecated:** + +* + +**Removed:** + +* + +**Fixed:** + +* + +**Security:** + +* diff --git a/news/ogb.rst b/news/ogb.rst index 88b00c8e4..8ed37e92b 100644 --- a/news/ogb.rst +++ b/news/ogb.rst @@ -1,23 +1,23 @@ -**Added:** - -* Add a datamodule for OGB - -**Changed:** - -* - -**Deprecated:** - -* - -**Removed:** - -* - -**Fixed:** - -* - -**Security:** - -* +**Added:** + +* Add a datamodule for OGB + +**Changed:** + +* + +**Deprecated:** + +* + +**Removed:** + +* + +**Fixed:** + +* + +**Security:** + +* diff --git a/notebooks/dev-datamodule-invalidate-cache.ipynb b/notebooks/dev-datamodule-invalidate-cache.ipynb index 615202b0d..91349a92d 100644 --- a/notebooks/dev-datamodule-invalidate-cache.ipynb +++ b/notebooks/dev-datamodule-invalidate-cache.ipynb @@ -1,449 +1,449 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Using backend: pytorch\n" - ] - } - ], - "source": [ - "%load_ext autoreload\n", - "%autoreload 2\n", - "\n", - "import pathlib\n", - "import functools\n", - "import tempfile\n", - "\n", - "import numpy as np\n", - "import pytorch_lightning as pl\n", - "import torch\n", - "import datamol as dm\n", - "\n", - "import goli" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [], - "source": [ - "# Setup a temporary cache file. Only for\n", - "# demo purposes, use a known path in prod.\n", - "cache_data_path = pathlib.Path(tempfile.mkdtemp()) / \"cache.pkl\"\n", - "cache_data_path = \"/home/hadim/test-cache.pkl\"\n", - "\n", - "# Load a dataframe\n", - "df = goli.data.load_tiny_zinc()\n", - "df.head()\n", - "\n", - "# Setup the featurization\n", - "featurization_args = {}\n", - "featurization_args[\"atom_property_list_onehot\"] = [\"atomic-number\", \"valence\"]\n", - "featurization_args[\"atom_property_list_float\"] = [\"mass\", \"electronegativity\", \"in-ring\"]\n", - "featurization_args[\"edge_property_list\"] = [\"bond-type-onehot\", \"stereo\", \"in-ring\"]\n", - "featurization_args[\"add_self_loop\"] = False\n", - "featurization_args[\"use_bonds_weights\"] = False\n", - "featurization_args[\"explicit_H\"] = False\n", - "\n", - "# Config for datamodule\n", - "dm_args = {}\n", - "dm_args[\"df\"] = df\n", - "dm_args[\"cache_data_path\"] = cache_data_path\n", - "dm_args[\"featurization\"] = featurization_args\n", - "dm_args[\"smiles_col\"] = \"SMILES\"\n", - "dm_args[\"label_cols\"] = [\"SA\"]\n", - "dm_args[\"split_val\"] = 0.2\n", - "dm_args[\"split_test\"] = 0.2\n", - "dm_args[\"split_seed\"] = 19\n", - "dm_args[\"batch_size_train_val\"] = 16\n", - "dm_args[\"batch_size_test\"] = 16\n", - "dm_args[\"num_workers\"] = 0\n", - "dm_args[\"pin_memory\"] = True\n", - "dm_args[\"featurization_n_jobs\"] = 16\n", - "dm_args[\"featurization_progress\"] = True\n", - "\n", - "datam = goli.data.DGLFromSmilesDataModule(**dm_args)\n", - "# datam" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2021-04-30 14:19:26.972 | INFO | goli.data.datamodule:_load_from_cache:460 - Try reloading the data module from /home/hadim/test-cache.pkl.\n", - "2021-04-30 14:19:27.001 | INFO | goli.data.datamodule:_load_from_cache:485 - Cache featurizer arguments are different than the provided ones.\n", - "2021-04-30 14:19:27.001 | INFO | goli.data.datamodule:_load_from_cache:486 - Cache featurizer arguments: {'atom_property_list_onehot': ['atomic-number', 'valence'], 'atom_property_list_float': ['mass', 'electronegativity', 'in-ring'], 'edge_property_list': ['bond-type-onehot', 'stereo'], 'add_self_loop': False, 'explicit_H': False, 'use_bonds_weights': False, 'pos_encoding_as_features': None, 'pos_encoding_as_directions': None, 'dtype': torch.float32}\n", - "2021-04-30 14:19:27.002 | INFO | goli.data.datamodule:_load_from_cache:487 - Provided featurizer arguments: {'atom_property_list_onehot': ['atomic-number', 'valence'], 'atom_property_list_float': ['mass', 'electronegativity', 'in-ring'], 'edge_property_list': ['bond-type-onehot', 'stereo', 'in-ring'], 'add_self_loop': False, 'explicit_H': False, 'use_bonds_weights': False, 'pos_encoding_as_features': None, 'pos_encoding_as_directions': None, 'dtype': torch.float32}.\n", - "2021-04-30 14:19:27.002 | INFO | goli.data.datamodule:_load_from_cache:488 - Fallback to regular data preparation steps.\n", - "2021-04-30 14:19:27.003 | INFO | goli.data.datamodule:prepare_data:313 - Prepare dataset with 100 data points.\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "00871e6c7c86454181162e37d837daf2", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/100 [00:00 3\n", - " - precision > 3\n", - " loss_fun: mse\n", - " random_seed: 42\n", - " optim_kwargs:\n", - " lr: 0.01\n", - " weight_decay: 1.0e-07\n", - " lr_reduce_on_plateau_kwargs:\n", - " factor: 0.5\n", - " patience: 7\n", - " scheduler_kwargs:\n", - " monitor: loss/val\n", - " frequency: 1\n", - " target_nan_mask: 0\n", - "\n" - ] - } - ], - "source": [ - "print_config_with_key(yaml_config, \"predictor\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Metrics\n", - "\n", - "All the metrics can be defined there. If we want to use a classification metric, we can also define a threshold.\n", - "\n", - "See class `goli.trainer.metrics.MetricWrapper` for more details.\n", - "\n", - "See `goli.trainer.metrics.METRICS_CLASSIFICATION` and `goli.trainer.metrics.METRICS_REGRESSION` for a dictionnary of accepted metrics." - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "metrics:\n", - "- name: mae\n", - " metric: mae\n", - " threshold_kwargs: null\n", - "- name: pearsonr\n", - " metric: pearsonr\n", - " threshold_kwargs: null\n", - "- name: f1 > 3\n", - " metric: f1\n", - " num_classes: 2\n", - " average: micro\n", - " threshold_kwargs:\n", - " operator: greater\n", - " threshold: 3\n", - " th_on_preds: true\n", - " th_on_target: true\n", - "- name: f1 > 5\n", - " metric: f1\n", - " num_classes: 2\n", - " average: micro\n", - " threshold_kwargs:\n", - " operator: greater\n", - " threshold: 5\n", - " th_on_preds: true\n", - " th_on_target: true\n", - "- name: precision > 3\n", - " metric: precision\n", - " class_reduction: micro\n", - " threshold_kwargs:\n", - " operator: greater\n", - " threshold: 3\n", - " th_on_preds: true\n", - " th_on_target: true\n", - "\n" - ] - } - ], - "source": [ - "print_config_with_key(yaml_config, \"metrics\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Trainer\n", - "\n", - "Finally, the Trainer defines the parameters for the number of epochs to train, the checkpoints, and the patience." - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "trainer:\n", - " logger:\n", - " save_dir: logs/micro_ZINC\n", - " early_stopping:\n", - " monitor: loss/val\n", - " min_delta: 0\n", - " patience: 10\n", - " mode: min\n", - " model_checkpoint:\n", - " dirpath: models_checkpoints/micro_ZINC/\n", - " filename: bob\n", - " monitor: loss/val\n", - " mode: min\n", - " save_top_k: 1\n", - " period: 1\n", - " trainer:\n", - " max_epochs: 25\n", - " min_epochs: 5\n", - " gpus: 1\n", - "\n" - ] - } - ], - "source": [ - "print_config_with_key(yaml_config, \"trainer\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Training the model\n", - "\n", - "Now that we defined all the configuration files, we want to train the model. The steps are fairly easy using the config loaders, and are given below." - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Using backend: pytorch\n", - "2021-03-25 09:44:37.314 | WARNING | goli.config._loader:load_trainer:111 - Number of GPUs selected is `1`, but will be ignored since no GPU are available on this device\n", - "/home/dominique/anaconda3/envs/goli/lib/python3.8/site-packages/pytorch_lightning/utilities/distributed.py:51: UserWarning: Checkpoint directory models_checkpoints/micro_ZINC/ exists and is not empty.\n", - " warnings.warn(*args, **kwargs)\n", - "GPU available: False, used: False\n", - "TPU available: None, using: 0 TPU cores\n", - "2021-03-25 09:44:37.331 | INFO | goli.data.datamodule:prepare_data:153 - Reload data from goli/data/cache/micro_ZINC/full.cache.\n", - "\n", - "datamodule:\n", - " name: DGLFromSmilesDataModule\n", - "len: 1000\n", - "batch_size_train_val: 128\n", - "batch_size_test: 256\n", - "num_node_feats: 55\n", - "num_edge_feats: 0\n", - "collate_fn: goli_collate_fn\n", - "featurization:\n", - " atom_property_list_onehot:\n", - " - atomic-number\n", - " - valence\n", - " atom_property_list_float:\n", - " - mass\n", - " - electronegativity\n", - " - in-ring\n", - " edge_property_list: []\n", - " add_self_loop: false\n", - " explicit_H: false\n", - " use_bonds_weights: false\n", - " \n", - "\n", - "{'mae': mean_absolute_error, 'pearsonr': pearsonr, 'f1 > 3': f1(>3), 'f1 > 5': f1(>5), 'precision > 3': precision(>3)}\n", - "DGL_GNN\n", - "---------\n", - " pre-NN(depth=1, ResidualConnectionNone)\n", - " [FCLayer[55 -> 32]\n", - " \n", - " GNN(depth=4, ResidualConnectionSimple(skip_steps=1))\n", - " PNAMessagePassingLayer[32 -> 32 -> 32 -> 32 -> 32]\n", - " -> Pooling(sum) -> FCLayer(32 -> 32, activation=None)\n", - " \n", - " post-NN(depth=2, ResidualConnectionNone)\n", - " [FCLayer[32 -> 32 -> 1]\n", - " | Name | Type | Params\n", - "------------------------------------------------------------------------------\n", - "0 | model | FullDGLNetwork | 69.7 K\n", - "1 | model.pre_nn | FeedForwardNN | 1.9 K \n", - "2 | model.pre_nn.activation | ReLU | 0 \n", - "3 | model.pre_nn.residual_layer | ResidualConnectionNone | 0 \n", - "4 | model.pre_nn.layers | ModuleList | 1.9 K \n", - "5 | model.pre_nn.layers.0 | FCLayer | 1.9 K \n", - "6 | model.gnn | FeedForwardDGL | 66.7 K\n", - "7 | model.gnn.activation | ReLU | 0 \n", - "8 | model.gnn.layers | ModuleList | 62.2 K\n", - "9 | model.gnn.layers.0 | PNAMessagePassingLayer | 15.6 K\n", - "10 | model.gnn.layers.1 | PNAMessagePassingLayer | 15.6 K\n", - "11 | model.gnn.layers.2 | PNAMessagePassingLayer | 15.6 K\n", - "12 | model.gnn.layers.3 | PNAMessagePassingLayer | 15.6 K\n", - "13 | model.gnn.virtual_node_layers | ModuleList | 3.4 K \n", - "14 | model.gnn.virtual_node_layers.0 | VirtualNode | 1.1 K \n", - "15 | model.gnn.virtual_node_layers.1 | VirtualNode | 1.1 K \n", - "16 | model.gnn.virtual_node_layers.2 | VirtualNode | 1.1 K \n", - "17 | model.gnn.residual_layer | ResidualConnectionSimple | 0 \n", - "18 | model.gnn.global_pool_layer | ModuleListConcat | 0 \n", - "19 | model.gnn.global_pool_layer.0 | SumPooling | 0 \n", - "20 | model.gnn.out_linear | FCLayer | 1.1 K \n", - "21 | model.gnn.out_linear.linear | Linear | 1.1 K \n", - "22 | model.gnn.out_linear.dropout | Dropout | 0 \n", - "23 | model.gnn.out_linear.batch_norm | BatchNorm1d | 64 \n", - "24 | model.post_nn | FeedForwardNN | 1.2 K \n", - "25 | model.post_nn.activation | ReLU | 0 \n", - "26 | model.post_nn.residual_layer | ResidualConnectionNone | 0 \n", - "27 | model.post_nn.layers | ModuleList | 1.2 K \n", - "28 | model.post_nn.layers.0 | FCLayer | 1.1 K \n", - "29 | model.post_nn.layers.1 | FCLayer | 33 \n", - "30 | loss_fun | MSELoss | 0 \n", - "------------------------------------------------------------------------------\n", - "69.7 K Trainable params\n", - "0 Non-trainable params\n", - "69.7 K Total params\n", - "0.279 Total estimated model params size (MB)\n", - "\n", - " | Name | Type | Params\n", - "--------------------------------------------\n", - "0 | model | FullDGLNetwork | 69.7 K\n", - "1 | loss_fun | MSELoss | 0 \n", - "--------------------------------------------\n", - "69.7 K Trainable params\n", - "0 Non-trainable params\n", - "69.7 K Total params\n", - "0.279 Total estimated model params size (MB)\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "61dc1894ee264599ab493d982b390430", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation sanity check: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/dominique/anaconda3/envs/goli/lib/python3.8/site-packages/pytorch_lightning/utilities/distributed.py:51: UserWarning: The validation_epoch_end should not return anything as of 9.1. To log, use self.log(...) or self.write(...) directly in the LightningModule\n", - " warnings.warn(*args, **kwargs)\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "657567b3d0b546a1a648173c2bfb1e4a", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Training: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "01792491c7fd49b08ce5086832135c7b", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validating: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "b9d96f2469234fe380d07ffd806350ed", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validating: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "eed00b1f81524fe99e07e2084c952532", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validating: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "c435a81142c24be09a0e38d7575b365b", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validating: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "ae0357d542ec4021b0c9c38fb38bd11c", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validating: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "a1d24f83f1504d3b80b0741d1c0f404b", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validating: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "647aee1f810f407c90697c394d0f604f", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validating: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "c69086a4802e421f816cab1ebbc20ae9", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validating: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "ebb4a32a0f78470ba21ff5bea8b450f4", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validating: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "840d27d095344fcd9a74862e61a2fe7f", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validating: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "905c7ee70f4b4282871c130b0c7b9f0a", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validating: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "daab3285c0854c23a4e3dd47846c820e", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validating: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "e0e0d9096fc64198a470ae1b3cd7f351", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validating: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "99ac351f4e334e8c838a6913ef6bee08", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validating: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "69b47fad071248eab8095d67e33b5d5e", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validating: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "8c68dcc01135429e845427bb6908f414", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validating: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "096aaea9ce2649fba9bf70b99b7e7955", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validating: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "d8ecff999d934a119157a3e0ca7a1c6a", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validating: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "4287a56d059b4eb2966eb2e90498a210", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validating: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "0a5ffab4db4e4768a4876b01a8b10f96", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validating: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "6177ef595f9542598e5b065d6d77bb32", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validating: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "e86938e35b0b443791119e37dd2e2199", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validating: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "98aff21b49cc434dbaaaf12c355ab783", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validating: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "e4c88c49c0c843c09934e9786e9b6aa5", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validating: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "dcb917418a084d4ba36d57f5b0406819", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validating: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "1" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "\n", - "\n", - "MAIN_DIR = dirname(dirname(abspath(goli.__file__)))\n", - "os.chdir(MAIN_DIR)\n", - "\n", - "cfg = dict(deepcopy(yaml_config))\n", - "\n", - "# Load and initialize the dataset\n", - "datamodule = load_datamodule(cfg)\n", - "print(\"\\ndatamodule:\\n\", datamodule, \"\\n\")\n", - "\n", - "# Initialize the network\n", - "model_class, model_kwargs = load_architecture(\n", - " cfg,\n", - " in_dim_nodes=datamodule.num_node_feats_with_positional_encoding,\n", - " in_dim_edges=datamodule.num_edge_feats,\n", - ")\n", - "\n", - "metrics = load_metrics(cfg)\n", - "print(metrics)\n", - "\n", - "predictor = load_predictor(cfg, model_class, model_kwargs, metrics)\n", - "\n", - "print(predictor.model)\n", - "print(predictor.summarize(mode=4, to_print=False))\n", - "\n", - "trainer = load_trainer(cfg, metrics)\n", - "\n", - "# Run the model training\n", - "trainer.fit(model=predictor, datamodule=datamodule)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python [conda env:goli]", - "language": "python", - "name": "conda-env-goli-py" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.8" - }, - "widgets": { - "application/vnd.jupyter.widget-state+json": { - "state": {}, - "version_major": 2, - "version_minor": 0 - } - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Using backend: pytorch\n" + ] + } + ], + "source": [ + "import goli\n", + "# from goli.config._loader import (load_datamodule, load_metrics, load_architecture, load_predictor, load_trainer)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Constants\n", + "\n", + "First, we define the constants such as the random seed and whether the model should raise or ignore an error." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "constants:\n", + " seed: 42\n", + " raise_train_error: true\n", + "\n" + ] + } + ], + "source": [ + "print_config_with_key(yaml_config, \"constants\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Datamodule\n", + "\n", + "Here, we define all the parameters required by the datamodule to run correctly, such as the dataset path, whether to cache, the columns for the training, the molecular featurization to use, the train/val/test splits and the batch size.\n", + "\n", + "For more details, see class `goli.data.datamodule.DGLFromSmilesDataModule`" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "datamodule:\n", + " df_path: goli/data/micro_ZINC/micro_ZINC.csv\n", + " cache_data_path: goli/data/cache/micro_ZINC/full.cache\n", + " label_cols:\n", + " - score\n", + " smiles_col: SMILES\n", + " featurization_n_jobs: -1\n", + " featurization_progress: true\n", + " featurization:\n", + " atom_property_list_onehot:\n", + " - atomic-number\n", + " - valence\n", + " atom_property_list_float:\n", + " - mass\n", + " - electronegativity\n", + " - in-ring\n", + " edge_property_list: []\n", + " add_self_loop: false\n", + " explicit_H: false\n", + " use_bonds_weights: false\n", + " split_val: 0.2\n", + " split_test: 0.2\n", + " split_seed: 42\n", + " splits_path: null\n", + " batch_size_train_val: 128\n", + " batch_size_test: 256\n", + " num_workers: -1\n", + " pin_memory: false\n", + "\n" + ] + } + ], + "source": [ + "print_config_with_key(yaml_config, \"datamodule\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Architecture\n", + "\n", + "In the architecture, we define all the layers for the model, including the layers for the pre-processing MLP (input layers `pre-nn`), the post-processing MLP (output layers `post-nn`), and the main GNN (graph neural network `gnn`).\n", + "\n", + "The parameters allow to chose the feature size, the depth, the skip connections, the pooling and the virtual node. It also support different GNN layers such as `gcn`, `gin`, `gat`, `gated-gcn`, `pna-conv` and `pna-msgpass`.\n", + "\n", + "For more details, see the following classes:\n", + "\n", + "- `goli.nn.architecture.FullDGLNetwork`: Main class for the architecture\n", + "- `goli.nn.architecture.FeedForwardNN`: Main class for the inputs and outputs MLP\n", + "- `goli.nn.architecture.FeedForwardDGL`: Main class for the GNN layers" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "architecture:\n", + " model_type: fulldglnetwork\n", + " pre_nn:\n", + " out_dim: 32\n", + " hidden_dims: 32\n", + " depth: 1\n", + " activation: relu\n", + " last_activation: none\n", + " dropout: 0.1\n", + " batch_norm: true\n", + " last_batch_norm: true\n", + " residual_type: none\n", + " gnn:\n", + " out_dim: 32\n", + " hidden_dims: 32\n", + " depth: 4\n", + " activation: relu\n", + " last_activation: none\n", + " dropout: 0.1\n", + " batch_norm: true\n", + " last_batch_norm: true\n", + " residual_type: simple\n", + " pooling: sum\n", + " virtual_node: sum\n", + " layer_type: pna-msgpass\n", + " layer_kwargs:\n", + " aggregators:\n", + " - mean\n", + " - max\n", + " - min\n", + " - std\n", + " scalers:\n", + " - identity\n", + " - amplification\n", + " - attenuation\n", + " post_nn:\n", + " out_dim: 1\n", + " hidden_dims: 32\n", + " depth: 2\n", + " activation: relu\n", + " last_activation: none\n", + " dropout: 0.1\n", + " batch_norm: true\n", + " last_batch_norm: false\n", + " residual_type: none\n", + "\n" + ] + } + ], + "source": [ + "print_config_with_key(yaml_config, \"architecture\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Predictor\n", + "\n", + "In the predictor, we define the loss functions, the metrics to track on the progress bar, and all the parameters necessary for the optimizer." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "predictor:\n", + " metrics_on_progress_bar:\n", + " - mae\n", + " - pearsonr\n", + " - f1 > 3\n", + " - precision > 3\n", + " loss_fun: mse\n", + " random_seed: 42\n", + " optim_kwargs:\n", + " lr: 0.01\n", + " weight_decay: 1.0e-07\n", + " lr_reduce_on_plateau_kwargs:\n", + " factor: 0.5\n", + " patience: 7\n", + " scheduler_kwargs:\n", + " monitor: loss/val\n", + " frequency: 1\n", + " target_nan_mask: 0\n", + "\n" + ] + } + ], + "source": [ + "print_config_with_key(yaml_config, \"predictor\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Metrics\n", + "\n", + "All the metrics can be defined there. If we want to use a classification metric, we can also define a threshold.\n", + "\n", + "See class `goli.trainer.metrics.MetricWrapper` for more details.\n", + "\n", + "See `goli.trainer.metrics.METRICS_CLASSIFICATION` and `goli.trainer.metrics.METRICS_REGRESSION` for a dictionnary of accepted metrics." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "metrics:\n", + "- name: mae\n", + " metric: mae\n", + " threshold_kwargs: null\n", + "- name: pearsonr\n", + " metric: pearsonr\n", + " threshold_kwargs: null\n", + "- name: f1 > 3\n", + " metric: f1\n", + " num_classes: 2\n", + " average: micro\n", + " threshold_kwargs:\n", + " operator: greater\n", + " threshold: 3\n", + " th_on_preds: true\n", + " th_on_target: true\n", + "- name: f1 > 5\n", + " metric: f1\n", + " num_classes: 2\n", + " average: micro\n", + " threshold_kwargs:\n", + " operator: greater\n", + " threshold: 5\n", + " th_on_preds: true\n", + " th_on_target: true\n", + "- name: precision > 3\n", + " metric: precision\n", + " class_reduction: micro\n", + " threshold_kwargs:\n", + " operator: greater\n", + " threshold: 3\n", + " th_on_preds: true\n", + " th_on_target: true\n", + "\n" + ] + } + ], + "source": [ + "print_config_with_key(yaml_config, \"metrics\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Trainer\n", + "\n", + "Finally, the Trainer defines the parameters for the number of epochs to train, the checkpoints, and the patience." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "trainer:\n", + " logger:\n", + " save_dir: logs/micro_ZINC\n", + " early_stopping:\n", + " monitor: loss/val\n", + " min_delta: 0\n", + " patience: 10\n", + " mode: min\n", + " model_checkpoint:\n", + " dirpath: models_checkpoints/micro_ZINC/\n", + " filename: bob\n", + " monitor: loss/val\n", + " mode: min\n", + " save_top_k: 1\n", + " period: 1\n", + " trainer:\n", + " max_epochs: 25\n", + " min_epochs: 5\n", + " gpus: 1\n", + "\n" + ] + } + ], + "source": [ + "print_config_with_key(yaml_config, \"trainer\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Training the model\n", + "\n", + "Now that we defined all the configuration files, we want to train the model. The steps are fairly easy using the config loaders, and are given below." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Using backend: pytorch\n", + "2021-03-25 09:44:37.314 | WARNING | goli.config._loader:load_trainer:111 - Number of GPUs selected is `1`, but will be ignored since no GPU are available on this device\n", + "/home/dominique/anaconda3/envs/goli/lib/python3.8/site-packages/pytorch_lightning/utilities/distributed.py:51: UserWarning: Checkpoint directory models_checkpoints/micro_ZINC/ exists and is not empty.\n", + " warnings.warn(*args, **kwargs)\n", + "GPU available: False, used: False\n", + "TPU available: None, using: 0 TPU cores\n", + "2021-03-25 09:44:37.331 | INFO | goli.data.datamodule:prepare_data:153 - Reload data from goli/data/cache/micro_ZINC/full.cache.\n", + "\n", + "datamodule:\n", + " name: DGLFromSmilesDataModule\n", + "len: 1000\n", + "batch_size_train_val: 128\n", + "batch_size_test: 256\n", + "num_node_feats: 55\n", + "num_edge_feats: 0\n", + "collate_fn: goli_collate_fn\n", + "featurization:\n", + " atom_property_list_onehot:\n", + " - atomic-number\n", + " - valence\n", + " atom_property_list_float:\n", + " - mass\n", + " - electronegativity\n", + " - in-ring\n", + " edge_property_list: []\n", + " add_self_loop: false\n", + " explicit_H: false\n", + " use_bonds_weights: false\n", + " \n", + "\n", + "{'mae': mean_absolute_error, 'pearsonr': pearsonr, 'f1 > 3': f1(>3), 'f1 > 5': f1(>5), 'precision > 3': precision(>3)}\n", + "DGL_GNN\n", + "---------\n", + " pre-NN(depth=1, ResidualConnectionNone)\n", + " [FCLayer[55 -> 32]\n", + " \n", + " GNN(depth=4, ResidualConnectionSimple(skip_steps=1))\n", + " PNAMessagePassingLayer[32 -> 32 -> 32 -> 32 -> 32]\n", + " -> Pooling(sum) -> FCLayer(32 -> 32, activation=None)\n", + " \n", + " post-NN(depth=2, ResidualConnectionNone)\n", + " [FCLayer[32 -> 32 -> 1]\n", + " | Name | Type | Params\n", + "------------------------------------------------------------------------------\n", + "0 | model | FullDGLNetwork | 69.7 K\n", + "1 | model.pre_nn | FeedForwardNN | 1.9 K \n", + "2 | model.pre_nn.activation | ReLU | 0 \n", + "3 | model.pre_nn.residual_layer | ResidualConnectionNone | 0 \n", + "4 | model.pre_nn.layers | ModuleList | 1.9 K \n", + "5 | model.pre_nn.layers.0 | FCLayer | 1.9 K \n", + "6 | model.gnn | FeedForwardDGL | 66.7 K\n", + "7 | model.gnn.activation | ReLU | 0 \n", + "8 | model.gnn.layers | ModuleList | 62.2 K\n", + "9 | model.gnn.layers.0 | PNAMessagePassingLayer | 15.6 K\n", + "10 | model.gnn.layers.1 | PNAMessagePassingLayer | 15.6 K\n", + "11 | model.gnn.layers.2 | PNAMessagePassingLayer | 15.6 K\n", + "12 | model.gnn.layers.3 | PNAMessagePassingLayer | 15.6 K\n", + "13 | model.gnn.virtual_node_layers | ModuleList | 3.4 K \n", + "14 | model.gnn.virtual_node_layers.0 | VirtualNode | 1.1 K \n", + "15 | model.gnn.virtual_node_layers.1 | VirtualNode | 1.1 K \n", + "16 | model.gnn.virtual_node_layers.2 | VirtualNode | 1.1 K \n", + "17 | model.gnn.residual_layer | ResidualConnectionSimple | 0 \n", + "18 | model.gnn.global_pool_layer | ModuleListConcat | 0 \n", + "19 | model.gnn.global_pool_layer.0 | SumPooling | 0 \n", + "20 | model.gnn.out_linear | FCLayer | 1.1 K \n", + "21 | model.gnn.out_linear.linear | Linear | 1.1 K \n", + "22 | model.gnn.out_linear.dropout | Dropout | 0 \n", + "23 | model.gnn.out_linear.batch_norm | BatchNorm1d | 64 \n", + "24 | model.post_nn | FeedForwardNN | 1.2 K \n", + "25 | model.post_nn.activation | ReLU | 0 \n", + "26 | model.post_nn.residual_layer | ResidualConnectionNone | 0 \n", + "27 | model.post_nn.layers | ModuleList | 1.2 K \n", + "28 | model.post_nn.layers.0 | FCLayer | 1.1 K \n", + "29 | model.post_nn.layers.1 | FCLayer | 33 \n", + "30 | loss_fun | MSELoss | 0 \n", + "------------------------------------------------------------------------------\n", + "69.7 K Trainable params\n", + "0 Non-trainable params\n", + "69.7 K Total params\n", + "0.279 Total estimated model params size (MB)\n", + "\n", + " | Name | Type | Params\n", + "--------------------------------------------\n", + "0 | model | FullDGLNetwork | 69.7 K\n", + "1 | loss_fun | MSELoss | 0 \n", + "--------------------------------------------\n", + "69.7 K Trainable params\n", + "0 Non-trainable params\n", + "69.7 K Total params\n", + "0.279 Total estimated model params size (MB)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "61dc1894ee264599ab493d982b390430", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validation sanity check: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/dominique/anaconda3/envs/goli/lib/python3.8/site-packages/pytorch_lightning/utilities/distributed.py:51: UserWarning: The validation_epoch_end should not return anything as of 9.1. To log, use self.log(...) or self.write(...) directly in the LightningModule\n", + " warnings.warn(*args, **kwargs)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "657567b3d0b546a1a648173c2bfb1e4a", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Training: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "01792491c7fd49b08ce5086832135c7b", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validating: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "b9d96f2469234fe380d07ffd806350ed", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validating: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "eed00b1f81524fe99e07e2084c952532", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validating: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "c435a81142c24be09a0e38d7575b365b", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validating: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "ae0357d542ec4021b0c9c38fb38bd11c", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validating: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "a1d24f83f1504d3b80b0741d1c0f404b", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validating: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "647aee1f810f407c90697c394d0f604f", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validating: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "c69086a4802e421f816cab1ebbc20ae9", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validating: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "ebb4a32a0f78470ba21ff5bea8b450f4", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validating: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "840d27d095344fcd9a74862e61a2fe7f", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validating: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "905c7ee70f4b4282871c130b0c7b9f0a", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validating: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "daab3285c0854c23a4e3dd47846c820e", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validating: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "e0e0d9096fc64198a470ae1b3cd7f351", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validating: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "99ac351f4e334e8c838a6913ef6bee08", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validating: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "69b47fad071248eab8095d67e33b5d5e", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validating: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "8c68dcc01135429e845427bb6908f414", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validating: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "096aaea9ce2649fba9bf70b99b7e7955", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validating: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "d8ecff999d934a119157a3e0ca7a1c6a", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validating: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "4287a56d059b4eb2966eb2e90498a210", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validating: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "0a5ffab4db4e4768a4876b01a8b10f96", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validating: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "6177ef595f9542598e5b065d6d77bb32", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validating: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "e86938e35b0b443791119e37dd2e2199", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validating: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "98aff21b49cc434dbaaaf12c355ab783", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validating: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "e4c88c49c0c843c09934e9786e9b6aa5", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validating: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "dcb917418a084d4ba36d57f5b0406819", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validating: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "1" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "\n", + "\n", + "MAIN_DIR = dirname(dirname(abspath(goli.__file__)))\n", + "os.chdir(MAIN_DIR)\n", + "\n", + "cfg = dict(deepcopy(yaml_config))\n", + "\n", + "# Load and initialize the dataset\n", + "datamodule = load_datamodule(cfg)\n", + "print(\"\\ndatamodule:\\n\", datamodule, \"\\n\")\n", + "\n", + "# Initialize the network\n", + "model_class, model_kwargs = load_architecture(\n", + " cfg,\n", + " in_dim_nodes=datamodule.num_node_feats_with_positional_encoding,\n", + " in_dim_edges=datamodule.num_edge_feats,\n", + ")\n", + "\n", + "metrics = load_metrics(cfg)\n", + "print(metrics)\n", + "\n", + "predictor = load_predictor(cfg, model_class, model_kwargs, metrics)\n", + "\n", + "print(predictor.model)\n", + "print(predictor.summarize(mode=4, to_print=False))\n", + "\n", + "trainer = load_trainer(cfg, metrics)\n", + "\n", + "# Run the model training\n", + "trainer.fit(model=predictor, datamodule=datamodule)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python [conda env:goli]", + "language": "python", + "name": "conda-env-goli-py" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.8" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "state": {}, + "version_major": 2, + "version_minor": 0 + } + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/notebooks/dev.ipynb b/notebooks/dev.ipynb index f2c4597ae..b66cf5421 100644 --- a/notebooks/dev.ipynb +++ b/notebooks/dev.ipynb @@ -1,113 +1,113 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Using backend: pytorch\n" - ] - } - ], - "source": [ - "%load_ext autoreload\n", - "%autoreload 2\n", - "\n", - "import pathlib\n", - "import functools\n", - "import tempfile\n", - "import yaml\n", - "\n", - "from loguru import logger\n", - "\n", - "import numpy as np\n", - "import pytorch_lightning as pl\n", - "import torch\n", - "import datamol as dm\n", - "import pandas as pd\n", - "\n", - "import goli" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "/home/hadim/test-data/goli-zinc-micro\n" - ] - } - ], - "source": [ - "import goli\n", - "\n", - "dataset_dir = \"/home/hadim/test-data\"\n", - "data_path = goli.data.utils.download_goli_dataset(\"goli-zinc-micro\", output_path=dataset_dir)\n", - "print(data_path)" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'/home/hadim/test-data/goli-zinc-micro'" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "data_path" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python [conda env:goli]", - "language": "python", - "name": "conda-env-goli-py" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.4" - }, - "widgets": { - "application/vnd.jupyter.widget-state+json": { - "state": {}, - "version_major": 2, - "version_minor": 0 - } - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Using backend: pytorch\n" + ] + } + ], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n", + "\n", + "import pathlib\n", + "import functools\n", + "import tempfile\n", + "import yaml\n", + "\n", + "from loguru import logger\n", + "\n", + "import numpy as np\n", + "import pytorch_lightning as pl\n", + "import torch\n", + "import datamol as dm\n", + "import pandas as pd\n", + "\n", + "import goli" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/home/hadim/test-data/goli-zinc-micro\n" + ] + } + ], + "source": [ + "import goli\n", + "\n", + "dataset_dir = \"/home/hadim/test-data\"\n", + "data_path = goli.data.utils.download_goli_dataset(\"goli-zinc-micro\", output_path=dataset_dir)\n", + "print(data_path)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'/home/hadim/test-data/goli-zinc-micro'" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data_path" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python [conda env:goli]", + "language": "python", + "name": "conda-env-goli-py" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.4" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "state": {}, + "version_major": 2, + "version_minor": 0 + } + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/notebooks/running-model-from-config.ipynb b/notebooks/running-model-from-config.ipynb index 5192e9988..2f252f890 100644 --- a/notebooks/running-model-from-config.ipynb +++ b/notebooks/running-model-from-config.ipynb @@ -1,220 +1,220 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Using backend: pytorch\n" - ] - } - ], - "source": [ - "%load_ext autoreload\n", - "%autoreload 2\n", - "\n", - "# General imports\n", - "import os\n", - "from os.path import dirname, abspath\n", - "import yaml\n", - "from copy import deepcopy\n", - "from omegaconf import DictConfig, OmegaConf\n", - "\n", - "\n", - "# Current project imports\n", - "import goli\n", - "from goli.utils.config_loader import (\n", - " config_load_constants,\n", - " config_load_dataset,\n", - " config_load_architecture,\n", - " config_load_metrics,\n", - " config_load_predictor,\n", - " config_load_training,\n", - ")\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Read the config file" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/dominique/anaconda3/envs/goli_env/lib/python3.7/site-packages/torch/cuda/__init__.py:52: UserWarning: CUDA initialization: Found no NVIDIA driver on your system. Please check that you have an NVIDIA GPU and installed a driver from http://www.nvidia.com/Download/index.aspx (Triggered internally at /opt/conda/conda-bld/pytorch_1607370156314/work/c10/cuda/CUDAFunctions.cpp:100.)\n", - " return torch._C._cuda_getDeviceCount() > 0\n" - ] - } - ], - "source": [ - "# Set up the working directory\n", - "MAIN_DIR = dirname(dirname(abspath(goli.__file__)))\n", - "os.chdir(MAIN_DIR)\n", - "\n", - "with open(os.path.join(MAIN_DIR, \"expts/config_micro_ZINC.yaml\"), \"r\") as f:\n", - " cfg = yaml.safe_load(f)\n", - "\n", - "cfg = dict(deepcopy(cfg))\n", - "\n", - "# Get the general parameters and generate the train/val/test datasets\n", - "data_device, model_device, dtype, exp_name, seed, raise_train_error = config_load_constants(\n", - " **cfg[\"constants\"], main_dir=MAIN_DIR\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Load a dataset" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "datamodule:\n", - " name: DGLFromSmilesDataModule\n", - "len: 1000\n", - "batch_size_train_val: 128\n", - "batch_size_test: 256\n", - "num_node_feats: 55\n", - "num_edge_feats: 13\n", - "collate_fn: goli_collate_fn\n", - "featurization:\n", - " atom_property_list_onehot:\n", - " - atomic-number\n", - " - valence\n", - " atom_property_list_float:\n", - " - mass\n", - " - electronegativity\n", - " - in-ring\n", - " edge_property_list:\n", - " - bond-type-onehot\n", - " - stereo\n", - " - in-ring\n", - " add_self_loop: false\n", - " explicit_H: false\n", - " use_bonds_weights: false\n", - " \n", - "\n" - ] - } - ], - "source": [ - "\n", - "# Load and initialize the dataset\n", - "datamodule = config_load_dataset(**cfg[\"datasets\"], main_dir=MAIN_DIR,)\n", - "print(\"\\ndatamodule:\\n\", datamodule, \"\\n\")\n" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "model:\n", - " DGL_GNN\n", - "---------\n", - " pre-trans-NN(depth=1, ResidualConnectionNone)\n", - " [FCLayer[55 -> 32] -> Linear(32)\n", - " \n", - " main-GNN(depth=4, ResidualConnectionSimple(skip_steps=1))\n", - " PNAMessagePassingLayer[32 -> 32 -> 32 -> 32 -> 32]\n", - " -> Pooling(sum) -> FCLayer(32 -> 32, activation=None)\n", - " \n", - " post-trans-NN(depth=2, ResidualConnectionNone)\n", - " [FCLayer[32 -> 32 -> 32] -> Linear(32) \n", - "\n" - ] - } - ], - "source": [ - "# Initialize the network\n", - "model = config_load_architecture(\n", - " **cfg[\"architecture\"],\n", - " in_dim_nodes=datamodule.num_node_feats_with_positional_encoding,\n", - " in_dim_edges=datamodule.num_edge_feats\n", - ")\n", - "\n", - "print(\"\\nmodel:\\n\", model, \"\\n\")" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'mae': mean_absolute_error, 'pearsonr': pearsonr, 'f1 > 5': f1(>5), 'precision > 5': precision(>5), 'auroc > 5': auroc(>5)}\n" - ] - } - ], - "source": [ - "metrics = config_load_metrics(cfg[\"metrics\"])\n", - "print(metrics)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.6" - }, - "widgets": { - "application/vnd.jupyter.widget-state+json": { - "state": {}, - "version_major": 2, - "version_minor": 0 - } - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Using backend: pytorch\n" + ] + } + ], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n", + "\n", + "# General imports\n", + "import os\n", + "from os.path import dirname, abspath\n", + "import yaml\n", + "from copy import deepcopy\n", + "from omegaconf import DictConfig, OmegaConf\n", + "\n", + "\n", + "# Current project imports\n", + "import goli\n", + "from goli.utils.config_loader import (\n", + " config_load_constants,\n", + " config_load_dataset,\n", + " config_load_architecture,\n", + " config_load_metrics,\n", + " config_load_predictor,\n", + " config_load_training,\n", + ")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Read the config file" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/dominique/anaconda3/envs/goli_env/lib/python3.7/site-packages/torch/cuda/__init__.py:52: UserWarning: CUDA initialization: Found no NVIDIA driver on your system. Please check that you have an NVIDIA GPU and installed a driver from http://www.nvidia.com/Download/index.aspx (Triggered internally at /opt/conda/conda-bld/pytorch_1607370156314/work/c10/cuda/CUDAFunctions.cpp:100.)\n", + " return torch._C._cuda_getDeviceCount() > 0\n" + ] + } + ], + "source": [ + "# Set up the working directory\n", + "MAIN_DIR = dirname(dirname(abspath(goli.__file__)))\n", + "os.chdir(MAIN_DIR)\n", + "\n", + "with open(os.path.join(MAIN_DIR, \"expts/config_micro_ZINC.yaml\"), \"r\") as f:\n", + " cfg = yaml.safe_load(f)\n", + "\n", + "cfg = dict(deepcopy(cfg))\n", + "\n", + "# Get the general parameters and generate the train/val/test datasets\n", + "data_device, model_device, dtype, exp_name, seed, raise_train_error = config_load_constants(\n", + " **cfg[\"constants\"], main_dir=MAIN_DIR\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load a dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "datamodule:\n", + " name: DGLFromSmilesDataModule\n", + "len: 1000\n", + "batch_size_train_val: 128\n", + "batch_size_test: 256\n", + "num_node_feats: 55\n", + "num_edge_feats: 13\n", + "collate_fn: goli_collate_fn\n", + "featurization:\n", + " atom_property_list_onehot:\n", + " - atomic-number\n", + " - valence\n", + " atom_property_list_float:\n", + " - mass\n", + " - electronegativity\n", + " - in-ring\n", + " edge_property_list:\n", + " - bond-type-onehot\n", + " - stereo\n", + " - in-ring\n", + " add_self_loop: false\n", + " explicit_H: false\n", + " use_bonds_weights: false\n", + " \n", + "\n" + ] + } + ], + "source": [ + "\n", + "# Load and initialize the dataset\n", + "datamodule = config_load_dataset(**cfg[\"datasets\"], main_dir=MAIN_DIR,)\n", + "print(\"\\ndatamodule:\\n\", datamodule, \"\\n\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "model:\n", + " DGL_GNN\n", + "---------\n", + " pre-trans-NN(depth=1, ResidualConnectionNone)\n", + " [FCLayer[55 -> 32] -> Linear(32)\n", + " \n", + " main-GNN(depth=4, ResidualConnectionSimple(skip_steps=1))\n", + " PNAMessagePassingLayer[32 -> 32 -> 32 -> 32 -> 32]\n", + " -> Pooling(sum) -> FCLayer(32 -> 32, activation=None)\n", + " \n", + " post-trans-NN(depth=2, ResidualConnectionNone)\n", + " [FCLayer[32 -> 32 -> 32] -> Linear(32) \n", + "\n" + ] + } + ], + "source": [ + "# Initialize the network\n", + "model = config_load_architecture(\n", + " **cfg[\"architecture\"],\n", + " in_dim_nodes=datamodule.num_node_feats_with_positional_encoding,\n", + " in_dim_edges=datamodule.num_edge_feats\n", + ")\n", + "\n", + "print(\"\\nmodel:\\n\", model, \"\\n\")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'mae': mean_absolute_error, 'pearsonr': pearsonr, 'f1 > 5': f1(>5), 'precision > 5': precision(>5), 'auroc > 5': auroc(>5)}\n" + ] + } + ], + "source": [ + "metrics = config_load_metrics(cfg[\"metrics\"])\n", + "print(metrics)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.6" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "state": {}, + "version_major": 2, + "version_minor": 0 + } + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/pyproject.toml b/pyproject.toml index 2a695b50e..a8193dc65 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,23 +1,23 @@ -[tool.black] -line-length = 110 -target-version = ['py37', 'py38'] -include = '\.pyi?$' - -[tool.pytest.ini_options] -minversion = "6.0" -# nb_test_files = false # buggy at the moment -# addopts = "--verbose --cov --cov-report xml --cov-report term" -addopts = "--verbose" -testpaths = ["tests", "notebooks"] -filterwarnings = [ - "ignore::DeprecationWarning:tensorboard.*:", - "ignore::DeprecationWarning:ray.*:", - "ignore::DeprecationWarning:numba.*:", - "ignore::UserWarning:umap.*:", -] - -[tool.coverage.run] -omit = ["setup.py", "tests/*"] - -[tool.coverage.xml] -output = "cov.xml" +[tool.black] +line-length = 110 +target-version = ['py37', 'py38'] +include = '\.pyi?$' + +[tool.pytest.ini_options] +minversion = "6.0" +# nb_test_files = false # buggy at the moment +# addopts = "--verbose --cov --cov-report xml --cov-report term" +addopts = "--verbose" +testpaths = ["tests", "notebooks"] +filterwarnings = [ + "ignore::DeprecationWarning:tensorboard.*:", + "ignore::DeprecationWarning:ray.*:", + "ignore::DeprecationWarning:numba.*:", + "ignore::UserWarning:umap.*:", +] + +[tool.coverage.run] +omit = ["setup.py", "tests/*"] + +[tool.coverage.xml] +output = "cov.xml" diff --git a/rever.xsh b/rever.xsh index b3bc502b9..8249f4baf 100644 --- a/rever.xsh +++ b/rever.xsh @@ -1,28 +1,28 @@ -# Configuration - -$PROJECT = $GITHUB_REPO = 'goli' -$GITHUB_ORG = 'valence-discovery' -$PUSH_TAG_REMOTE = 'git@github.com:valence-discovery/goli.git' - -# Logic - -$AUTHORS_FILENAME = 'AUTHORS.rst' -$AUTHORS_METADATA = '.authors.yml' -$AUTHORS_SORTBY = 'alpha' -$AUTHORS_MAILMAP = '.mailmap' - -$CHANGELOG_FILENAME = 'CHANGELOG.rst' -$CHANGELOG_TEMPLATE = 'TEMPLATE.rst' -$CHANGELOG_NEWS = 'news' - -$FORGE_FEEDSTOCK_ORG = 'valence-forge' -$FORGE_RERENDER = True -$FORGE_USE_GIT_URL = True -$FORGE_FORK = False -$FORGE_PULL_REQUEST = False - -$ACTIVITIES = ['check', 'authors', 'changelog', 'version_bump', 'tag', 'push_tag', 'ghrelease'] - -$VERSION_BUMP_PATTERNS = [('goli/_version.py', r'__version__\s*=.*', "__version__ = \"$VERSION\""), - ('setup.py', r'version\s*=.*,', "version=\"$VERSION\",") - ] +# Configuration + +$PROJECT = $GITHUB_REPO = 'goli' +$GITHUB_ORG = 'valence-discovery' +$PUSH_TAG_REMOTE = 'git@github.com:valence-discovery/goli.git' + +# Logic + +$AUTHORS_FILENAME = 'AUTHORS.rst' +$AUTHORS_METADATA = '.authors.yml' +$AUTHORS_SORTBY = 'alpha' +$AUTHORS_MAILMAP = '.mailmap' + +$CHANGELOG_FILENAME = 'CHANGELOG.rst' +$CHANGELOG_TEMPLATE = 'TEMPLATE.rst' +$CHANGELOG_NEWS = 'news' + +$FORGE_FEEDSTOCK_ORG = 'valence-forge' +$FORGE_RERENDER = True +$FORGE_USE_GIT_URL = True +$FORGE_FORK = False +$FORGE_PULL_REQUEST = False + +$ACTIVITIES = ['check', 'authors', 'changelog', 'version_bump', 'tag', 'push_tag', 'ghrelease'] + +$VERSION_BUMP_PATTERNS = [('goli/_version.py', r'__version__\s*=.*', "__version__ = \"$VERSION\""), + ('setup.py', r'version\s*=.*,', "version=\"$VERSION\",") + ] diff --git a/setup.py b/setup.py index 06d9defba..6ccfad651 100644 --- a/setup.py +++ b/setup.py @@ -1,17 +1,17 @@ -from setuptools import setup -from setuptools import find_packages - -setup( - name="goli", - version="0.1.0", - author="Valence Discovery", - author_email="dominique@valencediscovery.com", - url="https://github.com/valence-discovery/goli", - description="A deep learning library focused on graph representation learning for real-world chemical tasks.", - long_description_content_type="text/markdown", - packages=find_packages(), - include_package_data=True, - entry_points={ - "console_scripts": ["goli=goli.cli:main_cli"], - }, -) +from setuptools import setup +from setuptools import find_packages + +setup( + name="goli", + version="0.1.0", + author="Valence Discovery", + author_email="dominique@valencediscovery.com", + url="https://github.com/valence-discovery/goli", + description="A deep learning library focused on graph representation learning for real-world chemical tasks.", + long_description_content_type="text/markdown", + packages=find_packages(), + include_package_data=True, + entry_points={ + "console_scripts": ["goli=goli.cli:main_cli"], + }, +) diff --git a/tests/conftest.py b/tests/conftest.py index 8f76705f3..16d777885 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,11 +1,11 @@ -import pathlib - -import pytest - -TEST_DIR_PATH = pathlib.Path(__file__).parent / "data" -DATA_DIR_PATH = TEST_DIR_PATH.absolute() - - -@pytest.fixture -def datadir(request): - return DATA_DIR_PATH +import pathlib + +import pytest + +TEST_DIR_PATH = pathlib.Path(__file__).parent / "data" +DATA_DIR_PATH = TEST_DIR_PATH.absolute() + + +@pytest.fixture +def datadir(request): + return DATA_DIR_PATH diff --git a/tests/data/config_micro_ZINC.yaml b/tests/data/config_micro_ZINC.yaml index f67018daa..e8e85b3bd 100644 --- a/tests/data/config_micro_ZINC.yaml +++ b/tests/data/config_micro_ZINC.yaml @@ -1,169 +1,169 @@ -constants: - seed: &seed 42 - raise_train_error: true # Whether the code should raise an error if it crashes during training - -datamodule: - module_type: "DGLFromSmilesDataModule" - args: - df_path: goli/data/micro_ZINC/micro_ZINC.csv - cache_data_path: goli/data/cache/micro_ZINC/full.cache - label_cols: ['score'] - smiles_col: SMILES - - # Featurization - featurization_n_jobs: -1 - featurization_progress: True - featurization: - atom_property_list_onehot: [atomic-number, valence] - atom_property_list_float: [mass, electronegativity, in-ring] - edge_property_list: [bond-type-onehot, stereo, in-ring] - add_self_loop: False - explicit_H: False - use_bonds_weights: False - pos_encoding_as_features: &pos_enc - pos_type: laplacian_eigvec - num_pos: 3 - normalization: "none" - disconnected_comp: True - pos_encoding_as_directions: *pos_enc - - # Train, val, test parameters - split_val: 0.2 - split_test: 0.2 - split_seed: *seed - splits_path: null - batch_size_train_val: 128 - batch_size_test: 128 - - # Data loading - num_workers: 0 - pin_memory: False - persistent_workers: False # Keep True on Windows if running multiple workers - - -architecture: - model_type: fulldglnetwork - pre_nn: # Set as null to avoid a pre-nn network - out_dim: 32 - hidden_dims: 32 - depth: 1 - activation: relu - last_activation: none - dropout: &dropout 0.1 - batch_norm: &batch_norm True - last_batch_norm: *batch_norm - residual_type: none - - pre_nn_edges: # Set as null to avoid a pre-nn network - out_dim: 16 - hidden_dims: 16 - depth: 2 - activation: relu - last_activation: none - dropout: *dropout - batch_norm: *batch_norm - last_batch_norm: *batch_norm - residual_type: none - - gnn: # Set as null to avoid a post-nn network - out_dim: 32 - hidden_dims: 32 - depth: 4 - activation: relu - last_activation: none - dropout: *dropout - batch_norm: *batch_norm - last_batch_norm: *batch_norm - residual_type: simple - pooling: [sum, max, dir1] - virtual_node: 'sum' - layer_type: 'dgn-msgpass' - layer_kwargs: - # num_heads: 3 - aggregators: [mean, max, dir1/dx_abs, dir1/smooth] - scalers: [identity, amplification, attenuation] - - post_nn: - out_dim: 1 - hidden_dims: 32 - depth: 2 - activation: relu - last_activation: none - dropout: *dropout - batch_norm: *batch_norm - last_batch_norm: False - residual_type: none - -predictor: - metrics_on_progress_bar: ["mae", "pearsonr", "f1 > 3", "precision > 3"] - loss_fun: mse - random_seed: *seed - optim_kwargs: - lr: 1.e-2 - weight_decay: 1.e-7 - lr_reduce_on_plateau_kwargs: - factor: 0.5 - patience: 7 - scheduler_kwargs: - monitor: &monitor loss/val - frequency: 1 - target_nan_mask: 0 # null: no mask, 0: 0 mask, ignore: ignore nan values from loss - - -metrics: - - name: mae - metric: mae - threshold_kwargs: null - - - name: pearsonr - metric: pearsonr - threshold_kwargs: null - - - name: f1 > 3 - metric: f1 - num_classes: 2 - average: micro - threshold_kwargs: &threshold_3 - operator: greater - threshold: 3 - th_on_preds: True - th_on_target: True - target_to_int: True - - - name: f1 > 5 - metric: f1 - num_classes: 2 - average: micro - threshold_kwargs: - operator: greater - threshold: 5 - th_on_preds: True - th_on_target: True - target_to_int: True - - - name: precision > 3 - metric: precision - class_reduction: micro - threshold_kwargs: *threshold_3 - -trainer: - logger: - save_dir: logs/micro_ZINC - early_stopping: - monitor: *monitor - min_delta: 0 - patience: 10 - mode: &mode min - model_checkpoint: - dirpath: models_checkpoints/micro_ZINC/ - filename: "model" - monitor: *monitor - mode: *mode - save_top_k: 1 - period: 1 - trainer: - max_epochs: 25 - min_epochs: 5 - gpus: 1 - +constants: + seed: &seed 42 + raise_train_error: true # Whether the code should raise an error if it crashes during training + +datamodule: + module_type: "DGLFromSmilesDataModule" + args: + df_path: goli/data/micro_ZINC/micro_ZINC.csv + cache_data_path: goli/data/cache/micro_ZINC/full.cache + label_cols: ['score'] + smiles_col: SMILES + + # Featurization + featurization_n_jobs: -1 + featurization_progress: True + featurization: + atom_property_list_onehot: [atomic-number, valence] + atom_property_list_float: [mass, electronegativity, in-ring] + edge_property_list: [bond-type-onehot, stereo, in-ring] + add_self_loop: False + explicit_H: False + use_bonds_weights: False + pos_encoding_as_features: &pos_enc + pos_type: laplacian_eigvec + num_pos: 3 + normalization: "none" + disconnected_comp: True + pos_encoding_as_directions: *pos_enc + + # Train, val, test parameters + split_val: 0.2 + split_test: 0.2 + split_seed: *seed + splits_path: null + batch_size_train_val: 128 + batch_size_test: 128 + + # Data loading + num_workers: 0 + pin_memory: False + persistent_workers: False # Keep True on Windows if running multiple workers + + +architecture: + model_type: fulldglnetwork + pre_nn: # Set as null to avoid a pre-nn network + out_dim: 32 + hidden_dims: 32 + depth: 1 + activation: relu + last_activation: none + dropout: &dropout 0.1 + batch_norm: &batch_norm True + last_batch_norm: *batch_norm + residual_type: none + + pre_nn_edges: # Set as null to avoid a pre-nn network + out_dim: 16 + hidden_dims: 16 + depth: 2 + activation: relu + last_activation: none + dropout: *dropout + batch_norm: *batch_norm + last_batch_norm: *batch_norm + residual_type: none + + gnn: # Set as null to avoid a post-nn network + out_dim: 32 + hidden_dims: 32 + depth: 4 + activation: relu + last_activation: none + dropout: *dropout + batch_norm: *batch_norm + last_batch_norm: *batch_norm + residual_type: simple + pooling: [sum, max, dir1] + virtual_node: 'sum' + layer_type: 'dgn-msgpass' + layer_kwargs: + # num_heads: 3 + aggregators: [mean, max, dir1/dx_abs, dir1/smooth] + scalers: [identity, amplification, attenuation] + + post_nn: + out_dim: 1 + hidden_dims: 32 + depth: 2 + activation: relu + last_activation: none + dropout: *dropout + batch_norm: *batch_norm + last_batch_norm: False + residual_type: none + +predictor: + metrics_on_progress_bar: ["mae", "pearsonr", "f1 > 3", "precision > 3"] + loss_fun: mse + random_seed: *seed + optim_kwargs: + lr: 1.e-2 + weight_decay: 1.e-7 + lr_reduce_on_plateau_kwargs: + factor: 0.5 + patience: 7 + scheduler_kwargs: + monitor: &monitor loss/val + frequency: 1 + target_nan_mask: 0 # null: no mask, 0: 0 mask, ignore: ignore nan values from loss + + +metrics: + - name: mae + metric: mae + threshold_kwargs: null + + - name: pearsonr + metric: pearsonr + threshold_kwargs: null + + - name: f1 > 3 + metric: f1 + num_classes: 2 + average: micro + threshold_kwargs: &threshold_3 + operator: greater + threshold: 3 + th_on_preds: True + th_on_target: True + target_to_int: True + + - name: f1 > 5 + metric: f1 + num_classes: 2 + average: micro + threshold_kwargs: + operator: greater + threshold: 5 + th_on_preds: True + th_on_target: True + target_to_int: True + + - name: precision > 3 + metric: precision + class_reduction: micro + threshold_kwargs: *threshold_3 + +trainer: + logger: + save_dir: logs/micro_ZINC + early_stopping: + monitor: *monitor + min_delta: 0 + patience: 10 + mode: &mode min + model_checkpoint: + dirpath: models_checkpoints/micro_ZINC/ + filename: "model" + monitor: *monitor + mode: *mode + save_top_k: 1 + period: 1 + trainer: + max_epochs: 25 + min_epochs: 5 + gpus: 1 + \ No newline at end of file diff --git a/tests/test_architectures.py b/tests/test_architectures.py index 13cb815cd..320a92a1f 100644 --- a/tests/test_architectures.py +++ b/tests/test_architectures.py @@ -1,633 +1,633 @@ -""" -Unit tests for the different architectures of goli/nn/architectures... - -The layers are not thoroughly tested due to the difficulty of testing them -""" - -import torch -import unittest as ut -import dgl -from copy import deepcopy - -from goli.nn.architectures import FeedForwardNN, FeedForwardDGL, FullDGLNetwork, FullDGLSiameseNetwork -from goli.nn.base_layers import FCLayer -from goli.nn.residual_connections import ( - ResidualConnectionConcat, - ResidualConnectionDenseNet, - ResidualConnectionNone, - ResidualConnectionSimple, - ResidualConnectionWeighted, -) - - -class test_FeedForwardNN(ut.TestCase): - - kwargs = { - "activation": "relu", - "last_activation": "none", - "batch_norm": False, - "dropout": 0.2, - "name": "LNN", - "layer_type": FCLayer, - } - - def test_forward_no_residual(self): - in_dim = 8 - out_dim = 16 - hidden_dims = [5, 6, 7] - batch = 2 - - lnn = FeedForwardNN( - in_dim=in_dim, - out_dim=out_dim, - hidden_dims=hidden_dims, - residual_type="none", - residual_skip_steps=1, - **self.kwargs, - ) - - self.assertEqual(len(lnn.layers), len(hidden_dims) + 1) - self.assertEqual(lnn.layers[0].in_dim, in_dim) - self.assertEqual(lnn.layers[1].in_dim, hidden_dims[0]) - self.assertEqual(lnn.layers[2].in_dim, hidden_dims[1]) - self.assertEqual(lnn.layers[3].in_dim, hidden_dims[2]) - - h = torch.FloatTensor(batch, in_dim) - h_out = lnn.forward(h) - - self.assertListEqual(list(h_out.shape), [batch, out_dim]) - - def test_forward_simple_residual_1(self): - in_dim = 8 - out_dim = 16 - hidden_dims = [6, 6, 6, 6, 6] - batch = 2 - - lnn = FeedForwardNN( - in_dim=in_dim, - out_dim=out_dim, - hidden_dims=hidden_dims, - residual_type="simple", - residual_skip_steps=1, - **self.kwargs, - ) - - self.assertEqual(len(lnn.layers), len(hidden_dims) + 1) - self.assertEqual(lnn.layers[0].in_dim, in_dim) - self.assertEqual(lnn.layers[1].in_dim, hidden_dims[0]) - self.assertEqual(lnn.layers[2].in_dim, hidden_dims[1]) - self.assertEqual(lnn.layers[3].in_dim, hidden_dims[2]) - - h = torch.FloatTensor(batch, in_dim) - h_out = lnn.forward(h) - - self.assertListEqual(list(h_out.shape), [batch, out_dim]) - - def test_forward_simple_residual_2(self): - in_dim = 8 - out_dim = 16 - hidden_dims = [6, 6, 6, 6, 6] - batch = 2 - - lnn = FeedForwardNN( - in_dim=in_dim, - out_dim=out_dim, - hidden_dims=hidden_dims, - residual_type="simple", - residual_skip_steps=2, - **self.kwargs, - ) - - self.assertEqual(len(lnn.layers), len(hidden_dims) + 1) - self.assertEqual(lnn.layers[0].in_dim, in_dim) - self.assertEqual(lnn.layers[1].in_dim, hidden_dims[0]) - self.assertEqual(lnn.layers[2].in_dim, hidden_dims[1]) - self.assertEqual(lnn.layers[3].in_dim, hidden_dims[2]) - self.assertEqual(lnn.layers[4].in_dim, hidden_dims[3]) - self.assertEqual(lnn.layers[5].in_dim, hidden_dims[4]) - - h = torch.FloatTensor(batch, in_dim) - h_out = lnn.forward(h) - - self.assertListEqual(list(h_out.shape), [batch, out_dim]) - - def test_forward_concat_residual_1(self): - in_dim = 8 - out_dim = 16 - hidden_dims = [6, 6, 6, 6, 6] - batch = 2 - - lnn = FeedForwardNN( - in_dim=in_dim, - out_dim=out_dim, - hidden_dims=hidden_dims, - residual_type="concat", - residual_skip_steps=1, - **self.kwargs, - ) - - self.assertEqual(len(lnn.layers), len(hidden_dims) + 1) - self.assertEqual(lnn.layers[0].in_dim, in_dim) - self.assertEqual(lnn.layers[1].in_dim, hidden_dims[0]) - self.assertEqual(lnn.layers[2].in_dim, 2 * hidden_dims[1]) - self.assertEqual(lnn.layers[3].in_dim, 2 * hidden_dims[2]) - self.assertEqual(lnn.layers[4].in_dim, 2 * hidden_dims[3]) - self.assertEqual(lnn.layers[5].in_dim, 2 * hidden_dims[4]) - - h = torch.FloatTensor(batch, in_dim) - h_out = lnn.forward(h) - - self.assertListEqual(list(h_out.shape), [batch, out_dim]) - - def test_forward_concat_residual_2(self): - in_dim = 8 - out_dim = 16 - hidden_dims = [6, 6, 6, 6, 6] - batch = 2 - - lnn = FeedForwardNN( - in_dim=in_dim, - out_dim=out_dim, - hidden_dims=hidden_dims, - residual_type="concat", - residual_skip_steps=2, - **self.kwargs, - ) - - self.assertEqual(len(lnn.layers), len(hidden_dims) + 1) - self.assertEqual(lnn.layers[0].in_dim, in_dim) - self.assertEqual(lnn.layers[1].in_dim, 1 * hidden_dims[0]) - self.assertEqual(lnn.layers[2].in_dim, 1 * hidden_dims[1]) - self.assertEqual(lnn.layers[3].in_dim, 2 * hidden_dims[2]) - self.assertEqual(lnn.layers[4].in_dim, 1 * hidden_dims[3]) - self.assertEqual(lnn.layers[5].in_dim, 2 * hidden_dims[4]) - - h = torch.FloatTensor(batch, in_dim) - h_out = lnn.forward(h) - - self.assertListEqual(list(h_out.shape), [batch, out_dim]) - - def test_forward_densenet_residual_1(self): - in_dim = 8 - out_dim = 16 - hidden_dims = [6, 6, 6, 6, 6] - batch = 2 - - lnn = FeedForwardNN( - in_dim=in_dim, - out_dim=out_dim, - hidden_dims=hidden_dims, - residual_type="densenet", - residual_skip_steps=1, - **self.kwargs, - ) - - self.assertEqual(len(lnn.layers), len(hidden_dims) + 1) - self.assertEqual(lnn.layers[0].in_dim, in_dim) - self.assertEqual(lnn.layers[1].in_dim, hidden_dims[0]) - self.assertEqual(lnn.layers[2].in_dim, 2 * hidden_dims[1]) - self.assertEqual(lnn.layers[3].in_dim, 3 * hidden_dims[2]) - self.assertEqual(lnn.layers[4].in_dim, 4 * hidden_dims[3]) - self.assertEqual(lnn.layers[5].in_dim, 5 * hidden_dims[4]) - - h = torch.FloatTensor(batch, in_dim) - h_out = lnn.forward(h) - - self.assertListEqual(list(h_out.shape), [batch, out_dim]) - - def test_forward_densenet_residual_2(self): - in_dim = 8 - out_dim = 16 - hidden_dims = [6, 6, 6, 6, 6] - batch = 2 - - lnn = FeedForwardNN( - in_dim=in_dim, - out_dim=out_dim, - hidden_dims=hidden_dims, - residual_type="densenet", - residual_skip_steps=2, - **self.kwargs, - ) - - self.assertEqual(len(lnn.layers), len(hidden_dims) + 1) - self.assertEqual(lnn.layers[0].in_dim, in_dim) - self.assertEqual(lnn.layers[1].in_dim, hidden_dims[0]) - self.assertEqual(lnn.layers[2].in_dim, 1 * hidden_dims[1]) - self.assertEqual(lnn.layers[3].in_dim, 2 * hidden_dims[2]) - self.assertEqual(lnn.layers[4].in_dim, 1 * hidden_dims[3]) - self.assertEqual(lnn.layers[5].in_dim, 3 * hidden_dims[4]) - - h = torch.FloatTensor(batch, in_dim) - h_out = lnn.forward(h) - - self.assertListEqual(list(h_out.shape), [batch, out_dim]) - - def test_forward_weighted_residual_1(self): - in_dim = 8 - out_dim = 16 - hidden_dims = [6, 6, 6, 6, 6] - batch = 2 - - lnn = FeedForwardNN( - in_dim=in_dim, - out_dim=out_dim, - hidden_dims=hidden_dims, - residual_type="weighted", - residual_skip_steps=1, - **self.kwargs, - ) - - self.assertEqual(len(lnn.layers), len(hidden_dims) + 1) - self.assertEqual(lnn.layers[0].in_dim, in_dim) - self.assertEqual(lnn.layers[1].in_dim, hidden_dims[0]) - self.assertEqual(lnn.layers[2].in_dim, hidden_dims[1]) - self.assertEqual(lnn.layers[3].in_dim, hidden_dims[2]) - self.assertEqual(lnn.layers[4].in_dim, hidden_dims[3]) - self.assertEqual(lnn.layers[5].in_dim, hidden_dims[4]) - - self.assertEqual(len(lnn.residual_layer.residual_list), len(hidden_dims)) - - h = torch.FloatTensor(batch, in_dim) - h_out = lnn.forward(h) - - self.assertListEqual(list(h_out.shape), [batch, out_dim]) - - def test_forward_weighted_residual_2(self): - in_dim = 8 - out_dim = 16 - hidden_dims = [6, 6, 6, 6, 6] - batch = 2 - - lnn = FeedForwardNN( - in_dim=in_dim, - out_dim=out_dim, - hidden_dims=hidden_dims, - residual_type="weighted", - residual_skip_steps=2, - **self.kwargs, - ) - - self.assertEqual(len(lnn.layers), len(hidden_dims) + 1) - self.assertEqual(lnn.layers[0].in_dim, in_dim) - self.assertEqual(lnn.layers[1].in_dim, hidden_dims[0]) - self.assertEqual(lnn.layers[2].in_dim, hidden_dims[1]) - self.assertEqual(lnn.layers[3].in_dim, hidden_dims[2]) - self.assertEqual(lnn.layers[4].in_dim, hidden_dims[3]) - self.assertEqual(lnn.layers[5].in_dim, hidden_dims[4]) - - self.assertEqual(len(lnn.residual_layer.residual_list), (len(hidden_dims) // 2 + 1)) - - h = torch.FloatTensor(batch, in_dim) - h_out = lnn.forward(h) - - self.assertListEqual(list(h_out.shape), [batch, out_dim]) - - -class test_FeedForwardDGL(ut.TestCase): - - kwargs = { - "activation": "relu", - "last_activation": "none", - "batch_norm": False, - "dropout": 0.2, - "name": "LNN", - } - - in_dim = 7 - out_dim = 11 - in_dim_edges = 13 - hidden_dims = [6, 6, 6, 6, 6] - - g1 = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 3]))) - g2 = dgl.graph((torch.tensor([0, 0, 0, 1]), torch.tensor([0, 1, 2, 0]))) - g1.ndata["h"] = torch.zeros(g1.num_nodes(), in_dim, dtype=torch.float32) - g1.edata["e"] = torch.ones(g1.num_edges(), in_dim_edges, dtype=torch.float32) - g2.ndata["h"] = torch.ones(g2.num_nodes(), in_dim, dtype=torch.float32) - g2.edata["e"] = torch.zeros(g2.num_edges(), in_dim_edges, dtype=torch.float32) - bg = dgl.batch([g1, g2]) - bg = dgl.add_self_loop(bg) - - virtual_nodes = ["none", "mean", "sum"] - pna_kwargs = {"aggregators": ["mean", "max", "sum"], "scalers": ["identity", "amplification"]} - - layers_kwargs = { - "gcn": {}, - "gin": {}, - "gat": {"layer_kwargs": {"num_heads": 3}}, - "gated-gcn": {"in_dim_edges": in_dim_edges, "hidden_dims_edges": hidden_dims}, - "pna-conv": {"layer_kwargs": pna_kwargs}, - "pna-msgpass#1": {"layer_kwargs": pna_kwargs, "in_dim_edges": 0}, - "pna-msgpass#2": {"layer_kwargs": pna_kwargs, "in_dim_edges": in_dim_edges}, - } - - def test_dgl_forward_no_residual(self): - - for pooling in [["none"], ["sum"], ["mean", "s2s", "max"]]: - for residual_skip_steps in [1, 2, 3]: - for virtual_node in self.virtual_nodes: - for layer_name, this_kwargs in self.layers_kwargs.items(): - err_msg = f"pooling={pooling}, virtual_node={virtual_node}, layer_name={layer_name}, residual_skip_steps={residual_skip_steps}" - layer_type = layer_name.split("#")[0] - - gnn = FeedForwardDGL( - in_dim=self.in_dim, - out_dim=self.out_dim, - hidden_dims=self.hidden_dims, - residual_type="none", - residual_skip_steps=residual_skip_steps, - layer_type=layer_type, - pooling=pooling, - **this_kwargs, - **self.kwargs, - ) - gnn.to(torch.float32) - - self.assertIsInstance(gnn.residual_layer, ResidualConnectionNone) - self.assertEqual(len(gnn.layers), len(self.hidden_dims) + 1, msg=err_msg) - self.assertEqual(gnn.layers[0].out_dim, self.hidden_dims[0], msg=err_msg) - self.assertEqual(gnn.layers[1].out_dim, self.hidden_dims[1], msg=err_msg) - self.assertEqual(gnn.layers[2].out_dim, self.hidden_dims[2], msg=err_msg) - self.assertEqual(gnn.layers[3].out_dim, self.hidden_dims[3], msg=err_msg) - self.assertEqual(gnn.layers[4].out_dim, self.hidden_dims[4], msg=err_msg) - self.assertEqual(gnn.layers[5].out_dim, self.out_dim, msg=err_msg) - - f = gnn.layers[0].out_dim_factor - self.assertEqual(gnn.layers[0].in_dim, self.in_dim, msg=err_msg) - self.assertEqual(gnn.layers[1].in_dim, f * self.hidden_dims[0], msg=err_msg) - self.assertEqual(gnn.layers[2].in_dim, f * self.hidden_dims[1], msg=err_msg) - self.assertEqual(gnn.layers[3].in_dim, f * self.hidden_dims[2], msg=err_msg) - self.assertEqual(gnn.layers[4].in_dim, f * self.hidden_dims[3], msg=err_msg) - self.assertEqual(gnn.layers[5].in_dim, f * self.hidden_dims[4], msg=err_msg) - - bg = deepcopy(self.bg) - h_out = gnn.forward(bg) - - dim_1 = bg.num_nodes() if pooling == ["none"] else 1 - self.assertListEqual(list(h_out.shape), [dim_1, self.out_dim], msg=err_msg) - - def test_dgl_forward_simple_residual(self): - - for pooling in [["none"], ["sum"], ["mean", "s2s", "max"]]: - for residual_skip_steps in [1, 2, 3]: - for virtual_node in self.virtual_nodes: - for layer_name, this_kwargs in self.layers_kwargs.items(): - err_msg = f"pooling={pooling}, virtual_node={virtual_node}, layer_name={layer_name}, residual_skip_steps={residual_skip_steps}" - layer_type = layer_name.split("#")[0] - - gnn = FeedForwardDGL( - in_dim=self.in_dim, - out_dim=self.out_dim, - hidden_dims=self.hidden_dims, - residual_type="simple", - residual_skip_steps=1, - layer_type=layer_type, - pooling=pooling, - **this_kwargs, - **self.kwargs, - ) - gnn.to(torch.float32) - - self.assertIsInstance(gnn.residual_layer, ResidualConnectionSimple) - self.assertEqual(len(gnn.layers), len(self.hidden_dims) + 1, msg=err_msg) - self.assertEqual(gnn.layers[0].out_dim, self.hidden_dims[0], msg=err_msg) - self.assertEqual(gnn.layers[1].out_dim, self.hidden_dims[1], msg=err_msg) - self.assertEqual(gnn.layers[2].out_dim, self.hidden_dims[2], msg=err_msg) - self.assertEqual(gnn.layers[3].out_dim, self.hidden_dims[3], msg=err_msg) - self.assertEqual(gnn.layers[4].out_dim, self.hidden_dims[4], msg=err_msg) - self.assertEqual(gnn.layers[5].out_dim, self.out_dim, msg=err_msg) - - f = gnn.layers[0].out_dim_factor - self.assertEqual(gnn.layers[0].in_dim, self.in_dim, msg=err_msg) - self.assertEqual(gnn.layers[1].in_dim, f * self.hidden_dims[0], msg=err_msg) - self.assertEqual(gnn.layers[2].in_dim, f * self.hidden_dims[1], msg=err_msg) - self.assertEqual(gnn.layers[3].in_dim, f * self.hidden_dims[2], msg=err_msg) - self.assertEqual(gnn.layers[4].in_dim, f * self.hidden_dims[3], msg=err_msg) - self.assertEqual(gnn.layers[5].in_dim, f * self.hidden_dims[4], msg=err_msg) - - bg = deepcopy(self.bg) - h_out = gnn.forward(bg) - - dim_1 = bg.num_nodes() if pooling == ["none"] else 1 - self.assertListEqual(list(h_out.shape), [dim_1, self.out_dim], msg=err_msg) - - def test_dgl_forward_weighted_residual(self): - - for pooling in [["none"], ["sum"], ["mean", "s2s", "max"]]: - for residual_skip_steps in [1, 2, 3]: - for virtual_node in self.virtual_nodes: - for layer_name, this_kwargs in self.layers_kwargs.items(): - err_msg = f"pooling={pooling}, virtual_node={virtual_node}, layer_name={layer_name}, residual_skip_steps={residual_skip_steps}" - layer_type = layer_name.split("#")[0] - - gnn = FeedForwardDGL( - in_dim=self.in_dim, - out_dim=self.out_dim, - hidden_dims=self.hidden_dims, - residual_type="weighted", - residual_skip_steps=residual_skip_steps, - layer_type=layer_type, - pooling=pooling, - **this_kwargs, - **self.kwargs, - ) - gnn.to(torch.float32) - - self.assertIsInstance(gnn.residual_layer, ResidualConnectionWeighted) - self.assertEqual(len(gnn.layers), len(self.hidden_dims) + 1, msg=err_msg) - self.assertEqual(gnn.layers[0].out_dim, self.hidden_dims[0], msg=err_msg) - self.assertEqual(gnn.layers[1].out_dim, self.hidden_dims[1], msg=err_msg) - self.assertEqual(gnn.layers[2].out_dim, self.hidden_dims[2], msg=err_msg) - self.assertEqual(gnn.layers[3].out_dim, self.hidden_dims[3], msg=err_msg) - self.assertEqual(gnn.layers[4].out_dim, self.hidden_dims[4], msg=err_msg) - self.assertEqual(gnn.layers[5].out_dim, self.out_dim, msg=err_msg) - - f = gnn.layers[0].out_dim_factor - self.assertEqual(gnn.layers[0].in_dim, self.in_dim, msg=err_msg) - self.assertEqual(gnn.layers[1].in_dim, f * self.hidden_dims[0], msg=err_msg) - self.assertEqual(gnn.layers[2].in_dim, f * self.hidden_dims[1], msg=err_msg) - self.assertEqual(gnn.layers[3].in_dim, f * self.hidden_dims[2], msg=err_msg) - self.assertEqual(gnn.layers[4].in_dim, f * self.hidden_dims[3], msg=err_msg) - self.assertEqual(gnn.layers[5].in_dim, f * self.hidden_dims[4], msg=err_msg) - - bg = deepcopy(self.bg) - h_out = gnn.forward(bg) - - dim_1 = bg.num_nodes() if pooling == ["none"] else 1 - self.assertListEqual(list(h_out.shape), [dim_1, self.out_dim], msg=err_msg) - - def test_dgl_forward_concat_residual(self): - - for pooling in [["none"], ["sum"], ["mean", "s2s", "max"]]: - for residual_skip_steps in [1, 2, 3]: - for virtual_node in self.virtual_nodes: - for layer_name, this_kwargs in self.layers_kwargs.items(): - err_msg = f"pooling={pooling}, virtual_node={virtual_node}, layer_name={layer_name}, residual_skip_steps={residual_skip_steps}" - layer_type = layer_name.split("#")[0] - - gnn = FeedForwardDGL( - in_dim=self.in_dim, - out_dim=self.out_dim, - hidden_dims=self.hidden_dims, - residual_type="concat", - residual_skip_steps=residual_skip_steps, - layer_type=layer_type, - pooling=pooling, - **this_kwargs, - **self.kwargs, - ) - gnn.to(torch.float32) - - self.assertIsInstance(gnn.residual_layer, ResidualConnectionConcat) - self.assertEqual(len(gnn.layers), len(self.hidden_dims) + 1, msg=err_msg) - self.assertEqual(gnn.layers[0].out_dim, self.hidden_dims[0], msg=err_msg) - self.assertEqual(gnn.layers[1].out_dim, self.hidden_dims[1], msg=err_msg) - self.assertEqual(gnn.layers[2].out_dim, self.hidden_dims[2], msg=err_msg) - self.assertEqual(gnn.layers[3].out_dim, self.hidden_dims[3], msg=err_msg) - self.assertEqual(gnn.layers[4].out_dim, self.hidden_dims[4], msg=err_msg) - self.assertEqual(gnn.layers[5].out_dim, self.out_dim, msg=err_msg) - - f = gnn.layers[0].out_dim_factor - f2 = [2 * f if ((ii % residual_skip_steps) == 0 and ii > 0) else f for ii in range(6)] - self.assertEqual(gnn.layers[0].in_dim, self.in_dim, msg=err_msg) - self.assertEqual(gnn.layers[1].in_dim, f2[0] * self.hidden_dims[0], msg=err_msg) - self.assertEqual(gnn.layers[2].in_dim, f2[1] * self.hidden_dims[1], msg=err_msg) - self.assertEqual(gnn.layers[3].in_dim, f2[2] * self.hidden_dims[2], msg=err_msg) - self.assertEqual(gnn.layers[4].in_dim, f2[3] * self.hidden_dims[3], msg=err_msg) - self.assertEqual(gnn.layers[5].in_dim, f2[4] * self.hidden_dims[4], msg=err_msg) - - bg = deepcopy(self.bg) - h_out = gnn.forward(bg) - - dim_1 = bg.num_nodes() if pooling == ["none"] else 1 - self.assertListEqual(list(h_out.shape), [dim_1, self.out_dim], msg=err_msg) - - def test_dgl_forward_densenet_residual(self): - - for pooling in [["none"], ["sum"], ["mean", "s2s", "max"]]: - for residual_skip_steps in [1, 2, 3]: - for virtual_node in self.virtual_nodes: - for layer_name, this_kwargs in self.layers_kwargs.items(): - err_msg = f"pooling={pooling}, virtual_node={virtual_node}, layer_name={layer_name}, residual_skip_steps={residual_skip_steps}" - layer_type = layer_name.split("#")[0] - - gnn = FeedForwardDGL( - in_dim=self.in_dim, - out_dim=self.out_dim, - hidden_dims=self.hidden_dims, - residual_type="densenet", - residual_skip_steps=residual_skip_steps, - layer_type=layer_type, - pooling=pooling, - **this_kwargs, - **self.kwargs, - ) - gnn.to(torch.float32) - - self.assertIsInstance(gnn.residual_layer, ResidualConnectionDenseNet) - self.assertEqual(len(gnn.layers), len(self.hidden_dims) + 1, msg=err_msg) - self.assertEqual(gnn.layers[0].out_dim, self.hidden_dims[0], msg=err_msg) - self.assertEqual(gnn.layers[1].out_dim, self.hidden_dims[1], msg=err_msg) - self.assertEqual(gnn.layers[2].out_dim, self.hidden_dims[2], msg=err_msg) - self.assertEqual(gnn.layers[3].out_dim, self.hidden_dims[3], msg=err_msg) - self.assertEqual(gnn.layers[4].out_dim, self.hidden_dims[4], msg=err_msg) - self.assertEqual(gnn.layers[5].out_dim, self.out_dim, msg=err_msg) - - f = gnn.layers[0].out_dim_factor - f2 = [ - ((ii // residual_skip_steps) + 1) * f - if ((ii % residual_skip_steps) == 0 and ii > 0) - else f - for ii in range(6) - ] - self.assertEqual(gnn.layers[0].in_dim, self.in_dim, msg=err_msg) - self.assertEqual(gnn.layers[1].in_dim, f2[0] * self.hidden_dims[0], msg=err_msg) - self.assertEqual(gnn.layers[2].in_dim, f2[1] * self.hidden_dims[1], msg=err_msg) - self.assertEqual(gnn.layers[3].in_dim, f2[2] * self.hidden_dims[2], msg=err_msg) - self.assertEqual(gnn.layers[4].in_dim, f2[3] * self.hidden_dims[3], msg=err_msg) - self.assertEqual(gnn.layers[5].in_dim, f2[4] * self.hidden_dims[4], msg=err_msg) - - bg = deepcopy(self.bg) - h_out = gnn.forward(bg) - - dim_1 = bg.num_nodes() if pooling == ["none"] else 1 - self.assertListEqual(list(h_out.shape), [dim_1, self.out_dim], msg=err_msg) - - -class test_FullDGLNetwork(ut.TestCase): - - kwargs = { - "activation": "relu", - "last_activation": "none", - "batch_norm": False, - "dropout": 0.2, - "name": "LNN", - } - - in_dim = 7 - out_dim = 11 - in_dim_edges = 13 - hidden_dims = [6, 6, 6, 6, 6] - - g1 = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 3]))) - g2 = dgl.graph((torch.tensor([0, 0, 0, 1]), torch.tensor([0, 1, 2, 0]))) - g1.ndata["feat"] = torch.zeros(g1.num_nodes(), in_dim, dtype=torch.float32) - g1.edata["feat"] = torch.ones(g1.num_edges(), in_dim_edges, dtype=torch.float32) - g2.ndata["feat"] = torch.ones(g2.num_nodes(), in_dim, dtype=torch.float32) - g2.edata["feat"] = torch.zeros(g2.num_edges(), in_dim_edges, dtype=torch.float32) - bg = dgl.batch([g1, g2]) - bg = dgl.add_self_loop(bg) - - virtual_nodes = ["none", "mean", "sum"] - pna_kwargs = {"aggregators": ["mean", "max", "sum"], "scalers": ["identity", "amplification"]} - - gnn_layers_kwargs = { - "gcn": {}, - "gin": {}, - "gat": {"layer_kwargs": {"num_heads": 3}}, - "gated-gcn": {"in_dim_edges": in_dim_edges, "hidden_dims_edges": hidden_dims}, - "pna-conv": {"layer_kwargs": pna_kwargs}, - "pna-msgpass#1": {"layer_kwargs": pna_kwargs, "in_dim_edges": 0}, - "pna-msgpass#2": {"layer_kwargs": pna_kwargs, "in_dim_edges": in_dim_edges}, - } - - def test_full_network_densenet(self): - - temp_dim_1 = 5 - temp_dim_2 = 17 - - pre_nn_kwargs = dict(in_dim=self.in_dim, out_dim=temp_dim_1, hidden_dims=[4, 4, 4, 4, 4]) - - post_nn_kwargs = dict(in_dim=temp_dim_2, out_dim=self.out_dim, hidden_dims=[3, 3, 3, 3]) - - for pooling in [["none"], ["sum"], ["mean", "s2s", "max"]]: - for residual_skip_steps in [1, 2, 3]: - for virtual_node in self.virtual_nodes: - for layer_name, this_kwargs in self.gnn_layers_kwargs.items(): - err_msg = f"pooling={pooling}, virtual_node={virtual_node}, layer_name={layer_name}, residual_skip_steps={residual_skip_steps}" - layer_type = layer_name.split("#")[0] - - gnn_kwargs = dict( - in_dim=temp_dim_1, - out_dim=temp_dim_2, - hidden_dims=self.hidden_dims, - residual_type="densenet", - residual_skip_steps=residual_skip_steps, - layer_type=layer_type, - pooling=pooling, - **this_kwargs, - **self.kwargs, - ) - - net = FullDGLNetwork( - pre_nn_kwargs=pre_nn_kwargs, gnn_kwargs=gnn_kwargs, post_nn_kwargs=post_nn_kwargs - ) - bg = deepcopy(self.bg) - h_out = net.forward(bg) - - dim_1 = bg.num_nodes() if pooling == ["none"] else 1 - self.assertListEqual(list(h_out.shape), [dim_1, self.out_dim], msg=err_msg) - - -if __name__ == "__main__": - ut.main() +""" +Unit tests for the different architectures of goli/nn/architectures... + +The layers are not thoroughly tested due to the difficulty of testing them +""" + +import torch +import unittest as ut +import dgl +from copy import deepcopy + +from goli.nn.architectures import FeedForwardNN, FeedForwardDGL, FullDGLNetwork, FullDGLSiameseNetwork +from goli.nn.base_layers import FCLayer +from goli.nn.residual_connections import ( + ResidualConnectionConcat, + ResidualConnectionDenseNet, + ResidualConnectionNone, + ResidualConnectionSimple, + ResidualConnectionWeighted, +) + + +class test_FeedForwardNN(ut.TestCase): + + kwargs = { + "activation": "relu", + "last_activation": "none", + "batch_norm": False, + "dropout": 0.2, + "name": "LNN", + "layer_type": FCLayer, + } + + def test_forward_no_residual(self): + in_dim = 8 + out_dim = 16 + hidden_dims = [5, 6, 7] + batch = 2 + + lnn = FeedForwardNN( + in_dim=in_dim, + out_dim=out_dim, + hidden_dims=hidden_dims, + residual_type="none", + residual_skip_steps=1, + **self.kwargs, + ) + + self.assertEqual(len(lnn.layers), len(hidden_dims) + 1) + self.assertEqual(lnn.layers[0].in_dim, in_dim) + self.assertEqual(lnn.layers[1].in_dim, hidden_dims[0]) + self.assertEqual(lnn.layers[2].in_dim, hidden_dims[1]) + self.assertEqual(lnn.layers[3].in_dim, hidden_dims[2]) + + h = torch.FloatTensor(batch, in_dim) + h_out = lnn.forward(h) + + self.assertListEqual(list(h_out.shape), [batch, out_dim]) + + def test_forward_simple_residual_1(self): + in_dim = 8 + out_dim = 16 + hidden_dims = [6, 6, 6, 6, 6] + batch = 2 + + lnn = FeedForwardNN( + in_dim=in_dim, + out_dim=out_dim, + hidden_dims=hidden_dims, + residual_type="simple", + residual_skip_steps=1, + **self.kwargs, + ) + + self.assertEqual(len(lnn.layers), len(hidden_dims) + 1) + self.assertEqual(lnn.layers[0].in_dim, in_dim) + self.assertEqual(lnn.layers[1].in_dim, hidden_dims[0]) + self.assertEqual(lnn.layers[2].in_dim, hidden_dims[1]) + self.assertEqual(lnn.layers[3].in_dim, hidden_dims[2]) + + h = torch.FloatTensor(batch, in_dim) + h_out = lnn.forward(h) + + self.assertListEqual(list(h_out.shape), [batch, out_dim]) + + def test_forward_simple_residual_2(self): + in_dim = 8 + out_dim = 16 + hidden_dims = [6, 6, 6, 6, 6] + batch = 2 + + lnn = FeedForwardNN( + in_dim=in_dim, + out_dim=out_dim, + hidden_dims=hidden_dims, + residual_type="simple", + residual_skip_steps=2, + **self.kwargs, + ) + + self.assertEqual(len(lnn.layers), len(hidden_dims) + 1) + self.assertEqual(lnn.layers[0].in_dim, in_dim) + self.assertEqual(lnn.layers[1].in_dim, hidden_dims[0]) + self.assertEqual(lnn.layers[2].in_dim, hidden_dims[1]) + self.assertEqual(lnn.layers[3].in_dim, hidden_dims[2]) + self.assertEqual(lnn.layers[4].in_dim, hidden_dims[3]) + self.assertEqual(lnn.layers[5].in_dim, hidden_dims[4]) + + h = torch.FloatTensor(batch, in_dim) + h_out = lnn.forward(h) + + self.assertListEqual(list(h_out.shape), [batch, out_dim]) + + def test_forward_concat_residual_1(self): + in_dim = 8 + out_dim = 16 + hidden_dims = [6, 6, 6, 6, 6] + batch = 2 + + lnn = FeedForwardNN( + in_dim=in_dim, + out_dim=out_dim, + hidden_dims=hidden_dims, + residual_type="concat", + residual_skip_steps=1, + **self.kwargs, + ) + + self.assertEqual(len(lnn.layers), len(hidden_dims) + 1) + self.assertEqual(lnn.layers[0].in_dim, in_dim) + self.assertEqual(lnn.layers[1].in_dim, hidden_dims[0]) + self.assertEqual(lnn.layers[2].in_dim, 2 * hidden_dims[1]) + self.assertEqual(lnn.layers[3].in_dim, 2 * hidden_dims[2]) + self.assertEqual(lnn.layers[4].in_dim, 2 * hidden_dims[3]) + self.assertEqual(lnn.layers[5].in_dim, 2 * hidden_dims[4]) + + h = torch.FloatTensor(batch, in_dim) + h_out = lnn.forward(h) + + self.assertListEqual(list(h_out.shape), [batch, out_dim]) + + def test_forward_concat_residual_2(self): + in_dim = 8 + out_dim = 16 + hidden_dims = [6, 6, 6, 6, 6] + batch = 2 + + lnn = FeedForwardNN( + in_dim=in_dim, + out_dim=out_dim, + hidden_dims=hidden_dims, + residual_type="concat", + residual_skip_steps=2, + **self.kwargs, + ) + + self.assertEqual(len(lnn.layers), len(hidden_dims) + 1) + self.assertEqual(lnn.layers[0].in_dim, in_dim) + self.assertEqual(lnn.layers[1].in_dim, 1 * hidden_dims[0]) + self.assertEqual(lnn.layers[2].in_dim, 1 * hidden_dims[1]) + self.assertEqual(lnn.layers[3].in_dim, 2 * hidden_dims[2]) + self.assertEqual(lnn.layers[4].in_dim, 1 * hidden_dims[3]) + self.assertEqual(lnn.layers[5].in_dim, 2 * hidden_dims[4]) + + h = torch.FloatTensor(batch, in_dim) + h_out = lnn.forward(h) + + self.assertListEqual(list(h_out.shape), [batch, out_dim]) + + def test_forward_densenet_residual_1(self): + in_dim = 8 + out_dim = 16 + hidden_dims = [6, 6, 6, 6, 6] + batch = 2 + + lnn = FeedForwardNN( + in_dim=in_dim, + out_dim=out_dim, + hidden_dims=hidden_dims, + residual_type="densenet", + residual_skip_steps=1, + **self.kwargs, + ) + + self.assertEqual(len(lnn.layers), len(hidden_dims) + 1) + self.assertEqual(lnn.layers[0].in_dim, in_dim) + self.assertEqual(lnn.layers[1].in_dim, hidden_dims[0]) + self.assertEqual(lnn.layers[2].in_dim, 2 * hidden_dims[1]) + self.assertEqual(lnn.layers[3].in_dim, 3 * hidden_dims[2]) + self.assertEqual(lnn.layers[4].in_dim, 4 * hidden_dims[3]) + self.assertEqual(lnn.layers[5].in_dim, 5 * hidden_dims[4]) + + h = torch.FloatTensor(batch, in_dim) + h_out = lnn.forward(h) + + self.assertListEqual(list(h_out.shape), [batch, out_dim]) + + def test_forward_densenet_residual_2(self): + in_dim = 8 + out_dim = 16 + hidden_dims = [6, 6, 6, 6, 6] + batch = 2 + + lnn = FeedForwardNN( + in_dim=in_dim, + out_dim=out_dim, + hidden_dims=hidden_dims, + residual_type="densenet", + residual_skip_steps=2, + **self.kwargs, + ) + + self.assertEqual(len(lnn.layers), len(hidden_dims) + 1) + self.assertEqual(lnn.layers[0].in_dim, in_dim) + self.assertEqual(lnn.layers[1].in_dim, hidden_dims[0]) + self.assertEqual(lnn.layers[2].in_dim, 1 * hidden_dims[1]) + self.assertEqual(lnn.layers[3].in_dim, 2 * hidden_dims[2]) + self.assertEqual(lnn.layers[4].in_dim, 1 * hidden_dims[3]) + self.assertEqual(lnn.layers[5].in_dim, 3 * hidden_dims[4]) + + h = torch.FloatTensor(batch, in_dim) + h_out = lnn.forward(h) + + self.assertListEqual(list(h_out.shape), [batch, out_dim]) + + def test_forward_weighted_residual_1(self): + in_dim = 8 + out_dim = 16 + hidden_dims = [6, 6, 6, 6, 6] + batch = 2 + + lnn = FeedForwardNN( + in_dim=in_dim, + out_dim=out_dim, + hidden_dims=hidden_dims, + residual_type="weighted", + residual_skip_steps=1, + **self.kwargs, + ) + + self.assertEqual(len(lnn.layers), len(hidden_dims) + 1) + self.assertEqual(lnn.layers[0].in_dim, in_dim) + self.assertEqual(lnn.layers[1].in_dim, hidden_dims[0]) + self.assertEqual(lnn.layers[2].in_dim, hidden_dims[1]) + self.assertEqual(lnn.layers[3].in_dim, hidden_dims[2]) + self.assertEqual(lnn.layers[4].in_dim, hidden_dims[3]) + self.assertEqual(lnn.layers[5].in_dim, hidden_dims[4]) + + self.assertEqual(len(lnn.residual_layer.residual_list), len(hidden_dims)) + + h = torch.FloatTensor(batch, in_dim) + h_out = lnn.forward(h) + + self.assertListEqual(list(h_out.shape), [batch, out_dim]) + + def test_forward_weighted_residual_2(self): + in_dim = 8 + out_dim = 16 + hidden_dims = [6, 6, 6, 6, 6] + batch = 2 + + lnn = FeedForwardNN( + in_dim=in_dim, + out_dim=out_dim, + hidden_dims=hidden_dims, + residual_type="weighted", + residual_skip_steps=2, + **self.kwargs, + ) + + self.assertEqual(len(lnn.layers), len(hidden_dims) + 1) + self.assertEqual(lnn.layers[0].in_dim, in_dim) + self.assertEqual(lnn.layers[1].in_dim, hidden_dims[0]) + self.assertEqual(lnn.layers[2].in_dim, hidden_dims[1]) + self.assertEqual(lnn.layers[3].in_dim, hidden_dims[2]) + self.assertEqual(lnn.layers[4].in_dim, hidden_dims[3]) + self.assertEqual(lnn.layers[5].in_dim, hidden_dims[4]) + + self.assertEqual(len(lnn.residual_layer.residual_list), (len(hidden_dims) // 2 + 1)) + + h = torch.FloatTensor(batch, in_dim) + h_out = lnn.forward(h) + + self.assertListEqual(list(h_out.shape), [batch, out_dim]) + + +class test_FeedForwardDGL(ut.TestCase): + + kwargs = { + "activation": "relu", + "last_activation": "none", + "batch_norm": False, + "dropout": 0.2, + "name": "LNN", + } + + in_dim = 7 + out_dim = 11 + in_dim_edges = 13 + hidden_dims = [6, 6, 6, 6, 6] + + g1 = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 3]))) + g2 = dgl.graph((torch.tensor([0, 0, 0, 1]), torch.tensor([0, 1, 2, 0]))) + g1.ndata["h"] = torch.zeros(g1.num_nodes(), in_dim, dtype=torch.float32) + g1.edata["e"] = torch.ones(g1.num_edges(), in_dim_edges, dtype=torch.float32) + g2.ndata["h"] = torch.ones(g2.num_nodes(), in_dim, dtype=torch.float32) + g2.edata["e"] = torch.zeros(g2.num_edges(), in_dim_edges, dtype=torch.float32) + bg = dgl.batch([g1, g2]) + bg = dgl.add_self_loop(bg) + + virtual_nodes = ["none", "mean", "sum"] + pna_kwargs = {"aggregators": ["mean", "max", "sum"], "scalers": ["identity", "amplification"]} + + layers_kwargs = { + "gcn": {}, + "gin": {}, + "gat": {"layer_kwargs": {"num_heads": 3}}, + "gated-gcn": {"in_dim_edges": in_dim_edges, "hidden_dims_edges": hidden_dims}, + "pna-conv": {"layer_kwargs": pna_kwargs}, + "pna-msgpass#1": {"layer_kwargs": pna_kwargs, "in_dim_edges": 0}, + "pna-msgpass#2": {"layer_kwargs": pna_kwargs, "in_dim_edges": in_dim_edges}, + } + + def test_dgl_forward_no_residual(self): + + for pooling in [["none"], ["sum"], ["mean", "s2s", "max"]]: + for residual_skip_steps in [1, 2, 3]: + for virtual_node in self.virtual_nodes: + for layer_name, this_kwargs in self.layers_kwargs.items(): + err_msg = f"pooling={pooling}, virtual_node={virtual_node}, layer_name={layer_name}, residual_skip_steps={residual_skip_steps}" + layer_type = layer_name.split("#")[0] + + gnn = FeedForwardDGL( + in_dim=self.in_dim, + out_dim=self.out_dim, + hidden_dims=self.hidden_dims, + residual_type="none", + residual_skip_steps=residual_skip_steps, + layer_type=layer_type, + pooling=pooling, + **this_kwargs, + **self.kwargs, + ) + gnn.to(torch.float32) + + self.assertIsInstance(gnn.residual_layer, ResidualConnectionNone) + self.assertEqual(len(gnn.layers), len(self.hidden_dims) + 1, msg=err_msg) + self.assertEqual(gnn.layers[0].out_dim, self.hidden_dims[0], msg=err_msg) + self.assertEqual(gnn.layers[1].out_dim, self.hidden_dims[1], msg=err_msg) + self.assertEqual(gnn.layers[2].out_dim, self.hidden_dims[2], msg=err_msg) + self.assertEqual(gnn.layers[3].out_dim, self.hidden_dims[3], msg=err_msg) + self.assertEqual(gnn.layers[4].out_dim, self.hidden_dims[4], msg=err_msg) + self.assertEqual(gnn.layers[5].out_dim, self.out_dim, msg=err_msg) + + f = gnn.layers[0].out_dim_factor + self.assertEqual(gnn.layers[0].in_dim, self.in_dim, msg=err_msg) + self.assertEqual(gnn.layers[1].in_dim, f * self.hidden_dims[0], msg=err_msg) + self.assertEqual(gnn.layers[2].in_dim, f * self.hidden_dims[1], msg=err_msg) + self.assertEqual(gnn.layers[3].in_dim, f * self.hidden_dims[2], msg=err_msg) + self.assertEqual(gnn.layers[4].in_dim, f * self.hidden_dims[3], msg=err_msg) + self.assertEqual(gnn.layers[5].in_dim, f * self.hidden_dims[4], msg=err_msg) + + bg = deepcopy(self.bg) + h_out = gnn.forward(bg) + + dim_1 = bg.num_nodes() if pooling == ["none"] else 1 + self.assertListEqual(list(h_out.shape), [dim_1, self.out_dim], msg=err_msg) + + def test_dgl_forward_simple_residual(self): + + for pooling in [["none"], ["sum"], ["mean", "s2s", "max"]]: + for residual_skip_steps in [1, 2, 3]: + for virtual_node in self.virtual_nodes: + for layer_name, this_kwargs in self.layers_kwargs.items(): + err_msg = f"pooling={pooling}, virtual_node={virtual_node}, layer_name={layer_name}, residual_skip_steps={residual_skip_steps}" + layer_type = layer_name.split("#")[0] + + gnn = FeedForwardDGL( + in_dim=self.in_dim, + out_dim=self.out_dim, + hidden_dims=self.hidden_dims, + residual_type="simple", + residual_skip_steps=1, + layer_type=layer_type, + pooling=pooling, + **this_kwargs, + **self.kwargs, + ) + gnn.to(torch.float32) + + self.assertIsInstance(gnn.residual_layer, ResidualConnectionSimple) + self.assertEqual(len(gnn.layers), len(self.hidden_dims) + 1, msg=err_msg) + self.assertEqual(gnn.layers[0].out_dim, self.hidden_dims[0], msg=err_msg) + self.assertEqual(gnn.layers[1].out_dim, self.hidden_dims[1], msg=err_msg) + self.assertEqual(gnn.layers[2].out_dim, self.hidden_dims[2], msg=err_msg) + self.assertEqual(gnn.layers[3].out_dim, self.hidden_dims[3], msg=err_msg) + self.assertEqual(gnn.layers[4].out_dim, self.hidden_dims[4], msg=err_msg) + self.assertEqual(gnn.layers[5].out_dim, self.out_dim, msg=err_msg) + + f = gnn.layers[0].out_dim_factor + self.assertEqual(gnn.layers[0].in_dim, self.in_dim, msg=err_msg) + self.assertEqual(gnn.layers[1].in_dim, f * self.hidden_dims[0], msg=err_msg) + self.assertEqual(gnn.layers[2].in_dim, f * self.hidden_dims[1], msg=err_msg) + self.assertEqual(gnn.layers[3].in_dim, f * self.hidden_dims[2], msg=err_msg) + self.assertEqual(gnn.layers[4].in_dim, f * self.hidden_dims[3], msg=err_msg) + self.assertEqual(gnn.layers[5].in_dim, f * self.hidden_dims[4], msg=err_msg) + + bg = deepcopy(self.bg) + h_out = gnn.forward(bg) + + dim_1 = bg.num_nodes() if pooling == ["none"] else 1 + self.assertListEqual(list(h_out.shape), [dim_1, self.out_dim], msg=err_msg) + + def test_dgl_forward_weighted_residual(self): + + for pooling in [["none"], ["sum"], ["mean", "s2s", "max"]]: + for residual_skip_steps in [1, 2, 3]: + for virtual_node in self.virtual_nodes: + for layer_name, this_kwargs in self.layers_kwargs.items(): + err_msg = f"pooling={pooling}, virtual_node={virtual_node}, layer_name={layer_name}, residual_skip_steps={residual_skip_steps}" + layer_type = layer_name.split("#")[0] + + gnn = FeedForwardDGL( + in_dim=self.in_dim, + out_dim=self.out_dim, + hidden_dims=self.hidden_dims, + residual_type="weighted", + residual_skip_steps=residual_skip_steps, + layer_type=layer_type, + pooling=pooling, + **this_kwargs, + **self.kwargs, + ) + gnn.to(torch.float32) + + self.assertIsInstance(gnn.residual_layer, ResidualConnectionWeighted) + self.assertEqual(len(gnn.layers), len(self.hidden_dims) + 1, msg=err_msg) + self.assertEqual(gnn.layers[0].out_dim, self.hidden_dims[0], msg=err_msg) + self.assertEqual(gnn.layers[1].out_dim, self.hidden_dims[1], msg=err_msg) + self.assertEqual(gnn.layers[2].out_dim, self.hidden_dims[2], msg=err_msg) + self.assertEqual(gnn.layers[3].out_dim, self.hidden_dims[3], msg=err_msg) + self.assertEqual(gnn.layers[4].out_dim, self.hidden_dims[4], msg=err_msg) + self.assertEqual(gnn.layers[5].out_dim, self.out_dim, msg=err_msg) + + f = gnn.layers[0].out_dim_factor + self.assertEqual(gnn.layers[0].in_dim, self.in_dim, msg=err_msg) + self.assertEqual(gnn.layers[1].in_dim, f * self.hidden_dims[0], msg=err_msg) + self.assertEqual(gnn.layers[2].in_dim, f * self.hidden_dims[1], msg=err_msg) + self.assertEqual(gnn.layers[3].in_dim, f * self.hidden_dims[2], msg=err_msg) + self.assertEqual(gnn.layers[4].in_dim, f * self.hidden_dims[3], msg=err_msg) + self.assertEqual(gnn.layers[5].in_dim, f * self.hidden_dims[4], msg=err_msg) + + bg = deepcopy(self.bg) + h_out = gnn.forward(bg) + + dim_1 = bg.num_nodes() if pooling == ["none"] else 1 + self.assertListEqual(list(h_out.shape), [dim_1, self.out_dim], msg=err_msg) + + def test_dgl_forward_concat_residual(self): + + for pooling in [["none"], ["sum"], ["mean", "s2s", "max"]]: + for residual_skip_steps in [1, 2, 3]: + for virtual_node in self.virtual_nodes: + for layer_name, this_kwargs in self.layers_kwargs.items(): + err_msg = f"pooling={pooling}, virtual_node={virtual_node}, layer_name={layer_name}, residual_skip_steps={residual_skip_steps}" + layer_type = layer_name.split("#")[0] + + gnn = FeedForwardDGL( + in_dim=self.in_dim, + out_dim=self.out_dim, + hidden_dims=self.hidden_dims, + residual_type="concat", + residual_skip_steps=residual_skip_steps, + layer_type=layer_type, + pooling=pooling, + **this_kwargs, + **self.kwargs, + ) + gnn.to(torch.float32) + + self.assertIsInstance(gnn.residual_layer, ResidualConnectionConcat) + self.assertEqual(len(gnn.layers), len(self.hidden_dims) + 1, msg=err_msg) + self.assertEqual(gnn.layers[0].out_dim, self.hidden_dims[0], msg=err_msg) + self.assertEqual(gnn.layers[1].out_dim, self.hidden_dims[1], msg=err_msg) + self.assertEqual(gnn.layers[2].out_dim, self.hidden_dims[2], msg=err_msg) + self.assertEqual(gnn.layers[3].out_dim, self.hidden_dims[3], msg=err_msg) + self.assertEqual(gnn.layers[4].out_dim, self.hidden_dims[4], msg=err_msg) + self.assertEqual(gnn.layers[5].out_dim, self.out_dim, msg=err_msg) + + f = gnn.layers[0].out_dim_factor + f2 = [2 * f if ((ii % residual_skip_steps) == 0 and ii > 0) else f for ii in range(6)] + self.assertEqual(gnn.layers[0].in_dim, self.in_dim, msg=err_msg) + self.assertEqual(gnn.layers[1].in_dim, f2[0] * self.hidden_dims[0], msg=err_msg) + self.assertEqual(gnn.layers[2].in_dim, f2[1] * self.hidden_dims[1], msg=err_msg) + self.assertEqual(gnn.layers[3].in_dim, f2[2] * self.hidden_dims[2], msg=err_msg) + self.assertEqual(gnn.layers[4].in_dim, f2[3] * self.hidden_dims[3], msg=err_msg) + self.assertEqual(gnn.layers[5].in_dim, f2[4] * self.hidden_dims[4], msg=err_msg) + + bg = deepcopy(self.bg) + h_out = gnn.forward(bg) + + dim_1 = bg.num_nodes() if pooling == ["none"] else 1 + self.assertListEqual(list(h_out.shape), [dim_1, self.out_dim], msg=err_msg) + + def test_dgl_forward_densenet_residual(self): + + for pooling in [["none"], ["sum"], ["mean", "s2s", "max"]]: + for residual_skip_steps in [1, 2, 3]: + for virtual_node in self.virtual_nodes: + for layer_name, this_kwargs in self.layers_kwargs.items(): + err_msg = f"pooling={pooling}, virtual_node={virtual_node}, layer_name={layer_name}, residual_skip_steps={residual_skip_steps}" + layer_type = layer_name.split("#")[0] + + gnn = FeedForwardDGL( + in_dim=self.in_dim, + out_dim=self.out_dim, + hidden_dims=self.hidden_dims, + residual_type="densenet", + residual_skip_steps=residual_skip_steps, + layer_type=layer_type, + pooling=pooling, + **this_kwargs, + **self.kwargs, + ) + gnn.to(torch.float32) + + self.assertIsInstance(gnn.residual_layer, ResidualConnectionDenseNet) + self.assertEqual(len(gnn.layers), len(self.hidden_dims) + 1, msg=err_msg) + self.assertEqual(gnn.layers[0].out_dim, self.hidden_dims[0], msg=err_msg) + self.assertEqual(gnn.layers[1].out_dim, self.hidden_dims[1], msg=err_msg) + self.assertEqual(gnn.layers[2].out_dim, self.hidden_dims[2], msg=err_msg) + self.assertEqual(gnn.layers[3].out_dim, self.hidden_dims[3], msg=err_msg) + self.assertEqual(gnn.layers[4].out_dim, self.hidden_dims[4], msg=err_msg) + self.assertEqual(gnn.layers[5].out_dim, self.out_dim, msg=err_msg) + + f = gnn.layers[0].out_dim_factor + f2 = [ + ((ii // residual_skip_steps) + 1) * f + if ((ii % residual_skip_steps) == 0 and ii > 0) + else f + for ii in range(6) + ] + self.assertEqual(gnn.layers[0].in_dim, self.in_dim, msg=err_msg) + self.assertEqual(gnn.layers[1].in_dim, f2[0] * self.hidden_dims[0], msg=err_msg) + self.assertEqual(gnn.layers[2].in_dim, f2[1] * self.hidden_dims[1], msg=err_msg) + self.assertEqual(gnn.layers[3].in_dim, f2[2] * self.hidden_dims[2], msg=err_msg) + self.assertEqual(gnn.layers[4].in_dim, f2[3] * self.hidden_dims[3], msg=err_msg) + self.assertEqual(gnn.layers[5].in_dim, f2[4] * self.hidden_dims[4], msg=err_msg) + + bg = deepcopy(self.bg) + h_out = gnn.forward(bg) + + dim_1 = bg.num_nodes() if pooling == ["none"] else 1 + self.assertListEqual(list(h_out.shape), [dim_1, self.out_dim], msg=err_msg) + + +class test_FullDGLNetwork(ut.TestCase): + + kwargs = { + "activation": "relu", + "last_activation": "none", + "batch_norm": False, + "dropout": 0.2, + "name": "LNN", + } + + in_dim = 7 + out_dim = 11 + in_dim_edges = 13 + hidden_dims = [6, 6, 6, 6, 6] + + g1 = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 3]))) + g2 = dgl.graph((torch.tensor([0, 0, 0, 1]), torch.tensor([0, 1, 2, 0]))) + g1.ndata["feat"] = torch.zeros(g1.num_nodes(), in_dim, dtype=torch.float32) + g1.edata["feat"] = torch.ones(g1.num_edges(), in_dim_edges, dtype=torch.float32) + g2.ndata["feat"] = torch.ones(g2.num_nodes(), in_dim, dtype=torch.float32) + g2.edata["feat"] = torch.zeros(g2.num_edges(), in_dim_edges, dtype=torch.float32) + bg = dgl.batch([g1, g2]) + bg = dgl.add_self_loop(bg) + + virtual_nodes = ["none", "mean", "sum"] + pna_kwargs = {"aggregators": ["mean", "max", "sum"], "scalers": ["identity", "amplification"]} + + gnn_layers_kwargs = { + "gcn": {}, + "gin": {}, + "gat": {"layer_kwargs": {"num_heads": 3}}, + "gated-gcn": {"in_dim_edges": in_dim_edges, "hidden_dims_edges": hidden_dims}, + "pna-conv": {"layer_kwargs": pna_kwargs}, + "pna-msgpass#1": {"layer_kwargs": pna_kwargs, "in_dim_edges": 0}, + "pna-msgpass#2": {"layer_kwargs": pna_kwargs, "in_dim_edges": in_dim_edges}, + } + + def test_full_network_densenet(self): + + temp_dim_1 = 5 + temp_dim_2 = 17 + + pre_nn_kwargs = dict(in_dim=self.in_dim, out_dim=temp_dim_1, hidden_dims=[4, 4, 4, 4, 4]) + + post_nn_kwargs = dict(in_dim=temp_dim_2, out_dim=self.out_dim, hidden_dims=[3, 3, 3, 3]) + + for pooling in [["none"], ["sum"], ["mean", "s2s", "max"]]: + for residual_skip_steps in [1, 2, 3]: + for virtual_node in self.virtual_nodes: + for layer_name, this_kwargs in self.gnn_layers_kwargs.items(): + err_msg = f"pooling={pooling}, virtual_node={virtual_node}, layer_name={layer_name}, residual_skip_steps={residual_skip_steps}" + layer_type = layer_name.split("#")[0] + + gnn_kwargs = dict( + in_dim=temp_dim_1, + out_dim=temp_dim_2, + hidden_dims=self.hidden_dims, + residual_type="densenet", + residual_skip_steps=residual_skip_steps, + layer_type=layer_type, + pooling=pooling, + **this_kwargs, + **self.kwargs, + ) + + net = FullDGLNetwork( + pre_nn_kwargs=pre_nn_kwargs, gnn_kwargs=gnn_kwargs, post_nn_kwargs=post_nn_kwargs + ) + bg = deepcopy(self.bg) + h_out = net.forward(bg) + + dim_1 = bg.num_nodes() if pooling == ["none"] else 1 + self.assertListEqual(list(h_out.shape), [dim_1, self.out_dim], msg=err_msg) + + +if __name__ == "__main__": + ut.main() diff --git a/tests/test_data_utils.py b/tests/test_data_utils.py index 57723da77..f6a648ceb 100644 --- a/tests/test_data_utils.py +++ b/tests/test_data_utils.py @@ -1,20 +1,20 @@ -import pandas as pd - -import goli - - -def test_list_datasets(): - datasets = goli.data.utils.list_goli_datasets() - assert isinstance(datasets, set) - assert len(datasets) > 0 - - -def test_download_datasets(tmpdir): - dataset_dir = tmpdir.mkdir("goli-datasets") - - data_path = goli.data.utils.download_goli_dataset("goli-zinc-micro", output_path=dataset_dir) - - fpath = goli.utils.fs.join(data_path, "ZINC-micro.csv") - df = pd.read_csv(fpath) - assert df.shape == (1000, 4) # type: ignore - assert df.columns.tolist() == ["SMILES", "SA", "logp", "score"] # type: ignore +import pandas as pd + +import goli + + +def test_list_datasets(): + datasets = goli.data.utils.list_goli_datasets() + assert isinstance(datasets, set) + assert len(datasets) > 0 + + +def test_download_datasets(tmpdir): + dataset_dir = tmpdir.mkdir("goli-datasets") + + data_path = goli.data.utils.download_goli_dataset("goli-zinc-micro", output_path=dataset_dir) + + fpath = goli.utils.fs.join(data_path, "ZINC-micro.csv") + df = pd.read_csv(fpath) + assert df.shape == (1000, 4) # type: ignore + assert df.columns.tolist() == ["SMILES", "SA", "logp", "score"] # type: ignore diff --git a/tests/test_datamodule.py b/tests/test_datamodule.py index 97ba32ea5..639a22c0f 100644 --- a/tests/test_datamodule.py +++ b/tests/test_datamodule.py @@ -1,209 +1,209 @@ -import pathlib -import tempfile - -import unittest as ut - -import goli - - -class Test_DataModule(ut.TestCase): - def test_dglfromsmiles_dm(self): - - df = goli.data.load_tiny_zinc() - # Setup the featurization - featurization_args = {} - featurization_args["atom_property_list_float"] = [] # ["weight", "valence"] - featurization_args["atom_property_list_onehot"] = ["atomic-number", "degree"] - featurization_args["edge_property_list"] = ["in-ring", "bond-type-onehot"] - featurization_args["add_self_loop"] = False - featurization_args["use_bonds_weights"] = False - featurization_args["explicit_H"] = False - - # Config for datamodule - dm_args = {} - dm_args["df"] = df - dm_args["cache_data_path"] = None - dm_args["featurization"] = featurization_args - dm_args["smiles_col"] = "SMILES" - dm_args["label_cols"] = ["SA"] - dm_args["split_val"] = 0.2 - dm_args["split_test"] = 0.2 - dm_args["split_seed"] = 19 - dm_args["batch_size_train_val"] = 16 - dm_args["batch_size_test"] = 16 - dm_args["num_workers"] = 0 - dm_args["pin_memory"] = True - dm_args["featurization_n_jobs"] = 16 - dm_args["featurization_progress"] = True - - dm = goli.data.DGLFromSmilesDataModule(**dm_args) - - assert dm.num_node_feats == 50 - assert dm.num_edge_feats == 6 - - dm.prepare_data() - dm.setup() - - assert len(dm) == 100 - assert len(dm.train_ds) == 60 # type: ignore - assert len(dm.val_ds) == 20 # type: ignore - assert len(dm.test_ds) == 20 # type: ignore - assert dm.num_node_feats == 50 - assert dm.num_edge_feats == 6 - - for dl in [dm.train_dataloader(), dm.val_dataloader(), dm.test_dataloader()]: - it = iter(dl) - batch = next(it) - - assert set(batch.keys()) == {"labels", "features", "smiles"} - assert batch["labels"].shape == (16, 1) - - def test_dglfromsmiles_from_config(self): - - config = goli.load_config(name="zinc_default_fulldgl") - df = goli.data.load_tiny_zinc() - - dm_args = dict(config.data.args) - dm_args["df"] = df - - dm = goli.data.DGLFromSmilesDataModule(**dm_args) - - assert dm.num_node_feats == 50 - assert dm.num_edge_feats == 6 - - dm.prepare_data() - dm.setup() - - assert len(dm.train_ds) == 60 # type: ignore - assert len(dm.val_ds) == 20 # type: ignore - assert len(dm.test_ds) == 20 # type: ignore - assert dm.num_node_feats == 50 - assert dm.num_edge_feats == 6 - - for dl in [dm.train_dataloader(), dm.val_dataloader(), dm.test_dataloader()]: - it = iter(dl) - batch = next(it) - - assert set(batch.keys()) == {"labels", "features", "smiles"} - assert batch["labels"].shape == (16, 1) - - def test_ogb_datamodule(self): - - # other datasets are too large to be tested - dataset_names = ["ogbg-molhiv", "ogbg-molpcba", "ogbg-moltox21", "ogbg-molfreesolv"] - dataset_name = dataset_names[3] - - # Setup the featurization - featurization_args = {} - featurization_args["atom_property_list_float"] = [] # ["weight", "valence"] - featurization_args["atom_property_list_onehot"] = ["atomic-number", "degree"] - featurization_args["edge_property_list"] = ["bond-type-onehot"] - featurization_args["add_self_loop"] = False - featurization_args["use_bonds_weights"] = False - featurization_args["explicit_H"] = False - - # Config for datamodule - dm_args = {} - dm_args["dataset_name"] = dataset_name - dm_args["cache_data_path"] = None - dm_args["featurization"] = featurization_args - dm_args["batch_size_train_val"] = 16 - dm_args["batch_size_test"] = 16 - dm_args["num_workers"] = 0 - dm_args["pin_memory"] = True - dm_args["featurization_n_jobs"] = 16 - dm_args["featurization_progress"] = True - - ds = goli.data.DGLOGBDataModule(**dm_args) - - # test metadata - assert set(ds.metadata.keys()) == { - "num tasks", - "eval metric", - "download_name", - "version", - "url", - "add_inverse_edge", - "data type", - "has_node_attr", - "has_edge_attr", - "task type", - "num classes", - "split", - "additional node files", - "additional edge files", - "binary", - } - - ds.prepare_data() - ds.setup() - - # test module - assert ds.num_edge_feats == 5 - assert ds.num_node_feats == 50 - assert len(ds) == 642 - assert ds.dataset_name == "ogbg-molfreesolv" - - # test dataset - assert set(ds.train_ds[0].keys()) == {"smiles", "indices", "features", "labels"} - - # test batch loader - batch = next(iter(ds.train_dataloader())) - assert len(batch["smiles"]) == 16 - assert len(batch["labels"]) == 16 - assert len(batch["indices"]) == 16 - - def test_datamodule_cache_invalidation(self): - - df = goli.data.load_tiny_zinc() - - cache_data_path = pathlib.Path(tempfile.mkdtemp()) / "cache.pkl" - - # 1. Build a module with specific feature arguments - - featurization_args = {} - featurization_args["atom_property_list_float"] = ["mass", "electronegativity", "in-ring"] - featurization_args["edge_property_list"] = ["bond-type-onehot", "stereo", "in-ring"] - - dm_args = {} - dm_args["df"] = df - dm_args["cache_data_path"] = cache_data_path - dm_args["featurization"] = featurization_args - datam = goli.data.DGLFromSmilesDataModule(**dm_args) - datam.prepare_data() - datam.setup() - - assert datam.num_node_feats == 3 - assert datam.num_edge_feats == 13 - - # 2. Reload with the same arguments should not trigger a new preparation and give - # the same feature's dimensions. - - datam = goli.data.DGLFromSmilesDataModule(**dm_args) - datam.prepare_data() - datam.setup() - - assert datam.num_node_feats == 3 - assert datam.num_edge_feats == 13 - - # 3. Reloading from the same cache file should trigger a new data preparation and - # so different feature's dimensions. - - featurization_args = {} - featurization_args["edge_property_list"] = ["stereo", "in-ring"] - featurization_args["atom_property_list_float"] = ["mass", "electronegativity"] - - dm_args = {} - dm_args["df"] = df - dm_args["cache_data_path"] = cache_data_path - dm_args["featurization"] = featurization_args - datam = goli.data.DGLFromSmilesDataModule(**dm_args) - datam.prepare_data() - datam.setup() - - assert datam.num_node_feats == 2 - assert datam.num_edge_feats == 8 - - -if __name__ == "__main__": - ut.main() +import pathlib +import tempfile + +import unittest as ut + +import goli + + +class Test_DataModule(ut.TestCase): + def test_dglfromsmiles_dm(self): + + df = goli.data.load_tiny_zinc() + # Setup the featurization + featurization_args = {} + featurization_args["atom_property_list_float"] = [] # ["weight", "valence"] + featurization_args["atom_property_list_onehot"] = ["atomic-number", "degree"] + featurization_args["edge_property_list"] = ["in-ring", "bond-type-onehot"] + featurization_args["add_self_loop"] = False + featurization_args["use_bonds_weights"] = False + featurization_args["explicit_H"] = False + + # Config for datamodule + dm_args = {} + dm_args["df"] = df + dm_args["cache_data_path"] = None + dm_args["featurization"] = featurization_args + dm_args["smiles_col"] = "SMILES" + dm_args["label_cols"] = ["SA"] + dm_args["split_val"] = 0.2 + dm_args["split_test"] = 0.2 + dm_args["split_seed"] = 19 + dm_args["batch_size_train_val"] = 16 + dm_args["batch_size_test"] = 16 + dm_args["num_workers"] = 0 + dm_args["pin_memory"] = True + dm_args["featurization_n_jobs"] = 16 + dm_args["featurization_progress"] = True + + dm = goli.data.DGLFromSmilesDataModule(**dm_args) + + assert dm.num_node_feats == 50 + assert dm.num_edge_feats == 6 + + dm.prepare_data() + dm.setup() + + assert len(dm) == 100 + assert len(dm.train_ds) == 60 # type: ignore + assert len(dm.val_ds) == 20 # type: ignore + assert len(dm.test_ds) == 20 # type: ignore + assert dm.num_node_feats == 50 + assert dm.num_edge_feats == 6 + + for dl in [dm.train_dataloader(), dm.val_dataloader(), dm.test_dataloader()]: + it = iter(dl) + batch = next(it) + + assert set(batch.keys()) == {"labels", "features", "smiles"} + assert batch["labels"].shape == (16, 1) + + def test_dglfromsmiles_from_config(self): + + config = goli.load_config(name="zinc_default_fulldgl") + df = goli.data.load_tiny_zinc() + + dm_args = dict(config.data.args) + dm_args["df"] = df + + dm = goli.data.DGLFromSmilesDataModule(**dm_args) + + assert dm.num_node_feats == 50 + assert dm.num_edge_feats == 6 + + dm.prepare_data() + dm.setup() + + assert len(dm.train_ds) == 60 # type: ignore + assert len(dm.val_ds) == 20 # type: ignore + assert len(dm.test_ds) == 20 # type: ignore + assert dm.num_node_feats == 50 + assert dm.num_edge_feats == 6 + + for dl in [dm.train_dataloader(), dm.val_dataloader(), dm.test_dataloader()]: + it = iter(dl) + batch = next(it) + + assert set(batch.keys()) == {"labels", "features", "smiles"} + assert batch["labels"].shape == (16, 1) + + def test_ogb_datamodule(self): + + # other datasets are too large to be tested + dataset_names = ["ogbg-molhiv", "ogbg-molpcba", "ogbg-moltox21", "ogbg-molfreesolv"] + dataset_name = dataset_names[3] + + # Setup the featurization + featurization_args = {} + featurization_args["atom_property_list_float"] = [] # ["weight", "valence"] + featurization_args["atom_property_list_onehot"] = ["atomic-number", "degree"] + featurization_args["edge_property_list"] = ["bond-type-onehot"] + featurization_args["add_self_loop"] = False + featurization_args["use_bonds_weights"] = False + featurization_args["explicit_H"] = False + + # Config for datamodule + dm_args = {} + dm_args["dataset_name"] = dataset_name + dm_args["cache_data_path"] = None + dm_args["featurization"] = featurization_args + dm_args["batch_size_train_val"] = 16 + dm_args["batch_size_test"] = 16 + dm_args["num_workers"] = 0 + dm_args["pin_memory"] = True + dm_args["featurization_n_jobs"] = 16 + dm_args["featurization_progress"] = True + + ds = goli.data.DGLOGBDataModule(**dm_args) + + # test metadata + assert set(ds.metadata.keys()) == { + "num tasks", + "eval metric", + "download_name", + "version", + "url", + "add_inverse_edge", + "data type", + "has_node_attr", + "has_edge_attr", + "task type", + "num classes", + "split", + "additional node files", + "additional edge files", + "binary", + } + + ds.prepare_data() + ds.setup() + + # test module + assert ds.num_edge_feats == 5 + assert ds.num_node_feats == 50 + assert len(ds) == 642 + assert ds.dataset_name == "ogbg-molfreesolv" + + # test dataset + assert set(ds.train_ds[0].keys()) == {"smiles", "indices", "features", "labels"} + + # test batch loader + batch = next(iter(ds.train_dataloader())) + assert len(batch["smiles"]) == 16 + assert len(batch["labels"]) == 16 + assert len(batch["indices"]) == 16 + + def test_datamodule_cache_invalidation(self): + + df = goli.data.load_tiny_zinc() + + cache_data_path = pathlib.Path(tempfile.mkdtemp()) / "cache.pkl" + + # 1. Build a module with specific feature arguments + + featurization_args = {} + featurization_args["atom_property_list_float"] = ["mass", "electronegativity", "in-ring"] + featurization_args["edge_property_list"] = ["bond-type-onehot", "stereo", "in-ring"] + + dm_args = {} + dm_args["df"] = df + dm_args["cache_data_path"] = cache_data_path + dm_args["featurization"] = featurization_args + datam = goli.data.DGLFromSmilesDataModule(**dm_args) + datam.prepare_data() + datam.setup() + + assert datam.num_node_feats == 3 + assert datam.num_edge_feats == 13 + + # 2. Reload with the same arguments should not trigger a new preparation and give + # the same feature's dimensions. + + datam = goli.data.DGLFromSmilesDataModule(**dm_args) + datam.prepare_data() + datam.setup() + + assert datam.num_node_feats == 3 + assert datam.num_edge_feats == 13 + + # 3. Reloading from the same cache file should trigger a new data preparation and + # so different feature's dimensions. + + featurization_args = {} + featurization_args["edge_property_list"] = ["stereo", "in-ring"] + featurization_args["atom_property_list_float"] = ["mass", "electronegativity"] + + dm_args = {} + dm_args["df"] = df + dm_args["cache_data_path"] = cache_data_path + dm_args["featurization"] = featurization_args + datam = goli.data.DGLFromSmilesDataModule(**dm_args) + datam.prepare_data() + datam.setup() + + assert datam.num_node_feats == 2 + assert datam.num_edge_feats == 8 + + +if __name__ == "__main__": + ut.main() diff --git a/tests/test_featurizer.py b/tests/test_featurizer.py index 2f25d5059..5a8c84aa9 100644 --- a/tests/test_featurizer.py +++ b/tests/test_featurizer.py @@ -1,217 +1,217 @@ -""" -Unit tests for the different datasets of goli/features/featurizer.py -""" - -import numpy as np -import unittest as ut -from copy import deepcopy -from rdkit import Chem -import datamol as dm - -from goli.features.featurizer import ( - get_mol_atomic_features_onehot, - get_mol_atomic_features_float, - get_mol_edge_features, - mol_to_adj_and_features, - mol_to_dglgraph, -) - - -class test_featurizer(ut.TestCase): - - smiles = [ - "C", - "CC", - "CC(C)CC1=CC=C(C=C1)C(C)C(=O)O", - "OCCc1c(C)[n+](cs1)Cc2cnc(C)nc2N", - "O1C=C[C@H]([C@H]1O2)c3c2cc(OC)c4c3OC(=O)C5=C4CCC(=O)5", - ] - - atomic_onehot_props = [ - "atomic-number", - "valence", - "degree", - "implicit-valence", - "hybridization", - "chirality", - ] - - atomic_float_props = [ - "atomic-number", - "mass", - "valence", - "implicit-valence", - "hybridization", - "chirality", - "aromatic", - "in-ring", - "degree", - "radical-electron", - "formal-charge", - "vdw-radius", - "covalent-radius", - "electronegativity", - "ionization", - "melting-point", - "metal", - "single-bond", - "aromatic-bond", - "double-bond", - "triple-bond", - "is-carbon", - ] - - edge_props = [ - "bond-type-onehot", - "bond-type-float", - "stereo", - "in-ring", - "conjugated", - "estimated-bond-length", - "conformer-bond-length", - ] - - def test_get_mol_atomic_features_onehot(self): - props = deepcopy(self.atomic_onehot_props) - bad_props = ["bob"] - - for s in self.smiles: - err_msg = f"\n\tError for params:\n\t\tSMILES: {s}" - mol = dm.to_mol(s) - - for ii in range(len(props)): - this_props = props[:ii] - err_msg2 = err_msg + f"\n\t\tprops: {this_props}" - prop_dict = get_mol_atomic_features_onehot(mol, property_list=this_props) - self.assertListEqual(list(prop_dict.keys()), this_props, msg=err_msg) - for key, val in prop_dict.items(): - err_msg3 = err_msg2 + f"\n\t\tkey: {key}" - self.assertEqual(val.shape[0], mol.GetNumAtoms(), msg=err_msg3) - self.assertGreater(val.shape[1], 1, msg=err_msg3) - self.assertTrue(np.all((val == 0) | (val == 1)), msg=err_msg3) - - with self.assertRaises(ValueError, msg=err_msg): - get_mol_atomic_features_onehot(mol, property_list=bad_props) - - def test_get_mol_atomic_features_float(self): - props = deepcopy(self.atomic_float_props) - - bad_props = ["bob"] - - for s in self.smiles: - err_msg = f"\n\tError for params:\n\t\tSMILES: {s}" - mol = dm.to_mol(s) - - for ii in range(len(props)): - this_props = props[:ii] - err_msg2 = err_msg + f"\n\t\tprops: {this_props}" - prop_dict = get_mol_atomic_features_float(mol, property_list=this_props) - self.assertListEqual(list(prop_dict.keys()), this_props, msg=err_msg) - for key, val in prop_dict.items(): - err_msg3 = err_msg2 + f"\n\t\tkey: {key}" - self.assertListEqual(list(val.shape), [mol.GetNumAtoms()], msg=err_msg3) - - with self.assertRaises(ValueError, msg=err_msg): - get_mol_atomic_features_float(mol, property_list=bad_props) - - def test_get_mol_edge_features(self): - props = deepcopy(self.edge_props) - bad_props = ["bob"] - - for s in self.smiles: - err_msg = f"\n\tError for params:\n\t\tSMILES: {s}" - mol = dm.to_mol(s) - for ii in range(len(props)): - this_props = props[: ii + 1] - err_msg2 = err_msg + f"\n\t\tprops: {this_props}" - prop_dict = get_mol_edge_features(mol, property_list=this_props) - self.assertListEqual(list(prop_dict.keys()), this_props, msg=err_msg) - for key, val in prop_dict.items(): - err_msg3 = err_msg2 + f"\n\t\tkey: {key}" - self.assertEqual(val.shape[0], mol.GetNumBonds(), msg=err_msg3) - - if mol.GetNumBonds() > 0: - with self.assertRaises(ValueError, msg=err_msg): - get_mol_edge_features(mol, property_list=bad_props) - - def test_mol_to_adj_and_features(self): - - np.random.seed(42) - - for s in self.smiles: - err_msg = f"\n\tError for params:\n\t\tSMILES: {s}" - mol = dm.to_mol(s) - mol_Hs = Chem.AddHs(mol) # type: ignore - mol_No_Hs = Chem.RemoveHs(mol) # type: ignore - - for explicit_H in [True, False]: - this_mol = mol_Hs if explicit_H else mol_No_Hs - for ii in np.arange(0, 5, 0.2): - num_props = int(round(ii)) - err_msg2 = err_msg + f"\n\t\texplicit_H: {explicit_H}\n\t\tii: {ii}" - - adj, ndata, edata, _, _, _ = mol_to_adj_and_features( - mol=mol, - atom_property_list_onehot=np.random.choice( - self.atomic_onehot_props, size=num_props, replace=False - ), - atom_property_list_float=np.random.choice( - self.atomic_float_props, size=num_props, replace=False - ), - edge_property_list=np.random.choice(self.edge_props, size=num_props, replace=False), - add_self_loop=False, - explicit_H=explicit_H, - use_bonds_weights=False, - ) - - self.assertEqual(adj.shape[0], this_mol.GetNumAtoms(), msg=err_msg2) - if num_props > 0: - self.assertEqual(ndata.shape[0], this_mol.GetNumAtoms(), msg=err_msg2) - if this_mol.GetNumBonds() > 0: - self.assertEqual(edata.shape[0], this_mol.GetNumBonds(), msg=err_msg2) - self.assertGreaterEqual(edata.shape[1], num_props, msg=err_msg2) - self.assertGreaterEqual(ndata.shape[1], num_props, msg=err_msg2) - - def test_mol_to_dglgraph(self): - - np.random.seed(42) - - for s in self.smiles: - err_msg = f"\n\tError for params:\n\t\tSMILES: {s}" - mol = dm.to_mol(s) - mol_Hs = Chem.AddHs(mol) # type: ignore - mol_No_Hs = Chem.RemoveHs(mol) # type: ignore - - for explicit_H in [True, False]: - this_mol = mol_Hs if explicit_H else mol_No_Hs - for ii in np.arange(0, 5, 0.2): - num_props = int(round(ii)) - err_msg2 = err_msg + f"\n\t\texplicit_H: {explicit_H}\n\t\tii: {ii}" - - graph = mol_to_dglgraph( - mol=mol, - atom_property_list_onehot=np.random.choice( - self.atomic_onehot_props, size=num_props, replace=False - ), - atom_property_list_float=np.random.choice( - self.atomic_float_props, size=num_props, replace=False - ), - edge_property_list=np.random.choice(self.edge_props, size=num_props, replace=False), - add_self_loop=False, - explicit_H=explicit_H, - use_bonds_weights=False, - ) - - self.assertEqual(graph.num_nodes(), this_mol.GetNumAtoms(), msg=err_msg2) - self.assertEqual(graph.num_edges(), 2 * this_mol.GetNumBonds(), msg=err_msg2) - if num_props > 0: - ndata = graph.ndata["feat"] - edata = graph.edata["feat"] - self.assertEqual(ndata.shape[0], this_mol.GetNumAtoms(), msg=err_msg2) - self.assertEqual(edata.shape[0], 2 * this_mol.GetNumBonds(), msg=err_msg2) - self.assertGreaterEqual(ndata.shape[1], num_props, msg=err_msg2) - self.assertGreaterEqual(edata.shape[1], num_props, msg=err_msg2) - - -if __name__ == "__main__": - ut.main() +""" +Unit tests for the different datasets of goli/features/featurizer.py +""" + +import numpy as np +import unittest as ut +from copy import deepcopy +from rdkit import Chem +import datamol as dm + +from goli.features.featurizer import ( + get_mol_atomic_features_onehot, + get_mol_atomic_features_float, + get_mol_edge_features, + mol_to_adj_and_features, + mol_to_dglgraph, +) + + +class test_featurizer(ut.TestCase): + + smiles = [ + "C", + "CC", + "CC(C)CC1=CC=C(C=C1)C(C)C(=O)O", + "OCCc1c(C)[n+](cs1)Cc2cnc(C)nc2N", + "O1C=C[C@H]([C@H]1O2)c3c2cc(OC)c4c3OC(=O)C5=C4CCC(=O)5", + ] + + atomic_onehot_props = [ + "atomic-number", + "valence", + "degree", + "implicit-valence", + "hybridization", + "chirality", + ] + + atomic_float_props = [ + "atomic-number", + "mass", + "valence", + "implicit-valence", + "hybridization", + "chirality", + "aromatic", + "in-ring", + "degree", + "radical-electron", + "formal-charge", + "vdw-radius", + "covalent-radius", + "electronegativity", + "ionization", + "melting-point", + "metal", + "single-bond", + "aromatic-bond", + "double-bond", + "triple-bond", + "is-carbon", + ] + + edge_props = [ + "bond-type-onehot", + "bond-type-float", + "stereo", + "in-ring", + "conjugated", + "estimated-bond-length", + "conformer-bond-length", + ] + + def test_get_mol_atomic_features_onehot(self): + props = deepcopy(self.atomic_onehot_props) + bad_props = ["bob"] + + for s in self.smiles: + err_msg = f"\n\tError for params:\n\t\tSMILES: {s}" + mol = dm.to_mol(s) + + for ii in range(len(props)): + this_props = props[:ii] + err_msg2 = err_msg + f"\n\t\tprops: {this_props}" + prop_dict = get_mol_atomic_features_onehot(mol, property_list=this_props) + self.assertListEqual(list(prop_dict.keys()), this_props, msg=err_msg) + for key, val in prop_dict.items(): + err_msg3 = err_msg2 + f"\n\t\tkey: {key}" + self.assertEqual(val.shape[0], mol.GetNumAtoms(), msg=err_msg3) + self.assertGreater(val.shape[1], 1, msg=err_msg3) + self.assertTrue(np.all((val == 0) | (val == 1)), msg=err_msg3) + + with self.assertRaises(ValueError, msg=err_msg): + get_mol_atomic_features_onehot(mol, property_list=bad_props) + + def test_get_mol_atomic_features_float(self): + props = deepcopy(self.atomic_float_props) + + bad_props = ["bob"] + + for s in self.smiles: + err_msg = f"\n\tError for params:\n\t\tSMILES: {s}" + mol = dm.to_mol(s) + + for ii in range(len(props)): + this_props = props[:ii] + err_msg2 = err_msg + f"\n\t\tprops: {this_props}" + prop_dict = get_mol_atomic_features_float(mol, property_list=this_props) + self.assertListEqual(list(prop_dict.keys()), this_props, msg=err_msg) + for key, val in prop_dict.items(): + err_msg3 = err_msg2 + f"\n\t\tkey: {key}" + self.assertListEqual(list(val.shape), [mol.GetNumAtoms()], msg=err_msg3) + + with self.assertRaises(ValueError, msg=err_msg): + get_mol_atomic_features_float(mol, property_list=bad_props) + + def test_get_mol_edge_features(self): + props = deepcopy(self.edge_props) + bad_props = ["bob"] + + for s in self.smiles: + err_msg = f"\n\tError for params:\n\t\tSMILES: {s}" + mol = dm.to_mol(s) + for ii in range(len(props)): + this_props = props[: ii + 1] + err_msg2 = err_msg + f"\n\t\tprops: {this_props}" + prop_dict = get_mol_edge_features(mol, property_list=this_props) + self.assertListEqual(list(prop_dict.keys()), this_props, msg=err_msg) + for key, val in prop_dict.items(): + err_msg3 = err_msg2 + f"\n\t\tkey: {key}" + self.assertEqual(val.shape[0], mol.GetNumBonds(), msg=err_msg3) + + if mol.GetNumBonds() > 0: + with self.assertRaises(ValueError, msg=err_msg): + get_mol_edge_features(mol, property_list=bad_props) + + def test_mol_to_adj_and_features(self): + + np.random.seed(42) + + for s in self.smiles: + err_msg = f"\n\tError for params:\n\t\tSMILES: {s}" + mol = dm.to_mol(s) + mol_Hs = Chem.AddHs(mol) # type: ignore + mol_No_Hs = Chem.RemoveHs(mol) # type: ignore + + for explicit_H in [True, False]: + this_mol = mol_Hs if explicit_H else mol_No_Hs + for ii in np.arange(0, 5, 0.2): + num_props = int(round(ii)) + err_msg2 = err_msg + f"\n\t\texplicit_H: {explicit_H}\n\t\tii: {ii}" + + adj, ndata, edata, _, _, _ = mol_to_adj_and_features( + mol=mol, + atom_property_list_onehot=np.random.choice( + self.atomic_onehot_props, size=num_props, replace=False + ), + atom_property_list_float=np.random.choice( + self.atomic_float_props, size=num_props, replace=False + ), + edge_property_list=np.random.choice(self.edge_props, size=num_props, replace=False), + add_self_loop=False, + explicit_H=explicit_H, + use_bonds_weights=False, + ) + + self.assertEqual(adj.shape[0], this_mol.GetNumAtoms(), msg=err_msg2) + if num_props > 0: + self.assertEqual(ndata.shape[0], this_mol.GetNumAtoms(), msg=err_msg2) + if this_mol.GetNumBonds() > 0: + self.assertEqual(edata.shape[0], this_mol.GetNumBonds(), msg=err_msg2) + self.assertGreaterEqual(edata.shape[1], num_props, msg=err_msg2) + self.assertGreaterEqual(ndata.shape[1], num_props, msg=err_msg2) + + def test_mol_to_dglgraph(self): + + np.random.seed(42) + + for s in self.smiles: + err_msg = f"\n\tError for params:\n\t\tSMILES: {s}" + mol = dm.to_mol(s) + mol_Hs = Chem.AddHs(mol) # type: ignore + mol_No_Hs = Chem.RemoveHs(mol) # type: ignore + + for explicit_H in [True, False]: + this_mol = mol_Hs if explicit_H else mol_No_Hs + for ii in np.arange(0, 5, 0.2): + num_props = int(round(ii)) + err_msg2 = err_msg + f"\n\t\texplicit_H: {explicit_H}\n\t\tii: {ii}" + + graph = mol_to_dglgraph( + mol=mol, + atom_property_list_onehot=np.random.choice( + self.atomic_onehot_props, size=num_props, replace=False + ), + atom_property_list_float=np.random.choice( + self.atomic_float_props, size=num_props, replace=False + ), + edge_property_list=np.random.choice(self.edge_props, size=num_props, replace=False), + add_self_loop=False, + explicit_H=explicit_H, + use_bonds_weights=False, + ) + + self.assertEqual(graph.num_nodes(), this_mol.GetNumAtoms(), msg=err_msg2) + self.assertEqual(graph.num_edges(), 2 * this_mol.GetNumBonds(), msg=err_msg2) + if num_props > 0: + ndata = graph.ndata["feat"] + edata = graph.edata["feat"] + self.assertEqual(ndata.shape[0], this_mol.GetNumAtoms(), msg=err_msg2) + self.assertEqual(edata.shape[0], 2 * this_mol.GetNumBonds(), msg=err_msg2) + self.assertGreaterEqual(ndata.shape[1], num_props, msg=err_msg2) + self.assertGreaterEqual(edata.shape[1], num_props, msg=err_msg2) + + +if __name__ == "__main__": + ut.main() diff --git a/tests/test_gnn_layers.py b/tests/test_gnn_layers.py index 80e71ff72..fa0cc6b4b 100644 --- a/tests/test_gnn_layers.py +++ b/tests/test_gnn_layers.py @@ -1,365 +1,365 @@ -""" -Unit tests for the different layers of goli/dgl/dgl_layers/... - -The layers are not thoroughly tested due to the difficulty of testing them -""" - -import numpy as np -import torch -import unittest as ut -import dgl -from copy import deepcopy - -from goli.nn.dgl_layers import ( - GATLayer, - GCNLayer, - GINLayer, - GatedGCNLayer, - PNAConvolutionalLayer, - PNAMessagePassingLayer, - DGNConvolutionalLayer, - DGNMessagePassingLayer, -) - - -class test_DGL_Layers(ut.TestCase): - - in_dim = 21 - out_dim = 11 - in_dim_edges = 13 - out_dim_edges = 17 - - g1 = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 3]))) - g2 = dgl.graph((torch.tensor([0, 0, 0, 1]), torch.tensor([0, 1, 2, 0]))) - g1.ndata["h"] = torch.zeros(g1.num_nodes(), in_dim, dtype=float) - g1.edata["e"] = torch.ones(g1.num_edges(), in_dim_edges, dtype=float) - g2.ndata["h"] = torch.ones(g2.num_nodes(), in_dim, dtype=float) - g2.edata["e"] = torch.zeros(g2.num_edges(), in_dim_edges, dtype=float) - bg = dgl.batch([g1, g2]) - bg = dgl.add_self_loop(bg) - bg.ndata["pos_dir"] = torch.randn_like(bg.ndata["h"]) - - kwargs = { - "activation": "relu", - "dropout": 0.1, - "batch_norm": True, - } - - def test_gcnlayer(self): - - bg = deepcopy(self.bg) - h_in = bg.ndata["h"] - layer = GCNLayer(in_dim=self.in_dim, out_dim=self.out_dim, **self.kwargs).to(float) - - # Check the re-implementation of abstract methods - self.assertFalse(layer.layer_supports_edges) - self.assertFalse(layer.layer_inputs_edges) - self.assertFalse(layer.layer_outputs_edges) - self.assertEqual(layer.out_dim_factor, 1) - - # Apply layer - h = layer.forward(g=bg, h=h_in) - self.assertEqual(h.shape[0], h_in.shape[0]) - self.assertEqual(h.shape[1], self.out_dim * layer.out_dim_factor) - - def test_ginlayer(self): - - bg = deepcopy(self.bg) - h_in = bg.ndata["h"] - layer = GINLayer(in_dim=self.in_dim, out_dim=self.out_dim, **self.kwargs).to(float) - - # Check the re-implementation of abstract methods - self.assertFalse(layer.layer_supports_edges) - self.assertFalse(layer.layer_inputs_edges) - self.assertFalse(layer.layer_outputs_edges) - self.assertEqual(layer.out_dim_factor, 1) - - # Apply layer - h = layer.forward(g=bg, h=h_in) - self.assertEqual(h.shape[0], h_in.shape[0]) - self.assertEqual(h.shape[1], self.out_dim * layer.out_dim_factor) - - def test_gatlayer(self): - - num_heads = 3 - bg = deepcopy(self.bg) - h_in = bg.ndata["h"] - layer = GATLayer(in_dim=self.in_dim, out_dim=self.out_dim, num_heads=num_heads, **self.kwargs).to( - float - ) - - # Check the re-implementation of abstract methods - self.assertFalse(layer.layer_supports_edges) - self.assertFalse(layer.layer_inputs_edges) - self.assertFalse(layer.layer_outputs_edges) - self.assertEqual(layer.out_dim_factor, num_heads) - - # Apply layer - h = layer.forward(g=bg, h=h_in) - self.assertEqual(h.shape[0], h_in.shape[0]) - self.assertEqual(h.shape[1], self.out_dim * layer.out_dim_factor) - - def test_gatedgcnlayer(self): - - bg = deepcopy(self.bg) - h_in = bg.ndata["h"] - e_in = bg.edata["e"] - layer = GatedGCNLayer( - in_dim=self.in_dim, - out_dim=self.out_dim, - in_dim_edges=self.in_dim_edges, - out_dim_edges=self.out_dim_edges, - **self.kwargs, - ).to(float) - - # Check the re-implementation of abstract methods - self.assertTrue(layer.layer_supports_edges) - self.assertTrue(layer.layer_inputs_edges) - self.assertTrue(layer.layer_outputs_edges) - self.assertIsInstance(layer.layer_supports_edges, bool) - self.assertIsInstance(layer.layer_inputs_edges, bool) - self.assertIsInstance(layer.layer_outputs_edges, bool) - self.assertEqual(layer.out_dim_factor, 1) - - # Apply layer without edges - with self.assertRaises(TypeError): - h = layer.forward(g=bg, h=h_in) - - # Apply layer with edges - h, e = layer.forward(g=bg, h=h_in, e=e_in) - self.assertEqual(h.shape[0], h_in.shape[0]) - self.assertEqual(h.shape[1], self.out_dim * layer.out_dim_factor) - - def test_pnaconvolutionallayer(self): - - bg = deepcopy(self.bg) - h_in = bg.ndata["h"] - e_in = bg.edata["e"] - aggregators = ["mean", "max", "min", "lap", "std", "moment3", "moment4", "sum"] - scalers = ["identity", "amplification", "attenuation"] - - layer = PNAConvolutionalLayer( - in_dim=self.in_dim, out_dim=self.out_dim, aggregators=aggregators, scalers=scalers, **self.kwargs - ).to(float) - - # Check the re-implementation of abstract methods - self.assertTrue(layer.layer_supports_edges) - self.assertFalse(layer.layer_inputs_edges) - self.assertFalse(layer.layer_outputs_edges) - self.assertEqual(layer.out_dim_factor, 1) - - # Apply layer - h = layer.forward(g=bg, h=h_in) - self.assertEqual(h.shape[0], h_in.shape[0]) - self.assertEqual(h.shape[1], self.out_dim * layer.out_dim_factor) - - # Now try with edges - layer = PNAConvolutionalLayer( - in_dim=self.in_dim, - out_dim=self.out_dim, - aggregators=aggregators, - scalers=scalers, - in_dim_edges=self.in_dim_edges, - **self.kwargs, - ).to(float) - - # Check the re-implementation of abstract methods - self.assertTrue(layer.layer_supports_edges) - self.assertTrue(layer.layer_inputs_edges) - self.assertIsInstance(layer.layer_supports_edges, bool) - self.assertIsInstance(layer.layer_inputs_edges, bool) - self.assertFalse(layer.layer_outputs_edges) - self.assertEqual(layer.out_dim_factor, 1) - - # Apply layer - with self.assertRaises(Exception): - h = layer.forward(g=bg, h=h_in) - - h = layer.forward(g=bg, h=h_in, e=e_in) - self.assertEqual(h.shape[0], h_in.shape[0]) - self.assertEqual(h.shape[1], self.out_dim * layer.out_dim_factor) - - def test_pnamessagepassinglayer(self): - - bg = deepcopy(self.bg) - h_in = bg.ndata["h"] - e_in = bg.edata["e"] - aggregators = ["mean", "max", "min", "lap", "std", "moment3", "moment4", "sum"] - scalers = ["identity", "amplification", "attenuation"] - - layer = PNAMessagePassingLayer( - in_dim=self.in_dim, out_dim=self.out_dim, aggregators=aggregators, scalers=scalers, **self.kwargs - ).to(float) - - # Check the re-implementation of abstract methods - self.assertTrue(layer.layer_supports_edges) - self.assertIsInstance(layer.layer_supports_edges, bool) - self.assertFalse(layer.layer_inputs_edges) - self.assertFalse(layer.layer_outputs_edges) - self.assertEqual(layer.out_dim_factor, 1) - - # Apply layer - h = layer.forward(g=bg, h=h_in) - self.assertEqual(h.shape[0], h_in.shape[0]) - self.assertEqual(h.shape[1], self.out_dim * layer.out_dim_factor) - - # Now try with edges - layer = PNAMessagePassingLayer( - in_dim=self.in_dim, - out_dim=self.out_dim, - aggregators=aggregators, - scalers=scalers, - in_dim_edges=self.in_dim_edges, - **self.kwargs, - ).to(float) - - # Check the re-implementation of abstract methods - self.assertTrue(layer.layer_supports_edges) - self.assertTrue(layer.layer_inputs_edges) - self.assertIsInstance(layer.layer_supports_edges, bool) - self.assertIsInstance(layer.layer_inputs_edges, bool) - self.assertFalse(layer.layer_outputs_edges) - self.assertEqual(layer.out_dim_factor, 1) - - # Apply layer - with self.assertRaises(Exception): - h = layer.forward(g=bg, h=h_in) - - h = layer.forward(g=bg, h=h_in, e=e_in) - self.assertEqual(h.shape[0], h_in.shape[0]) - self.assertEqual(h.shape[1], self.out_dim * layer.out_dim_factor) - - def test_dgnconvolutionallayer(self): - - bg = deepcopy(self.bg) - h_in = bg.ndata["h"] - e_in = bg.edata["e"] - aggregators = [ - "mean", - "max", - "min", - "lap", - "std", - "moment3", - "moment4", - "sum", - "dir0/dx_abs", - "dir1/dx_abs", - "dir2/dx_no_abs", - "dir1/smooth", - "dir1/forward", - "dir1/backward/0.5", - "dir4/dx_abs/5", - ] - scalers = ["identity", "amplification", "attenuation"] - - layer = DGNConvolutionalLayer( - in_dim=self.in_dim, out_dim=self.out_dim, aggregators=aggregators, scalers=scalers, **self.kwargs - ).to(float) - - # Check the re-implementation of abstract methods - self.assertTrue(layer.layer_supports_edges) - self.assertFalse(layer.layer_inputs_edges) - self.assertFalse(layer.layer_outputs_edges) - self.assertEqual(layer.out_dim_factor, 1) - - # Apply layer - h = layer.forward(g=bg, h=h_in) - self.assertEqual(h.shape[0], h_in.shape[0]) - self.assertEqual(h.shape[1], self.out_dim * layer.out_dim_factor) - - # Now try with edges - layer = DGNConvolutionalLayer( - in_dim=self.in_dim, - out_dim=self.out_dim, - aggregators=aggregators, - scalers=scalers, - in_dim_edges=self.in_dim_edges, - **self.kwargs, - ).to(float) - - # Check the re-implementation of abstract methods - self.assertTrue(layer.layer_supports_edges) - self.assertTrue(layer.layer_inputs_edges) - self.assertIsInstance(layer.layer_supports_edges, bool) - self.assertIsInstance(layer.layer_inputs_edges, bool) - self.assertFalse(layer.layer_outputs_edges) - self.assertEqual(layer.out_dim_factor, 1) - - # Apply layer - with self.assertRaises(Exception): - h = layer.forward(g=bg, h=h_in) - - h = layer.forward(g=bg, h=h_in, e=e_in) - self.assertEqual(h.shape[0], h_in.shape[0]) - self.assertEqual(h.shape[1], self.out_dim * layer.out_dim_factor) - - def test_dgnmessagepassinglayer(self): - - bg = deepcopy(self.bg) - h_in = bg.ndata["h"] - e_in = bg.edata["e"] - aggregators = [ - "mean", - "max", - "min", - "lap", - "std", - "moment3", - "moment4", - "sum", - "dir0/dx_abs", - "dir1/dx_abs", - "dir2/dx_no_abs", - "dir1/smooth", - "dir1/forward", - "dir1/backward/0.5", - "dir4/dx_abs/5", - ] - scalers = ["identity", "amplification", "attenuation"] - - layer = DGNMessagePassingLayer( - in_dim=self.in_dim, out_dim=self.out_dim, aggregators=aggregators, scalers=scalers, **self.kwargs - ).to(float) - - # Check the re-implementation of abstract methods - self.assertTrue(layer.layer_supports_edges) - self.assertIsInstance(layer.layer_supports_edges, bool) - self.assertFalse(layer.layer_inputs_edges) - self.assertFalse(layer.layer_outputs_edges) - self.assertEqual(layer.out_dim_factor, 1) - - # Apply layer - h = layer.forward(g=bg, h=h_in) - self.assertEqual(h.shape[0], h_in.shape[0]) - self.assertEqual(h.shape[1], self.out_dim * layer.out_dim_factor) - - # Now try with edges - layer = DGNMessagePassingLayer( - in_dim=self.in_dim, - out_dim=self.out_dim, - aggregators=aggregators, - scalers=scalers, - in_dim_edges=self.in_dim_edges, - **self.kwargs, - ).to(float) - - # Check the re-implementation of abstract methods - self.assertTrue(layer.layer_supports_edges) - self.assertTrue(layer.layer_inputs_edges) - self.assertIsInstance(layer.layer_supports_edges, bool) - self.assertIsInstance(layer.layer_inputs_edges, bool) - self.assertFalse(layer.layer_outputs_edges) - self.assertEqual(layer.out_dim_factor, 1) - - # Apply layer - with self.assertRaises(Exception): - h = layer.forward(g=bg, h=h_in) - - h = layer.forward(g=bg, h=h_in, e=e_in) - self.assertEqual(h.shape[0], h_in.shape[0]) - self.assertEqual(h.shape[1], self.out_dim * layer.out_dim_factor) - - -if __name__ == "__main__": - ut.main() +""" +Unit tests for the different layers of goli/dgl/dgl_layers/... + +The layers are not thoroughly tested due to the difficulty of testing them +""" + +import numpy as np +import torch +import unittest as ut +import dgl +from copy import deepcopy + +from goli.nn.dgl_layers import ( + GATLayer, + GCNLayer, + GINLayer, + GatedGCNLayer, + PNAConvolutionalLayer, + PNAMessagePassingLayer, + DGNConvolutionalLayer, + DGNMessagePassingLayer, +) + + +class test_DGL_Layers(ut.TestCase): + + in_dim = 21 + out_dim = 11 + in_dim_edges = 13 + out_dim_edges = 17 + + g1 = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 3]))) + g2 = dgl.graph((torch.tensor([0, 0, 0, 1]), torch.tensor([0, 1, 2, 0]))) + g1.ndata["h"] = torch.zeros(g1.num_nodes(), in_dim, dtype=float) + g1.edata["e"] = torch.ones(g1.num_edges(), in_dim_edges, dtype=float) + g2.ndata["h"] = torch.ones(g2.num_nodes(), in_dim, dtype=float) + g2.edata["e"] = torch.zeros(g2.num_edges(), in_dim_edges, dtype=float) + bg = dgl.batch([g1, g2]) + bg = dgl.add_self_loop(bg) + bg.ndata["pos_dir"] = torch.randn_like(bg.ndata["h"]) + + kwargs = { + "activation": "relu", + "dropout": 0.1, + "batch_norm": True, + } + + def test_gcnlayer(self): + + bg = deepcopy(self.bg) + h_in = bg.ndata["h"] + layer = GCNLayer(in_dim=self.in_dim, out_dim=self.out_dim, **self.kwargs).to(float) + + # Check the re-implementation of abstract methods + self.assertFalse(layer.layer_supports_edges) + self.assertFalse(layer.layer_inputs_edges) + self.assertFalse(layer.layer_outputs_edges) + self.assertEqual(layer.out_dim_factor, 1) + + # Apply layer + h = layer.forward(g=bg, h=h_in) + self.assertEqual(h.shape[0], h_in.shape[0]) + self.assertEqual(h.shape[1], self.out_dim * layer.out_dim_factor) + + def test_ginlayer(self): + + bg = deepcopy(self.bg) + h_in = bg.ndata["h"] + layer = GINLayer(in_dim=self.in_dim, out_dim=self.out_dim, **self.kwargs).to(float) + + # Check the re-implementation of abstract methods + self.assertFalse(layer.layer_supports_edges) + self.assertFalse(layer.layer_inputs_edges) + self.assertFalse(layer.layer_outputs_edges) + self.assertEqual(layer.out_dim_factor, 1) + + # Apply layer + h = layer.forward(g=bg, h=h_in) + self.assertEqual(h.shape[0], h_in.shape[0]) + self.assertEqual(h.shape[1], self.out_dim * layer.out_dim_factor) + + def test_gatlayer(self): + + num_heads = 3 + bg = deepcopy(self.bg) + h_in = bg.ndata["h"] + layer = GATLayer(in_dim=self.in_dim, out_dim=self.out_dim, num_heads=num_heads, **self.kwargs).to( + float + ) + + # Check the re-implementation of abstract methods + self.assertFalse(layer.layer_supports_edges) + self.assertFalse(layer.layer_inputs_edges) + self.assertFalse(layer.layer_outputs_edges) + self.assertEqual(layer.out_dim_factor, num_heads) + + # Apply layer + h = layer.forward(g=bg, h=h_in) + self.assertEqual(h.shape[0], h_in.shape[0]) + self.assertEqual(h.shape[1], self.out_dim * layer.out_dim_factor) + + def test_gatedgcnlayer(self): + + bg = deepcopy(self.bg) + h_in = bg.ndata["h"] + e_in = bg.edata["e"] + layer = GatedGCNLayer( + in_dim=self.in_dim, + out_dim=self.out_dim, + in_dim_edges=self.in_dim_edges, + out_dim_edges=self.out_dim_edges, + **self.kwargs, + ).to(float) + + # Check the re-implementation of abstract methods + self.assertTrue(layer.layer_supports_edges) + self.assertTrue(layer.layer_inputs_edges) + self.assertTrue(layer.layer_outputs_edges) + self.assertIsInstance(layer.layer_supports_edges, bool) + self.assertIsInstance(layer.layer_inputs_edges, bool) + self.assertIsInstance(layer.layer_outputs_edges, bool) + self.assertEqual(layer.out_dim_factor, 1) + + # Apply layer without edges + with self.assertRaises(TypeError): + h = layer.forward(g=bg, h=h_in) + + # Apply layer with edges + h, e = layer.forward(g=bg, h=h_in, e=e_in) + self.assertEqual(h.shape[0], h_in.shape[0]) + self.assertEqual(h.shape[1], self.out_dim * layer.out_dim_factor) + + def test_pnaconvolutionallayer(self): + + bg = deepcopy(self.bg) + h_in = bg.ndata["h"] + e_in = bg.edata["e"] + aggregators = ["mean", "max", "min", "lap", "std", "moment3", "moment4", "sum"] + scalers = ["identity", "amplification", "attenuation"] + + layer = PNAConvolutionalLayer( + in_dim=self.in_dim, out_dim=self.out_dim, aggregators=aggregators, scalers=scalers, **self.kwargs + ).to(float) + + # Check the re-implementation of abstract methods + self.assertTrue(layer.layer_supports_edges) + self.assertFalse(layer.layer_inputs_edges) + self.assertFalse(layer.layer_outputs_edges) + self.assertEqual(layer.out_dim_factor, 1) + + # Apply layer + h = layer.forward(g=bg, h=h_in) + self.assertEqual(h.shape[0], h_in.shape[0]) + self.assertEqual(h.shape[1], self.out_dim * layer.out_dim_factor) + + # Now try with edges + layer = PNAConvolutionalLayer( + in_dim=self.in_dim, + out_dim=self.out_dim, + aggregators=aggregators, + scalers=scalers, + in_dim_edges=self.in_dim_edges, + **self.kwargs, + ).to(float) + + # Check the re-implementation of abstract methods + self.assertTrue(layer.layer_supports_edges) + self.assertTrue(layer.layer_inputs_edges) + self.assertIsInstance(layer.layer_supports_edges, bool) + self.assertIsInstance(layer.layer_inputs_edges, bool) + self.assertFalse(layer.layer_outputs_edges) + self.assertEqual(layer.out_dim_factor, 1) + + # Apply layer + with self.assertRaises(Exception): + h = layer.forward(g=bg, h=h_in) + + h = layer.forward(g=bg, h=h_in, e=e_in) + self.assertEqual(h.shape[0], h_in.shape[0]) + self.assertEqual(h.shape[1], self.out_dim * layer.out_dim_factor) + + def test_pnamessagepassinglayer(self): + + bg = deepcopy(self.bg) + h_in = bg.ndata["h"] + e_in = bg.edata["e"] + aggregators = ["mean", "max", "min", "lap", "std", "moment3", "moment4", "sum"] + scalers = ["identity", "amplification", "attenuation"] + + layer = PNAMessagePassingLayer( + in_dim=self.in_dim, out_dim=self.out_dim, aggregators=aggregators, scalers=scalers, **self.kwargs + ).to(float) + + # Check the re-implementation of abstract methods + self.assertTrue(layer.layer_supports_edges) + self.assertIsInstance(layer.layer_supports_edges, bool) + self.assertFalse(layer.layer_inputs_edges) + self.assertFalse(layer.layer_outputs_edges) + self.assertEqual(layer.out_dim_factor, 1) + + # Apply layer + h = layer.forward(g=bg, h=h_in) + self.assertEqual(h.shape[0], h_in.shape[0]) + self.assertEqual(h.shape[1], self.out_dim * layer.out_dim_factor) + + # Now try with edges + layer = PNAMessagePassingLayer( + in_dim=self.in_dim, + out_dim=self.out_dim, + aggregators=aggregators, + scalers=scalers, + in_dim_edges=self.in_dim_edges, + **self.kwargs, + ).to(float) + + # Check the re-implementation of abstract methods + self.assertTrue(layer.layer_supports_edges) + self.assertTrue(layer.layer_inputs_edges) + self.assertIsInstance(layer.layer_supports_edges, bool) + self.assertIsInstance(layer.layer_inputs_edges, bool) + self.assertFalse(layer.layer_outputs_edges) + self.assertEqual(layer.out_dim_factor, 1) + + # Apply layer + with self.assertRaises(Exception): + h = layer.forward(g=bg, h=h_in) + + h = layer.forward(g=bg, h=h_in, e=e_in) + self.assertEqual(h.shape[0], h_in.shape[0]) + self.assertEqual(h.shape[1], self.out_dim * layer.out_dim_factor) + + def test_dgnconvolutionallayer(self): + + bg = deepcopy(self.bg) + h_in = bg.ndata["h"] + e_in = bg.edata["e"] + aggregators = [ + "mean", + "max", + "min", + "lap", + "std", + "moment3", + "moment4", + "sum", + "dir0/dx_abs", + "dir1/dx_abs", + "dir2/dx_no_abs", + "dir1/smooth", + "dir1/forward", + "dir1/backward/0.5", + "dir4/dx_abs/5", + ] + scalers = ["identity", "amplification", "attenuation"] + + layer = DGNConvolutionalLayer( + in_dim=self.in_dim, out_dim=self.out_dim, aggregators=aggregators, scalers=scalers, **self.kwargs + ).to(float) + + # Check the re-implementation of abstract methods + self.assertTrue(layer.layer_supports_edges) + self.assertFalse(layer.layer_inputs_edges) + self.assertFalse(layer.layer_outputs_edges) + self.assertEqual(layer.out_dim_factor, 1) + + # Apply layer + h = layer.forward(g=bg, h=h_in) + self.assertEqual(h.shape[0], h_in.shape[0]) + self.assertEqual(h.shape[1], self.out_dim * layer.out_dim_factor) + + # Now try with edges + layer = DGNConvolutionalLayer( + in_dim=self.in_dim, + out_dim=self.out_dim, + aggregators=aggregators, + scalers=scalers, + in_dim_edges=self.in_dim_edges, + **self.kwargs, + ).to(float) + + # Check the re-implementation of abstract methods + self.assertTrue(layer.layer_supports_edges) + self.assertTrue(layer.layer_inputs_edges) + self.assertIsInstance(layer.layer_supports_edges, bool) + self.assertIsInstance(layer.layer_inputs_edges, bool) + self.assertFalse(layer.layer_outputs_edges) + self.assertEqual(layer.out_dim_factor, 1) + + # Apply layer + with self.assertRaises(Exception): + h = layer.forward(g=bg, h=h_in) + + h = layer.forward(g=bg, h=h_in, e=e_in) + self.assertEqual(h.shape[0], h_in.shape[0]) + self.assertEqual(h.shape[1], self.out_dim * layer.out_dim_factor) + + def test_dgnmessagepassinglayer(self): + + bg = deepcopy(self.bg) + h_in = bg.ndata["h"] + e_in = bg.edata["e"] + aggregators = [ + "mean", + "max", + "min", + "lap", + "std", + "moment3", + "moment4", + "sum", + "dir0/dx_abs", + "dir1/dx_abs", + "dir2/dx_no_abs", + "dir1/smooth", + "dir1/forward", + "dir1/backward/0.5", + "dir4/dx_abs/5", + ] + scalers = ["identity", "amplification", "attenuation"] + + layer = DGNMessagePassingLayer( + in_dim=self.in_dim, out_dim=self.out_dim, aggregators=aggregators, scalers=scalers, **self.kwargs + ).to(float) + + # Check the re-implementation of abstract methods + self.assertTrue(layer.layer_supports_edges) + self.assertIsInstance(layer.layer_supports_edges, bool) + self.assertFalse(layer.layer_inputs_edges) + self.assertFalse(layer.layer_outputs_edges) + self.assertEqual(layer.out_dim_factor, 1) + + # Apply layer + h = layer.forward(g=bg, h=h_in) + self.assertEqual(h.shape[0], h_in.shape[0]) + self.assertEqual(h.shape[1], self.out_dim * layer.out_dim_factor) + + # Now try with edges + layer = DGNMessagePassingLayer( + in_dim=self.in_dim, + out_dim=self.out_dim, + aggregators=aggregators, + scalers=scalers, + in_dim_edges=self.in_dim_edges, + **self.kwargs, + ).to(float) + + # Check the re-implementation of abstract methods + self.assertTrue(layer.layer_supports_edges) + self.assertTrue(layer.layer_inputs_edges) + self.assertIsInstance(layer.layer_supports_edges, bool) + self.assertIsInstance(layer.layer_inputs_edges, bool) + self.assertFalse(layer.layer_outputs_edges) + self.assertEqual(layer.out_dim_factor, 1) + + # Apply layer + with self.assertRaises(Exception): + h = layer.forward(g=bg, h=h_in) + + h = layer.forward(g=bg, h=h_in, e=e_in) + self.assertEqual(h.shape[0], h_in.shape[0]) + self.assertEqual(h.shape[1], self.out_dim * layer.out_dim_factor) + + +if __name__ == "__main__": + ut.main() diff --git a/tests/test_metrics.py b/tests/test_metrics.py index bdffd6409..af8bd8499 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -1,149 +1,149 @@ -""" -Unit tests for the metrics and wrappers of goli/trainer/metrics/... -""" - -import torch -import unittest as ut - -from goli.trainer.metrics import ( - MetricWrapper, - Thresholder, - pearsonr, - spearmanr, - mean_squared_error, -) - - -class test_Metrics(ut.TestCase): - def test_thresholder(self): - - torch.manual_seed(42) - preds = torch.rand(100, dtype=torch.float32) - target = torch.rand(100, dtype=torch.float32) - - th = 0.7 - preds_greater = preds > th - target_greater = target > th - - # Test thresholder greater - for th_on_preds in [True, False]: - for th_on_target in [True, False]: - thresholder = Thresholder( - threshold=th, operator="greater", th_on_target=th_on_target, th_on_preds=th_on_preds - ) - preds2, target2 = thresholder(preds, target) - if th_on_preds: - self.assertListEqual(preds2.tolist(), preds_greater.tolist()) - else: - self.assertListEqual(preds2.tolist(), preds.tolist()) - if th_on_target: - self.assertListEqual(target2.tolist(), target_greater.tolist()) - else: - self.assertListEqual(target2.tolist(), target.tolist()) - - # Test thresholder lower - for th_on_preds in [True, False]: - for th_on_target in [True, False]: - thresholder = Thresholder( - threshold=th, operator="lower", th_on_target=th_on_target, th_on_preds=th_on_preds - ) - preds2, target2 = thresholder(preds, target) - if th_on_preds: - self.assertListEqual(preds2.tolist(), (~preds_greater).tolist()) - else: - self.assertListEqual(preds2.tolist(), preds.tolist()) - if th_on_target: - self.assertListEqual(target2.tolist(), (~target_greater).tolist()) - else: - self.assertListEqual(target2.tolist(), target.tolist()) - - def test_pearsonr_spearmanr(self): - preds = torch.tensor([0.0, 1, 2, 3]) - target = torch.tensor([0.0, 1, 2, 1.5]) - - self.assertAlmostEqual(pearsonr(preds, target).tolist(), 0.8315, places=4) - self.assertAlmostEqual(spearmanr(preds, target).tolist(), 0.8, places=4) - - preds = torch.tensor([76, 25, 72, 0, 60, 96, 55, 57, 10, 26, 47, 87, 97, 2, 20]) - target = torch.tensor([12, 80, 35, 6, 58, 22, 41, 66, 92, 55, 46, 61, 89, 83, 14]) - - self.assertAlmostEqual(pearsonr(preds, target).tolist(), -0.0784, places=4) - self.assertAlmostEqual(spearmanr(preds, target).tolist(), -0.024999, places=4) - - preds = preds.repeat(2, 1).T - target = target.repeat(2, 1).T - - self.assertAlmostEqual(pearsonr(preds, target).tolist(), -0.0784, places=4) - self.assertAlmostEqual(spearmanr(preds, target).tolist(), -0.024999, places=4) - - -class test_MetricWrapper(ut.TestCase): - def test_target_nan_mask(self): - - torch.random.manual_seed(42) - - for sz in [(100,), (100, 1), (100, 10)]: - - err_msg = f"Error for `sz = {sz}`" - - # Generate prediction and target matrices - target = torch.rand(sz, dtype=torch.float32) - preds = (0.5 * target) + (0.5 * torch.rand(sz, dtype=torch.float32)) - is_nan = torch.rand(sz) > 0.3 - target = (target > 0.5).to(torch.float32) - target[is_nan] = float("nan") - - # Compute score with different ways of ignoring NaNs - metric = MetricWrapper(metric="mse", target_nan_mask=None) - score1 = metric(preds, target) - self.assertTrue(torch.isnan(score1), msg=err_msg) - - # Replace NaNs by 0 - metric = MetricWrapper(metric="mse", target_nan_mask=0) - score2 = metric(preds, target) - - this_target = target.clone() - this_target[is_nan] = 0 - this_preds = preds.clone() - self.assertAlmostEqual(score2, mean_squared_error(this_preds, this_target), msg=err_msg) - - # Replace NaNs by 1.5 - metric = MetricWrapper(metric="mse", target_nan_mask=1.5) - score3 = metric(preds, target) - - this_target = target.clone() - this_target[is_nan] = 1.5 - this_preds = preds.clone() - self.assertAlmostEqual(score3, mean_squared_error(this_preds, this_target), msg=err_msg) - - # Flatten matrix and ignore NaNs - metric = MetricWrapper(metric="mse", target_nan_mask="ignore-flatten") - score4 = metric(preds, target) - - this_target = target.clone()[~is_nan] - this_preds = preds.clone()[~is_nan] - self.assertAlmostEqual(score4, mean_squared_error(this_preds, this_target), msg=err_msg) - - # Ignore NaNs in each column and average the score - metric = MetricWrapper(metric="mse", target_nan_mask="ignore-mean-label") - score5 = metric(preds, target) - - this_target = target.clone() - this_preds = preds.clone() - this_is_nan = is_nan.clone() - if len(sz) == 1: - this_target = target.unsqueeze(-1) - this_preds = preds.unsqueeze(-1) - this_is_nan = is_nan.unsqueeze(-1) - - this_target = [this_target[:, ii][~this_is_nan[:, ii]] for ii in range(this_target.shape[1])] - this_preds = [this_preds[:, ii][~this_is_nan[:, ii]] for ii in range(this_preds.shape[1])] - mse = [] - for ii in range(len(this_preds)): - mse.append(mean_squared_error(this_preds[ii], this_target[ii])) - mse = torch.mean(torch.stack(mse)) - self.assertAlmostEqual(score5.tolist(), mse.tolist(), msg=err_msg) - - -if __name__ == "__main__": - ut.main() +""" +Unit tests for the metrics and wrappers of goli/trainer/metrics/... +""" + +import torch +import unittest as ut + +from goli.trainer.metrics import ( + MetricWrapper, + Thresholder, + pearsonr, + spearmanr, + mean_squared_error, +) + + +class test_Metrics(ut.TestCase): + def test_thresholder(self): + + torch.manual_seed(42) + preds = torch.rand(100, dtype=torch.float32) + target = torch.rand(100, dtype=torch.float32) + + th = 0.7 + preds_greater = preds > th + target_greater = target > th + + # Test thresholder greater + for th_on_preds in [True, False]: + for th_on_target in [True, False]: + thresholder = Thresholder( + threshold=th, operator="greater", th_on_target=th_on_target, th_on_preds=th_on_preds + ) + preds2, target2 = thresholder(preds, target) + if th_on_preds: + self.assertListEqual(preds2.tolist(), preds_greater.tolist()) + else: + self.assertListEqual(preds2.tolist(), preds.tolist()) + if th_on_target: + self.assertListEqual(target2.tolist(), target_greater.tolist()) + else: + self.assertListEqual(target2.tolist(), target.tolist()) + + # Test thresholder lower + for th_on_preds in [True, False]: + for th_on_target in [True, False]: + thresholder = Thresholder( + threshold=th, operator="lower", th_on_target=th_on_target, th_on_preds=th_on_preds + ) + preds2, target2 = thresholder(preds, target) + if th_on_preds: + self.assertListEqual(preds2.tolist(), (~preds_greater).tolist()) + else: + self.assertListEqual(preds2.tolist(), preds.tolist()) + if th_on_target: + self.assertListEqual(target2.tolist(), (~target_greater).tolist()) + else: + self.assertListEqual(target2.tolist(), target.tolist()) + + def test_pearsonr_spearmanr(self): + preds = torch.tensor([0.0, 1, 2, 3]) + target = torch.tensor([0.0, 1, 2, 1.5]) + + self.assertAlmostEqual(pearsonr(preds, target).tolist(), 0.8315, places=4) + self.assertAlmostEqual(spearmanr(preds, target).tolist(), 0.8, places=4) + + preds = torch.tensor([76, 25, 72, 0, 60, 96, 55, 57, 10, 26, 47, 87, 97, 2, 20]) + target = torch.tensor([12, 80, 35, 6, 58, 22, 41, 66, 92, 55, 46, 61, 89, 83, 14]) + + self.assertAlmostEqual(pearsonr(preds, target).tolist(), -0.0784, places=4) + self.assertAlmostEqual(spearmanr(preds, target).tolist(), -0.024999, places=4) + + preds = preds.repeat(2, 1).T + target = target.repeat(2, 1).T + + self.assertAlmostEqual(pearsonr(preds, target).tolist(), -0.0784, places=4) + self.assertAlmostEqual(spearmanr(preds, target).tolist(), -0.024999, places=4) + + +class test_MetricWrapper(ut.TestCase): + def test_target_nan_mask(self): + + torch.random.manual_seed(42) + + for sz in [(100,), (100, 1), (100, 10)]: + + err_msg = f"Error for `sz = {sz}`" + + # Generate prediction and target matrices + target = torch.rand(sz, dtype=torch.float32) + preds = (0.5 * target) + (0.5 * torch.rand(sz, dtype=torch.float32)) + is_nan = torch.rand(sz) > 0.3 + target = (target > 0.5).to(torch.float32) + target[is_nan] = float("nan") + + # Compute score with different ways of ignoring NaNs + metric = MetricWrapper(metric="mse", target_nan_mask=None) + score1 = metric(preds, target) + self.assertTrue(torch.isnan(score1), msg=err_msg) + + # Replace NaNs by 0 + metric = MetricWrapper(metric="mse", target_nan_mask=0) + score2 = metric(preds, target) + + this_target = target.clone() + this_target[is_nan] = 0 + this_preds = preds.clone() + self.assertAlmostEqual(score2, mean_squared_error(this_preds, this_target), msg=err_msg) + + # Replace NaNs by 1.5 + metric = MetricWrapper(metric="mse", target_nan_mask=1.5) + score3 = metric(preds, target) + + this_target = target.clone() + this_target[is_nan] = 1.5 + this_preds = preds.clone() + self.assertAlmostEqual(score3, mean_squared_error(this_preds, this_target), msg=err_msg) + + # Flatten matrix and ignore NaNs + metric = MetricWrapper(metric="mse", target_nan_mask="ignore-flatten") + score4 = metric(preds, target) + + this_target = target.clone()[~is_nan] + this_preds = preds.clone()[~is_nan] + self.assertAlmostEqual(score4, mean_squared_error(this_preds, this_target), msg=err_msg) + + # Ignore NaNs in each column and average the score + metric = MetricWrapper(metric="mse", target_nan_mask="ignore-mean-label") + score5 = metric(preds, target) + + this_target = target.clone() + this_preds = preds.clone() + this_is_nan = is_nan.clone() + if len(sz) == 1: + this_target = target.unsqueeze(-1) + this_preds = preds.unsqueeze(-1) + this_is_nan = is_nan.unsqueeze(-1) + + this_target = [this_target[:, ii][~this_is_nan[:, ii]] for ii in range(this_target.shape[1])] + this_preds = [this_preds[:, ii][~this_is_nan[:, ii]] for ii in range(this_preds.shape[1])] + mse = [] + for ii in range(len(this_preds)): + mse.append(mean_squared_error(this_preds[ii], this_target[ii])) + mse = torch.mean(torch.stack(mse)) + self.assertAlmostEqual(score5.tolist(), mse.tolist(), msg=err_msg) + + +if __name__ == "__main__": + ut.main() diff --git a/tests/test_positional_encoding.py b/tests/test_positional_encoding.py index 66471880a..a907c19ba 100644 --- a/tests/test_positional_encoding.py +++ b/tests/test_positional_encoding.py @@ -1,100 +1,100 @@ -""" -Unit tests for the different datasets of goli/features/featurizer.py -""" - -import numpy as np -import unittest as ut -from copy import deepcopy -from rdkit import Chem -import datamol as dm - -from goli.features.featurizer import ( - mol_to_adj_and_features, - mol_to_dglgraph, -) -from goli.features.positional_encoding import graph_positional_encoder - - -class test_positional_encoder(ut.TestCase): - - smiles = [ - "C", - "CC", - "CC.CCCC", - "CC(C)CC1=CC=C(C=C1)C(C)C(=O)O", - "OCCc1c(C)[n+](cs1)Cc2cnc(C)nc2N", - "O1C=C[C@H]([C@H]1O2)c3c2cc(OC)c4c3OC(=O)C5=C4CCC(=O)5", - ] - mols = [dm.to_mol(s) for s in smiles] - adjs = [Chem.rdmolops.GetAdjacencyMatrix(mol) for mol in mols] - - def test_laplacian_eigvec(self): - - for ii, adj in enumerate(deepcopy(self.adjs)): - for num_pos in [1, 2, 4]: # Can't test too much eigs because of multiplicities - for disconnected_comp in [True, False]: - err_msg = f"adj_id={ii}, num_pos={num_pos}, disconnected_comp={disconnected_comp}" - pos_enc_sign_flip, pos_enc_no_flip = graph_positional_encoder( - adj, pos_type="laplacian_eigvec", num_pos=num_pos, disconnected_comp=disconnected_comp - ) - self.assertEqual(list(pos_enc_sign_flip.shape), [adj.shape[0], num_pos], msg=err_msg) - self.assertIsNone(pos_enc_no_flip) - - # Compute eigvals and eigvecs - lap = np.diag(np.sum(adj, axis=1)) - adj - eigvals, eigvecs = np.linalg.eig(lap) - sort_idx = np.argsort(eigvals) - eigvals, eigvecs = eigvals[sort_idx], eigvecs[:, sort_idx] - eigvecs = eigvecs / (np.sum(eigvecs ** 2, axis=0, keepdims=True) + 1e-8) - - true_num_pos = min(num_pos, len(eigvals)) - eigvals, eigvecs = eigvals[:true_num_pos], eigvecs[:, :true_num_pos] - eigvecs = np.sign(eigvecs[0:1, :]) * eigvecs - pos_enc_sign_flip = (np.sign(pos_enc_sign_flip[0:1, :]) * pos_enc_sign_flip).numpy() - - # Compare the positional encoding - if disconnected_comp and ("." in self.smiles[ii]): - self.assertGreater(np.max(np.abs(eigvecs - pos_enc_sign_flip)), 1e-3) - elif not ("." in self.smiles[ii]): - np.testing.assert_array_almost_equal( - eigvecs, pos_enc_sign_flip[:, :true_num_pos], decimal=6, err_msg=err_msg - ) - - def test_laplacian_eigvec_eigval(self): - - for ii, adj in enumerate(deepcopy(self.adjs)): - for num_pos in [1, 2, 4]: # Can't test too much eigs because of multiplicities - for disconnected_comp in [True, False]: - err_msg = f"adj_id={ii}, num_pos={num_pos}, disconnected_comp={disconnected_comp}" - pos_enc_sign_flip, pos_enc_no_flip = graph_positional_encoder( - adj, - pos_type="laplacian_eigvec_eigval", - num_pos=num_pos, - disconnected_comp=disconnected_comp, - ) - self.assertEqual(list(pos_enc_sign_flip.shape), [adj.shape[0], num_pos], msg=err_msg) - self.assertEqual(list(pos_enc_no_flip.shape), [adj.shape[0], num_pos], msg=err_msg) - - # Compute eigvals and eigvecs - lap = np.diag(np.sum(adj, axis=1)) - adj - eigvals, eigvecs = np.linalg.eig(lap) - sort_idx = np.argsort(eigvals) - eigvals, eigvecs = eigvals[sort_idx], eigvecs[:, sort_idx] - eigvecs = eigvecs / (np.sum(eigvecs ** 2, axis=0, keepdims=True) + 1e-8) - - true_num_pos = min(num_pos, len(eigvals)) - eigvals, eigvecs = eigvals[:true_num_pos], eigvecs[:, :true_num_pos] - eigvecs = np.sign(eigvecs[0:1, :]) * eigvecs - pos_enc_sign_flip = (np.sign(pos_enc_sign_flip[0:1, :]) * pos_enc_sign_flip).numpy() - - if not ("." in self.smiles[ii]): - np.testing.assert_array_almost_equal( - eigvecs, pos_enc_sign_flip[:, :true_num_pos], decimal=6, err_msg=err_msg - ) - np.testing.assert_array_almost_equal( - eigvals, pos_enc_no_flip[0, :true_num_pos], decimal=6, err_msg=err_msg - ) - - -if __name__ == "__main__": - ut.main() +""" +Unit tests for the different datasets of goli/features/featurizer.py +""" + +import numpy as np +import unittest as ut +from copy import deepcopy +from rdkit import Chem +import datamol as dm + +from goli.features.featurizer import ( + mol_to_adj_and_features, + mol_to_dglgraph, +) +from goli.features.positional_encoding import graph_positional_encoder + + +class test_positional_encoder(ut.TestCase): + + smiles = [ + "C", + "CC", + "CC.CCCC", + "CC(C)CC1=CC=C(C=C1)C(C)C(=O)O", + "OCCc1c(C)[n+](cs1)Cc2cnc(C)nc2N", + "O1C=C[C@H]([C@H]1O2)c3c2cc(OC)c4c3OC(=O)C5=C4CCC(=O)5", + ] + mols = [dm.to_mol(s) for s in smiles] + adjs = [Chem.rdmolops.GetAdjacencyMatrix(mol) for mol in mols] + + def test_laplacian_eigvec(self): + + for ii, adj in enumerate(deepcopy(self.adjs)): + for num_pos in [1, 2, 4]: # Can't test too much eigs because of multiplicities + for disconnected_comp in [True, False]: + err_msg = f"adj_id={ii}, num_pos={num_pos}, disconnected_comp={disconnected_comp}" + pos_enc_sign_flip, pos_enc_no_flip = graph_positional_encoder( + adj, pos_type="laplacian_eigvec", num_pos=num_pos, disconnected_comp=disconnected_comp + ) + self.assertEqual(list(pos_enc_sign_flip.shape), [adj.shape[0], num_pos], msg=err_msg) + self.assertIsNone(pos_enc_no_flip) + + # Compute eigvals and eigvecs + lap = np.diag(np.sum(adj, axis=1)) - adj + eigvals, eigvecs = np.linalg.eig(lap) + sort_idx = np.argsort(eigvals) + eigvals, eigvecs = eigvals[sort_idx], eigvecs[:, sort_idx] + eigvecs = eigvecs / (np.sum(eigvecs ** 2, axis=0, keepdims=True) + 1e-8) + + true_num_pos = min(num_pos, len(eigvals)) + eigvals, eigvecs = eigvals[:true_num_pos], eigvecs[:, :true_num_pos] + eigvecs = np.sign(eigvecs[0:1, :]) * eigvecs + pos_enc_sign_flip = (np.sign(pos_enc_sign_flip[0:1, :]) * pos_enc_sign_flip).numpy() + + # Compare the positional encoding + if disconnected_comp and ("." in self.smiles[ii]): + self.assertGreater(np.max(np.abs(eigvecs - pos_enc_sign_flip)), 1e-3) + elif not ("." in self.smiles[ii]): + np.testing.assert_array_almost_equal( + eigvecs, pos_enc_sign_flip[:, :true_num_pos], decimal=6, err_msg=err_msg + ) + + def test_laplacian_eigvec_eigval(self): + + for ii, adj in enumerate(deepcopy(self.adjs)): + for num_pos in [1, 2, 4]: # Can't test too much eigs because of multiplicities + for disconnected_comp in [True, False]: + err_msg = f"adj_id={ii}, num_pos={num_pos}, disconnected_comp={disconnected_comp}" + pos_enc_sign_flip, pos_enc_no_flip = graph_positional_encoder( + adj, + pos_type="laplacian_eigvec_eigval", + num_pos=num_pos, + disconnected_comp=disconnected_comp, + ) + self.assertEqual(list(pos_enc_sign_flip.shape), [adj.shape[0], num_pos], msg=err_msg) + self.assertEqual(list(pos_enc_no_flip.shape), [adj.shape[0], num_pos], msg=err_msg) + + # Compute eigvals and eigvecs + lap = np.diag(np.sum(adj, axis=1)) - adj + eigvals, eigvecs = np.linalg.eig(lap) + sort_idx = np.argsort(eigvals) + eigvals, eigvecs = eigvals[sort_idx], eigvecs[:, sort_idx] + eigvecs = eigvecs / (np.sum(eigvecs ** 2, axis=0, keepdims=True) + 1e-8) + + true_num_pos = min(num_pos, len(eigvals)) + eigvals, eigvecs = eigvals[:true_num_pos], eigvecs[:, :true_num_pos] + eigvecs = np.sign(eigvecs[0:1, :]) * eigvecs + pos_enc_sign_flip = (np.sign(pos_enc_sign_flip[0:1, :]) * pos_enc_sign_flip).numpy() + + if not ("." in self.smiles[ii]): + np.testing.assert_array_almost_equal( + eigvecs, pos_enc_sign_flip[:, :true_num_pos], decimal=6, err_msg=err_msg + ) + np.testing.assert_array_almost_equal( + eigvals, pos_enc_no_flip[0, :true_num_pos], decimal=6, err_msg=err_msg + ) + + +if __name__ == "__main__": + ut.main() diff --git a/tests/test_predictor_module.py b/tests/test_predictor_module.py index a34e01455..85d248755 100644 --- a/tests/test_predictor_module.py +++ b/tests/test_predictor_module.py @@ -1,50 +1,50 @@ -import yaml - -import goli - - -def test_load_pretrained_model(): - predictor = goli.trainer.PredictorModule.load_pretrained_models("goli-zinc-micro-dummy-test") - assert isinstance(predictor, goli.trainer.predictor.PredictorModule) - - -def test_training(datadir, tmpdir): - - config_path = datadir / "config_micro_ZINC.yaml" - data_path = datadir / "micro_ZINC.csv" - - # Load a config - with open(config_path, "r") as file: - yaml_config = yaml.load(file, Loader=yaml.FullLoader) - - training_dir = tmpdir.mkdir("training") - - # Tweak config and paths - yaml_config["datamodule"]["args"]["df_path"] = data_path - yaml_config["datamodule"]["args"]["cache_data_path"] = None - - yaml_config["trainer"]["trainer"]["min_epochs"] = 1 - yaml_config["trainer"]["trainer"]["max_epochs"] = 1 - - yaml_config["trainer"]["logger"]["save_dir"] = training_dir - yaml_config["trainer"]["model_checkpoint"]["dirpath"] = None - yaml_config["trainer"]["trainer"]["default_root_dir"] = training_dir - yaml_config["trainer"]["trainer"]["gpus"] = 0 - yaml_config["trainer"]["trainer"]["limit_train_batches"] = 1 - yaml_config["trainer"]["trainer"]["limit_val_batches"] = 1 - yaml_config["trainer"]["trainer"]["limit_test_batches"] = 1 - - # Load datamodule - datamodule = goli.config.load_datamodule(yaml_config) - - # Load a trainer - trainer = goli.config.load_trainer(yaml_config) - - # Load a pretrained model - predictor = goli.trainer.PredictorModule.load_pretrained_models("goli-zinc-micro-dummy-test") - - # Inference - results = trainer.predict(predictor, datamodule=datamodule, return_predictions=True) - - assert len(results) == 8 # type: ignore - assert tuple(results[0].shape) == (128, 1) # type: ignore +import yaml + +import goli + + +def test_load_pretrained_model(): + predictor = goli.trainer.PredictorModule.load_pretrained_models("goli-zinc-micro-dummy-test") + assert isinstance(predictor, goli.trainer.predictor.PredictorModule) + + +def test_training(datadir, tmpdir): + + config_path = datadir / "config_micro_ZINC.yaml" + data_path = datadir / "micro_ZINC.csv" + + # Load a config + with open(config_path, "r") as file: + yaml_config = yaml.load(file, Loader=yaml.FullLoader) + + training_dir = tmpdir.mkdir("training") + + # Tweak config and paths + yaml_config["datamodule"]["args"]["df_path"] = data_path + yaml_config["datamodule"]["args"]["cache_data_path"] = None + + yaml_config["trainer"]["trainer"]["min_epochs"] = 1 + yaml_config["trainer"]["trainer"]["max_epochs"] = 1 + + yaml_config["trainer"]["logger"]["save_dir"] = training_dir + yaml_config["trainer"]["model_checkpoint"]["dirpath"] = None + yaml_config["trainer"]["trainer"]["default_root_dir"] = training_dir + yaml_config["trainer"]["trainer"]["gpus"] = 0 + yaml_config["trainer"]["trainer"]["limit_train_batches"] = 1 + yaml_config["trainer"]["trainer"]["limit_val_batches"] = 1 + yaml_config["trainer"]["trainer"]["limit_test_batches"] = 1 + + # Load datamodule + datamodule = goli.config.load_datamodule(yaml_config) + + # Load a trainer + trainer = goli.config.load_trainer(yaml_config) + + # Load a pretrained model + predictor = goli.trainer.PredictorModule.load_pretrained_models("goli-zinc-micro-dummy-test") + + # Inference + results = trainer.predict(predictor, datamodule=datamodule, return_predictions=True) + + assert len(results) == 8 # type: ignore + assert tuple(results[0].shape) == (128, 1) # type: ignore diff --git a/tests/test_residual_connections.py b/tests/test_residual_connections.py index 4cd511231..7ba54a62a 100644 --- a/tests/test_residual_connections.py +++ b/tests/test_residual_connections.py @@ -1,236 +1,236 @@ -""" -Unit tests for the file goli/dgl/residual_connections.py -""" - -import numpy as np -import torch -import unittest as ut - -from goli.nn.residual_connections import ( - ResidualConnectionConcat, - ResidualConnectionDenseNet, - ResidualConnectionNone, - ResidualConnectionSimple, - ResidualConnectionWeighted, -) - - -class test_ResidualConnectionNone(ut.TestCase): - def test_get_true_out_dims_none(self): - full_dims = [4, 6, 8, 10, 12] - in_dims, out_dims = full_dims[:-1], full_dims[1:] - rc = ResidualConnectionNone(skip_steps=1) - true_out_dims = rc.get_true_out_dims(out_dims) - expected_out_dims = out_dims[:-1] - - self.assertListEqual(expected_out_dims, true_out_dims) - - def test_forward_none(self): - rc = ResidualConnectionNone(skip_steps=1) - num_loops = 10 - shape = (3, 11) - h_original = [torch.rand(shape) for _ in range(num_loops)] - - h_prev = None - for ii in range(num_loops): - h, h_prev = rc.forward(h_original[ii], h_prev, step_idx=ii) - np.testing.assert_array_equal(h.numpy(), h_original[ii].numpy(), err_msg=f"ii={ii}") - self.assertIsNone(h_prev) - - -class test_ResidualConnectionSimple(ut.TestCase): - def test_get_true_out_dims_simple(self): - full_dims = [4, 6, 8, 10, 12] - in_dims, out_dims = full_dims[:-1], full_dims[1:] - rc = ResidualConnectionSimple(skip_steps=1) - true_out_dims = rc.get_true_out_dims(out_dims) - expected_out_dims = out_dims[:-1] - - self.assertListEqual(expected_out_dims, true_out_dims) - - def test_forward_simple(self): - - for skip_steps in [1, 2, 3]: - rc = ResidualConnectionSimple(skip_steps=skip_steps) - num_loops = 10 - shape = (3, 11) - h_original = [torch.ones(shape) * (ii + 1) for ii in range(num_loops)] - - h_prev = None - for ii in range(num_loops): - h, h_prev = rc.forward(h_original[ii], h_prev, step_idx=ii) - - if ((ii % skip_steps) == 0) and (ii > 0): - h_expected = ( - torch.sum(torch.stack(h_original[0 : ii + 1 : skip_steps], dim=0), dim=0) - ).numpy() - h_expected_prev = h_expected - else: - h_expected = h_original[ii].numpy() - if ii == 0: - h_expected_prev = h_expected - - np.testing.assert_array_equal( - h.numpy(), h_expected, err_msg=f"Error at: skip_steps={skip_steps}, ii={ii}" - ) - np.testing.assert_array_equal( - h_prev.numpy(), h_expected_prev, err_msg=f"Error at: skip_steps={skip_steps}, ii={ii}" - ) - - -class test_ResidualConnectionWeighted(ut.TestCase): - def test_get_true_out_dims_weighted(self): - full_dims = [4, 6, 8, 10, 12] - in_dims, out_dims = full_dims[:-1], full_dims[1:] - rc = ResidualConnectionWeighted(skip_steps=1, out_dims=full_dims[1:]) - true_out_dims = rc.get_true_out_dims(out_dims) - expected_out_dims = out_dims[:-1] - - self.assertListEqual(expected_out_dims, true_out_dims) - - def test_forward_weighted(self): - - for skip_steps in [1, 2, 3]: - num_loops = 10 - shape = (3, 11) - full_dims = [shape[1]] * (num_loops + 1) - rc = ResidualConnectionWeighted( - skip_steps=skip_steps, out_dims=full_dims[1:], activation="none", batch_norm=False, bias=False - ) - - h_original = [torch.ones(shape) * (ii + 1) for ii in range(num_loops)] - h_forward = [] - - h_prev = None - step_counter = 0 - for ii in range(num_loops): - - h_prev_backup = h_prev - h, h_prev = rc.forward(h_original[ii], h_prev, step_idx=ii) - - if ((ii % skip_steps) == 0) and (ii > 0): - h_forward.append(rc.residual_list[step_counter].forward(h_prev_backup)) - h_expected = (h_forward[-1] + h_original[ii]).detach().numpy() - h_expected_prev = h_expected - step_counter += 1 - else: - h_expected = h_original[ii].detach().numpy() - if ii == 0: - h_expected_prev = h_expected - - np.testing.assert_array_equal( - h.detach().numpy(), h_expected, err_msg=f"Error at: skip_steps={skip_steps}, ii={ii}" - ) - np.testing.assert_array_equal( - h_prev.detach().numpy(), - h_expected_prev, - err_msg=f"Error at: skip_steps={skip_steps}, ii={ii}", - ) - - -class test_ResidualConnectionConcat(ut.TestCase): - def test_get_true_out_dims_concat(self): - full_dims = [4, 6, 8, 10, 12, 14, 16, 18, 20] - in_dims, out_dims = full_dims[:-1], full_dims[1:] - - # skip_steps=1 - rc = ResidualConnectionConcat(skip_steps=1) - true_out_dims = rc.get_true_out_dims(out_dims) - expected_out_dims = [6, 14, 18, 22, 26, 30, 34] - self.assertListEqual(expected_out_dims, true_out_dims) - - # skip_steps=2 - rc = ResidualConnectionConcat(skip_steps=2) - true_out_dims = rc.get_true_out_dims(out_dims) - expected_out_dims = [6, 8, 16, 12, 24, 16, 32] - self.assertListEqual(expected_out_dims, true_out_dims) - - # skip_steps=3 - rc = ResidualConnectionConcat(skip_steps=3) - true_out_dims = rc.get_true_out_dims(out_dims) - expected_out_dims = [6, 8, 10, 18, 14, 16, 30] - self.assertListEqual(expected_out_dims, true_out_dims) - - def test_forward_concat(self): - - for skip_steps in [1, 2, 3]: - rc = ResidualConnectionConcat(skip_steps=skip_steps) - num_loops = 10 - shape = (3, 11) - h_original = [torch.ones(shape) * (ii + 1) for ii in range(num_loops)] - - h_prev = None - for ii in range(num_loops): - h, h_prev = rc.forward(h_original[ii], h_prev, step_idx=ii) - - if ((ii % skip_steps) == 0) and (ii > 0): - h_expected = ( - torch.cat(h_original[ii - skip_steps : ii + 1 : skip_steps][::-1], dim=-1) - ).numpy() - h_expected_prev = h_original[ii].numpy() - else: - h_expected = h_original[ii].numpy() - if ii == 0: - h_expected_prev = h_expected - - np.testing.assert_array_equal( - h.numpy(), h_expected, err_msg=f"Error at: skip_steps={skip_steps}, ii={ii}" - ) - np.testing.assert_array_equal( - h_prev.numpy(), h_expected_prev, err_msg=f"Error at: skip_steps={skip_steps}, ii={ii}" - ) - - -class test_ResidualConnectionDenseNet(ut.TestCase): - def test_get_true_out_dims_densenet(self): - full_dims = [4, 6, 8, 10, 12, 14, 16, 18, 20] - in_dims, out_dims = full_dims[:-1], full_dims[1:] - - # skip_steps=1 - rc = ResidualConnectionDenseNet(skip_steps=1) - true_out_dims = rc.get_true_out_dims(out_dims) - expected_out_dims = np.cumsum(out_dims).tolist()[:-1] - self.assertListEqual(expected_out_dims, true_out_dims) - - # skip_steps=2 - rc = ResidualConnectionDenseNet(skip_steps=2) - true_out_dims = rc.get_true_out_dims(out_dims) - expected_out_dims = [6, 8, 10 + 6, 12, 14 + 10 + 6, 16, 18 + 14 + 10 + 6] - self.assertListEqual(expected_out_dims, true_out_dims) - - # skip_steps=3 - rc = ResidualConnectionDenseNet(skip_steps=3) - true_out_dims = rc.get_true_out_dims(out_dims) - expected_out_dims = [6, 8, 10, 12 + 6, 14, 16, 18 + 12 + 6] - self.assertListEqual(expected_out_dims, true_out_dims) - - def test_forward_densenet(self): - - for skip_steps in [1, 2, 3]: - rc = ResidualConnectionDenseNet(skip_steps=skip_steps) - num_loops = 10 - shape = (3, 11) - h_original = [torch.ones(shape) * (ii + 1) for ii in range(num_loops)] - - h_prev = None - for ii in range(num_loops): - h, h_prev = rc.forward(h_original[ii], h_prev, step_idx=ii) - - if ((ii % skip_steps) == 0) and (ii > 0): - h_expected = (torch.cat(h_original[0 : ii + 1 : skip_steps][::-1], dim=-1)).numpy() - h_expected_prev = h_expected - else: - h_expected = h_original[ii].numpy() - if ii == 0: - h_expected_prev = h_expected - - np.testing.assert_array_equal( - h.numpy(), h_expected, err_msg=f"Error at: skip_steps={skip_steps}, ii={ii}" - ) - np.testing.assert_array_equal( - h_prev.numpy(), h_expected_prev, err_msg=f"Error at: skip_steps={skip_steps}, ii={ii}" - ) - - -if __name__ == "__main__": - ut.main() +""" +Unit tests for the file goli/dgl/residual_connections.py +""" + +import numpy as np +import torch +import unittest as ut + +from goli.nn.residual_connections import ( + ResidualConnectionConcat, + ResidualConnectionDenseNet, + ResidualConnectionNone, + ResidualConnectionSimple, + ResidualConnectionWeighted, +) + + +class test_ResidualConnectionNone(ut.TestCase): + def test_get_true_out_dims_none(self): + full_dims = [4, 6, 8, 10, 12] + in_dims, out_dims = full_dims[:-1], full_dims[1:] + rc = ResidualConnectionNone(skip_steps=1) + true_out_dims = rc.get_true_out_dims(out_dims) + expected_out_dims = out_dims[:-1] + + self.assertListEqual(expected_out_dims, true_out_dims) + + def test_forward_none(self): + rc = ResidualConnectionNone(skip_steps=1) + num_loops = 10 + shape = (3, 11) + h_original = [torch.rand(shape) for _ in range(num_loops)] + + h_prev = None + for ii in range(num_loops): + h, h_prev = rc.forward(h_original[ii], h_prev, step_idx=ii) + np.testing.assert_array_equal(h.numpy(), h_original[ii].numpy(), err_msg=f"ii={ii}") + self.assertIsNone(h_prev) + + +class test_ResidualConnectionSimple(ut.TestCase): + def test_get_true_out_dims_simple(self): + full_dims = [4, 6, 8, 10, 12] + in_dims, out_dims = full_dims[:-1], full_dims[1:] + rc = ResidualConnectionSimple(skip_steps=1) + true_out_dims = rc.get_true_out_dims(out_dims) + expected_out_dims = out_dims[:-1] + + self.assertListEqual(expected_out_dims, true_out_dims) + + def test_forward_simple(self): + + for skip_steps in [1, 2, 3]: + rc = ResidualConnectionSimple(skip_steps=skip_steps) + num_loops = 10 + shape = (3, 11) + h_original = [torch.ones(shape) * (ii + 1) for ii in range(num_loops)] + + h_prev = None + for ii in range(num_loops): + h, h_prev = rc.forward(h_original[ii], h_prev, step_idx=ii) + + if ((ii % skip_steps) == 0) and (ii > 0): + h_expected = ( + torch.sum(torch.stack(h_original[0 : ii + 1 : skip_steps], dim=0), dim=0) + ).numpy() + h_expected_prev = h_expected + else: + h_expected = h_original[ii].numpy() + if ii == 0: + h_expected_prev = h_expected + + np.testing.assert_array_equal( + h.numpy(), h_expected, err_msg=f"Error at: skip_steps={skip_steps}, ii={ii}" + ) + np.testing.assert_array_equal( + h_prev.numpy(), h_expected_prev, err_msg=f"Error at: skip_steps={skip_steps}, ii={ii}" + ) + + +class test_ResidualConnectionWeighted(ut.TestCase): + def test_get_true_out_dims_weighted(self): + full_dims = [4, 6, 8, 10, 12] + in_dims, out_dims = full_dims[:-1], full_dims[1:] + rc = ResidualConnectionWeighted(skip_steps=1, out_dims=full_dims[1:]) + true_out_dims = rc.get_true_out_dims(out_dims) + expected_out_dims = out_dims[:-1] + + self.assertListEqual(expected_out_dims, true_out_dims) + + def test_forward_weighted(self): + + for skip_steps in [1, 2, 3]: + num_loops = 10 + shape = (3, 11) + full_dims = [shape[1]] * (num_loops + 1) + rc = ResidualConnectionWeighted( + skip_steps=skip_steps, out_dims=full_dims[1:], activation="none", batch_norm=False, bias=False + ) + + h_original = [torch.ones(shape) * (ii + 1) for ii in range(num_loops)] + h_forward = [] + + h_prev = None + step_counter = 0 + for ii in range(num_loops): + + h_prev_backup = h_prev + h, h_prev = rc.forward(h_original[ii], h_prev, step_idx=ii) + + if ((ii % skip_steps) == 0) and (ii > 0): + h_forward.append(rc.residual_list[step_counter].forward(h_prev_backup)) + h_expected = (h_forward[-1] + h_original[ii]).detach().numpy() + h_expected_prev = h_expected + step_counter += 1 + else: + h_expected = h_original[ii].detach().numpy() + if ii == 0: + h_expected_prev = h_expected + + np.testing.assert_array_equal( + h.detach().numpy(), h_expected, err_msg=f"Error at: skip_steps={skip_steps}, ii={ii}" + ) + np.testing.assert_array_equal( + h_prev.detach().numpy(), + h_expected_prev, + err_msg=f"Error at: skip_steps={skip_steps}, ii={ii}", + ) + + +class test_ResidualConnectionConcat(ut.TestCase): + def test_get_true_out_dims_concat(self): + full_dims = [4, 6, 8, 10, 12, 14, 16, 18, 20] + in_dims, out_dims = full_dims[:-1], full_dims[1:] + + # skip_steps=1 + rc = ResidualConnectionConcat(skip_steps=1) + true_out_dims = rc.get_true_out_dims(out_dims) + expected_out_dims = [6, 14, 18, 22, 26, 30, 34] + self.assertListEqual(expected_out_dims, true_out_dims) + + # skip_steps=2 + rc = ResidualConnectionConcat(skip_steps=2) + true_out_dims = rc.get_true_out_dims(out_dims) + expected_out_dims = [6, 8, 16, 12, 24, 16, 32] + self.assertListEqual(expected_out_dims, true_out_dims) + + # skip_steps=3 + rc = ResidualConnectionConcat(skip_steps=3) + true_out_dims = rc.get_true_out_dims(out_dims) + expected_out_dims = [6, 8, 10, 18, 14, 16, 30] + self.assertListEqual(expected_out_dims, true_out_dims) + + def test_forward_concat(self): + + for skip_steps in [1, 2, 3]: + rc = ResidualConnectionConcat(skip_steps=skip_steps) + num_loops = 10 + shape = (3, 11) + h_original = [torch.ones(shape) * (ii + 1) for ii in range(num_loops)] + + h_prev = None + for ii in range(num_loops): + h, h_prev = rc.forward(h_original[ii], h_prev, step_idx=ii) + + if ((ii % skip_steps) == 0) and (ii > 0): + h_expected = ( + torch.cat(h_original[ii - skip_steps : ii + 1 : skip_steps][::-1], dim=-1) + ).numpy() + h_expected_prev = h_original[ii].numpy() + else: + h_expected = h_original[ii].numpy() + if ii == 0: + h_expected_prev = h_expected + + np.testing.assert_array_equal( + h.numpy(), h_expected, err_msg=f"Error at: skip_steps={skip_steps}, ii={ii}" + ) + np.testing.assert_array_equal( + h_prev.numpy(), h_expected_prev, err_msg=f"Error at: skip_steps={skip_steps}, ii={ii}" + ) + + +class test_ResidualConnectionDenseNet(ut.TestCase): + def test_get_true_out_dims_densenet(self): + full_dims = [4, 6, 8, 10, 12, 14, 16, 18, 20] + in_dims, out_dims = full_dims[:-1], full_dims[1:] + + # skip_steps=1 + rc = ResidualConnectionDenseNet(skip_steps=1) + true_out_dims = rc.get_true_out_dims(out_dims) + expected_out_dims = np.cumsum(out_dims).tolist()[:-1] + self.assertListEqual(expected_out_dims, true_out_dims) + + # skip_steps=2 + rc = ResidualConnectionDenseNet(skip_steps=2) + true_out_dims = rc.get_true_out_dims(out_dims) + expected_out_dims = [6, 8, 10 + 6, 12, 14 + 10 + 6, 16, 18 + 14 + 10 + 6] + self.assertListEqual(expected_out_dims, true_out_dims) + + # skip_steps=3 + rc = ResidualConnectionDenseNet(skip_steps=3) + true_out_dims = rc.get_true_out_dims(out_dims) + expected_out_dims = [6, 8, 10, 12 + 6, 14, 16, 18 + 12 + 6] + self.assertListEqual(expected_out_dims, true_out_dims) + + def test_forward_densenet(self): + + for skip_steps in [1, 2, 3]: + rc = ResidualConnectionDenseNet(skip_steps=skip_steps) + num_loops = 10 + shape = (3, 11) + h_original = [torch.ones(shape) * (ii + 1) for ii in range(num_loops)] + + h_prev = None + for ii in range(num_loops): + h, h_prev = rc.forward(h_original[ii], h_prev, step_idx=ii) + + if ((ii % skip_steps) == 0) and (ii > 0): + h_expected = (torch.cat(h_original[0 : ii + 1 : skip_steps][::-1], dim=-1)).numpy() + h_expected_prev = h_expected + else: + h_expected = h_original[ii].numpy() + if ii == 0: + h_expected_prev = h_expected + + np.testing.assert_array_equal( + h.numpy(), h_expected, err_msg=f"Error at: skip_steps={skip_steps}, ii={ii}" + ) + np.testing.assert_array_equal( + h_prev.numpy(), h_expected_prev, err_msg=f"Error at: skip_steps={skip_steps}, ii={ii}" + ) + + +if __name__ == "__main__": + ut.main() diff --git a/tests/test_utils.py b/tests/test_utils.py index c1ae3f81e..1391e62c4 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,82 +1,82 @@ -""" -Unit tests for the metrics and wrappers of goli/utils/... -""" - -from goli.utils.tensor import nan_mean, nan_std, nan_var -import torch -import numpy as np -import unittest as ut - - -class test_nan_statistics(ut.TestCase): - - torch.manual_seed(42) - - dims = [ - None, - (0), - (1), - (2), - (-1), - (-2), - (-3), - (0, 1), - (0, 2), - ] - # Create tensor - sz = (10, 6, 7) - tensor = torch.randn(sz, dtype=torch.float32) ** 2 + 3 - is_nan = torch.rand(sz) > 0.4 - tensor[is_nan] = float("nan") - - def test_nan_mean(self): - - for keepdim in [False, True]: - for dim in self.dims: - err_msg = f"Error for :\n dim = {dim}\n keepdim = {keepdim}" - - tensor = self.tensor.clone() - - # Prepare the arguments for numpy vs torch - if dim is not None: - torch_kwargs = {"dim": dim, "keepdim": keepdim} - numpy_kwargs = {"axis": dim, "keepdims": keepdim} - else: - torch_kwargs = {} - numpy_kwargs = {} - - # Compare the nan-mean - torch_mean = nan_mean(tensor, **torch_kwargs) - numpy_mean = np.nanmean(tensor.numpy(), **numpy_kwargs) - np.testing.assert_almost_equal(torch_mean.numpy(), numpy_mean, decimal=6, err_msg=err_msg) - - def test_nan_std_var(self): - - for unbiased in [True, False]: - for keepdim in [False, True]: - for dim in self.dims: - err_msg = f"Error for :\n\tdim = {dim}\n\tkeepdim = {keepdim}\n\tunbiased = {unbiased}" - - tensor = self.tensor.clone() - - # Prepare the arguments for numpy vs torch - if dim is not None: - torch_kwargs = {"dim": dim, "keepdim": keepdim, "unbiased": unbiased} - numpy_kwargs = {"axis": dim, "keepdims": keepdim, "ddof": float(unbiased)} - else: - torch_kwargs = {"unbiased": unbiased} - numpy_kwargs = {"ddof": float(unbiased)} - - # Compare the std - torch_std = nan_std(tensor, **torch_kwargs) - numpy_std = np.nanstd(tensor.numpy(), **numpy_kwargs) - np.testing.assert_almost_equal(torch_std.numpy(), numpy_std, decimal=6, err_msg=err_msg) - - # Compare the variance - torch_var = nan_var(tensor, **torch_kwargs) - numpy_var = np.nanvar(tensor.numpy(), **numpy_kwargs) - np.testing.assert_almost_equal(torch_var.numpy(), numpy_var, decimal=6, err_msg=err_msg) - - -if __name__ == "__main__": - ut.main() +""" +Unit tests for the metrics and wrappers of goli/utils/... +""" + +from goli.utils.tensor import nan_mean, nan_std, nan_var +import torch +import numpy as np +import unittest as ut + + +class test_nan_statistics(ut.TestCase): + + torch.manual_seed(42) + + dims = [ + None, + (0), + (1), + (2), + (-1), + (-2), + (-3), + (0, 1), + (0, 2), + ] + # Create tensor + sz = (10, 6, 7) + tensor = torch.randn(sz, dtype=torch.float32) ** 2 + 3 + is_nan = torch.rand(sz) > 0.4 + tensor[is_nan] = float("nan") + + def test_nan_mean(self): + + for keepdim in [False, True]: + for dim in self.dims: + err_msg = f"Error for :\n dim = {dim}\n keepdim = {keepdim}" + + tensor = self.tensor.clone() + + # Prepare the arguments for numpy vs torch + if dim is not None: + torch_kwargs = {"dim": dim, "keepdim": keepdim} + numpy_kwargs = {"axis": dim, "keepdims": keepdim} + else: + torch_kwargs = {} + numpy_kwargs = {} + + # Compare the nan-mean + torch_mean = nan_mean(tensor, **torch_kwargs) + numpy_mean = np.nanmean(tensor.numpy(), **numpy_kwargs) + np.testing.assert_almost_equal(torch_mean.numpy(), numpy_mean, decimal=6, err_msg=err_msg) + + def test_nan_std_var(self): + + for unbiased in [True, False]: + for keepdim in [False, True]: + for dim in self.dims: + err_msg = f"Error for :\n\tdim = {dim}\n\tkeepdim = {keepdim}\n\tunbiased = {unbiased}" + + tensor = self.tensor.clone() + + # Prepare the arguments for numpy vs torch + if dim is not None: + torch_kwargs = {"dim": dim, "keepdim": keepdim, "unbiased": unbiased} + numpy_kwargs = {"axis": dim, "keepdims": keepdim, "ddof": float(unbiased)} + else: + torch_kwargs = {"unbiased": unbiased} + numpy_kwargs = {"ddof": float(unbiased)} + + # Compare the std + torch_std = nan_std(tensor, **torch_kwargs) + numpy_std = np.nanstd(tensor.numpy(), **numpy_kwargs) + np.testing.assert_almost_equal(torch_std.numpy(), numpy_std, decimal=6, err_msg=err_msg) + + # Compare the variance + torch_var = nan_var(tensor, **torch_kwargs) + numpy_var = np.nanvar(tensor.numpy(), **numpy_kwargs) + np.testing.assert_almost_equal(torch_var.numpy(), numpy_var, decimal=6, err_msg=err_msg) + + +if __name__ == "__main__": + ut.main()