diff --git a/.github/ISSUE_TEMPLATE/bug.md b/.github/ISSUE_TEMPLATE/bug.md index a48e240a23..88f418fff4 100644 --- a/.github/ISSUE_TEMPLATE/bug.md +++ b/.github/ISSUE_TEMPLATE/bug.md @@ -1,6 +1,7 @@ --- name: Bug Report about: Any general feedback or bug reports about the Vyper Compiler. No new features proposals. +labels: ["needs triage"] --- ### Version Information diff --git a/.github/ISSUE_TEMPLATE/vip.md b/.github/ISSUE_TEMPLATE/vip.md index b35a1e7c23..d32a8ac3de 100644 --- a/.github/ISSUE_TEMPLATE/vip.md +++ b/.github/ISSUE_TEMPLATE/vip.md @@ -1,6 +1,7 @@ --- name: Vyper Improvement Proposal (VIP) about: This is the suggested template for new VIPs. +labels: ["needs triage"] --- ## Simple Summary "If you can't explain it simply, you don't understand it well enough." Provide a simplified and layman-accessible explanation of the VIP. diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 7cb8de830c..5dd98413a7 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -31,6 +31,10 @@ jobs: # need to fetch unshallow so that setuptools_scm can infer the version fetch-depth: 0 + # debug + - name: Git shorthash + run: git rev-parse --short HEAD + - name: Python uses: actions/setup-python@v5 with: @@ -45,8 +49,9 @@ jobs: - name: Upload Artifact - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: + name: vyper-${{ runner.os }} path: dist/vyper.* windows-build: @@ -60,6 +65,10 @@ jobs: # need to fetch unshallow so that setuptools_scm can infer the version fetch-depth: 0 + # debug + - name: Git shorthash + run: git rev-parse --short HEAD + - name: Python uses: actions/setup-python@v5 with: @@ -73,8 +82,9 @@ jobs: ./make.cmd freeze - name: Upload Artifact - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: + name: vyper-${{ runner.os }} path: dist/vyper.* publish-release-assets: @@ -84,14 +94,13 @@ jobs: steps: - uses: actions/checkout@v4 - - uses: actions/download-artifact@v3 + - 
uses: actions/download-artifact@v4 with: path: artifacts/ + merge-multiple: true - name: Upload assets - # fun - artifacts are downloaded into "artifact/". - # TODO: this needs to be tested when upgrading to upload-artifact v4 - working-directory: artifacts/artifact + working-directory: artifacts run: | set -Eeuxo pipefail for BIN_NAME in $(ls) diff --git a/.github/workflows/pull-request.yaml b/.github/workflows/pull-request.yaml index 4c4ff8a1df..75a8762d04 100644 --- a/.github/workflows/pull-request.yaml +++ b/.github/workflows/pull-request.yaml @@ -24,6 +24,7 @@ jobs: with: types: | feat + perf fix chore refactor @@ -31,7 +32,9 @@ jobs: # docs: documentation # test: test suite # lang: language changes + # stdlib: changes to the stdlib # ux: language changes (UX) + # parser: parser changes # tool: integration # ir: (old) IR/codegen changes # codegen: lowering from vyper AST to codegen @@ -42,7 +45,9 @@ jobs: docs test lang + stdlib ux + parser tool ir codegen diff --git a/.github/workflows/release-pypi.yml b/.github/workflows/release-pypi.yml index d09aeb9adc..1511c61e51 100644 --- a/.github/workflows/release-pypi.yml +++ b/.github/workflows/release-pypi.yml @@ -20,6 +20,14 @@ jobs: steps: - uses: actions/checkout@v4 + with: + # fetch unshallow so commit hash matches github release. 
+ # see https://github.com/vyperlang/vyper/blob/8f9a8cac49aafb3fbc9dde78f0f6125c390c32f0/.github/workflows/build.yml#L27-L32 + fetch-depth: 0 + + # debug + - name: Git shorthash + run: git rev-parse --short HEAD - name: Python uses: actions/setup-python@v5 diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 8d491e2530..d4bfc2ee9c 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -116,6 +116,7 @@ jobs: # modes across all python versions - one is enough - python-version: ["3.10", "310"] - python-version: ["3.12", "312"] + - python-version: ["3.13", "313"] # os-specific rules - os: windows @@ -148,16 +149,17 @@ jobs: --evm-backend ${{ matrix.evm-backend || 'revm' }} ${{ matrix.debug && '--enable-compiler-debug-mode' || '' }} ${{ matrix.experimental-codegen && '--experimental-codegen' || '' }} - --cov-branch - --cov-report xml:coverage.xml + --cov-config=setup.cfg --cov=vyper tests/ - - name: Upload Coverage - uses: codecov/codecov-action@v3 + - name: Upload coverage artifact + uses: actions/upload-artifact@v4 with: - token: ${{ secrets.CODECOV_TOKEN }} - file: ./coverage.xml + name: coverage-files-${{ github.job }}-${{ strategy.job-index }} + include-hidden-files: true + path: .coverage + if-no-files-found: error core-tests-success: if: always() @@ -208,16 +210,17 @@ jobs: --splits 120 \ --group ${{ matrix.group }} \ --splitting-algorithm least_duration \ - --cov-branch \ - --cov-report xml:coverage.xml \ + --cov-config=setup.cfg \ --cov=vyper \ tests/ - - name: Upload Coverage - uses: codecov/codecov-action@v3 + - name: Upload coverage artifact + uses: actions/upload-artifact@v4 with: - token: ${{ secrets.CODECOV_TOKEN }} - file: ./coverage.xml + name: coverage-files-${{ github.job }}-${{ strategy.job-index }} + include-hidden-files: true + path: .coverage + if-no-files-found: error slow-tests-success: if: always() @@ -230,3 +233,73 @@ jobs: - name: Check slow tests all succeeded if: ${{ needs.fuzzing.result != 
'success' }} run: exit 1 + + coverage-report: + # Consolidate code coverage using `coverage combine` and + # call coverage report with fail-under=90 + runs-on: ubuntu-latest + needs: [tests, fuzzing] + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python 3.11 + uses: actions/setup-python@v5 + with: + python-version: "3.11" + cache: "pip" + + - name: Install coverage + run: pip install coverage + + - name: Download coverage artifacts + uses: actions/download-artifact@v4 + with: + pattern: coverage-files-* + path: coverage-files + + - name: Combine coverage + run: | + coverage combine coverage-files/**/.coverage + + - name: Coverage report + # coverage report and fail if coverage is too low + run: | + coverage report --fail-under=90 + + - name: Generate coverage.xml + run: | + coverage xml + + - name: Upload coverage sqlite artifact + # upload coverage sqlite db for debugging + uses: actions/upload-artifact@v4 + with: + name: coverage-sqlite + include-hidden-files: true + path: .coverage + if-no-files-found: error + + - name: Upload coverage.xml + uses: actions/upload-artifact@v4 + with: + name: coverage-xml + path: coverage.xml + if-no-files-found: error + + upload-coverage: + # upload coverage to the codecov app + runs-on: ubuntu-latest + needs: [coverage-report] + + steps: + - uses: actions/checkout@v4 + - uses: actions/download-artifact@v4 + with: + name: coverage-xml + + - name: Upload Coverage + uses: codecov/codecov-action@v5 + with: + token: ${{ secrets.CODECOV_TOKEN }} + file: coverage.xml diff --git a/FUNDING.json b/FUNDING.json new file mode 100644 index 0000000000..f4befa822d --- /dev/null +++ b/FUNDING.json @@ -0,0 +1,10 @@ +{ + "drips": { + "ethereum": { + "ownedBy": "0x70CCBE10F980d80b7eBaab7D2E3A73e87D67B775" + } + }, + "opRetro": { + "projectId": "0x9ca1f7b0e0d10d3bd2619e51a54f2e4175e029c87a2944cf1ebc89164ba77ea0" + } +} diff --git a/README.md b/README.md index bcaa50b570..827d40d549 100644 --- a/README.md +++ b/README.md @@ -3,6 +3,7 @@ 
[![Build Status](https://github.com/vyperlang/vyper/workflows/Test/badge.svg)](https://github.com/vyperlang/vyper/actions/workflows/test.yml) [![Documentation Status](https://readthedocs.org/projects/vyper/badge/?version=latest)](http://docs.vyperlang.org/en/latest/?badge=latest "ReadTheDocs") [![Discord](https://img.shields.io/discord/969926564286459934.svg?label=%23vyper)](https://discord.gg/6tw7PTM7C2) +[![Telegram](https://img.shields.io/badge/Vyperholics🐍-Telegram-blue)](https://t.me/vyperlang) [![PyPI](https://badge.fury.io/py/vyper.svg)](https://pypi.org/project/vyper "PyPI") [![Docker](https://img.shields.io/docker/cloud/build/vyperlang/vyper)](https://hub.docker.com/r/vyperlang/vyper "DockerHub") @@ -49,7 +50,7 @@ be a bit behind the latest version found in the master branch of this repository ```bash make dev-init -python setup.py test +./quicktest.sh -m "not fuzzing" ``` ## Developing (working on the compiler) diff --git a/SECURITY.md b/SECURITY.md index 1a16f521d3..977a00f7b2 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -3,7 +3,7 @@ ## Supported Versions - it is recommended to follow the list of known [vulnerabilities](https://github.com/vyperlang/vyper/security/advisories) and stay up-to-date with the latest releases - - as of May 2024, the `0.4.0` release is the most secure and the most comprehensively reviewed one and is recommended for use in production environments + - as of May 2024, the [`0.4.0`](https://github.com/vyperlang/vyper/releases/tag/v0.4.0) release is the most comprehensively reviewed one and is recommended for use in production environments - if a compiler vulnerability is found, a new compiler version with a patch will be released. The vulnerable version itself is not updated (see the examples below). 
- `example1`: suppose `0.4.0` is the latest version and a hypothetical vulnerability is found in `0.4.0`, then a patch will be released in `0.4.1` - `example2`: suppose `0.4.0` is the latest version and a hypothetical vulnerability is found both in `0.3.10` and `0.4.0`, then a patch will be released only in `0.4.1` @@ -26,7 +26,22 @@ we will add an entry to the list of security advisories for posterity and refere ## Bug Bounty Program -- as of May 2024, Vyper does not have a bug bounty program. It is planned to instantiate one soon. +- Vyper runs a bug bounty program via the Ethereum Foundation. + - Bugs should be reported through the [Ethereum Foundation's bounty program](https://ethereum.org/bug-bounty). + +### Scope +- Rules from the Ethereum Foundation's bug bounty program apply; for any questions please reach out [here](mailto:bounty@ethereum.org). Here we further clarify the scope of the Vyper bounty program. +- If a compiler bug affects production code, it is in scope (excluding known issues). + - This includes bugs in older compiler versions still used in production. +- If a compiler bug does not currently affect production but is likely to in the future, it is in scope. + - This mainly applies to the latest compiler release (e.g., a new release is available but contracts are not yet deployed with it). + - Experimental features (e.g. `--experimental-codegen`) are out of scope, as they are not intended for production and are unlikely to affect production code. + - Bugs in older compiler versions are generally out of scope, as they are no longer used for new contracts. + - There might be exceptions, e.g., when an L2 doesn't support recent compiler releases. In such cases, it might be reasonable for an older version to be used. It is up to the discretion of the EF & Vyper team to decide if the bug is in scope. +- If a vulnerability affects multiple contracts, the whitehat is eligible for only one payout (though the severity of the bug may increase). 
+- Eligibility for project-specific bounties is independent of this bounty. +- [Security advisories](https://github.com/vyperlang/vyper/security/advisories) and [known issues](https://github.com/vyperlang/vyper/issues) are not eligible for the bounty program, as they are publicly disclosed and protocols should structure their contracts accordingly. +- Individuals or organizations contracted or engaged specifically for security development, auditing, or testing of this project are ineligible for the bounty program. ## Reporting a Vulnerability diff --git a/docs/built-in-functions.rst b/docs/built-in-functions.rst index a0e424adb4..7a79379d08 100644 --- a/docs/built-in-functions.rst +++ b/docs/built-in-functions.rst @@ -1023,7 +1023,7 @@ Utilities >>> ExampleContract.foo() 0xa9059cbb -.. py:function:: abi_encode(*args, ensure_tuple: bool = True) -> Bytes[] +.. py:function:: abi_encode(*args, ensure_tuple: bool = True, method_id: Bytes[4] = None) -> Bytes[] Takes a variable number of args as input, and returns the ABIv2-encoded bytestring. Used for packing arguments to raw_call, EIP712 and other cases where a consistent and efficient serialization method is needed. Once this function has seen more use we provisionally plan to put it into the ``ethereum.abi`` namespace. diff --git a/docs/compiling-a-contract.rst b/docs/compiling-a-contract.rst index 751af980b2..c839e1e81d 100644 --- a/docs/compiling-a-contract.rst +++ b/docs/compiling-a-contract.rst @@ -31,7 +31,7 @@ Include the ``-f`` flag to specify which output formats to return. Use ``vyper - .. 
code:: shell - $ vyper -f abi,abi_python,bytecode,bytecode_runtime,blueprint_bytecode,interface,external_interface,ast,annotated_ast,integrity,ir,ir_json,ir_runtime,asm,opcodes,opcodes_runtime,source_map,source_map_runtime,archive,solc_json,method_identifiers,userdoc,devdoc,metadata,combined_json,layout yourFileName.vy + $ vyper -f abi,abi_python,bb,bb_runtime,bytecode,bytecode_runtime,blueprint_bytecode,cfg,cfg_runtime,interface,external_interface,ast,annotated_ast,integrity,ir,ir_json,ir_runtime,asm,opcodes,opcodes_runtime,source_map,source_map_runtime,archive,solc_json,method_identifiers,userdoc,devdoc,metadata,combined_json,layout yourFileName.vy .. note:: The ``opcodes`` and ``opcodes_runtime`` output of the compiler has been returning incorrect opcodes since ``0.2.0`` due to a lack of 0 padding (patched via `PR 3735 `_). If you rely on these functions for debugging, please use the latest patched versions. @@ -106,7 +106,7 @@ Online Compilers Try VyperLang! ----------------- -`Try VyperLang! `_ is a JupterHub instance hosted by the Vyper team as a sandbox for developing and testing contracts in Vyper. It requires github for login, and supports deployment via the browser. +`Try VyperLang! `_ is a JupyterHub instance hosted by the Vyper team as a sandbox for developing and testing contracts in Vyper. It requires github for login, and supports deployment via the browser. Remix IDE --------- @@ -134,6 +134,11 @@ In codesize optimized mode, the compiler will try hard to minimize codesize by * out-lining code, and * using more loops for data copies. +Enabling Experimental Code Generation +===================================== + +When compiling, you can use the CLI flag ``--experimental-codegen`` or its alias ``--venom`` to activate the new `Venom IR `_. +Venom IR is inspired by LLVM IR and enables new advanced analysis and optimizations. ..
_evm-version: @@ -198,7 +203,7 @@ The following is a list of supported EVM versions, and changes in the compiler i Integrity Hash ============== -To help tooling detect whether two builds are the same, Vyper provides the ``-f integrity`` output, which outputs the integrity hash of a contract. The integrity hash is recursively defined as the sha256 of the source code with the integrity hashes of its dependencies (imports). +To help tooling detect whether two builds are the same, Vyper provides the ``-f integrity`` output, which outputs the integrity hash of a contract. The integrity hash is recursively defined as the sha256 of the source code with the integrity hashes of its dependencies (imports) and storage layout overrides (if provided). .. _vyper-archives: @@ -214,8 +219,9 @@ A Vyper archive is a compileable bundle of input sources and settings. Technical β”œβ”€β”€ compilation_targets β”œβ”€β”€ compiler_version β”œβ”€β”€ integrity + β”œβ”€β”€ settings.json β”œβ”€β”€ searchpaths - └── settings.json + └── storage_layout.json [OPTIONAL] * ``cli_settings.txt`` is a text representation of the settings that were used on the compilation run that generated this archive. * ``compilation_targets`` is a newline separated list of compilation targets. Currently only one compilation is supported @@ -223,6 +229,7 @@ A Vyper archive is a compileable bundle of input sources and settings. Technical * ``integrity`` is the :ref:`integrity hash ` of the input contract * ``searchpaths`` is a newline-separated list of the search paths used on this compilation run * ``settings.json`` is a json representation of the settings used on this compilation run. It is 1:1 with ``cli_settings.txt``, but both are provided as they are convenient for different workflows (typically, manually vs automated). +* ``storage_layout.json`` is a json representation of the storage layout overrides to be used on this compilation run. It is optional. 
A Vyper archive file can be produced by requesting the ``-f archive`` output format. The compiler can also produce the archive in base64 encoded form using the ``--base64`` flag. The Vyper compiler can accept both ``.vyz`` and base64-encoded Vyper archives directly as input. @@ -276,6 +283,14 @@ The following example describes the expected input format of ``vyper-json``. (Co } }, // Optional + // Storage layout overrides for the contracts that are compiled + "storage_layout_overrides": { + "contracts/foo.vy": { + "a": {"type": "uint256", "slot": 1, "n_slots": 1}, + "b": {"type": "uint256", "slot": 0, "n_slots": 1}, + } + }, + // Optional "settings": { "evmVersion": "cancun", // EVM version to compile for. Can be london, paris, shanghai or cancun (default). // optional, optimization mode @@ -308,10 +323,10 @@ The following example describes the expected input format of ``vyper-json``. (Co // devdoc - Natspec developer documentation // evm.bytecode.object - Bytecode object // evm.bytecode.opcodes - Opcodes list + // evm.bytecode.sourceMap - Source mapping (useful for debugging) // evm.deployedBytecode.object - Deployed bytecode object // evm.deployedBytecode.opcodes - Deployed opcodes list - // evm.deployedBytecode.sourceMap - Solidity-style source mapping - // evm.deployedBytecode.sourceMapFull - Deployed source mapping (useful for debugging) + // evm.deployedBytecode.sourceMap - Deployed source mapping (useful for debugging) // evm.methodIdentifiers - The list of function hashes // // Using `evm`, `evm.bytecode`, etc. will select every target part of that output. @@ -359,6 +374,13 @@ The following example describes the output format of ``vyper-json``. 
Comments ar "formattedMessage": "line 5:11 Unsupported type conversion: int128 to bool" } ], + // Optional: not present if there are no storage layout overrides + "storage_layout_overrides": { + "contracts/foo.vy": { + "a": {"type": "uint256", "slot": 1, "n_slots": 1}, + "b": {"type": "uint256", "slot": 0, "n_slots": 1}, + } + }, // This contains the file-level outputs. Can be limited/filtered by the outputSelection settings. "sources": { "source_file.vy": { @@ -388,15 +410,37 @@ The following example describes the output format of ``vyper-json``. Comments ar // The bytecode as a hex string. "object": "00fe", // Opcodes list (string) - "opcodes": "" + "opcodes": "", + // The source mapping. + "sourceMap": { + "breakpoints": [], + "error_map": {}, + "pc_ast_map": {}, + "pc_ast_map_item_keys": [], + "pc_breakpoints": [], + "pc_jump_map": {}, + "pc_pos_map": {}, + // The source mapping as a string. + "pc_pos_map_compressed": "" + } }, "deployedBytecode": { // The deployed bytecode as a hex string. "object": "00fe", // Deployed opcodes list (string) "opcodes": "", - // The deployed source mapping as a string. - "sourceMap": "" + // The deployed source mapping. + "sourceMap": { + "breakpoints": [], + "error_map": {}, + "pc_ast_map": {}, + "pc_ast_map_item_keys": [], + "pc_breakpoints": [], + "pc_jump_map": {}, + "pc_pos_map": {}, + // The deployed source mapping as a string.
+ "pc_pos_map_compressed": "" + } }, // The list of function hashes "methodIdentifiers": { diff --git a/docs/constants-and-vars.rst b/docs/constants-and-vars.rst index 0043f45d5d..bca1d70440 100644 --- a/docs/constants-and-vars.rst +++ b/docs/constants-and-vars.rst @@ -26,6 +26,7 @@ Name Type Value ``chain.id`` ``uint256`` Chain ID ``msg.data`` ``Bytes`` Message data ``msg.gas`` ``uint256`` Remaining gas +``msg.mana`` ``uint256`` Remaining gas (alias for ``msg.gas``) ``msg.sender`` ``address`` Sender of the message (current call) ``msg.value`` ``uint256`` Number of wei sent with the message ``tx.origin`` ``address`` Sender of the transaction (full call chain) diff --git a/docs/control-structures.rst b/docs/control-structures.rst index d46e7a4a28..6304637728 100644 --- a/docs/control-structures.rst +++ b/docs/control-structures.rst @@ -48,7 +48,16 @@ External functions (marked with the ``@external`` decorator) are a part of the c A Vyper contract cannot call directly between two external functions. If you must do this, you can use an :ref:`interface `. .. note:: - For external functions with default arguments like ``def my_function(x: uint256, b: uint256 = 1)`` the Vyper compiler will generate ``N+1`` overloaded function selectors based on ``N`` default arguments. + For external functions with default arguments like ``def my_function(x: uint256, b: uint256 = 1)`` the Vyper compiler will generate ``N+1`` overloaded function selectors based on ``N`` default arguments. Consequently, the ABI signature for a function (this includes interface functions) excludes optional arguments when their default values are used in the function call. + + .. code-block:: vyper + + from ethereum.ercs import IERC4626 + + @external + def foo(x: IERC4626): + extcall x.withdraw(0, self, self) # keccak256("withdraw(uint256,address,address)")[:4] = 0xb460af94 + extcall x.withdraw(0) # keccak256("withdraw(uint256)")[:4] = 0x2e1a7d4d .. 
_structure-functions-internal: @@ -75,6 +84,14 @@ Or for internal functions which are defined in :ref:`imported modules ` def calculate(amount: uint256) -> uint256: return calculator_library._times_two(amount) +Marking an internal function as ``payable`` specifies that the function can interact with ``msg.value``. A ``nonpayable`` internal function can be called from an external ``payable`` function, but it cannot access ``msg.value``. + +.. code-block:: vyper + + @payable + def _foo() -> uint256: + return msg.value % 2 + .. note:: As of v0.4.0, the ``@internal`` decorator is optional. That is, functions with no visibility decorator default to being ``internal``. @@ -110,7 +127,7 @@ You can optionally declare a function's mutability by using a :ref:`decorator `_. +Vyper is a contract-oriented, Pythonic programming language that targets the `Ethereum Virtual Machine (EVM) `_. +It prioritizes user safety, encourages clear coding practices via language design and efficient execution. In other words, Vyper code is safe, clear and efficient! Principles and Goals ==================== * **Security**: It should be possible and natural to build secure smart-contracts in Vyper. * **Language and compiler simplicity**: The language and the compiler implementation should strive to be simple. -* **Auditability**: Vyper code should be maximally human-readable. Furthermore, it should be maximally difficult to write misleading code. Simplicity for the reader is more important than simplicity for the writer, and simplicity for readers with low prior experience with Vyper (and low prior experience with programming in general) is particularly important. +* **Auditability**: Vyper code should be maximally human-readable. + Furthermore, it should be maximally difficult to write misleading code. 
+ Simplicity for the reader is more important than simplicity for the writer, and simplicity for readers with low prior experience with Vyper (and low prior experience with programming in general) is particularly important. Because of this Vyper provides the following features: diff --git a/docs/installing-vyper.rst b/docs/installing-vyper.rst index 8eaa93590a..0c7d54903f 100644 --- a/docs/installing-vyper.rst +++ b/docs/installing-vyper.rst @@ -7,37 +7,37 @@ any errors. .. note:: - The easiest way to experiment with the language is to use the `Remix online compiler `_. - (Activate the vyper-remix plugin in the Plugin manager.) + The easiest way to experiment with the language is to use either `Try Vyper! `_ (maintained by the Vyper team) or the `Remix online compiler `_ (maintained by the Ethereum Foundation). + - To use Try Vyper, go to https://try.vyperlang.org and log in (requires Github login). + - To use remix, go to https://remix.ethereum.org and activate the vyper-remix plugin in the Plugin manager. -Docker -****** -Vyper can be downloaded as docker image from `dockerhub `_: -:: +Standalone +********** - docker pull vyperlang/vyper +The Vyper CLI can be installed with any ``pip`` compatible tool, for example, ``pipx`` or ``uv tool``. If you do not have ``pipx`` or ``uv`` installed, first, go to the respective tool's installation page: -To run the compiler use the ``docker run`` command: -:: +- https://github.com/pypa/pipx?tab=readme-ov-file +- https://github.com/astral-sh/uv?tab=readme-ov-file#uv - docker run -v $(pwd):/code vyperlang/vyper /code/ +Then, the command to install Vyper would be -Alternatively you can log into the docker image and execute vyper on the prompt. 
:: - docker run -v $(pwd):/code/ -it --entrypoint /bin/bash vyperlang/vyper - root@d35252d1fb1b:/code# vyper + pipx install vyper + +Or, -The normal parameters are also supported, for example: :: - docker run -v $(pwd):/code vyperlang/vyper -f abi /code/ - [{'name': 'test1', 'outputs': [], 'inputs': [{'type': 'uint256', 'name': 'a'}, {'type': 'bytes', 'name': 'b'}], 'constant': False, 'payable': False, 'type': 'function', 'gas': 441}, {'name': 'test2', 'outputs': [], 'inputs': [{'type': 'uint256', 'name': 'a'}], 'constant': False, 'payable': False, 'type': 'function', 'gas': 316}] + uv tool install vyper -.. note:: - If you would like to know how to install Docker, please follow their `documentation `_. +Binaries +******** + +Alternatively, prebuilt Vyper binaries for Windows, Mac and Linux are available for download from the GitHub releases page: https://github.com/vyperlang/vyper/releases. + PIP *** @@ -45,12 +45,17 @@ PIP Installing Python ================= -Vyper can only be built using Python 3.6 and higher. If you need to know how to install the correct version of python, +Vyper can only be built using Python 3.10 and higher. If you need to know how to install the correct version of python, follow the instructions from the official `Python website `_. Creating a virtual environment ============================== +Because pip installations are not isolated by default, this method of +installation is meant for more experienced Python developers who are using +Vyper as a library, or want to use it within a Python project with other +pip dependencies. 
+ It is **strongly recommended** to install Vyper in **a virtual Python environment**, so that new packages installed and dependencies built are strictly contained in your Vyper project and will not alter or affect your @@ -76,13 +81,43 @@ Each tagged version of vyper is uploaded to `pypi `_: +:: + + docker pull vyperlang/vyper + +To run the compiler use the ``docker run`` command: +:: + + docker run -v $(pwd):/code vyperlang/vyper /code/ + +Alternatively you can log into the docker image and execute vyper on the prompt. +:: + + docker run -v $(pwd):/code/ -it --entrypoint /bin/bash vyperlang/vyper + root@d35252d1fb1b:/code# vyper + +The normal parameters are also supported, for example: +:: + + docker run -v $(pwd):/code vyperlang/vyper -f abi /code/ + [{'name': 'test1', 'outputs': [], 'inputs': [{'type': 'uint256', 'name': 'a'}, {'type': 'bytes', 'name': 'b'}], 'constant': False, 'payable': False, 'type': 'function', 'gas': 441}, {'name': 'test2', 'outputs': [], 'inputs': [{'type': 'uint256', 'name': 'a'}], 'constant': False, 'payable': False, 'type': 'function', 'gas': 316}] + +.. note:: + + If you would like to know how to install Docker, please follow their `documentation `_. + nix *** diff --git a/docs/interfaces.rst b/docs/interfaces.rst index acc0ce91f3..22a0874fa7 100644 --- a/docs/interfaces.rst +++ b/docs/interfaces.rst @@ -85,10 +85,6 @@ The ``default_return_value`` parameter can be used to handle ERC20 tokens affect extcall IERC20(USDT).transfer(msg.sender, 1, default_return_value=True) # returns True extcall IERC20(USDT).transfer(msg.sender, 1) # reverts because nothing returned -.. warning:: - - When ``skip_contract_check=True`` is used and the called function returns data (ex.: ``x: uint256 = SomeContract.foo(skip_contract_check=True)``, no guarantees are provided by the compiler as to the validity of the returned value. In other words, it is undefined behavior what happens if the called contract did not exist. 
In particular, the returned value might point to garbage memory. It is therefore recommended to only use ``skip_contract_check=True`` to call contracts which have been manually ensured to exist at the time of the call. - Built-in Interfaces =================== @@ -124,6 +120,10 @@ This imports the defined interface from the vyper file at ``an_interface.vyi`` ( Prior to v0.4.0, ``implements`` required that events defined in an interface were re-defined in the "implementing" contract. As of v0.4.0, this is no longer required because events can be used just by importing them. Any events used in a contract will automatically be exported in the ABI output. +.. note:: + + An interface function with default parameters (e.g. ``deposit(assets: uint256, receiver: address = msg.sender)``) implies that the contract being interfaced with supports these default arguments via the ABI-encoded function signatures (e.g. ``keccak256("deposit(uint256,address)")[:4]`` and ``keccak256("deposit(uint256)")[:4]``). It is the responsibility of the callee to implement the behavior associated with these defaults. + Standalone Interfaces ===================== diff --git a/docs/release-notes.rst b/docs/release-notes.rst index c107ee5554..fa17ef4f7b 100644 --- a/docs/release-notes.rst +++ b/docs/release-notes.rst @@ -11,17 +11,319 @@ Release Notes :'<,'>s/\v(https:\/\/github.com\/vyperlang\/vyper\/pull\/)(\d+)/(`#\2 <\1\2>`_)/g ex. in: https://github.com/vyperlang/vyper/pull/3373 ex. out: (`#3373 `_) + remove authorship slugs (leave them on github release page; they have no meaning outside of github though) + :'<,'>s/by @\S\+ //c for advisory links: :'<,'>s/\v(https:\/\/github.com\/vyperlang\/vyper\/security\/advisories\/)([-A-Za-z0-9]+)/(`\2 <\1\2>`_)/g -v0.4.0b1 ("Nagini") -******************* +v0.4.0 ("Nagini") +***************** -Date released: TBD -================== +Date released: 2024-06-20 +========================= v0.4.0 represents a major overhaul to the Vyper language. 
Notably, it overhauls the import system and adds support for code reuse. It also adds a new, experimental backend to Vyper which lays the foundation for improved analysis, optimization and integration with third party tools. +Breaking Changes +---------------- +* feat[tool]!: make cancun the default evm version (`#4029 `_) +* feat[lang]: remove named reentrancy locks (`#3769 `_) +* feat[lang]!: change the signature of ``block.prevrandao`` (`#3879 `_) +* feat[lang]!: change ABI type of ``decimal`` to ``int168`` (`#3696 `_) +* feat[lang]: rename ``_abi_encode`` and ``_abi_decode`` (`#4097 `_) +* feat[lang]!: add feature flag for decimals (`#3930 `_) +* feat[lang]!: make internal decorator optional (`#4040 `_) +* feat[lang]: protect external calls with keyword (`#2938 `_) +* introduce floordiv, ban regular div for integers (`#2937 `_) +* feat[lang]: use keyword arguments for struct instantiation (`#3777 `_) +* feat: require type annotations for loop variables (`#3596 `_) +* feat: replace ``enum`` with ``flag`` keyword (`#3697 `_) +* feat: remove builtin constants (`#3350 `_) +* feat: drop istanbul and berlin support (`#3843 `_) +* feat: allow range with two arguments and bound (`#3679 `_) +* fix[codegen]: range bound check for signed integers (`#3814 `_) +* feat: default code offset = 3 (`#3454 `_) +* feat: rename ``vyper.interfaces`` to ``ethereum.ercs`` (`#3741 `_) +* chore: add prefix to ERC interfaces (`#3804 `_) +* chore[ux]: compute natspec as part of standard pipeline (`#3946 `_) +* feat: deprecate ``vyper-serve`` (`#3666 `_) + +Module system +------------- +* refactor: internal handling of imports (`#3655 `_) +* feat: implement "stateless" modules (`#3663 `_) +* feat[lang]: export interfaces (`#3919 `_) +* feat[lang]: singleton modules with ownership hierarchy (`#3729 `_) +* feat[lang]: implement function exports (`#3786 `_) +* feat[lang]: auto-export events in ABI (`#3808 `_) +* fix: allow using interface defs from imported modules (`#3725 `_) +* feat: add 
support for constants in imported modules (`#3726 `_) +* fix[lang]: prevent modules as storage variables (`#4088 `_) +* fix[ux]: improve initializer hint for unimported modules (`#4145 `_) +* feat: add python ``sys.path`` to vyper path (`#3763 `_) +* feat[ux]: improve error message for importing ERC20 (`#3816 `_) +* fix[lang]: fix importing of flag types (`#3871 `_) +* feat: search path resolution for cli (`#3694 `_) +* fix[lang]: transitive exports (`#3888 `_) +* fix[ux]: error messages relating to initializer issues (`#3831 `_) +* fix[lang]: recursion in ``uses`` analysis for nonreentrant functions (`#3971 `_) +* fix[ux]: fix ``uses`` error message (`#3926 `_) +* fix[lang]: fix ``uses`` analysis for nonreentrant functions (`#3927 `_) +* fix[lang]: fix a hint in global initializer check (`#4089 `_) +* fix[lang]: builtin type comparisons (`#3956 `_) +* fix[tool]: fix ``combined_json`` output for CLI (`#3901 `_) +* fix[tool]: compile multiple files (`#4053 `_) +* refactor: reimplement AST folding (`#3669 `_) +* refactor: constant folding (`#3719 `_) +* fix[lang]: typecheck hashmap indexes with folding (`#4007 `_) +* fix[lang]: fix array index checks when the subscript is folded (`#3924 `_) +* fix[lang]: pure access analysis (`#3895 `_) + +Venom +----- +* feat: implement new IR for vyper (venom IR) (`#3659 `_) +* feat[ir]: add ``make_ssa`` pass to venom pipeline (`#3825 `_) +* feat[venom]: implement ``mem2var`` and ``sccp`` passes (`#3941 `_) +* feat[venom]: add store elimination pass (`#4021 `_) +* feat[venom]: add ``extract_literals`` pass (`#4067 `_) +* feat[venom]: optimize branching (`#4049 `_) +* feat[venom]: avoid last ``swap`` for commutative ops (`#4048 `_) +* feat[venom]: "pickaxe" stack scheduler optimization (`#3951 `_) +* feat[venom]: add algebraic optimization pass (`#4054 `_) +* feat: Implement target constrained venom jump instruction (`#3687 `_) +* feat: remove ``deploy`` instruction from venom (`#3703 `_) +* fix[venom]: liveness analysis in some 
loops (`#3732 `_) +* feat: add more venom instructions (`#3733 `_) +* refactor[venom]: use venom pass instances (`#3908 `_) +* refactor[venom]: refactor venom operand classes (`#3915 `_) +* refactor[venom]: introduce ``IRContext`` and ``IRAnalysisCache`` (`#3983 `_) +* feat: add utility functions to ``OrderedSet`` (`#3833 `_) +* feat[venom]: optimize ``get_basic_block()`` (`#4002 `_) +* fix[venom]: fix branch eliminator cases in sccp (`#4003 `_) +* fix[codegen]: same symbol jumpdest merge (`#3982 `_) +* fix[venom]: fix eval of ``exp`` in sccp (`#4009 `_) +* refactor[venom]: remove unused method in ``make_ssa.py`` (`#4012 `_) +* fix[venom]: fix return opcode handling in mem2var (`#4011 `_) +* fix[venom]: fix ``cfg`` output format (`#4010 `_) +* chore[venom]: fix output formatting of data segment in ``IRContext`` (`#4016 `_) +* feat[venom]: optimize mem2var and store/variable elimination pass sequences (`#4032 `_) +* fix[venom]: fix some sccp evaluations (`#4028 `_) +* fix[venom]: add ``unique_symbols`` check to venom pipeline (`#4149 `_) +* feat[venom]: remove redundant store elimination pass (`#4036 `_) +* fix[venom]: remove some dead code in ``venom_to_assembly`` (`#4042 `_) +* feat[venom]: improve unused variable removal pass (`#4055 `_) +* fix[venom]: remove liveness requests (`#4058 `_) +* fix[venom]: fix list of volatile instructions (`#4065 `_) +* fix[venom]: remove dominator tree invalidation for store elimination pass (`#4069 `_) +* fix[venom]: move loop invariant assertion to entry block (`#4098 `_) +* fix[venom]: clear ``out_vars`` during calculation (`#4129 `_) +* fix[venom]: alloca for default arguments (`#4155 `_) +* Refactor ctx.add_instruction() and friends (`#3685 `_) +* fix: type annotation of helper function (`#3702 `_) +* feat[ir]: emit ``djump`` in dense selector table (`#3849 `_) +* chore: move venom tests to ``tests/unit/compiler`` (`#3684 `_) + +Other new features +------------------ +* feat[lang]: add ``blobhash()`` builtin (`#3962 `_) +* 
feat[lang]: support ``block.blobbasefee`` (`#3945 `_) +* feat[lang]: add ``revert_on_failure`` kwarg for create builtins (`#3844 `_) +* feat[lang]: allow downcasting of bytestrings (`#3832 `_) + +Docs +---- +* chore[docs]: add docs for v0.4.0 features (`#3947 `_) +* chore[docs]: ``implements`` does not check event declarations (`#4052 `_) +* docs: adopt a new theme: ``shibuya`` (`#3754 `_) +* chore[docs]: add evaluation order warning for builtins (`#4158 `_) +* Update ``FUNDING.yml`` (`#3636 `_) +* docs: fix nit in v0.3.10 release notes (`#3638 `_) +* docs: add note on ``pragma`` parsing (`#3640 `_) +* docs: retire security@vyperlang.org (`#3660 `_) +* feat[docs]: add more detail to modules docs (`#4087 `_) +* docs: update resources section (`#3656 `_) +* docs: add script to help working on the compiler (`#3674 `_) +* docs: add warnings at the top of all example token contracts (`#3676 `_) +* docs: typo in ``on_chain_market_maker.vy`` (`#3677 `_) +* docs: clarify ``address.codehash`` for empty account (`#3711 `_) +* docs: indexed arguments for events are limited (`#3715 `_) +* docs: Fix typos (`#3747 `_) +* docs: Upgrade dependencies and fixes (`#3745 `_) +* docs: add missing cli flags (`#3736 `_) +* chore: fix formatting and docs for new struct instantiation syntax (`#3792 `_) +* docs: floordiv (`#3797 `_) +* docs: add missing ``annotated_ast`` flag (`#3813 `_) +* docs: update logo in readme, remove competition reference (`#3837 `_) +* docs: add rationale for floordiv rounding behavior (`#3845 `_) +* chore[docs]: amend ``revert_on_failure`` kwarg docs for create builtins (`#3921 `_) +* fix[docs]: fix clipped ``endAuction`` method in example section (`#3969 `_) +* refactor[docs]: refactor security policy (`#3981 `_) +* fix: edit link to style guide (`#3658 `_) +* Add Vyper online compiler tooling (`#3680 `_) +* chore: fix typos (`#3749 `_) + +Bugfixes +-------- +* fix[codegen]: fix ``raw_log()`` when topics are non-literals (`#3977 `_) +* fix[codegen]: fix 
transient codegen for ``slice`` and ``extract32`` (`#3874 `_) +* fix[codegen]: bounds check for signed index accesses (`#3817 `_) +* fix: disallow ``value=`` passing for delegate and static raw_calls (`#3755 `_) +* fix[codegen]: fix double evals in sqrt, slice, blueprint (`#3976 `_) +* fix[codegen]: fix double eval in dynarray append/pop (`#4030 `_) +* fix[codegen]: fix double eval of start in range expr (`#4033 `_) +* fix[codegen]: overflow check in ``slice()`` (`#3818 `_) +* fix: concat buffer bug (`#3738 `_) +* fix[codegen]: fix ``make_setter`` overlap with internal calls (`#4037 `_) +* fix[codegen]: fix ``make_setter`` overlap in ``dynarray_append`` (`#4059 `_) +* fix[codegen]: ``make_setter`` overlap in the presence of ``staticcall`` (`#4128 `_) +* fix[codegen]: fix ``_abi_decode`` buffer overflow (`#3925 `_) +* fix[codegen]: zero-length dynarray ``abi_decode`` validation (`#4060 `_) +* fix[codegen]: recursive dynarray oob check (`#4091 `_) +* fix[codegen]: add back in ``returndatasize`` check (`#4144 `_) +* fix: block memory allocation overflow (`#3639 `_) +* fix[codegen]: panic on potential eval order issue for some builtins (`#4157 `_) +* fix[codegen]: panic on potential subscript eval order issue (`#4159 `_) +* add comptime check for uint2str input (`#3671 `_) +* fix: dead code analysis inside for loops (`#3731 `_) +* fix[ir]: fix a latent bug in ``sha3_64`` codegen (`#4063 `_) +* fix: ``opcodes`` and ``opcodes_runtime`` outputs (`#3735 `_) +* fix: bad assertion in expr.py (`#3758 `_) +* fix: iterator modification analysis (`#3764 `_) +* feat: allow constant interfaces (`#3718 `_) +* fix: assembly dead code eliminator (`#3791 `_) +* fix: prevent range over decimal (`#3798 `_) +* fix: mutability check for interface implements (`#3805 `_) +* fix[codegen]: fix non-memory reason strings (`#3877 `_) +* fix[ux]: fix compiler hang for large exponentiations (`#3893 `_) +* fix[lang]: allow type expressions inside pure functions (`#3906 `_) +* fix[ux]: raise 
``VersionException`` with source info (`#3920 `_) +* fix[lang]: fix ``pow`` folding when args are not literals (`#3949 `_) +* fix[codegen]: fix some hardcoded references to ``STORAGE`` location (`#4015 `_) + +Patched security advisories (GHSAs) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +* Bounds check on built-in ``slice()`` function can be overflowed (`GHSA-9x7f-gwxq-6f2c `_) +* ``concat`` built-in can corrupt memory (`GHSA-2q8v-3gqq-4f8p `_) +* ``raw_call`` ``value=`` kwargs not disabled for static and delegate calls (`GHSA-x2c2-q32w-4w6m `_) +* negative array index bounds checks (`GHSA-52xq-j7v9-v4v2 `_) +* ``range(start, start + N)`` reverts for negative numbers (`GHSA-ppx5-q359-pvwj `_) +* incorrect topic logging in ``raw_log`` (`GHSA-xchq-w5r3-4wg3 `_) +* double eval of the ``slice`` start/length args in certain cases (`GHSA-r56x-j438-vw5m `_) +* multiple eval of ``sqrt()`` built in argument (`GHSA-5jrj-52x8-m64h `_) +* double eval of raw_args in ``create_from_blueprint`` (`GHSA-3whq-64q2-qfj6 `_) +* ``sha3`` codegen bug (`GHSA-6845-xw22-ffxv `_) +* ``extract32`` can read dirty memory (`GHSA-4hwq-4cpm-8vmx `_) +* ``_abi_decode`` Memory Overflow (`GHSA-9p8r-4xp4-gw5w `_) +* External calls can overflow return data to return input buffer (`GHSA-gp3w-2v2m-p686 `_) + +Tooling +------- +* feat[tool]: archive format (`#3891 `_) +* feat[tool]: add source map for constructors (`#4008 `_) +* feat: add short options ``-v`` and ``-O`` to the CLI (`#3695 `_) +* feat: Add ``bb`` and ``bb_runtime`` output options (`#3700 `_) +* fix: remove hex-ir from format cli options list (`#3657 `_) +* fix: pickleability of ``CompilerData`` (`#3803 `_) +* feat[tool]: validate AST nodes early in the pipeline (`#3809 `_) +* feat[tool]: delay global constraint check (`#3810 `_) +* feat[tool]: export variable read/write access (`#3790 `_) +* feat[tool]: improvements to AST annotation (`#3829 `_) +* feat[tool]: add ``node_id`` map to source map (`#3811 `_) +* chore[tool]: add help text for 
``hex-ir`` CLI flag (`#3942 `_) +* refactor[tool]: refactor storage layout export (`#3789 `_) +* fix[tool]: fix cross-compilation issues, add windows CI (`#4014 `_) +* fix[tool]: star option in ``outputSelection`` (`#4094 `_) + +Performance +----------- +* perf: lazy eval of f-strings in IRnode ctor (`#3602 `_) +* perf: levenshtein optimization (`#3780 `_) +* feat: frontend optimizations (`#3781 `_) +* feat: optimize ``VyperNode.deepcopy`` (`#3784 `_) +* feat: more frontend optimizations (`#3785 `_) +* perf: reimplement ``IRnode.__deepcopy__`` (`#3761 `_) + +Testing suite improvements +-------------------------- +* refactor[test]: bypass ``eth-tester`` and interface with evm backend directly (`#3846 `_) +* feat: Refactor assert_tx_failed into a context (`#3706 `_) +* feat[test]: implement ``abi_decode`` spec test (`#4095 `_) +* feat[test]: add more coverage to ``abi_decode`` fuzzer tests (`#4153 `_) +* feat[ci]: enable cancun testing (`#3861 `_) +* fix: add missing test for memory allocation overflow (`#3650 `_) +* chore: fix test for ``slice`` (`#3633 `_) +* add abi_types unit tests (`#3662 `_) +* refactor: test directory structure (`#3664 `_) +* chore: test all output formats (`#3683 `_) +* chore: deduplicate test files (`#3773 `_) +* feat[test]: add more transient storage tests (`#3883 `_) +* chore[ci]: fix apt-get failure in era pipeline (`#3821 `_) +* chore[ci]: enable python3.12 tests (`#3860 `_) +* chore[ci]: refactor jobs to use gh actions (`#3863 `_) +* chore[ci]: use ``--dist worksteal`` from latest ``xdist`` (`#3869 `_) +* chore: run mypy as part of lint rule in Makefile (`#3771 `_) +* chore[test]: always specify the evm backend (`#4006 `_) +* chore: update lint dependencies (`#3704 `_) +* chore: add color to mypy output (`#3793 `_) +* chore: remove tox rules for lint commands (`#3826 `_) +* chore[ci]: roll back GH actions/artifacts version (`#3838 `_) +* chore: Upgrade GitHub action dependencies (`#3807 `_) +* chore[ci]: pin eth-abi for decode 
regression (`#3834 `_) +* fix[ci]: release artifacts (`#3839 `_) +* chore[ci]: merge mypy job into lint (`#3840 `_) +* test: parametrize CI over EVM versions (`#3842 `_) +* feat[ci]: add PR title validation (`#3887 `_) +* fix[test]: fix failure in grammar fuzzing (`#3892 `_) +* feat[test]: add ``xfail_strict``, clean up ``setup.cfg`` (`#3889 `_) +* fix[ci]: pin hexbytes to pre-1.0.0 (`#3903 `_) +* chore[test]: update hexbytes version and tests (`#3904 `_) +* fix[test]: fix a bad bound in decimal fuzzing (`#3909 `_) +* fix[test]: fix a boundary case in decimal fuzzing (`#3918 `_) +* feat[ci]: update pypi release pipeline to use OIDC (`#3912 `_) +* chore[ci]: reconfigure single commit validation (`#3937 `_) +* chore[ci]: downgrade codecov action to v3 (`#3940 `_) +* feat[ci]: add codecov configuration (`#4057 `_) +* feat[test]: remove memory mocker (`#4005 `_) +* refactor[test]: change fixture scope in examples (`#3995 `_) +* fix[test]: fix call graph stability fuzzer (`#4064 `_) +* chore[test]: add macos to test matrix (`#4025 `_) +* refactor[test]: change default expected exception type (`#4004 `_) + +Misc / refactor +--------------- +* feat[ir]: add ``eval_once`` sanity fences to more builtins (`#3835 `_) +* fix: reorder compilation of branches in stmt.py (`#3603 `_) +* refactor[codegen]: make settings into a global object (`#3929 `_) +* chore: improve exception handling in IR generation (`#3705 `_) +* refactor: merge ``annotation.py`` and ``local.py`` (`#3456 `_) +* chore[ux]: remove deprecated python AST classes (`#3998 `_) +* refactor[ux]: remove deprecated ``VyperNode`` properties (`#3999 `_) +* feat: remove Index AST node (`#3757 `_) +* refactor: for loop target parsing (`#3724 `_) +* chore: improve diagnostics for invalid for loop annotation (`#3721 `_) +* refactor: builtin functions inherit from ``VyperType`` (`#3559 `_) +* fix: remove .keyword from Call AST node (`#3689 `_) +* improvement: assert descriptions in Crowdfund finalize() and participate() 
(`#3064 `_) +* feat: improve panics in IR generation (`#3708 `_) +* feat: improve warnings, refactor ``vyper_warn()`` (`#3800 `_) +* fix[ir]: unique symbol name (`#3848 `_) +* refactor: remove duplicate terminus checking code (`#3541 `_) +* refactor: ``ExprVisitor`` type validation (`#3739 `_) +* chore: improve exception for type validation (`#3759 `_) +* fix: fuzz test not updated to use TypeMismatch (`#3768 `_) +* chore: fix StringEnum._generate_next_value_ signature (`#3770 `_) +* chore: improve some error messages (`#3775 `_) +* refactor: ``get_search_paths()`` for vyper cli (`#3778 `_) +* chore: replace occurrences of 'enum' by 'flag' (`#3794 `_) +* chore: add another borrowship test (`#3802 `_) +* chore[ux]: improve an exports error message (`#3822 `_) +* chore: improve codegen test coverage report (`#3824 `_) +* chore: improve syntax error messages (`#3885 `_) +* chore[tool]: remove ``vyper-serve`` from ``setup.py`` (`#3936 `_) +* fix[ux]: replace standard strings with f-strings (`#3953 `_) +* chore[ir]: sanity check types in for range codegen (`#3968 `_) + v0.3.10 ("Black Adder") *********************** diff --git a/docs/structure-of-a-contract.rst index fc817cf4b6..7e599d677b 100644 --- a/docs/structure-of-a-contract.rst +++ b/docs/structure-of-a-contract.rst @@ -54,6 +54,16 @@ EVM Version The EVM version can be set with the ``evm-version`` pragma, which is documented in :ref:`evm-version`. +Experimental Code Generation +---------------------------- +The new experimental code generation feature can be activated using the following directive: + +.. code-block:: vyper + + #pragma experimental-codegen + +Alternatively, you can use the alias ``"venom"`` instead of ``"experimental-codegen"`` to enable this feature. + Imports ======= diff --git a/docs/types.rst index 752e06b14f..807c83848f 100644 --- a/docs/types.rst +++ b/docs/types.rst @@ -359,11 +359,12 @@ A byte array with a max size.
The syntax being ``Bytes[maxLen]``, where ``maxLen`` is an integer which denotes the maximum number of bytes. On the ABI level the Fixed-size bytes array is annotated as ``bytes``. -Bytes literals may be given as bytes strings. +Bytes literals may be given as bytes strings or as hex strings. .. code-block:: vyper bytes_string: Bytes[100] = b"\x01" + bytes_string: Bytes[100] = x"01" .. index:: !string diff --git a/docs/using-modules.rst b/docs/using-modules.rst index 7d63eb6617..4400a8dfa8 100644 --- a/docs/using-modules.rst +++ b/docs/using-modules.rst @@ -62,6 +62,21 @@ The ``_times_two()`` helper function in the above module can be immediately used The other functions cannot be used yet, because they touch the ``ownable`` module's state. There are two ways to declare a module so that its state can be used. +Using a module as an interface +============================== + +A module can be used as an interface with the ``__at__`` syntax. + +.. code-block:: vyper + + import ownable + + an_ownable: ownable.__interface__ + + def call_ownable(addr: address): + self.an_ownable = ownable.__at__(addr) + self.an_ownable.transfer_ownership(...) + Initializing a module ===================== diff --git a/quicktest.sh b/quicktest.sh index cd3aad1f15..af928f5d7c 100755 --- a/quicktest.sh +++ b/quicktest.sh @@ -2,8 +2,17 @@ # examples: # ./quicktest.sh +# ./quicktest.sh -m "not fuzzing" +# ./quicktest.sh -m "not fuzzing" -n (this is the most useful) +# ./quicktest.sh -m "not fuzzing" -n0 # ./quicktest.sh tests/.../mytest.py # run pytest but bail out on first error -# useful for dev workflow +# useful for dev workflow. 
+ pytest -q -s --instafail -x --disable-warnings "$@" + +# useful options include: +# -n0 (uses only one core but faster startup) +# -nauto (distributes tests across all available cores) +# -m "not fuzzing" - skip slow/fuzzing tests diff --git a/setup.cfg index 5998961ee8..4cce85034d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -33,3 +33,15 @@ markers = fuzzing: Run Hypothesis fuzz test suite (deselect with '-m "not fuzzing"') requires_evm_version(version): Mark tests that require at least a specific EVM version and would throw `EvmVersionException` otherwise venom_xfail: mark a test case as a regression (expected to fail) under the venom pipeline + + +[coverage:run] +branch = True +source = vyper + +# this is not available on the CI step that performs `coverage combine` +omit = vyper/version.py + +# allow `coverage combine` to combine reports from heterogeneous OSes. +# (mainly important for consolidating coverage reports in the CI). +relative_files = True diff --git a/setup.py index 6e48129cba..e6d4c5763d 100644 --- a/setup.py +++ b/setup.py @@ -36,7 +36,7 @@ extras_require["dev"] = extras_require["dev"] + extras_require["test"] + extras_require["lint"] -with open("README.md", "r") as f: +with open("README.md", "r", encoding="utf-8") as f: long_description = f.read() @@ -94,17 +94,18 @@ def _global_version(version): "asttokens>=2.0.5,<3", "pycryptodome>=3.5.1,<4", "packaging>=23.1,<24", + "lark>=1.0.0,<2", "importlib-metadata", "wheel", ], - setup_requires=["pytest-runner", "setuptools_scm>=7.1.0,<8.0.0"], - tests_require=extras_require["test"], + setup_requires=["setuptools_scm>=7.1.0,<8.0.0"], extras_require=extras_require, entry_points={ "console_scripts": [ "vyper=vyper.cli.vyper_compile:_parse_cli_args", "fang=vyper.cli.vyper_ir:_parse_cli_args", "vyper-json=vyper.cli.vyper_json:_parse_cli_args", + "venom=vyper.cli.venom_main:_parse_cli_args", ] }, classifiers=[ @@ -113,6 +114,7 @@ def _global_version(version): "Programming Language :: Python :: 
3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", ], package_data={"vyper.ast": ["grammar.lark"]}, data_files=[("", [hash_file_rel_path])], diff --git a/tests/ast_utils.py b/tests/ast_utils.py new file mode 100644 index 0000000000..e4be35adb2 --- /dev/null +++ b/tests/ast_utils.py @@ -0,0 +1,25 @@ +from vyper.ast.nodes import VyperNode + + +def deepequals(node: VyperNode, other: VyperNode): + # checks two nodes are recursively equal, ignoring metadata + # like line info. + if not isinstance(other, type(node)): + return False + + if isinstance(node, list): + if len(node) != len(other): + return False + return all(deepequals(a, b) for a, b in zip(node, other)) + + if not isinstance(node, VyperNode): + return node == other + + if getattr(node, "node_id", None) != getattr(other, "node_id", None): + return False + for field_name in (i for i in node.get_fields() if i not in VyperNode.__slots__): + lhs = getattr(node, field_name, None) + rhs = getattr(other, field_name, None) + if not deepequals(lhs, rhs): + return False + return True diff --git a/tests/conftest.py b/tests/conftest.py index 31c72246bd..76ebc2df22 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,7 +14,7 @@ from tests.utils import working_directory from vyper import compiler from vyper.codegen.ir_node import IRnode -from vyper.compiler.input_bundle import FilesystemInputBundle, InputBundle +from vyper.compiler.input_bundle import FilesystemInputBundle from vyper.compiler.settings import OptimizationLevel, Settings, set_global_settings from vyper.exceptions import EvmVersionException from vyper.ir import compile_ir, optimizer @@ -166,12 +166,6 @@ def fn(sources_dict): return fn -# for tests which just need an input bundle, doesn't matter what it is -@pytest.fixture -def dummy_input_bundle(): - return InputBundle([]) - - @pytest.fixture(scope="module") def gas_limit(): # set absurdly high gas limit so that london 
basefee never adjusts diff --git a/tests/functional/builtins/codegen/test_abi_decode.py b/tests/functional/builtins/codegen/test_abi_decode.py index 9ae869c9cc..475118c7e3 100644 --- a/tests/functional/builtins/codegen/test_abi_decode.py +++ b/tests/functional/builtins/codegen/test_abi_decode.py @@ -8,6 +8,8 @@ TEST_ADDR = "0x" + b"".join(chr(i).encode("utf-8") for i in range(20)).hex() +BUFFER_OVERHEAD = 4 + 2 * 32 + def test_abi_decode_complex(get_contract): contract = """ @@ -474,8 +476,10 @@ def test_abi_decode_length_mismatch(get_contract, assert_compile_failed, bad_cod assert_compile_failed(lambda: get_contract(bad_code), exception) -def _abi_payload_from_tuple(payload: tuple[int | bytes, ...]) -> bytes: - return b"".join(p.to_bytes(32, "big") if isinstance(p, int) else p for p in payload) +def _abi_payload_from_tuple(payload: tuple[int | bytes, ...], max_sz: int) -> bytes: + ret = b"".join(p.to_bytes(32, "big") if isinstance(p, int) else p for p in payload) + assert len(ret) <= max_sz + return ret def _replicate(value: int, count: int) -> tuple[int, ...]: @@ -486,11 +490,12 @@ def test_abi_decode_arithmetic_overflow(env, tx_failed, get_contract): # test based on GHSA-9p8r-4xp4-gw5w: # https://github.com/vyperlang/vyper/security/advisories/GHSA-9p8r-4xp4-gw5w#advisory-comment-91841 # buf + head causes arithmetic overflow - code = """ + buffer_size = 32 * 3 + code = f""" @external -def f(x: Bytes[32 * 3]): +def f(x: Bytes[{buffer_size}]): a: Bytes[32] = b"foo" - y: Bytes[32 * 3] = x + y: Bytes[{buffer_size}] = x decoded_y1: Bytes[32] = _abi_decode(y, Bytes[32]) a = b"bar" @@ -500,39 +505,47 @@ def f(x: Bytes[32 * 3]): """ c = get_contract(code) - data = method_id("f(bytes)") - payload = ( - 0x20, # tuple head - 0x60, # parent array length - # parent payload - this word will be considered as the head of the abi-encoded inner array - # and it will be added to base ptr leading to an arithmetic overflow - 2**256 - 0x60, - ) - data += 
_abi_payload_from_tuple(payload) + tuple_head_ofst = 0x20 + parent_array_len = 0x60 + msg_call_overhead = (method_id("f(bytes)"), tuple_head_ofst, parent_array_len) + + data = _abi_payload_from_tuple(msg_call_overhead, BUFFER_OVERHEAD) + + # parent payload - this word will be considered as the head of the + # abi-encoded inner array and it will be added to base ptr leading to an + # arithmetic overflow + buffer_payload = (2**256 - 0x60,) + + data += _abi_payload_from_tuple(buffer_payload, buffer_size) with tx_failed(): env.message_call(c.address, data=data) -def test_abi_decode_nonstrict_head(env, tx_failed, get_contract): +def test_abi_decode_nonstrict_head(env, get_contract): # data isn't strictly encoded - head is 0x21 instead of 0x20 # but the head + length is still within runtime bounds of the parent buffer - code = """ + buffer_size = 32 * 5 + code = f""" @external -def f(x: Bytes[32 * 5]): - y: Bytes[32 * 5] = x +def f(x: Bytes[{buffer_size}]): + y: Bytes[{buffer_size}] = x a: Bytes[32] = b"a" decoded_y1: DynArray[uint256, 3] = _abi_decode(y, DynArray[uint256, 3]) + assert len(decoded_y1) == 1 and decoded_y1[0] == 0 a = b"aaaa" decoded_y1 = _abi_decode(y, DynArray[uint256, 3]) + assert len(decoded_y1) == 1 and decoded_y1[0] == 0 """ c = get_contract(code) - data = method_id("f(bytes)") + tuple_head_ofst = 0x20 + parent_array_len = 0xA0 + msg_call_overhead = (method_id("f(bytes)"), tuple_head_ofst, parent_array_len) - payload = ( - 0x20, # tuple head - 0xA0, # parent array length + data = _abi_payload_from_tuple(msg_call_overhead, BUFFER_OVERHEAD) + + buffer_payload = ( # head should be 0x20 but is 0x21 thus the data isn't strictly encoded 0x21, # we don't want to revert on invalid length, so set this to 0 @@ -543,27 +556,30 @@ def f(x: Bytes[32 * 5]): *_replicate(0x03, 2), ) - data += _abi_payload_from_tuple(payload) + data += _abi_payload_from_tuple(buffer_payload, buffer_size) env.message_call(c.address, data=data) def 
test_abi_decode_child_head_points_to_parent(tx_failed, get_contract): # data isn't strictly encoded and the head for the inner array - # skipts the corresponding payload and points to other valid section of the parent buffer - code = """ + # skips the corresponding payload and points to other valid section of the + # parent buffer + buffer_size = 14 * 32 + code = f""" @external -def run(x: Bytes[14 * 32]): - y: Bytes[14 * 32] = x +def run(x: Bytes[{buffer_size}]) -> DynArray[DynArray[DynArray[uint256, 2], 1], 2]: + y: Bytes[{buffer_size}] = x decoded_y1: DynArray[DynArray[DynArray[uint256, 2], 1], 2] = _abi_decode( y, DynArray[DynArray[DynArray[uint256, 2], 1], 2] ) + return decoded_y1 """ c = get_contract(code) # encode [[[1, 1]], [[2, 2]]] and modify the head for [1, 1] # to actually point to [2, 2] - payload = ( + buffer_payload = ( 0x20, # top-level array head 0x02, # top-level array length 0x40, # head of DAr[DAr[DAr, uint256]]][0] @@ -582,30 +598,33 @@ def run(x: Bytes[14 * 32]): 0x02, # DAr[DAr[DAr, uint256]]][1][0][1] ) - data = _abi_payload_from_tuple(payload) + data = _abi_payload_from_tuple(buffer_payload, buffer_size) - c.run(data) + res = c.run(data) + assert res == [[[2, 2]], [[2, 2]]] def test_abi_decode_nonstrict_head_oob(tx_failed, get_contract): # data isn't strictly encoded and (non_strict_head + len(DynArray[..][2])) > parent_static_sz # thus decoding the data pointed to by the head would cause an OOB read # non_strict_head + length == parent + parent_static_sz + 1 - code = """ + buffer_size = 2 * 32 + 3 * 32 + 3 * 32 * 4 + code = f""" @external -def run(x: Bytes[2 * 32 + 3 * 32 + 3 * 32 * 4]): - y: Bytes[2 * 32 + 3 * 32 + 3 * 32 * 4] = x +def run(x: Bytes[{buffer_size}]): + y: Bytes[{buffer_size}] = x decoded_y1: DynArray[Bytes[32 * 3], 3] = _abi_decode(y, DynArray[Bytes[32 * 3], 3]) """ c = get_contract(code) - payload = ( + buffer_payload = ( 0x20, # DynArray head 0x03, # DynArray length - # non_strict_head - if the length pointed to by this 
head is 0x60 (which is valid - # length for the Bytes[32*3] buffer), the decoding function would decode - # 1 byte over the end of the buffer - # we define the non_strict_head as: skip the remaining heads, 1st and 2nd tail + # non_strict_head - if the length pointed to by this head is 0x60 + # (which is valid length for the Bytes[32*3] buffer), the decoding + # function would decode 1 byte over the end of the buffer + # we define the non_strict_head as: + # skip the remaining heads, 1st and 2nd tail # to the third tail + 1B 0x20 * 8 + 0x20 * 3 + 0x01, # inner array0 head 0x20 * 4 + 0x20 * 3, # inner array1 head @@ -622,7 +641,7 @@ def run(x: Bytes[2 * 32 + 3 * 32 + 3 * 32 * 4]): *_replicate(0x03, 2), ) - data = _abi_payload_from_tuple(payload) + data = _abi_payload_from_tuple(buffer_payload, buffer_size) with tx_failed(): c.run(data) @@ -631,10 +650,11 @@ def run(x: Bytes[2 * 32 + 3 * 32 + 3 * 32 * 4]): def test_abi_decode_nonstrict_head_oob2(tx_failed, get_contract): # same principle as in Test_abi_decode_nonstrict_head_oob # but adapted for dynarrays - code = """ + buffer_size = 2 * 32 + 3 * 32 + 3 * 32 * 4 + code = f""" @external -def run(x: Bytes[2 * 32 + 3 * 32 + 3 * 32 * 4]): - y: Bytes[2 * 32 + 3 * 32 + 3 * 32 * 4] = x +def run(x: Bytes[{buffer_size}]): + y: Bytes[{buffer_size}] = x decoded_y1: DynArray[DynArray[uint256, 3], 3] = _abi_decode( y, DynArray[DynArray[uint256, 3], 3] @@ -642,7 +662,7 @@ def run(x: Bytes[2 * 32 + 3 * 32 + 3 * 32 * 4]): """ c = get_contract(code) - payload = ( + buffer_payload = ( 0x20, # DynArray head 0x03, # DynArray length (0x20 * 8 + 0x20 * 3 + 0x01), # inner array0 head @@ -658,7 +678,7 @@ def run(x: Bytes[2 * 32 + 3 * 32 + 3 * 32 * 4]): *_replicate(0x01, 2), # DynArray[..][2] data ) - data = _abi_payload_from_tuple(payload) + data = _abi_payload_from_tuple(buffer_payload, buffer_size) with tx_failed(): c.run(data) @@ -666,33 +686,36 @@ def run(x: Bytes[2 * 32 + 3 * 32 + 3 * 32 * 4]): def 
test_abi_decode_head_pointing_outside_buffer(tx_failed, get_contract): # the head points completely outside the buffer - code = """ + buffer_size = 3 * 32 + code = f""" @external -def run(x: Bytes[3 * 32]): - y: Bytes[3 * 32] = x +def run(x: Bytes[{buffer_size}]): + y: Bytes[{buffer_size}] = x decoded_y1: Bytes[32] = _abi_decode(y, Bytes[32]) """ c = get_contract(code) - payload = (0x80, 0x20, 0x01) - data = _abi_payload_from_tuple(payload) + buffer_payload = (0x80, 0x20, 0x01) + data = _abi_payload_from_tuple(buffer_payload, buffer_size) with tx_failed(): c.run(data) def test_abi_decode_bytearray_clamp(tx_failed, get_contract): - # data has valid encoding, but the length of DynArray[Bytes[96], 3][0] is set to 0x61 + # data has valid encoding, but the length of DynArray[Bytes[96], 3][0] is + # set to 0x61 # and thus the decoding should fail on bytestring clamp - code = """ + buffer_size = 2 * 32 + 3 * 32 + 3 * 32 * 4 + code = f""" @external -def run(x: Bytes[2 * 32 + 3 * 32 + 3 * 32 * 4]): - y: Bytes[2 * 32 + 3 * 32 + 3 * 32 * 4] = x +def run(x: Bytes[{buffer_size}]): + y: Bytes[{buffer_size}] = x decoded_y1: DynArray[Bytes[32 * 3], 3] = _abi_decode(y, DynArray[Bytes[32 * 3], 3]) """ c = get_contract(code) - payload = ( + buffer_payload = ( 0x20, # DynArray head 0x03, # DynArray length 0x20 * 3, # inner array0 head @@ -707,32 +730,38 @@ def run(x: Bytes[2 * 32 + 3 * 32 + 3 * 32 * 4]): *_replicate(0x01, 3), # DynArray[Bytes[96], 3][2] data ) - data = _abi_payload_from_tuple(payload) + data = _abi_payload_from_tuple(buffer_payload, buffer_size) with tx_failed(): c.run(data) def test_abi_decode_runtimesz_oob(tx_failed, get_contract, env): - # provide enough data, but set the runtime size to be smaller than the actual size - # so after y: [..] = x, y will have the incorrect size set and only part of the - # original data will be copied. 
This will cause oob read outside the - # runtime sz (but still within static size of the buffer) - code = """ + # provide enough data, but set the runtime size to be smaller than the + # actual size so after y: [..] = x, y will have the incorrect size set and + # only part of the original data will be copied. This will cause oob read + # outside the runtime sz (but still within static size of the buffer) + buffer_size = 2 * 32 + 3 * 32 + 3 * 32 * 4 + code = f""" @external -def f(x: Bytes[2 * 32 + 3 * 32 + 3 * 32 * 4]): - y: Bytes[2 * 32 + 3 * 32 + 3 * 32 * 4] = x +def f(x: Bytes[{buffer_size}]): + y: Bytes[{buffer_size}] = x decoded_y1: DynArray[Bytes[32 * 3], 3] = _abi_decode(y, DynArray[Bytes[32 * 3], 3]) """ c = get_contract(code) - data = method_id("f(bytes)") - - payload = ( + msg_call_overhead = ( + method_id("f(bytes)"), 0x20, # tuple head # the correct size is 0x220 (2*32+3*32+4*3*32) - # therefore we will decode after the end of runtime size (but still within the buffer) + # therefore we will decode after the end of runtime size (but still + # within the buffer) 0x01E4, # top-level bytes array length + ) + + data = _abi_payload_from_tuple(msg_call_overhead, BUFFER_OVERHEAD) + + buffer_payload = ( 0x20, # DynArray head 0x03, # DynArray length 0x20 * 3, # inner array0 head @@ -746,7 +775,7 @@ def f(x: Bytes[2 * 32 + 3 * 32 + 3 * 32 * 4]): *_replicate(0x01, 3), # DynArray[Bytes[96], 3][2] data ) - data += _abi_payload_from_tuple(payload) + data += _abi_payload_from_tuple(buffer_payload, buffer_size) with tx_failed(): env.message_call(c.address, data=data) @@ -755,10 +784,11 @@ def f(x: Bytes[2 * 32 + 3 * 32 + 3 * 32 * 4]): def test_abi_decode_runtimesz_oob2(tx_failed, get_contract, env): # same principle as in test_abi_decode_runtimesz_oob # but adapted for dynarrays - code = """ + buffer_size = 2 * 32 + 3 * 32 + 3 * 32 * 4 + code = f""" @external -def f(x: Bytes[2 * 32 + 3 * 32 + 3 * 32 * 4]): - y: Bytes[2 * 32 + 3 * 32 + 3 * 32 * 4] = x +def f(x: 
Bytes[{buffer_size}]): + y: Bytes[{buffer_size}] = x decoded_y1: DynArray[DynArray[uint256, 3], 3] = _abi_decode( y, DynArray[DynArray[uint256, 3], 3] @@ -766,11 +796,15 @@ def f(x: Bytes[2 * 32 + 3 * 32 + 3 * 32 * 4]): """ c = get_contract(code) - data = method_id("f(bytes)") - - payload = ( + msg_call_overhead = ( + method_id("f(bytes)"), 0x20, # tuple head 0x01E4, # top-level bytes array length + ) + + data = _abi_payload_from_tuple(msg_call_overhead, BUFFER_OVERHEAD) + + buffer_payload = ( 0x20, # DynArray head 0x03, # DynArray length 0x20 * 3, # inner array0 head @@ -784,7 +818,7 @@ def f(x: Bytes[2 * 32 + 3 * 32 + 3 * 32 * 4]): *_replicate(0x01, 3), # DynArray[..][2] data ) - data += _abi_payload_from_tuple(payload) + data += _abi_payload_from_tuple(buffer_payload, buffer_size) with tx_failed(): env.message_call(c.address, data=data) @@ -796,11 +830,13 @@ def test_abi_decode_head_roundtrip(tx_failed, get_contract, env): # which are in turn in the y2 buffer # NOTE: the test is memory allocator dependent - we assume that y1 and y2 # have the 800 & 960 addresses respectively - code = """ + buffer_size1 = 4 * 32 + buffer_size2 = 2 * 32 + 3 * 32 + 3 * 32 * 4 + code = f""" @external -def run(x1: Bytes[4 * 32], x2: Bytes[2 * 32 + 3 * 32 + 3 * 32 * 4]): - y1: Bytes[4*32] = x1 # addr: 800 - y2: Bytes[2 * 32 + 3 * 32 + 3 * 32 * 4] = x2 # addr: 960 +def run(x1: Bytes[{buffer_size1}], x2: Bytes[{buffer_size2}]): + y1: Bytes[{buffer_size1}] = x1 # addr: 800 + y2: Bytes[{buffer_size2}] = x2 # addr: 960 decoded_y1: DynArray[DynArray[uint256, 3], 3] = _abi_decode( y2, DynArray[DynArray[uint256, 3], 3] @@ -808,7 +844,7 @@ def run(x1: Bytes[4 * 32], x2: Bytes[2 * 32 + 3 * 32 + 3 * 32 * 4]): """ c = get_contract(code) - payload = ( + buffer_payload = ( 0x03, # DynArray length # distance to y2 from y1 is 160 160 + 0x20 + 0x20 * 3, # points to DynArray[..][0] length @@ -816,9 +852,9 @@ def run(x1: Bytes[4 * 32], x2: Bytes[2 * 32 + 3 * 32 + 3 * 32 * 4]): 160 + 0x20 + 0x20 * 8 + 
0x20 * 3, # points to DynArray[..][2] length ) - data1 = _abi_payload_from_tuple(payload) + data1 = _abi_payload_from_tuple(buffer_payload, buffer_size1) - payload = ( + buffer_payload = ( # (960 + (2**256 - 160)) % 2**256 == 800, ie will roundtrip to y1 2**256 - 160, # points to y1 0x03, # DynArray length (not used) @@ -833,7 +869,7 @@ def run(x1: Bytes[4 * 32], x2: Bytes[2 * 32 + 3 * 32 + 3 * 32 * 4]): *_replicate(0x03, 3), # DynArray[..][2] data ) - data2 = _abi_payload_from_tuple(payload) + data2 = _abi_payload_from_tuple(buffer_payload, buffer_size2) with tx_failed(): c.run(data1, data2) @@ -841,22 +877,23 @@ def run(x1: Bytes[4 * 32], x2: Bytes[2 * 32 + 3 * 32 + 3 * 32 * 4]): def test_abi_decode_merge_head_and_length(get_contract): # compress head and length into 33B - code = """ + buffer_size = 32 * 2 + 8 * 32 + code = f""" @external -def run(x: Bytes[32 * 2 + 8 * 32]) -> uint256: - y: Bytes[32 * 2 + 8 * 32] = x +def run(x: Bytes[{buffer_size}]) -> Bytes[{buffer_size}]: + y: Bytes[{buffer_size}] = x decoded_y1: Bytes[256] = _abi_decode(y, Bytes[256]) - return len(decoded_y1) + return decoded_y1 """ c = get_contract(code) - payload = (0x01, (0x00).to_bytes(1, "big"), *_replicate(0x00, 8)) + buffer_payload = (0x01, (0x00).to_bytes(1, "big"), *_replicate(0x00, 8)) - data = _abi_payload_from_tuple(payload) + data = _abi_payload_from_tuple(buffer_payload, buffer_size) - length = c.run(data) + res = c.run(data) - assert length == 256 + assert res == bytes(256) def test_abi_decode_extcall_invalid_head(tx_failed, get_contract): @@ -880,8 +917,8 @@ def foo(): def test_abi_decode_extcall_oob(tx_failed, get_contract): # the head returned from the extcall is 1 byte bigger than expected - # thus we'll take the last 31 0-bytes from tuple[1] and the 1st byte from tuple[2] - # and consider this the length - thus the length is 2**5 + # thus we'll take the last 31 0-bytes from tuple[1] and the 1st byte from + # tuple[2] and consider this the length - thus the length is 2**5 # 
and thus we'll read 1B over the buffer end (33 + 32 + 32) code = """ @external @@ -902,7 +939,8 @@ def foo(): def test_abi_decode_extcall_runtimesz_oob(tx_failed, get_contract): # the runtime size (33) is bigger than the actual payload (32 bytes) - # thus we'll read 1B over the runtime size - but still within the static size of the buffer + # thus we'll read 1 byte over the runtime size - but still within the + # static size of the buffer code = """ @external def bar() -> (uint256, uint256, uint256): @@ -932,11 +970,13 @@ def bar() -> (uint256, uint256, uint256, uint256): def bar() -> Bytes[32]: nonpayable @external -def foo(): - x:Bytes[32] = extcall A(self).bar() +def foo() -> Bytes[32]: + return extcall A(self).bar() """ c = get_contract(code) - c.foo() + res = c.foo() + + assert res == (36).to_bytes(32, "big") def test_abi_decode_extcall_truncate_returndata2(tx_failed, get_contract): @@ -1053,12 +1093,14 @@ def bar() -> (uint256, uint256): def bar() -> DynArray[Bytes[32], 2]: nonpayable @external -def run(): - x: DynArray[Bytes[32], 2] = extcall A(self).bar() +def run() -> DynArray[Bytes[32], 2]: + return extcall A(self).bar() """ c = get_contract(code) - c.run() + res = c.run() + + assert res == [] def test_abi_decode_extcall_complex_empty_dynarray(get_contract): @@ -1079,13 +1121,14 @@ def bar() -> (uint256, uint256, uint256, uint256, uint256, uint256): def bar() -> DynArray[Point, 2]: nonpayable @external -def run(): - x: DynArray[Point, 2] = extcall A(self).bar() - assert len(x) == 1 and len(x[0].y) == 0 +def run() -> DynArray[Point, 2]: + return extcall A(self).bar() """ c = get_contract(code) - c.run() + res = c.run() + + assert res == [(1, [], 0)] def test_abi_decode_extcall_complex_empty_dynarray2(tx_failed, get_contract): @@ -1124,21 +1167,21 @@ def bar() -> (uint256, uint256): def bar() -> DynArray[Bytes[32], 2]: nonpayable @external -def run() -> uint256: - x: DynArray[Bytes[32], 2] = extcall A(self).bar() - return len(x) +def run() -> 
DynArray[Bytes[32], 2]: + return extcall A(self).bar() """ c = get_contract(code) - length = c.run() + res = c.run() - assert length == 0 + assert res == [] def test_abi_decode_top_level_head_oob(tx_failed, get_contract): - code = """ + buffer_size = 256 + code = f""" @external -def run(x: Bytes[256], y: uint256): +def run(x: Bytes[{buffer_size}], y: uint256): player_lost: bool = empty(bool) if y == 1: @@ -1150,9 +1193,9 @@ def run(x: Bytes[256], y: uint256): c = get_contract(code) # head points over the buffer end - payload = (0x0100, *_replicate(0x00, 7)) + bufffer_payload = (0x0100, *_replicate(0x00, 7)) - data = _abi_payload_from_tuple(payload) + data = _abi_payload_from_tuple(bufffer_payload, buffer_size) with tx_failed(): c.run(data, 1) @@ -1162,23 +1205,24 @@ def run(x: Bytes[256], y: uint256): def test_abi_decode_dynarray_complex_insufficient_data(env, tx_failed, get_contract): - code = """ + buffer_size = 32 * 8 + code = f""" struct Point: x: uint256 y: uint256 @external -def run(x: Bytes[32 * 8]): - y: Bytes[32 * 8] = x +def run(x: Bytes[{buffer_size}]): + y: Bytes[{buffer_size}] = x decoded_y1: DynArray[Point, 3] = _abi_decode(y, DynArray[Point, 3]) """ c = get_contract(code) # runtime buffer has insufficient size - we decode 3 points, but provide only # 3 * 32B of payload - payload = (0x20, 0x03, *_replicate(0x03, 3)) + buffer_payload = (0x20, 0x03, *_replicate(0x03, 3)) - data = _abi_payload_from_tuple(payload) + data = _abi_payload_from_tuple(buffer_payload, buffer_size) with tx_failed(): c.run(data) @@ -1187,7 +1231,8 @@ def run(x: Bytes[32 * 8]): def test_abi_decode_dynarray_complex2(env, tx_failed, get_contract): # point head to the 1st 0x01 word (ie the length) # but size of the point is 3 * 32B, thus we'd decode 2B over the buffer end - code = """ + buffer_size = 32 * 8 + code = f""" struct Point: x: uint256 y: uint256 @@ -1195,19 +1240,19 @@ def test_abi_decode_dynarray_complex2(env, tx_failed, get_contract): @external -def run(x: Bytes[32 * 
8]): +def run(x: Bytes[{buffer_size}]): y: Bytes[32 * 11] = x decoded_y1: DynArray[Point, 2] = _abi_decode(y, DynArray[Point, 2]) """ c = get_contract(code) - payload = ( + buffer_payload = ( 0xC0, # points to the 1st 0x01 word (ie the length) *_replicate(0x03, 5), *_replicate(0x01, 2), ) - data = _abi_payload_from_tuple(payload) + data = _abi_payload_from_tuple(buffer_payload, buffer_size) with tx_failed(): c.run(data) @@ -1216,7 +1261,8 @@ def run(x: Bytes[32 * 8]): def test_abi_decode_complex_empty_dynarray(env, tx_failed, get_contract): # point head to the last word of the payload # this will be the length, but because it's set to 0, the decoding should succeed - code = """ + buffer_size = 32 * 16 + code = f""" struct Point: x: uint256 y: DynArray[uint256, 2] @@ -1224,14 +1270,13 @@ def test_abi_decode_complex_empty_dynarray(env, tx_failed, get_contract): @external -def run(x: Bytes[32 * 16]): - y: Bytes[32 * 16] = x - decoded_y1: DynArray[Point, 2] = _abi_decode(y, DynArray[Point, 2]) - assert len(decoded_y1) == 1 and len(decoded_y1[0].y) == 0 +def run(x: Bytes[{buffer_size}]) -> DynArray[Point, 2]: + y: Bytes[{buffer_size}] = x + return _abi_decode(y, DynArray[Point, 2]) """ c = get_contract(code) - payload = ( + buffer_payload = ( 0x20, 0x01, 0x20, @@ -1243,14 +1288,17 @@ def run(x: Bytes[32 * 16]): 0x00, # length is 0, so decoding should succeed ) - data = _abi_payload_from_tuple(payload) + data = _abi_payload_from_tuple(buffer_payload, buffer_size) + + res = c.run(data) - c.run(data) + assert res == [(1, [], 4)] def test_abi_decode_complex_arithmetic_overflow(tx_failed, get_contract): # inner head roundtrips due to arithmetic overflow - code = """ + buffer_size = 32 * 16 + code = f""" struct Point: x: uint256 y: DynArray[uint256, 2] @@ -1258,13 +1306,13 @@ def test_abi_decode_complex_arithmetic_overflow(tx_failed, get_contract): @external -def run(x: Bytes[32 * 16]): - y: Bytes[32 * 16] = x +def run(x: Bytes[{buffer_size}]): + y: Bytes[{buffer_size}] = x 
decoded_y1: DynArray[Point, 2] = _abi_decode(y, DynArray[Point, 2]) """ c = get_contract(code) - payload = ( + buffer_payload = ( 0x20, 0x01, 0x20, @@ -1276,39 +1324,43 @@ def run(x: Bytes[32 * 16]): 0x00, ) - data = _abi_payload_from_tuple(payload) + data = _abi_payload_from_tuple(buffer_payload, buffer_size) with tx_failed(): c.run(data) def test_abi_decode_empty_toplevel_dynarray(get_contract): - code = """ + buffer_size = 2 * 32 + 3 * 32 + 3 * 32 * 4 + code = f""" @external -def run(x: Bytes[2 * 32 + 3 * 32 + 3 * 32 * 4]): - y: Bytes[2 * 32 + 3 * 32 + 3 * 32 * 4] = x +def run(x: Bytes[{buffer_size}]) -> DynArray[DynArray[uint256, 3], 3]: + y: Bytes[{buffer_size}] = x assert len(y) == 2 * 32 decoded_y1: DynArray[DynArray[uint256, 3], 3] = _abi_decode( y, DynArray[DynArray[uint256, 3], 3] ) - assert len(decoded_y1) == 0 + return decoded_y1 """ c = get_contract(code) - payload = (0x20, 0x00) # DynArray head, DynArray length + buffer_payload = (0x20, 0x00) # DynArray head, DynArray length + + data = _abi_payload_from_tuple(buffer_payload, buffer_size) - data = _abi_payload_from_tuple(payload) + res = c.run(data) - c.run(data) + assert res == [] def test_abi_decode_invalid_toplevel_dynarray_head(tx_failed, get_contract): # head points 1B over the bounds of the runtime buffer - code = """ + buffer_size = 2 * 32 + 3 * 32 + 3 * 32 * 4 + code = f""" @external -def run(x: Bytes[2 * 32 + 3 * 32 + 3 * 32 * 4]): - y: Bytes[2 * 32 + 3 * 32 + 3 * 32 * 4] = x +def run(x: Bytes[{buffer_size}]): + y: Bytes[{buffer_size}] = x decoded_y1: DynArray[DynArray[uint256, 3], 3] = _abi_decode( y, DynArray[DynArray[uint256, 3], 3] @@ -1317,33 +1369,34 @@ def run(x: Bytes[2 * 32 + 3 * 32 + 3 * 32 * 4]): c = get_contract(code) # head points 1B over the bounds of the runtime buffer - payload = (0x21, 0x00) # DynArray head, DynArray length + buffer_payload = (0x21, 0x00) # DynArray head, DynArray length - data = _abi_payload_from_tuple(payload) + data = _abi_payload_from_tuple(buffer_payload, 
buffer_size) with tx_failed(): c.run(data) def test_nested_invalid_dynarray_head(get_contract, tx_failed): - code = """ + buffer_size = 320 + code = f""" @nonpayable @external -def foo(x:Bytes[320]): +def foo(x:Bytes[{buffer_size}]): if True: a: Bytes[320-32] = b'' # make the word following the buffer x_mem dirty to make a potential # OOB revert fake_head: uint256 = 32 - x_mem: Bytes[320] = x + x_mem: Bytes[{buffer_size}] = x y: DynArray[DynArray[uint256, 2], 2] = _abi_decode(x_mem,DynArray[DynArray[uint256, 2], 2]) @nonpayable @external -def bar(x:Bytes[320]): - x_mem: Bytes[320] = x +def bar(x:Bytes[{buffer_size}]): + x_mem: Bytes[{buffer_size}] = x y:DynArray[DynArray[uint256, 2], 2] = _abi_decode(x_mem,DynArray[DynArray[uint256, 2], 2]) """ @@ -1355,7 +1408,7 @@ def bar(x:Bytes[320]): # 0x0, # head2 ) - encoded = _abi_payload_from_tuple(encoded + inner) + encoded = _abi_payload_from_tuple(encoded + inner, buffer_size) with tx_failed(): c.foo(encoded) # revert with tx_failed(): @@ -1363,22 +1416,23 @@ def bar(x:Bytes[320]): def test_static_outer_type_invalid_heads(get_contract, tx_failed): - code = """ + buffer_size = 320 + code = f""" @nonpayable @external -def foo(x:Bytes[320]): - x_mem: Bytes[320] = x +def foo(x:Bytes[{buffer_size}]): + x_mem: Bytes[{buffer_size}] = x y:DynArray[uint256, 2][2] = _abi_decode(x_mem,DynArray[uint256, 2][2]) @nonpayable @external -def bar(x:Bytes[320]): +def bar(x:Bytes[{buffer_size}]): if True: a: Bytes[160] = b'' # write stuff here to make the call revert in case decode do # an out of bound access: fake_head: uint256 = 32 - x_mem: Bytes[320] = x + x_mem: Bytes[{buffer_size}] = x y:DynArray[uint256, 2][2] = _abi_decode(x_mem,DynArray[uint256, 2][2]) """ c = get_contract(code) @@ -1389,7 +1443,7 @@ def bar(x:Bytes[320]): # 0x00, # head of the second dynarray ) - encoded = _abi_payload_from_tuple(encoded + inner) + encoded = _abi_payload_from_tuple(encoded + inner, buffer_size) with tx_failed(): c.foo(encoded) @@ -1402,9 +1456,10 
@@ def test_abi_decode_max_size(get_contract, tx_failed): # of abi encoding the type. this can happen when the payload is # "sparse" and has garbage bytes in between the static and dynamic # sections - code = """ + buffer_size = 1000 + code = f""" @external -def foo(a:Bytes[1000]): +def foo(a:Bytes[{buffer_size}]): v: DynArray[uint256, 1] = _abi_decode(a,DynArray[uint256, 1]) """ c = get_contract(code) @@ -1420,7 +1475,7 @@ def foo(a:Bytes[1000]): ) with tx_failed(): - c.foo(_abi_payload_from_tuple(payload)) + c.foo(_abi_payload_from_tuple(payload, buffer_size)) # returndatasize check for uint256 diff --git a/tests/functional/builtins/codegen/test_ecrecover.py b/tests/functional/builtins/codegen/test_ecrecover.py index 8db51fdd07..47a225068d 100644 --- a/tests/functional/builtins/codegen/test_ecrecover.py +++ b/tests/functional/builtins/codegen/test_ecrecover.py @@ -1,7 +1,10 @@ +import contextlib + from eth_account import Account from eth_account._utils.signing import to_bytes32 -from tests.utils import ZERO_ADDRESS +from tests.utils import ZERO_ADDRESS, check_precompile_asserts +from vyper.compiler.settings import OptimizationLevel def test_ecrecover_test(get_contract): @@ -86,3 +89,40 @@ def test_ecrecover() -> bool: """ c = get_contract(code) assert c.test_ecrecover() is True + + +def test_ecrecover_oog_handling(env, get_contract, tx_failed, optimize, experimental_codegen): + # GHSA-vgf2-gvx8-xwc3 + code = """ +@external +@view +def do_ecrecover(hash: bytes32, v: uint256, r:uint256, s:uint256) -> address: + return ecrecover(hash, v, r, s) + """ + check_precompile_asserts(code) + + c = get_contract(code) + + h = b"\x35" * 32 + local_account = Account.from_key(b"\x46" * 32) + sig = local_account.signHash(h) + v, r, s = sig.v, sig.r, sig.s + + assert c.do_ecrecover(h, v, r, s) == local_account.address + + gas_used = env.last_result.gas_used + + if optimize == OptimizationLevel.NONE and not experimental_codegen: + # if optimizations are off, enough gas is used by 
the contract + # that the gas provided to ecrecover (63/64ths rule) is enough + # for it to succeed + ctx = contextlib.nullcontext + else: + # in other cases, the gas forwarded is small enough for ecrecover + # to fail with oog, which we handle by reverting. + ctx = tx_failed + + with ctx(): + # provide enough spare gas for the top-level call to not oog but + # not enough for ecrecover to succeed + c.do_ecrecover(h, v, r, s, gas=gas_used) diff --git a/tests/functional/builtins/codegen/test_empty.py b/tests/functional/builtins/codegen/test_empty.py index dd6c5c7cc1..3088162238 100644 --- a/tests/functional/builtins/codegen/test_empty.py +++ b/tests/functional/builtins/codegen/test_empty.py @@ -672,11 +672,11 @@ def test_empty_array_in_event_logging(get_contract, get_logs): @external def foo(): log MyLog( - b'hellohellohellohellohellohellohellohellohello', - empty(int128[2][3]), - 314159, - b'helphelphelphelphelphelphelphelphelphelphelp', - empty(uint256[3]) + arg1=b'hellohellohellohellohellohellohellohellohello', + arg2=empty(int128[2][3]), + arg3=314159, + arg4=b'helphelphelphelphelphelphelphelphelphelphelp', + arg5=empty(uint256[3]) ) """ diff --git a/tests/functional/builtins/codegen/test_raw_call.py b/tests/functional/builtins/codegen/test_raw_call.py index 4107f9a4d0..bf953ff018 100644 --- a/tests/functional/builtins/codegen/test_raw_call.py +++ b/tests/functional/builtins/codegen/test_raw_call.py @@ -261,6 +261,12 @@ def __default__(): assert env.message_call(caller.address, data=sig) == b"" +def _strip_initcode_suffix(bytecode): + bs = bytes.fromhex(bytecode.removeprefix("0x")) + to_strip = int.from_bytes(bs[-2:], "big") + return bs[:-to_strip].hex() + + # check max_outsize=0 does same thing as not setting max_outsize. # compile to bytecode and compare bytecode directly. 
def test_max_outsize_0(): @@ -276,7 +282,11 @@ def test_raw_call(_target: address): """ output1 = compile_code(code1, output_formats=["bytecode", "bytecode_runtime"]) output2 = compile_code(code2, output_formats=["bytecode", "bytecode_runtime"]) - assert output1 == output2 + assert output1["bytecode_runtime"] == output2["bytecode_runtime"] + + bytecode1 = output1["bytecode"] + bytecode2 = output2["bytecode"] + assert _strip_initcode_suffix(bytecode1) == _strip_initcode_suffix(bytecode2) # check max_outsize=0 does same thing as not setting max_outsize, @@ -298,7 +308,11 @@ def test_raw_call(_target: address) -> bool: """ output1 = compile_code(code1, output_formats=["bytecode", "bytecode_runtime"]) output2 = compile_code(code2, output_formats=["bytecode", "bytecode_runtime"]) - assert output1 == output2 + assert output1["bytecode_runtime"] == output2["bytecode_runtime"] + + bytecode1 = output1["bytecode"] + bytecode2 = output2["bytecode"] + assert _strip_initcode_suffix(bytecode1) == _strip_initcode_suffix(bytecode2) # test functionality of max_outsize=0 diff --git a/tests/functional/builtins/codegen/test_slice.py b/tests/functional/builtins/codegen/test_slice.py index db03f6f023..0777b2f08b 100644 --- a/tests/functional/builtins/codegen/test_slice.py +++ b/tests/functional/builtins/codegen/test_slice.py @@ -5,7 +5,7 @@ from vyper.compiler import compile_code from vyper.compiler.settings import OptimizationLevel, Settings from vyper.evm.opcodes import version_check -from vyper.exceptions import ArgumentException, TypeMismatch +from vyper.exceptions import ArgumentException, StaticAssertionException, TypeMismatch _fun_bytes32_bounds = [(0, 32), (3, 29), (27, 5), (0, 5), (5, 3), (30, 2)] @@ -533,9 +533,15 @@ def do_slice(): @pytest.mark.parametrize("bad_code", oob_fail_list) def test_slice_buffer_oob_reverts(bad_code, get_contract, tx_failed): - c = get_contract(bad_code) - with tx_failed(): - c.do_slice() + try: + c = get_contract(bad_code) + with tx_failed(): + 
c.do_slice()
+    except StaticAssertionException:
+        # it is fine if we catch the
+        # assert at compile time, since the
+        # code is supposed to revert
+        pass
 
 
 # tests all 3 adhoc locations: `msg.data`, `self.code`, `
.code` diff --git a/tests/functional/codegen/calling_convention/test_default_function.py b/tests/functional/codegen/calling_convention/test_default_function.py index 4d54e31f91..08d9c08678 100644 --- a/tests/functional/codegen/calling_convention/test_default_function.py +++ b/tests/functional/codegen/calling_convention/test_default_function.py @@ -28,7 +28,7 @@ def test_basic_default(env, get_logs, get_contract): @external @payable def __default__(): - log Sent(msg.sender) + log Sent(sender=msg.sender) """ c = get_contract(code) env.set_balance(env.deployer, 10**18) @@ -46,13 +46,13 @@ def test_basic_default_default_param_function(env, get_logs, get_contract): @external @payable def fooBar(a: int128 = 12345) -> int128: - log Sent(empty(address)) + log Sent(sender=empty(address)) return a @external @payable def __default__(): - log Sent(msg.sender) + log Sent(sender=msg.sender) """ c = get_contract(code) env.set_balance(env.deployer, 10**18) @@ -69,7 +69,7 @@ def test_basic_default_not_payable(env, tx_failed, get_contract): @external def __default__(): - log Sent(msg.sender) + log Sent(sender=msg.sender) """ c = get_contract(code) env.set_balance(env.deployer, 10**17) @@ -103,7 +103,7 @@ def test_always_public_2(assert_compile_failed, get_contract): sender: indexed(address) def __default__(): - log Sent(msg.sender) + log Sent(sender=msg.sender) """ assert_compile_failed(lambda: get_contract(code)) @@ -119,12 +119,12 @@ def test_zero_method_id(env, get_logs, get_contract, tx_failed): @payable # function selector: 0x00000000 def blockHashAskewLimitary(v: uint256) -> uint256: - log Sent(2) + log Sent(sig=2) return 7 @external def __default__(): - log Sent(1) + log Sent(sig=1) """ c = get_contract(code) @@ -165,12 +165,12 @@ def test_another_zero_method_id(env, get_logs, get_contract, tx_failed): @payable # function selector: 0x00000000 def wycpnbqcyf() -> uint256: - log Sent(2) + log Sent(sig=2) return 7 @external def __default__(): - log Sent(1) + log Sent(sig=1) """ 
c = get_contract(code) @@ -205,12 +205,12 @@ def test_partial_selector_match_trailing_zeroes(env, get_logs, get_contract): @payable # function selector: 0xd88e0b00 def fow() -> uint256: - log Sent(2) + log Sent(sig=2) return 7 @external def __default__(): - log Sent(1) + log Sent(sig=1) """ c = get_contract(code) diff --git a/tests/functional/codegen/calling_convention/test_external_contract_calls.py b/tests/functional/codegen/calling_convention/test_external_contract_calls.py index f9252f0a99..d98c8d79dc 100644 --- a/tests/functional/codegen/calling_convention/test_external_contract_calls.py +++ b/tests/functional/codegen/calling_convention/test_external_contract_calls.py @@ -1441,12 +1441,18 @@ def get_lucky(gas_amount: uint256) -> int128: c2.get_lucky(50) # too little gas. -def test_skip_contract_check(get_contract): +def test_skip_contract_check(get_contract, tx_failed): contract_2 = """ @external @view def bar(): pass + +# include fallback for sanity, make sure we don't get trivially rejected in +# selector table +@external +def __default__(): + pass """ contract_1 = """ interface Bar: @@ -1454,9 +1460,10 @@ def bar() -> uint256: view def baz(): nonpayable @external -def call_bar(addr: address): - # would fail if returndatasize check were on - x: uint256 = staticcall Bar(addr).bar(skip_contract_check=True) +def call_bar(addr: address) -> uint256: + # fails during abi decoding + return staticcall Bar(addr).bar(skip_contract_check=True) + @external def call_baz(): # some address with no code @@ -1466,7 +1473,10 @@ def call_baz(): """ c1 = get_contract(contract_1) c2 = get_contract(contract_2) - c1.call_bar(c2.address) + + with tx_failed(): + c1.call_bar(c2.address) + c1.call_baz() diff --git a/tests/functional/codegen/features/decorators/test_private.py b/tests/functional/codegen/features/decorators/test_private.py index d313aa3bda..b9e34ea49b 100644 --- a/tests/functional/codegen/features/decorators/test_private.py +++ 
b/tests/functional/codegen/features/decorators/test_private.py @@ -436,7 +436,7 @@ def i_am_me() -> bool: @external @nonpayable def whoami() -> address: - log Addr(self._whoami()) + log Addr(addr=self._whoami()) return self._whoami() """ diff --git a/tests/functional/codegen/features/iteration/test_for_in_list.py b/tests/functional/codegen/features/iteration/test_for_in_list.py index 184e6a2859..036e7c0647 100644 --- a/tests/functional/codegen/features/iteration/test_for_in_list.py +++ b/tests/functional/codegen/features/iteration/test_for_in_list.py @@ -897,3 +897,36 @@ def foo(): compile_code(main, input_bundle=input_bundle) assert e.value._message == "Cannot modify loop variable `queue`" + + +def test_iterator_modification_memory(get_contract): + code = """ +@external +def foo() -> DynArray[uint256, 10]: + # check VarInfos are distinguished by decl_node when they have same type + alreadyDone: DynArray[uint256, 10] = [] + _assets: DynArray[uint256, 10] = [1, 2, 3, 4, 3, 2, 1] + for a: uint256 in _assets: + if a in alreadyDone: + continue + alreadyDone.append(a) + return alreadyDone + """ + c = get_contract(code) + assert c.foo() == [1, 2, 3, 4] + + +def test_iterator_modification_func_arg(get_contract): + code = """ +@internal +def boo(a: DynArray[uint256, 12] = [], b: DynArray[uint256, 12] = []) -> DynArray[uint256, 12]: + for i: uint256 in a: + b.append(i) + return b + +@external +def foo() -> DynArray[uint256, 12]: + return self.boo([1, 2, 3]) + """ + c = get_contract(code) + assert c.foo() == [1, 2, 3] diff --git a/tests/functional/codegen/features/test_clampers.py b/tests/functional/codegen/features/test_clampers.py index 1adffcf29a..2b015a1cce 100644 --- a/tests/functional/codegen/features/test_clampers.py +++ b/tests/functional/codegen/features/test_clampers.py @@ -5,7 +5,6 @@ from eth_utils import keccak from tests.utils import ZERO_ADDRESS, decimal_to_int -from vyper.exceptions import StackTooDeep from vyper.utils import int_bounds @@ -429,7 +428,7 @@ 
def foo(b: int128[6][1][2]) -> int128[6][1][2]: c = get_contract(code) with tx_failed(): - _make_tx(env, c.address, "foo(int128[6][1][2]])", values) + _make_tx(env, c.address, "foo(int128[6][1][2])", values) @pytest.mark.parametrize("value", [0, 1, -1, 2**127 - 1, -(2**127)]) @@ -453,7 +452,7 @@ def test_int128_dynarray_clamper_failing(env, tx_failed, get_contract, bad_value # ensure the invalid value is detected at all locations in the array code = """ @external -def foo(b: int128[5]) -> int128[5]: +def foo(b: DynArray[int128, 5]) -> DynArray[int128, 5]: return b """ @@ -502,7 +501,6 @@ def foo(b: DynArray[int128, 10]) -> DynArray[int128, 10]: @pytest.mark.parametrize("value", [0, 1, -1, 2**127 - 1, -(2**127)]) -@pytest.mark.venom_xfail(raises=StackTooDeep, reason="stack scheduler regression") def test_multidimension_dynarray_clamper_passing(get_contract, value): code = """ @external diff --git a/tests/functional/codegen/features/test_logging.py b/tests/functional/codegen/features/test_logging.py index cf77a30bd9..87d848fae5 100644 --- a/tests/functional/codegen/features/test_logging.py +++ b/tests/functional/codegen/features/test_logging.py @@ -5,13 +5,14 @@ from tests.utils import decimal_to_int from vyper import compile_code from vyper.exceptions import ( - ArgumentException, EventDeclarationException, + InstantiationException, InvalidType, NamespaceCollision, StructureException, TypeMismatch, UndeclaredDefinition, + UnknownAttribute, ) from vyper.utils import keccak256 @@ -50,7 +51,7 @@ def test_event_logging_with_topics(get_logs, keccak, get_contract): @external def foo(): self.a = b"bar" - log MyLog(self.a) + log MyLog(arg1=self.a) """ c = get_contract(loggy_code) @@ -78,7 +79,7 @@ def test_event_logging_with_multiple_topics(env, keccak, get_logs, get_contract) @external def foo(): - log MyLog(-2, True, self) + log MyLog(arg1=-2, arg2=True, arg3=self) """ c = get_contract(loggy_code) @@ -120,7 +121,7 @@ def 
test_event_logging_with_multiple_topics_var_and_store(get_contract, get_logs def foo(arg1: int128): a: bool = True self.b = self - log MyLog(arg1, a, self.b) + log MyLog(arg1=arg1, arg2=a, arg3=self.b) """ c = get_contract(code) @@ -141,13 +142,13 @@ def test_logging_the_same_event_multiple_times_with_topics(env, keccak, get_logs @external def foo(): - log MyLog(1, self) - log MyLog(1, self) + log MyLog(arg1=1, arg2=self) + log MyLog(arg1=1, arg2=self) @external def bar(): - log MyLog(1, self) - log MyLog(1, self) + log MyLog(arg1=1, arg2=self) + log MyLog(arg1=1, arg2=self) """ c = get_contract(loggy_code) @@ -198,7 +199,7 @@ def test_event_logging_with_data(get_logs, keccak, get_contract): @external def foo(): - log MyLog(123) + log MyLog(arg1=123) """ c = get_contract(loggy_code) @@ -231,8 +232,16 @@ def test_event_logging_with_fixed_array_data(env, keccak, get_logs, get_contract @external def foo(): - log MyLog([1,2], [block.timestamp, block.timestamp+1, block.timestamp+2], [[1,2],[1,2]]) - log MyLog([1,2], [block.timestamp, block.timestamp+1, block.timestamp+2], [[1,2],[1,2]]) + log MyLog( + arg1=[1,2], + arg2=[block.timestamp, block.timestamp+1, block.timestamp+2], + arg3=[[1,2],[1,2]] + ) + log MyLog( + arg1=[1,2], + arg2=[block.timestamp, block.timestamp+1, block.timestamp+2], + arg3=[[1,2],[1,2]] + ) """ c = get_contract(loggy_code) @@ -271,7 +280,7 @@ def test_logging_with_input_bytes_1(env, keccak, get_logs, get_contract): @external def foo(arg1: Bytes[29], arg2: Bytes[31]): - log MyLog(b'bar', arg1, arg2) + log MyLog(arg1=b'bar', arg2=arg1, arg3=arg2) """ c = get_contract(loggy_code) @@ -307,7 +316,7 @@ def test_event_logging_with_bytes_input_2(env, keccak, get_logs, get_contract): @external def foo(_arg1: Bytes[20]): - log MyLog(_arg1) + log MyLog(arg1=_arg1) """ c = get_contract(loggy_code) @@ -335,7 +344,7 @@ def test_event_logging_with_bytes_input_3(get_logs, keccak, get_contract): @external def foo(_arg1: Bytes[5]): - log MyLog(_arg1) + log 
MyLog(arg1=_arg1) """ c = get_contract(loggy_code) @@ -369,7 +378,7 @@ def test_event_logging_with_data_with_different_types(env, keccak, get_logs, get @external def foo(): - log MyLog(123, b'home', b'bar', 0xc305c901078781C232A2a521C2aF7980f8385ee9, self, block.timestamp) # noqa: E501 + log MyLog(arg1=123, arg2=b'home', arg3=b'bar', arg4=0xc305c901078781C232A2a521C2aF7980f8385ee9, arg5=self, arg6=block.timestamp) # noqa: E501 """ c = get_contract(loggy_code) @@ -412,7 +421,7 @@ def test_event_logging_with_topics_and_data_1(env, keccak, get_logs, get_contrac @external def foo(): - log MyLog(1, b'bar') + log MyLog(arg1=1, arg2=b'bar') """ c = get_contract(loggy_code) @@ -457,8 +466,8 @@ def test_event_logging_with_multiple_logs_topics_and_data(env, keccak, get_logs, @external def foo(): - log MyLog(1, b'bar') - log YourLog(self, MyStruct(x=1, y=b'abc', z=SmallStruct(t='house', w=13.5))) + log MyLog(arg1=1, arg2=b'bar') + log YourLog(arg1=self, arg2=MyStruct(x=1, y=b'abc', z=SmallStruct(t='house', w=13.5))) """ c = get_contract(loggy_code) @@ -524,7 +533,7 @@ def test_fails_when_input_is_the_wrong_type(tx_failed, get_contract): @external def foo_(): - log MyLog(b'yo') + log MyLog(arg1=b'yo') """ with tx_failed(TypeMismatch): @@ -539,7 +548,7 @@ def test_fails_when_topic_is_the_wrong_size(tx_failed, get_contract): @external def foo(): - log MyLog(b'bars') + log MyLog(arg1=b'bars') """ with tx_failed(TypeMismatch): @@ -553,7 +562,7 @@ def test_fails_when_input_topic_is_the_wrong_size(tx_failed, get_contract): @external def foo(arg1: Bytes[4]): - log MyLog(arg1) + log MyLog(arg1=arg1) """ with tx_failed(TypeMismatch): @@ -567,7 +576,7 @@ def test_fails_when_data_is_the_wrong_size(tx_failed, get_contract): @external def foo(): - log MyLog(b'bars') + log MyLog(arg1=b'bars') """ with tx_failed(TypeMismatch): @@ -581,7 +590,7 @@ def test_fails_when_input_data_is_the_wrong_size(tx_failed, get_contract): @external def foo(arg1: Bytes[4]): - log MyLog(arg1) + log 
MyLog(arg1=arg1) """ with tx_failed(TypeMismatch): @@ -610,7 +619,7 @@ def test_logging_fails_with_over_three_topics(tx_failed, get_contract): @deploy def __init__(): - log MyLog(1, 2, 3, 4) + log MyLog(arg1=1, arg2=2, arg3=3, arg4=4) """ with tx_failed(EventDeclarationException): @@ -650,7 +659,7 @@ def test_logging_fails_with_topic_type_mismatch(tx_failed, get_contract): @external def foo(): - log MyLog(self) + log MyLog(arg1=self) """ with tx_failed(TypeMismatch): @@ -664,7 +673,7 @@ def test_logging_fails_with_data_type_mismatch(tx_failed, get_contract): @external def foo(): - log MyLog(self) + log MyLog(arg1=self) """ with tx_failed(TypeMismatch): @@ -680,9 +689,9 @@ def test_logging_fails_when_number_of_arguments_is_greater_than_declaration( @external def foo(): - log MyLog(1, 2) + log MyLog(arg1=1, arg2=2) """ - with tx_failed(ArgumentException): + with tx_failed(UnknownAttribute): get_contract(loggy_code) @@ -694,9 +703,9 @@ def test_logging_fails_when_number_of_arguments_is_less_than_declaration(tx_fail @external def foo(): - log MyLog(1) + log MyLog(arg1=1) """ - with tx_failed(ArgumentException): + with tx_failed(InstantiationException): get_contract(loggy_code) @@ -852,7 +861,7 @@ def test_variable_list_packing(get_logs, get_contract): @external def foo(): a: int128[4] = [1, 2, 3, 4] - log Bar(a) + log Bar(_value=a) """ c = get_contract(code) @@ -868,7 +877,7 @@ def test_literal_list_packing(get_logs, get_contract): @external def foo(): - log Bar([1, 2, 3, 4]) + log Bar(_value=[1, 2, 3, 4]) """ c = get_contract(code) @@ -886,7 +895,7 @@ def test_storage_list_packing(get_logs, get_contract): @external def foo(): - log Bar(self.x) + log Bar(_value=self.x) @external def set_list(): @@ -910,7 +919,7 @@ def test_passed_list_packing(get_logs, get_contract): @external def foo(barbaric: int128[4]): - log Bar(barbaric) + log Bar(_value=barbaric) """ c = get_contract(code) @@ -926,7 +935,7 @@ def test_variable_decimal_list_packing(get_logs, get_contract): 
@external def foo(): - log Bar([1.11, 2.22, 3.33, 4.44]) + log Bar(_value=[1.11, 2.22, 3.33, 4.44]) """ c = get_contract(code) @@ -949,7 +958,7 @@ def test_storage_byte_packing(get_logs, get_contract): @external def foo(a: int128): - log MyLog(self.x) + log MyLog(arg1=self.x) @external def setbytez(): @@ -975,7 +984,7 @@ def test_storage_decimal_list_packing(get_logs, get_contract): @external def foo(): - log Bar(self.x) + log Bar(_value=self.x) @external def set_list(): @@ -1004,7 +1013,7 @@ def test_logging_fails_when_input_is_too_big(tx_failed, get_contract): @external def foo(inp: Bytes[33]): - log Bar(inp) + log Bar(_value=inp) """ with tx_failed(TypeMismatch): get_contract(code) @@ -1019,7 +1028,7 @@ def test_2nd_var_list_packing(get_logs, get_contract): @external def foo(): a: int128[4] = [1, 2, 3, 4] - log Bar(10, a) + log Bar(arg1=10, arg2=a) """ c = get_contract(code) @@ -1037,7 +1046,7 @@ def test_2nd_var_storage_list_packing(get_logs, get_contract): @external def foo(): - log Bar(10, self.x) + log Bar(arg1=10, arg2=self.x) @external def set_list(): @@ -1071,7 +1080,7 @@ def __init__(): @external def foo(): v: int128[3] = [7, 8, 9] - log Bar(10, self.x, b"test", v, self.y) + log Bar(arg1=10, arg2=self.x, arg3=b"test", arg4=v, arg5=self.y) @external def set_list(): @@ -1104,7 +1113,7 @@ def test_hashed_indexed_topics_calldata(get_logs, keccak, get_contract): @external def foo(a: Bytes[36], b: int128, c: String[7]): - log MyLog(a, b, c) + log MyLog(arg1=a, arg2=b, arg3=c) """ c = get_contract(loggy_code) @@ -1144,7 +1153,7 @@ def foo(): a: Bytes[10] = b"potato" b: int128 = -777 c: String[44] = "why hello, neighbor! how are you today?" 
- log MyLog(a, b, c) + log MyLog(arg1=a, arg2=b, arg3=c) """ c = get_contract(loggy_code) @@ -1191,7 +1200,7 @@ def setter(_a: Bytes[32], _b: int128, _c: String[6]): @external def foo(): - log MyLog(self.a, self.b, self.c) + log MyLog(arg1=self.a, arg2=self.b, arg3=self.c) """ c = get_contract(loggy_code) @@ -1229,7 +1238,7 @@ def test_hashed_indexed_topics_storxxage(get_logs, keccak, get_contract): @external def foo(): - log MyLog(b"wow", 666, "madness!") + log MyLog(arg1=b"wow", arg2=666, arg3="madness!") """ c = get_contract(loggy_code) @@ -1245,6 +1254,23 @@ def foo(): assert log.topics == [event_id, topic1, topic2, topic3] +valid_list = [ + # test constant folding inside raw_log + """ +topic: constant(bytes32) = 0x1212121212121210212801291212121212121210121212121212121212121212 + +@external +def foo(): + raw_log([[topic]][0], b'') + """ +] + + +@pytest.mark.parametrize("code", valid_list) +def test_raw_log_pass(code): + assert compile_code(code) is not None + + fail_list = [ ( """ diff --git a/tests/functional/codegen/features/test_logging_bytes_extended.py b/tests/functional/codegen/features/test_logging_bytes_extended.py index 6b84cdd23a..64c848bb8e 100644 --- a/tests/functional/codegen/features/test_logging_bytes_extended.py +++ b/tests/functional/codegen/features/test_logging_bytes_extended.py @@ -7,7 +7,7 @@ def test_bytes_logging_extended(get_contract, get_logs): @external def foo(): - log MyLog(667788, b'hellohellohellohellohellohellohellohellohello', 334455) + log MyLog(arg1=667788, arg2=b'hellohellohellohellohellohellohellohellohello', arg3=334455) """ c = get_contract(code) @@ -31,7 +31,7 @@ def foo(): a: Bytes[64] = b'hellohellohellohellohellohellohellohellohello' b: Bytes[64] = b'hellohellohellohellohellohellohellohello' # test literal much smaller than buffer - log MyLog(a, b, b'hello') + log MyLog(arg1=a, arg2=b, arg3=b'hello') """ c = get_contract(code) @@ -51,7 +51,7 @@ def test_bytes_logging_extended_passthrough(get_contract, get_logs): 
@external def foo(a: int128, b: Bytes[64], c: int128): - log MyLog(a, b, c) + log MyLog(arg1=a, arg2=b, arg3=c) """ c = get_contract(code) @@ -77,7 +77,7 @@ def test_bytes_logging_extended_storage(get_contract, get_logs): @external def foo(): - log MyLog(self.a, self.b, self.c) + log MyLog(arg1=self.a, arg2=self.b, arg3=self.c) @external def set(x: int128, y: Bytes[64], z: int128): @@ -114,10 +114,10 @@ def test_bytes_logging_extended_mixed_with_lists(get_contract, get_logs): @external def foo(): log MyLog( - [[24, 26], [12, 10]], - b'hellohellohellohellohellohellohellohellohello', - 314159, - b'helphelphelphelphelphelphelphelphelphelphelp' + arg1=[[24, 26], [12, 10]], + arg2=b'hellohellohellohellohellohellohellohellohello', + arg3=314159, + arg4=b'helphelphelphelphelphelphelphelphelphelphelp' ) """ diff --git a/tests/functional/codegen/features/test_logging_from_call.py b/tests/functional/codegen/features/test_logging_from_call.py index 190be7b4f4..2b14cd8398 100644 --- a/tests/functional/codegen/features/test_logging_from_call.py +++ b/tests/functional/codegen/features/test_logging_from_call.py @@ -21,11 +21,11 @@ def to_bytes32(_value: uint256) -> bytes32: @external def test_func(_value: uint256): data2: Bytes[60] = concat(self.to_bytes32(_value),self.to_bytes(_value),b"testing") - log TestLog(self.to_bytes32(_value), data2, self.to_bytes(_value)) + log TestLog(testData1=self.to_bytes32(_value), testData2=data2, testData3=self.to_bytes(_value)) loggedValue: bytes32 = self.to_bytes32(_value) loggedValue2: Bytes[8] = self.to_bytes(_value) - log TestLog(loggedValue, data2, loggedValue2) + log TestLog(testData1=loggedValue, testData2=data2, testData3=loggedValue2) """ c = get_contract(code) @@ -65,8 +65,8 @@ def test_func(_value: uint256,input: Bytes[133]): data2: Bytes[200] = b"hello world" - # log TestLog(self.to_bytes32(_value),input,self.to_bytes(_value)) - log TestLog(self.to_bytes32(_value),input,"bababa") + # log 
TestLog(testData1=self.to_bytes32(_value),testData2=input,testData3=self.to_bytes(_value)) + log TestLog(testData1=self.to_bytes32(_value),testData2=input,testData3="bababa") """ c = get_contract(code) @@ -99,8 +99,8 @@ def test_func(_value: uint256,input: Bytes[133]): data2: Bytes[200] = b"hello world" # log will be malformed - # log TestLog(self.to_bytes32(_value),input,self.to_bytes(_value)) - log TestLog(self.to_bytes32(_value), input) + # log TestLog(testData1=self.to_bytes32(_value),testData2=input,testData3=self.to_bytes(_value)) + log TestLog(testData1=self.to_bytes32(_value), testData2=input) """ c = get_contract(code) @@ -137,12 +137,12 @@ def test_func(_value: uint256,input: Bytes[2048]): data2: Bytes[2064] = concat(self.to_bytes(_value),self.to_bytes(_value),input) # log will be malformed - log TestLog(self.to_bytes32(_value), data2, self.to_bytes(_value)) + log TestLog(testData1=self.to_bytes32(_value), testData2=data2, testData3=self.to_bytes(_value)) loggedValue: Bytes[8] = self.to_bytes(_value) # log will be normal - log TestLog(self.to_bytes32(_value),data2,loggedValue) + log TestLog(testData1=self.to_bytes32(_value),testData2=data2,testData3=loggedValue) """ c = get_contract(code) diff --git a/tests/functional/codegen/features/test_mana.py b/tests/functional/codegen/features/test_mana.py new file mode 100644 index 0000000000..1169b011ff --- /dev/null +++ b/tests/functional/codegen/features/test_mana.py @@ -0,0 +1,11 @@ +def test_mana_call(get_contract): + mana_call = """ +@external +def foo() -> uint256: + return msg.mana + """ + + c = get_contract(mana_call) + + assert c.foo(gas=50000) < 50000 + assert c.foo(gas=50000) > 25000 diff --git a/tests/functional/codegen/features/test_memory_dealloc.py b/tests/functional/codegen/features/test_memory_dealloc.py index 3be57038ef..b733de736b 100644 --- a/tests/functional/codegen/features/test_memory_dealloc.py +++ b/tests/functional/codegen/features/test_memory_dealloc.py @@ -9,7 +9,7 @@ def sendit(): 
nonpayable @external def foo(target: address) -> uint256[2]: - log Shimmy(empty(address), 3) + log Shimmy(a=empty(address), b=3) amount: uint256 = 1 flargen: uint256 = 42 extcall Other(target).sendit() diff --git a/tests/functional/codegen/modules/test_events.py b/tests/functional/codegen/modules/test_events.py index ae5198cf90..c32a66caec 100644 --- a/tests/functional/codegen/modules/test_events.py +++ b/tests/functional/codegen/modules/test_events.py @@ -50,7 +50,7 @@ def test_module_event_indexed(get_contract, make_input_bundle, get_logs): @internal def foo(): - log MyEvent(5, 6) + log MyEvent(x=5, y=6) """ main = """ import lib1 diff --git a/tests/functional/codegen/modules/test_exports.py b/tests/functional/codegen/modules/test_exports.py index 93f4fe6c2f..3cc21d61a9 100644 --- a/tests/functional/codegen/modules/test_exports.py +++ b/tests/functional/codegen/modules/test_exports.py @@ -440,3 +440,26 @@ def __init__(): # call `c.__default__()` env.message_call(c.address) assert c.counter() == 6 + + +def test_inline_interface_export(make_input_bundle, get_contract): + lib1 = """ +interface IAsset: + def asset() -> address: view + +implements: IAsset + +@external +@view +def asset() -> address: + return self + """ + main = """ +import lib1 + +exports: lib1.IAsset + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + c = get_contract(main, input_bundle=input_bundle) + + assert c.asset() == c.address diff --git a/tests/functional/codegen/modules/test_interface_imports.py b/tests/functional/codegen/modules/test_interface_imports.py index c0fae6496c..af9f9b5e68 100644 --- a/tests/functional/codegen/modules/test_interface_imports.py +++ b/tests/functional/codegen/modules/test_interface_imports.py @@ -1,3 +1,6 @@ +import pytest + + def test_import_interface_types(make_input_bundle, get_contract): ifaces = """ interface IFoo: @@ -50,11 +53,70 @@ def foo() -> bool: # check that this typechecks both directions a: lib1.IERC20 = IERC20(msg.sender) b: lib2.IERC20 = 
IERC20(msg.sender) + c: IERC20 = lib1.IERC20(msg.sender) # allowed in call position # return the equality so we can sanity check it - return a == b + return a == b and b == c """ input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) c = get_contract(main, input_bundle=input_bundle) assert c.foo() is True + + +@pytest.mark.parametrize("interface_syntax", ["__at__", "__interface__"]) +def test_intrinsic_interface(get_contract, make_input_bundle, interface_syntax): + lib = """ +@external +@view +def foo() -> uint256: + # detect self call + if msg.sender == self: + return 4 + else: + return 5 + """ + + main = f""" +import lib + +exports: lib.__interface__ + +@external +@view +def bar() -> uint256: + return staticcall lib.{interface_syntax}(self).foo() + """ + input_bundle = make_input_bundle({"lib.vy": lib}) + c = get_contract(main, input_bundle=input_bundle) + + assert c.foo() == 5 + assert c.bar() == 4 + + +def test_import_interface_flags(make_input_bundle, get_contract): + ifaces = """ +flag Foo: + BOO + MOO + POO + +interface IFoo: + def foo() -> Foo: nonpayable + """ + + contract = """ +import ifaces + +implements: ifaces + +@external +def foo() -> ifaces.Foo: + return ifaces.Foo.POO + """ + + input_bundle = make_input_bundle({"ifaces.vyi": ifaces}) + + c = get_contract(contract, input_bundle=input_bundle) + + assert c.foo() == 4 diff --git a/tests/functional/codegen/test_interfaces.py b/tests/functional/codegen/test_interfaces.py index 9442362696..e0b59ff668 100644 --- a/tests/functional/codegen/test_interfaces.py +++ b/tests/functional/codegen/test_interfaces.py @@ -4,7 +4,7 @@ from eth_utils import to_wei from tests.utils import decimal_to_int -from vyper.compiler import compile_code +from vyper.compiler import compile_code, compile_from_file_input from vyper.exceptions import ( ArgumentException, DuplicateImport, @@ -13,6 +13,7 @@ ) +# TODO CMC 2024-10-13: this should probably be in tests/unit/compiler/ def test_basic_extract_interface(): code = 
""" # Events @@ -22,6 +23,7 @@ def test_basic_extract_interface(): _to: address _value: uint256 + # Functions @view @@ -37,6 +39,7 @@ def allowance(_owner: address, _spender: address) -> (uint256, uint256): assert code_pass.strip() == out.strip() +# TODO CMC 2024-10-13: this should probably be in tests/unit/compiler/ def test_basic_extract_external_interface(): code = """ @view @@ -68,6 +71,7 @@ def test(_owner: address): nonpayable assert interface.strip() == out.strip() +# TODO CMC 2024-10-13: should probably be in syntax tests def test_basic_interface_implements(assert_compile_failed): code = """ from ethereum.ercs import IERC20 @@ -82,6 +86,7 @@ def test() -> bool: assert_compile_failed(lambda: compile_code(code), InterfaceViolation) +# TODO CMC 2024-10-13: should probably be in syntax tests def test_external_interface_parsing(make_input_bundle, assert_compile_failed): interface_code = """ @external @@ -126,6 +131,7 @@ def foo() -> uint256: compile_code(not_implemented_code, input_bundle=input_bundle) +# TODO CMC 2024-10-13: should probably be in syntax tests def test_log_interface_event(make_input_bundle, assert_compile_failed): interface_code = """ event Foo: @@ -155,11 +161,10 @@ def bar() -> uint256: ("import Foo as Foo", "Foo.vyi"), ("from a import Foo", "a/Foo.vyi"), ("from b.a import Foo", "b/a/Foo.vyi"), - ("from .a import Foo", "./a/Foo.vyi"), - ("from ..a import Foo", "../a/Foo.vyi"), ] +# TODO CMC 2024-10-13: should probably be in syntax tests @pytest.mark.parametrize("code,filename", VALID_IMPORT_CODE) def test_extract_file_interface_imports(code, filename, make_input_bundle): input_bundle = make_input_bundle({filename: ""}) @@ -167,6 +172,22 @@ def test_extract_file_interface_imports(code, filename, make_input_bundle): assert compile_code(code, input_bundle=input_bundle) is not None +VALID_RELATIVE_IMPORT_CODE = [ + # import statement, import path without suffix + ("from .a import Foo", "mock.vy"), + ("from ..a import Foo", "b/mock.vy"), +] + + +# 
TODO CMC 2024-10-13: should probably be in syntax tests +@pytest.mark.parametrize("code,filename", VALID_RELATIVE_IMPORT_CODE) +def test_extract_file_interface_relative_imports(code, filename, make_input_bundle): + input_bundle = make_input_bundle({"a/Foo.vyi": "", filename: code}) + + file_input = input_bundle.load_file(filename) + assert compile_from_file_input(file_input, input_bundle=input_bundle) is not None + + BAD_IMPORT_CODE = [ ("import a as A\nimport a as A", DuplicateImport), ("import a as A\nimport a as a", DuplicateImport), @@ -177,13 +198,13 @@ def test_extract_file_interface_imports(code, filename, make_input_bundle): ] +# TODO CMC 2024-10-13: should probably be in syntax tests @pytest.mark.parametrize("code,exception_type", BAD_IMPORT_CODE) -def test_extract_file_interface_imports_raises( - code, exception_type, assert_compile_failed, make_input_bundle -): - input_bundle = make_input_bundle({"a.vyi": "", "b/a.vyi": "", "c.vyi": ""}) +def test_extract_file_interface_imports_raises(code, exception_type, make_input_bundle): + input_bundle = make_input_bundle({"a.vyi": "", "b/a.vyi": "", "c.vyi": "", "mock.vy": code}) + file_input = input_bundle.load_file("mock.vy") with pytest.raises(exception_type): - compile_code(code, input_bundle=input_bundle) + compile_from_file_input(file_input, input_bundle=input_bundle) def test_external_call_to_interface(env, get_contract, make_input_bundle): @@ -695,3 +716,268 @@ def test_call(a: address, b: {type_str}) -> {type_str}: make_file("jsonabi.json", json.dumps(convert_v1_abi(abi))) c3 = get_contract(code, input_bundle=input_bundle) assert c3.test_call(c1.address, value) == value + + +def test_interface_function_without_visibility(make_input_bundle, get_contract): + interface_code = """ +def foo() -> uint256: + ... + +@external +def bar() -> uint256: + ... 
+ """ + + code = """ +import a as FooInterface + +implements: FooInterface + +@external +def foo() -> uint256: + return 1 + +@external +def bar() -> uint256: + return 1 + """ + + input_bundle = make_input_bundle({"a.vyi": interface_code}) + + c = get_contract(code, input_bundle=input_bundle) + + assert c.foo() == c.bar() == 1 + + +def test_interface_with_structures(): + code = """ +struct MyStruct: + a: address + b: uint256 + +event Transfer: + sender: indexed(address) + receiver: indexed(address) + value: uint256 + +struct Voter: + weight: int128 + voted: bool + delegate: address + vote: int128 + +@external +def bar(): + pass + +event Buy: + buyer: indexed(address) + buy_order: uint256 + +@external +@view +def foo(s: MyStruct) -> MyStruct: + return s + """ + + out = compile_code(code, contract_path="code.vy", output_formats=["interface"])["interface"] + + assert "# Structs" in out + assert "struct MyStruct:" in out + assert "b: uint256" in out + assert "struct Voter:" in out + assert "voted: bool" in out + + +def test_intrinsic_interface_instantiation(make_input_bundle, get_contract): + lib1 = """ +@external +@view +def foo(): + pass + """ + main = """ +import lib1 + +i: lib1.__interface__ + +@external +def bar() -> lib1.__interface__: + self.i = lib1.__at__(self) + return self.i + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + c = get_contract(main, input_bundle=input_bundle) + + assert c.bar() == c.address + + +def test_intrinsic_interface_converts(make_input_bundle, get_contract): + lib1 = """ +@external +@view +def foo(): + pass + """ + main = """ +import lib1 + +@external +def bar() -> lib1.__interface__: + return lib1.__at__(self) + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + c = get_contract(main, input_bundle=input_bundle) + + assert c.bar() == c.address + + +def test_intrinsic_interface_kws(env, make_input_bundle, get_contract): + value = 10**5 + lib1 = f""" +@external +@payable +def foo(a: address): + send(a, {value}) + """ + 
main = f""" +import lib1 + +exports: lib1.__interface__ + +@external +def bar(a: address): + extcall lib1.__at__(self).foo(a, value={value}) + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + c = get_contract(main, input_bundle=input_bundle) + env.set_balance(c.address, value) + original_balance = env.get_balance(env.deployer) + c.bar(env.deployer) + assert env.get_balance(env.deployer) == original_balance + value + + +def test_intrinsic_interface_defaults(env, make_input_bundle, get_contract): + lib1 = """ +@external +@payable +def foo(i: uint256=1) -> uint256: + return i + """ + main = """ +import lib1 + +exports: lib1.__interface__ + +@external +def bar() -> uint256: + return extcall lib1.__at__(self).foo() + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + c = get_contract(main, input_bundle=input_bundle) + assert c.bar() == 1 + + +def test_interface_with_flags(): + code = """ +struct MyStruct: + a: address + +flag Foo: + BOO + MOO + POO + +event Transfer: + sender: indexed(address) + +@external +def bar(): + pass +flag BAR: + BIZ + BAZ + BOO + +@external +@view +def foo(s: MyStruct) -> MyStruct: + return s + """ + + out = compile_code(code, contract_path="code.vy", output_formats=["interface"])["interface"] + + assert "# Flags" in out + assert "flag Foo:" in out + assert "flag BAR" in out + assert "BOO" in out + assert "MOO" in out + + compile_code(out, contract_path="code.vyi", output_formats=["interface"]) + + +vyi_filenames = [ + "test__test.vyi", + "test__t.vyi", + "t__test.vyi", + "t__t.vyi", + "t_t.vyi", + "test_test.vyi", + "t_test.vyi", + "test_t.vyi", + "_test_t__t_tt_.vyi", + "foo_bar_baz.vyi", +] + + +@pytest.mark.parametrize("vyi_filename", vyi_filenames) +def test_external_interface_names(vyi_filename): + code = """ +@external +def foo(): + ... 
+ """ + + compile_code(code, contract_path=vyi_filename, output_formats=["external_interface"]) + + +def test_external_interface_with_flag(): + code = """ +flag Foo: + Blah + +@external +def foo() -> Foo: + ... + """ + + out = compile_code(code, contract_path="test__test.vyi", output_formats=["external_interface"])[ + "external_interface" + ] + assert "-> Foo:" in out + + +def test_external_interface_compiles_again(): + code = """ +@external +def foo() -> uint256: + ... +@external +def bar(a:int32) -> uint256: + ... + """ + + out = compile_code(code, contract_path="test.vyi", output_formats=["external_interface"])[ + "external_interface" + ] + compile_code(out, contract_path="test.vyi", output_formats=["external_interface"]) + + +@pytest.mark.xfail +def test_weird_interface_name(): + # based on comment https://github.com/vyperlang/vyper/pull/4290#discussion_r1884137428 + # we replace "_" for "" which results in an interface without name + out = compile_code("", contract_path="_.vyi", output_formats=["external_interface"])[ + "external_interface" + ] + assert "interface _:" in out diff --git a/tests/functional/codegen/types/numbers/test_decimals.py b/tests/functional/codegen/types/numbers/test_decimals.py index 36c14f804d..ad8bf74b0d 100644 --- a/tests/functional/codegen/types/numbers/test_decimals.py +++ b/tests/functional/codegen/types/numbers/test_decimals.py @@ -299,7 +299,7 @@ def foo(): compile_code(code) -def test_replace_decimal_nested_intermediate_underflow(dummy_input_bundle): +def test_replace_decimal_nested_intermediate_underflow(): code = """ @external def foo(): diff --git a/tests/functional/codegen/types/numbers/test_exponents.py b/tests/functional/codegen/types/numbers/test_exponents.py index 702cbcb1dd..28dba59edc 100644 --- a/tests/functional/codegen/types/numbers/test_exponents.py +++ b/tests/functional/codegen/types/numbers/test_exponents.py @@ -173,3 +173,17 @@ def foo(b: int128) -> int128: c.foo(max_power) with tx_failed(): c.foo(max_power + 1) 
+ + +valid_list = [ + """ +@external +def foo() -> uint256: + return (10**18)**2 + """ +] + + +@pytest.mark.parametrize("good_code", valid_list) +def test_exponent_success(good_code): + assert compile_code(good_code) is not None diff --git a/tests/functional/codegen/types/numbers/test_unsigned_ints.py b/tests/functional/codegen/types/numbers/test_unsigned_ints.py index 42619a8bd5..2bd3184ec0 100644 --- a/tests/functional/codegen/types/numbers/test_unsigned_ints.py +++ b/tests/functional/codegen/types/numbers/test_unsigned_ints.py @@ -269,13 +269,65 @@ def foo(): compile_code(code) -def test_invalid_div(): - code = """ +div_code_with_hint = [ + ( + """ @external def foo(): a: uint256 = 5 / 9 - """ + """, + "did you mean `5 // 9`?", + ), + ( + """ +@external +def foo(): + a: uint256 = 10 + a /= (3 + 10) // (2 + 3) + """, + "did you mean `a //= (3 + 10) // (2 + 3)`?", + ), + ( + """ +@external +def foo(a: uint256, b:uint256, c: uint256) -> uint256: + return (a + b) / c + """, + "did you mean `(a + b) // c`?", + ), + ( + """ +@external +def foo(a: uint256, b:uint256, c: uint256) -> uint256: + return (a + b) / (a + c) + """, + "did you mean `(a + b) // (a + c)`?", + ), + ( + """ +@external +def foo(a: uint256, b:uint256, c: uint256) -> uint256: + return (a + (c + b)) / (a + c) + """, + "did you mean `(a + (c + b)) // (a + c)`?", + ), + ( + """ +interface Foo: + def foo() -> uint256: view + +@external +def foo(a: uint256, b:uint256, c: uint256) -> uint256: + return (a + b) / staticcall Foo(self).foo() + """, + "did you mean `(a + b) // staticcall Foo(self).foo()`?", + ), +] + + +@pytest.mark.parametrize("code, expected_hint", div_code_with_hint) +def test_invalid_div(code, expected_hint): with pytest.raises(InvalidOperation) as e: compile_code(code) - assert e.value._hint == "did you mean `5 // 9`?" 
+ assert e.value._hint == expected_hint diff --git a/tests/functional/codegen/types/test_array_indexing.py b/tests/functional/codegen/types/test_array_indexing.py index 45e777d919..7f5c0d0e21 100644 --- a/tests/functional/codegen/types/test_array_indexing.py +++ b/tests/functional/codegen/types/test_array_indexing.py @@ -1,5 +1,9 @@ # TODO: rewrite the tests in type-centric way, parametrize array and indices types +import pytest + +from vyper.exceptions import CompilerPanic + def test_negative_ix_access(get_contract, tx_failed): # Arrays can't be accessed with negative indices @@ -130,3 +134,76 @@ def foo(): c.foo() for i in range(10): assert c.arr(i) == i + + +# to fix in future release +@pytest.mark.xfail(raises=CompilerPanic, reason="risky overlap") +def test_array_index_overlap(get_contract): + code = """ +a: public(DynArray[DynArray[Bytes[96], 5], 5]) + +@external +def foo() -> Bytes[96]: + self.a.append([b'xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx']) + return self.a[0][self.bar()] + + +@internal +def bar() -> uint256: + self.a[0] = [b'yyy'] + self.a.pop() + return 0 + """ + c = get_contract(code) + # tricky to get this right, for now we just panic instead of generating code + assert c.foo() == b"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" + + +# to fix in future release +@pytest.mark.xfail(raises=CompilerPanic, reason="risky overlap") +def test_array_index_overlap_extcall(get_contract): + code = """ + +interface Bar: + def bar() -> uint256: payable + +a: public(DynArray[DynArray[Bytes[96], 5], 5]) + +@external +def foo() -> Bytes[96]: + self.a.append([b'xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx']) + return self.a[0][extcall Bar(self).bar()] + + +@external +def bar() -> uint256: + self.a[0] = [b'yyy'] + self.a.pop() + return 0 + """ + c = get_contract(code) + assert c.foo() == b"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" + + +# to fix in future release +@pytest.mark.xfail(raises=CompilerPanic, reason="risky overlap") +def 
test_array_index_overlap_extcall2(get_contract): + code = """ +interface B: + def calculate_index() -> uint256: nonpayable + +a: HashMap[uint256, DynArray[uint256, 5]] + +@external +def bar() -> uint256: + self.a[0] = [2] + return self.a[0][extcall B(self).calculate_index()] + +@external +def calculate_index() -> uint256: + self.a[0] = [1] + return 0 + """ + c = get_contract(code) + + assert c.bar() == 1 diff --git a/tests/functional/codegen/types/test_bytes.py b/tests/functional/codegen/types/test_bytes.py index a5b119f143..8bd7cb6a2c 100644 --- a/tests/functional/codegen/types/test_bytes.py +++ b/tests/functional/codegen/types/test_bytes.py @@ -1,5 +1,6 @@ import pytest +from vyper.compiler import compile_code from vyper.exceptions import TypeMismatch @@ -259,6 +260,42 @@ def test2(l: bytes{m} = {vyper_literal}) -> bool: assert c.test2(vyper_literal) is True +@pytest.mark.parametrize("m,val", [(2, "ab"), (3, "ab"), (4, "abcd")]) +def test_native_hex_literals(get_contract, m, val): + vyper_literal = bytes.fromhex(val) + code = f""" +@external +def test() -> bool: + l: Bytes[{m}] = x"{val}" + return l == {vyper_literal} + +@external +def test2(l: Bytes[{m}] = x"{val}") -> bool: + return l == {vyper_literal} + """ + print(code) + + c = get_contract(code) + + assert c.test() is True + assert c.test2() is True + assert c.test2(vyper_literal) is True + + +def test_hex_literal_parser_edge_case(): + # see GH issue 4405 example 2 + code = """ +interface FooBar: + def test(a: Bytes[2], b: String[4]): payable + +@deploy +def __init__(ext: FooBar): + extcall ext.test(x'6161', x'6161') #ext.test(b'\x61\61', '6161') gets called + """ + with pytest.raises(TypeMismatch): + compile_code(code) + + def test_zero_padding_with_private(get_contract): code = """ counter: uint256 diff --git a/tests/functional/codegen/types/test_dynamic_array.py b/tests/functional/codegen/types/test_dynamic_array.py index 5f26e05839..e35bec9dbc 100644 --- 
a/tests/functional/codegen/types/test_dynamic_array.py +++ b/tests/functional/codegen/types/test_dynamic_array.py @@ -1,17 +1,20 @@ +import contextlib import itertools from typing import Any, Callable import pytest -from tests.utils import decimal_to_int +from tests.utils import check_precompile_asserts, decimal_to_int from vyper.compiler import compile_code +from vyper.evm.opcodes import version_check from vyper.exceptions import ( ArgumentException, ArrayIndexException, + CompilerPanic, ImmutableViolation, OverflowException, - StackTooDeep, StateAccessViolation, + StaticAssertionException, TypeMismatch, ) @@ -736,7 +739,6 @@ def test_array_decimal_return3() -> DynArray[DynArray[decimal, 2], 2]: ] -@pytest.mark.venom_xfail(raises=StackTooDeep, reason="stack scheduler regression") def test_mult_list(get_contract): code = """ nest3: DynArray[DynArray[DynArray[uint256, 2], 2], 2] @@ -1862,9 +1864,16 @@ def should_revert() -> DynArray[String[65], 2]: @pytest.mark.parametrize("code", dynarray_length_no_clobber_cases) def test_dynarray_length_no_clobber(get_contract, tx_failed, code): # check that length is not clobbered before dynarray data copy happens - c = get_contract(code) - with tx_failed(): - c.should_revert() + try: + c = get_contract(code) + with tx_failed(): + c.should_revert() + except StaticAssertionException: + # this test should create + # assert error so if it is + # detected in compile time + # we can continue + pass def test_dynarray_make_setter_overlap(get_contract): @@ -1887,3 +1896,74 @@ def boo() -> uint256: c = get_contract(code) assert c.foo() == [1, 2, 3, 4] + + +@pytest.mark.xfail(raises=CompilerPanic) +def test_dangling_reference(get_contract, tx_failed): + code = """ +a: DynArray[DynArray[uint256, 5], 5] + +@external +def foo(): + self.a = [[1]] + self.a.pop().append(2) + """ + c = get_contract(code) + with tx_failed(): + c.foo() + + +def test_dynarray_copy_oog(env, get_contract, tx_failed): + # GHSA-vgf2-gvx8-xwc3 + code = """ + +@external 
+def foo(a: DynArray[uint256, 4000]) -> uint256: + b: DynArray[uint256, 4000] = a + return b[0] + """ + check_precompile_asserts(code) + + c = get_contract(code) + dynarray = [2] * 4000 + assert c.foo(dynarray) == 2 + + gas_used = env.last_result.gas_used + if version_check(begin="cancun"): + ctx = contextlib.nullcontext + else: + ctx = tx_failed + + with ctx(): + # depends on EVM version. pre-cancun, will revert due to checking + # success flag from identity precompile. + c.foo(dynarray, gas=gas_used) + + +def test_dynarray_copy_oog2(env, get_contract, tx_failed): + # GHSA-vgf2-gvx8-xwc3 + code = """ +@external +@view +def foo(x: String[1000000], y: String[1000000]) -> DynArray[String[1000000], 2]: + z: DynArray[String[1000000], 2] = [x, y] + # Some code + return z + """ + check_precompile_asserts(code) + + c = get_contract(code) + calldata0 = "a" * 10 + calldata1 = "b" * 1000000 + assert c.foo(calldata0, calldata1) == [calldata0, calldata1] + + gas_used = env.last_result.gas_used + if version_check(begin="cancun"): + ctx = contextlib.nullcontext + else: + ctx = tx_failed + + with ctx(): + # depends on EVM version. pre-cancun, will revert due to checking + # success flag from identity precompile. 
+ c.foo(calldata0, calldata1, gas=gas_used) diff --git a/tests/functional/codegen/types/test_lists.py b/tests/functional/codegen/types/test_lists.py index 953a9a9f9f..26cd16ed32 100644 --- a/tests/functional/codegen/types/test_lists.py +++ b/tests/functional/codegen/types/test_lists.py @@ -1,8 +1,12 @@ +import contextlib import itertools import pytest -from tests.utils import decimal_to_int +from tests.evm_backends.base_env import EvmError +from tests.utils import check_precompile_asserts, decimal_to_int +from vyper.compiler.settings import OptimizationLevel +from vyper.evm.opcodes import version_check from vyper.exceptions import ArrayIndexException, OverflowException, TypeMismatch @@ -848,3 +852,73 @@ def foo() -> {return_type}: return MY_CONSTANT[0][0] """ assert_compile_failed(lambda: get_contract(code), TypeMismatch) + + +def test_array_copy_oog(env, get_contract, tx_failed, optimize, experimental_codegen, request): + # GHSA-vgf2-gvx8-xwc3 + code = """ +@internal +def bar(x: uint256[3000]) -> uint256[3000]: + a: uint256[3000] = x + return a + +@external +def foo(x: uint256[3000]) -> uint256: + s: uint256[3000] = self.bar(x) + return s[0] + """ + check_precompile_asserts(code) + + if optimize == OptimizationLevel.NONE and not experimental_codegen: + # fails in bytecode generation due to jumpdests too large + with pytest.raises(AssertionError): + get_contract(code) + return + + c = get_contract(code) + array = [2] * 3000 + assert c.foo(array) == array[0] + + # get the minimum gas for the contract complete execution + gas_used = env.last_result.gas_used + if version_check(begin="cancun"): + ctx = contextlib.nullcontext + else: + ctx = tx_failed + with ctx(): + # depends on EVM version. pre-cancun, will revert due to checking + # success flag from identity precompile. 
+ c.foo(array, gas=gas_used) + + +def test_array_copy_oog2(env, get_contract, tx_failed, optimize, experimental_codegen, request): + # GHSA-vgf2-gvx8-xwc3 + code = """ +@external +def foo(x: uint256[2500]) -> uint256: + s: uint256[2500] = x + t: uint256[2500] = s + return t[0] + """ + check_precompile_asserts(code) + + if optimize == OptimizationLevel.NONE and not experimental_codegen: + # fails in creating contract due to code too large + with tx_failed(EvmError): + get_contract(code) + return + + c = get_contract(code) + array = [2] * 2500 + assert c.foo(array) == array[0] + + # get the minimum gas for the contract complete execution + gas_used = env.last_result.gas_used + if version_check(begin="cancun"): + ctx = contextlib.nullcontext + else: + ctx = tx_failed + with ctx(): + # depends on EVM version. pre-cancun, will revert due to checking + # success flag from identity precompile. + c.foo(array, gas=gas_used) diff --git a/tests/functional/codegen/types/test_string.py b/tests/functional/codegen/types/test_string.py index 51899b50f3..b4e6919ea7 100644 --- a/tests/functional/codegen/types/test_string.py +++ b/tests/functional/codegen/types/test_string.py @@ -1,5 +1,10 @@ +import contextlib + import pytest +from tests.utils import check_precompile_asserts +from vyper.evm.opcodes import version_check + def test_string_return(get_contract): code = """ @@ -116,7 +121,7 @@ def test_logging_extended_string(get_contract, get_logs): @external def foo(): - log MyLog(667788, 'hellohellohellohellohellohellohellohellohello', 334455) + log MyLog(arg1=667788, arg2='hellohellohellohellohellohellohellohellohello', arg3=334455) """ c = get_contract(code) @@ -359,3 +364,56 @@ def compare_var_storage_not_equal_false() -> bool: assert c.compare_var_storage_equal_false() is False assert c.compare_var_storage_not_equal_true() is True assert c.compare_var_storage_not_equal_false() is False + + +def test_string_copy_oog(env, get_contract, tx_failed): + # GHSA-vgf2-gvx8-xwc3 + code = 
""" +@external +@view +def foo(x: String[1000000]) -> String[1000000]: + return x + """ + check_precompile_asserts(code) + + c = get_contract(code) + calldata = "a" * 1000000 + assert c.foo(calldata) == calldata + + gas_used = env.last_result.gas_used + if version_check(begin="cancun"): + ctx = contextlib.nullcontext + else: + ctx = tx_failed + + with ctx(): + # depends on EVM version. pre-cancun, will revert due to checking + # success flag from identity precompile. + c.foo(calldata, gas=gas_used) + + +def test_string_copy_oog2(env, get_contract, tx_failed): + # GHSA-vgf2-gvx8-xwc3 + code = """ +@external +@view +def foo(x: String[1000000]) -> uint256: + y: String[1000000] = x + return len(y) + """ + check_precompile_asserts(code) + + c = get_contract(code) + calldata = "a" * 1000000 + assert c.foo(calldata) == len(calldata) + + gas_used = env.last_result.gas_used + if version_check(begin="cancun"): + ctx = contextlib.nullcontext + else: + ctx = tx_failed + + with ctx(): + # depends on EVM version. pre-cancun, will revert due to checking + # success flag from identity precompile. 
+ c.foo(calldata, gas=gas_used) diff --git a/tests/functional/grammar/test_grammar.py b/tests/functional/grammar/test_grammar.py index de399e84b7..871ba4547f 100644 --- a/tests/functional/grammar/test_grammar.py +++ b/tests/functional/grammar/test_grammar.py @@ -9,7 +9,7 @@ from vyper.ast import Module, parse_to_ast from vyper.ast.grammar import parse_vyper_source, vyper_grammar -from vyper.ast.pre_parser import pre_parse +from vyper.ast.pre_parser import PreParser def test_basic_grammar(): @@ -37,9 +37,9 @@ def test_basic_grammar_empty(): assert len(tree.children) == 0 -def fix_terminal(terminal: str) -> bool: +def fix_terminal(terminal: str) -> str: # these throw exceptions in the grammar - for bad in ("\x00", "\\ ", "\x0c"): + for bad in ("\x00", "\\ ", "\x0c", "\x0d"): terminal = terminal.replace(bad, " ") return terminal @@ -102,6 +102,7 @@ def has_no_docstrings(c): max_examples=500, suppress_health_check=[HealthCheck.too_slow, HealthCheck.filter_too_much] ) def test_grammar_bruteforce(code): - _, _, _, reformatted_code = pre_parse(code + "\n") - tree = parse_to_ast(reformatted_code) + pre_parser = PreParser() + pre_parser.parse(code + "\n") + tree = parse_to_ast(pre_parser.reformatted_code) assert isinstance(tree, Module) diff --git a/tests/functional/syntax/exceptions/test_constancy_exception.py b/tests/functional/syntax/exceptions/test_constancy_exception.py index 4c6af8f6da..484e291da1 100644 --- a/tests/functional/syntax/exceptions/test_constancy_exception.py +++ b/tests/functional/syntax/exceptions/test_constancy_exception.py @@ -130,6 +130,18 @@ def bar()->DynArray[uint16,3]: @view def topup(amount: uint256): assert extcall self.token.transferFrom(msg.sender, self, amount) + """, + """ +@external +@view +def foo(_topic: bytes32): + raw_log([_topic], b"") + """, + """ +@external +@pure +def foo(_topic: bytes32): + raw_log([_topic], b"") """, ], ) diff --git a/tests/functional/syntax/exceptions/test_invalid_literal_exception.py 
b/tests/functional/syntax/exceptions/test_invalid_literal_exception.py index a0cf10ad02..f3fd73fbfc 100644 --- a/tests/functional/syntax/exceptions/test_invalid_literal_exception.py +++ b/tests/functional/syntax/exceptions/test_invalid_literal_exception.py @@ -36,6 +36,14 @@ def foo(): def foo(): a: bytes32 = keccak256("ѓtest") """, + # test constant folding inside of `convert()` + """ +BAR: constant(uint16) = 256 + +@external +def foo(): + a: uint8 = convert(BAR, uint8) + """, ] diff --git a/tests/functional/syntax/exceptions/test_syntax_exception.py b/tests/functional/syntax/exceptions/test_syntax_exception.py index 80f499ac89..b95e63f598 100644 --- a/tests/functional/syntax/exceptions/test_syntax_exception.py +++ b/tests/functional/syntax/exceptions/test_syntax_exception.py @@ -1,5 +1,6 @@ import pytest +from vyper.compiler import compile_code from vyper.exceptions import SyntaxException fail_list = [ @@ -107,5 +108,30 @@ def foo(): @pytest.mark.parametrize("bad_code", fail_list) -def test_syntax_exception(assert_compile_failed, get_contract, bad_code): - assert_compile_failed(lambda: get_contract(bad_code), SyntaxException) +def test_syntax_exception(bad_code): + with pytest.raises(SyntaxException): + compile_code(bad_code) + + +def test_bad_staticcall_keyword(): + bad_code = """ +from ethereum.ercs import IERC20Detailed + +def foo(): + staticcall ERC20(msg.sender).transfer(msg.sender, staticall IERC20Detailed(msg.sender).decimals()) + """ # noqa + with pytest.raises(SyntaxException) as e: + compile_code(bad_code) + + expected_error = """ +invalid syntax. Perhaps you forgot a comma? (, line 5) + + contract ":5", line 5:54 + 4 def foo(): + ---> 5 staticcall ERC20(msg.sender).transfer(msg.sender, staticall IERC20Detailed(msg.sender).decimals()) + -------------------------------------------------------------^ + 6 + + (hint: did you mean `staticcall`?) 
+ """ # noqa + assert str(e.value) == expected_error.strip() diff --git a/tests/functional/syntax/exceptions/test_type_mismatch_exception.py b/tests/functional/syntax/exceptions/test_type_mismatch_exception.py index 514f2df618..63e0eb6d11 100644 --- a/tests/functional/syntax/exceptions/test_type_mismatch_exception.py +++ b/tests/functional/syntax/exceptions/test_type_mismatch_exception.py @@ -41,12 +41,20 @@ def foo(): message: String[1] @external def foo(): - log Foo("abcd") + log Foo(message="abcd") """, # Address literal must be checksummed """ a: constant(address) = 0x3cd751e6b0078be393132286c442345e5dc49699 """, + # test constant folding inside `convert()` + """ +BAR: constant(Bytes[5]) = b"vyper" + +@external +def foo(): + a: Bytes[4] = convert(BAR, Bytes[4]) + """, ] diff --git a/tests/functional/syntax/exceptions/test_undeclared_definition.py b/tests/functional/syntax/exceptions/test_undeclared_definition.py index f90aa4137b..5786b37b1f 100644 --- a/tests/functional/syntax/exceptions/test_undeclared_definition.py +++ b/tests/functional/syntax/exceptions/test_undeclared_definition.py @@ -66,5 +66,6 @@ def foo(): @pytest.mark.parametrize("bad_code", fail_list) def test_undeclared_def_exception(bad_code): - with pytest.raises(UndeclaredDefinition): + with pytest.raises(UndeclaredDefinition) as e: compiler.compile_code(bad_code) + assert "(hint: )" not in str(e.value) diff --git a/tests/functional/syntax/exceptions/test_unknown_type.py b/tests/functional/syntax/exceptions/test_unknown_type.py new file mode 100644 index 0000000000..cd8866d5cb --- /dev/null +++ b/tests/functional/syntax/exceptions/test_unknown_type.py @@ -0,0 +1,15 @@ +import pytest + +from vyper import compiler +from vyper.exceptions import UnknownType + + +def test_unknown_type_exception(): + code = """ +@internal +def foobar(token: IERC20): + pass + """ + with pytest.raises(UnknownType) as e: + compiler.compile_code(code) + assert "(hint: )" not in str(e.value) diff --git 
a/tests/functional/syntax/exceptions/test_vyper_exception_pos.py b/tests/functional/syntax/exceptions/test_vyper_exception_pos.py index 9e0767cb83..17bd4de1cd 100644 --- a/tests/functional/syntax/exceptions/test_vyper_exception_pos.py +++ b/tests/functional/syntax/exceptions/test_vyper_exception_pos.py @@ -1,6 +1,7 @@ from pytest import raises -from vyper.exceptions import VyperException +from vyper import compile_code +from vyper.exceptions import SyntaxException, VyperException def test_type_exception_pos(): @@ -29,3 +30,32 @@ def __init__(): """ assert_compile_failed(lambda: get_contract(code), VyperException) + + +def test_exception_contains_file(make_input_bundle): + code = """ +def bar()>: + """ + input_bundle = make_input_bundle({"code.vy": code}) + with raises(SyntaxException, match="contract"): + compile_code(code, input_bundle=input_bundle) + + +def test_exception_reports_correct_file(make_input_bundle, chdir_tmp_path): + code_a = "def bar()>:" + code_b = "import A" + input_bundle = make_input_bundle({"A.vy": code_a, "B.vy": code_b}) + + with raises(SyntaxException, match=r'contract "A\.vy:\d+"'): + compile_code(code_b, input_bundle=input_bundle) + + +def test_syntax_exception_reports_correct_offset(make_input_bundle): + code = """ +def foo(): + uint256 a = pass + """ + input_bundle = make_input_bundle({"code.vy": code}) + + with raises(SyntaxException, match=r"line \d+:12"): + compile_code(code, input_bundle=input_bundle) diff --git a/tests/functional/syntax/modules/test_deploy_visibility.py b/tests/functional/syntax/modules/test_deploy_visibility.py index f51bf9575b..c908d4adae 100644 --- a/tests/functional/syntax/modules/test_deploy_visibility.py +++ b/tests/functional/syntax/modules/test_deploy_visibility.py @@ -1,7 +1,7 @@ import pytest from vyper.compiler import compile_code -from vyper.exceptions import CallViolation +from vyper.exceptions import CallViolation, UnknownAttribute def test_call_deploy_from_external(make_input_bundle): @@ -25,3 +25,35 
@@ def foo(): compile_code(main, input_bundle=input_bundle) assert e.value.message == "Cannot call an @deploy function from an @external function!" + + +@pytest.mark.parametrize("interface_syntax", ["__interface__", "__at__"]) +def test_module_interface_init(make_input_bundle, tmp_path, interface_syntax): + lib1 = """ +#lib1.vy +k: uint256 + +@external +def bar(): + pass + +@deploy +def __init__(): + self.k = 10 + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + + code = f""" +import lib1 + +@deploy +def __init__(): + lib1.{interface_syntax}(self).__init__() + """ + + with pytest.raises(UnknownAttribute) as e: + compile_code(code, input_bundle=input_bundle) + + # as_posix() for windows tests + lib1_path = (tmp_path / "lib1.vy").as_posix() + assert e.value.message == f"interface {lib1_path} has no member '__init__'." diff --git a/tests/functional/syntax/modules/test_exports.py b/tests/functional/syntax/modules/test_exports.py index 7b00d29c98..4314c1bbf0 100644 --- a/tests/functional/syntax/modules/test_exports.py +++ b/tests/functional/syntax/modules/test_exports.py @@ -385,6 +385,28 @@ def do_xyz(): assert e.value._message == "requested `lib1.ifoo` but `lib1` does not implement `lib1.ifoo`!" +def test_no_export_unimplemented_inline_interface(make_input_bundle): + lib1 = """ +interface ifoo: + def do_xyz(): nonpayable + +# technically implements ifoo, but missing `implements: ifoo` + +@external +def do_xyz(): + pass + """ + main = """ +import lib1 + +exports: lib1.ifoo + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + with pytest.raises(InterfaceViolation) as e: + compile_code(main, input_bundle=input_bundle) + assert e.value._message == "requested `lib1.ifoo` but `lib1` does not implement `lib1.ifoo`!" 
+ + def test_export_selector_conflict(make_input_bundle): ifoo = """ @external @@ -444,3 +466,87 @@ def __init__(): with pytest.raises(InterfaceViolation) as e: compile_code(main, input_bundle=input_bundle) assert e.value._message == "requested `lib1.ifoo` but `lib1` does not implement `lib1.ifoo`!" + + +def test_export_empty_interface(make_input_bundle, tmp_path): + lib1 = """ +def an_internal_function(): + pass + """ + main = """ +import lib1 + +exports: lib1.__interface__ + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + with pytest.raises(StructureException) as e: + compile_code(main, input_bundle=input_bundle) + + # as_posix() for windows + lib1_path = (tmp_path / "lib1.vy").as_posix() + assert e.value._message == f"lib1 (located at `{lib1_path}`) has no external functions!" + + +def test_invalid_export(make_input_bundle): + lib1 = """ +@external +def foo(): + pass + """ + main = """ +import lib1 +a: address + +exports: lib1.__interface__(self.a).foo + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + + with pytest.raises(StructureException) as e: + compile_code(main, input_bundle=input_bundle) + + assert e.value._message == "invalid export of a value" + assert e.value._hint == "exports should look like ." + + main = """ +interface Foo: + def foo(): nonpayable + +exports: Foo + """ + with pytest.raises(StructureException) as e: + compile_code(main) + + assert e.value._message == "invalid export" + assert e.value._hint == "exports should look like ." 
+ + +@pytest.mark.parametrize("exports_item", ["__at__", "__at__(self)", "__at__(self).__interface__"]) +def test_invalid_at_exports(get_contract, make_input_bundle, exports_item): + lib = """ +@external +@view +def foo() -> uint256: + return 5 + """ + + main = f""" +import lib + +exports: lib.{exports_item} + +@external +@view +def bar() -> uint256: + return staticcall lib.__at__(self).foo() + """ + input_bundle = make_input_bundle({"lib.vy": lib}) + + with pytest.raises(Exception) as e: + compile_code(main, input_bundle=input_bundle) + + if exports_item == "__at__": + assert "not a function or interface" in str(e.value) + if exports_item == "__at__(self)": + assert "invalid exports" in str(e.value) + if exports_item == "__at__(self).__interface__": + assert "has no member '__interface__'" in str(e.value) diff --git a/tests/functional/syntax/names/test_event_names.py b/tests/functional/syntax/names/test_event_names.py index 367b646bfe..28cd6bdad0 100644 --- a/tests/functional/syntax/names/test_event_names.py +++ b/tests/functional/syntax/names/test_event_names.py @@ -26,7 +26,7 @@ def foo(i: int128) -> int128: @external def foo(i: int128) -> int128: temp_var : int128 = i - log int128(temp_var) + log int128(variable=temp_var) return temp_var """, NamespaceCollision, @@ -39,7 +39,7 @@ def foo(i: int128) -> int128: @external def foo(i: int128) -> int128: temp_var : int128 = i - log decimal(temp_var) + log decimal(variable=temp_var) return temp_var """, NamespaceCollision, @@ -52,7 +52,7 @@ def foo(i: int128) -> int128: @external def foo(i: int128) -> int128: temp_var : int128 = i - log wei(temp_var) + log wei(variable=temp_var) return temp_var """, StructureException, @@ -65,7 +65,7 @@ def foo(i: int128) -> int128: @external def foo(i: int128) -> int128: temp_var : int128 = i - log false(temp_var) + log false(variable=temp_var) return temp_var """, StructureException, @@ -102,7 +102,7 @@ def test_varname_validity_fail(bad_code, exc): @external def foo(i: int128) -> 
int128: variable : int128 = i - log Assigned(variable) + log Assigned(variable=variable) return variable """, """ @@ -122,7 +122,7 @@ def foo(i: int128) -> int128: @external def foo(i: int128) -> int128: variable : int128 = i - log Assigned1(variable) + log Assigned1(variable=variable) return variable """, ] diff --git a/tests/functional/syntax/test_abi_encode.py b/tests/functional/syntax/test_abi_encode.py index 5e0175857d..edb441652a 100644 --- a/tests/functional/syntax/test_abi_encode.py +++ b/tests/functional/syntax/test_abi_encode.py @@ -1,7 +1,7 @@ import pytest from vyper import compiler -from vyper.exceptions import TypeMismatch +from vyper.exceptions import InvalidLiteral, TypeMismatch fail_list = [ ( @@ -41,11 +41,37 @@ def foo(x: uint256) -> Bytes[36]: ( """ @external +def foo(x: uint256) -> Bytes[36]: + return _abi_encode(x, method_id=b"abc") + """, + InvalidLiteral, # len(method_id) must be greater than 3 + ), + ( + """ +@external def foo(x: uint256) -> Bytes[36]: return _abi_encode(x, method_id=0x1234567890) """, TypeMismatch, # len(method_id) must be less than 4 ), + ( + """ +@external +def foo(x: uint256) -> Bytes[36]: + return _abi_encode(x, method_id=0x123456) + """, + TypeMismatch, # len(method_id) must be greater than 3 + ), + ( + """ +@external +def foo() -> Bytes[132]: + x: uint256 = 1 + y: Bytes[32] = b"234" + return abi_encode(x, y, method_id=b"") + """, + InvalidLiteral, # len(method_id) must be 4 + ), ] @@ -82,6 +108,11 @@ def foo(x: Bytes[1]) -> Bytes[68]: return _abi_encode(x, ensure_tuple=False, method_id=0x12345678) """, """ +@external +def foo(x: Bytes[1]) -> Bytes[68]: + return _abi_encode(x, ensure_tuple=False, method_id=b"1234") + """, + """ BAR: constant(DynArray[uint256, 5]) = [1, 2, 3, 4, 5] @external diff --git a/tests/functional/syntax/test_ann_assign.py b/tests/functional/syntax/test_ann_assign.py index 23ebeb9560..fba9eff38d 100644 --- a/tests/functional/syntax/test_ann_assign.py +++ 
b/tests/functional/syntax/test_ann_assign.py @@ -3,11 +3,11 @@ from vyper import compiler from vyper.exceptions import ( + InstantiationException, InvalidAttribute, TypeMismatch, UndeclaredDefinition, UnknownAttribute, - VariableDeclarationException, ) fail_list = [ @@ -73,7 +73,7 @@ def foo() -> int128: def foo() -> int128: s: S = S(a=1) """, - VariableDeclarationException, + InstantiationException, ), ( """ diff --git a/tests/functional/syntax/test_bytes.py b/tests/functional/syntax/test_bytes.py index 0ca3b27fee..06c3c1f443 100644 --- a/tests/functional/syntax/test_bytes.py +++ b/tests/functional/syntax/test_bytes.py @@ -80,6 +80,25 @@ def test() -> Bytes[1]: ( """ @external +def test() -> Bytes[2]: + a: Bytes[2] = x"abc" # non-hex nibbles + return a + """, + SyntaxException, + ), + ( + """ +@external +def test() -> Bytes[10]: + # GH issue 4405 example 1 + a: Bytes[10] = x x x x x x"61" # messed up hex prefix + return a + """, + SyntaxException, + ), + ( + """ +@external def foo(): a: Bytes = b"abc" """, @@ -98,6 +117,24 @@ def test_bytes_fail(bad_code): compiler.compile_code(bad_code) +@pytest.mark.xfail +def test_hexbytes_offset(): + good_code = """ + event X: + a: Bytes[2] + +@deploy +def __init__(): + # GH issue 4405, example 1 + # + # log changes offset of HexString, and the hex_string_locations tracked + # location is incorrect when visiting ast + log X(a = x"6161") + """ + # move this to valid list once it passes. + assert compiler.compile_code(good_code) is not None + + valid_list = [ """ @external diff --git a/tests/functional/syntax/test_external_calls.py b/tests/functional/syntax/test_external_calls.py index 79f3f6db93..fd6fa28cc9 100644 --- a/tests/functional/syntax/test_external_calls.py +++ b/tests/functional/syntax/test_external_calls.py @@ -61,7 +61,7 @@ def foo(f: Foo): s: uint256 = staticcall f.foo() """, # TODO: tokenizer currently has issue with log+staticcall/extcall, e.g. 
- # `log Bar(staticcall f.foo() + extcall f.bar())` + # `log Bar(_value=staticcall f.foo() + extcall f.bar())` ] @@ -305,7 +305,7 @@ def bar(): extcall Foo(msg.sender) """, StructureException, - "Function `type(interface Foo)` cannot be called without assigning the result", + "Function `type(Foo)` cannot be called without assigning the result", None, ), ] diff --git a/tests/functional/syntax/test_import.py b/tests/functional/syntax/test_import.py new file mode 100644 index 0000000000..acc556206e --- /dev/null +++ b/tests/functional/syntax/test_import.py @@ -0,0 +1,132 @@ +import pytest + +from vyper import compiler +from vyper.exceptions import ModuleNotFound + +CODE_TOP = """ +import subdir0.lib0 as lib0 +@external +def foo(): + lib0.foo() +""" + +CODE_LIB1 = """ +def foo(): + pass +""" + + +def test_implicitly_relative_import_crashes(make_input_bundle): + lib0 = """ +import subdir1.lib1 as lib1 +def foo(): + lib1.foo() + """ + + input_bundle = make_input_bundle( + {"top.vy": CODE_TOP, "subdir0/lib0.vy": lib0, "subdir0/subdir1/lib1.vy": CODE_LIB1} + ) + + file_input = input_bundle.load_file("top.vy") + with pytest.raises(ModuleNotFound) as e: + compiler.compile_from_file_input(file_input, input_bundle=input_bundle) + assert "lib0.vy:" in str(e.value) + + +def test_implicitly_relative_import_crashes_2(make_input_bundle): + lib0 = """ +from subdir1 import lib1 as lib1 +def foo(): + lib1.foo() + """ + + input_bundle = make_input_bundle( + {"top.vy": CODE_TOP, "subdir0/lib0.vy": lib0, "subdir0/subdir1/lib1.vy": CODE_LIB1} + ) + + file_input = input_bundle.load_file("top.vy") + with pytest.raises(ModuleNotFound) as e: + compiler.compile_from_file_input(file_input, input_bundle=input_bundle) + assert "lib0.vy:" in str(e.value) + + +def test_relative_import_searches_only_current_path(make_input_bundle): + top = """ +from subdir import b as b +@external +def foo(): + b.foo() + """ + + a = """ +def foo(): + pass + """ + + b = """ +from . 
import a as a +def foo(): + a.foo() + """ + + input_bundle = make_input_bundle({"top.vy": top, "a.vy": a, "subdir/b.vy": b}) + file_input = input_bundle.load_file("top.vy") + + with pytest.raises(ModuleNotFound) as e: + compiler.compile_from_file_input(file_input, input_bundle=input_bundle) + assert "b.vy:" in str(e.value) + + +def test_absolute_import_within_relative_import(make_input_bundle): + top = """ +import subdir0.subdir1.c as c +@external +def foo(): + c.foo() + """ + a = """ +import subdir0.b as b +def foo(): + b.foo() + """ + b = """ +def foo(): + pass + """ + + c = """ +from .. import a as a +def foo(): + a.foo() + """ + + input_bundle = make_input_bundle( + {"top.vy": top, "subdir0/a.vy": a, "subdir0/b.vy": b, "subdir0/subdir1/c.vy": c} + ) + compiler.compile_code(top, input_bundle=input_bundle) + + +def test_absolute_path_passes(make_input_bundle): + lib0 = """ +import subdir0.subdir1.lib1 as lib1 +def foo(): + lib1.foo() + """ + + input_bundle = make_input_bundle( + {"top.vy": CODE_TOP, "subdir0/lib0.vy": lib0, "subdir0/subdir1/lib1.vy": CODE_LIB1} + ) + compiler.compile_code(CODE_TOP, input_bundle=input_bundle) + + +def test_absolute_path_passes_2(make_input_bundle): + lib0 = """ +from .subdir1 import lib1 as lib1 +def foo(): + lib1.foo() + """ + + input_bundle = make_input_bundle( + {"top.vy": CODE_TOP, "subdir0/lib0.vy": lib0, "subdir0/subdir1/lib1.vy": CODE_LIB1} + ) + compiler.compile_code(CODE_TOP, input_bundle=input_bundle) diff --git a/tests/functional/syntax/test_interfaces.py b/tests/functional/syntax/test_interfaces.py index 113629220e..fcfa5ba0c9 100644 --- a/tests/functional/syntax/test_interfaces.py +++ b/tests/functional/syntax/test_interfaces.py @@ -3,6 +3,7 @@ from vyper import compiler from vyper.exceptions import ( ArgumentException, + FunctionDeclarationException, InterfaceViolation, InvalidReference, InvalidType, @@ -157,12 +158,12 @@ def f(a: uint256): # visibility is nonpayable instead of view @external def transfer(_to : 
address, _value : uint256) -> bool: - log Transfer(msg.sender, _to, _value) + log Transfer(sender=msg.sender, receiver=_to, value=_value) return True @external def transferFrom(_from : address, _to : address, _value : uint256) -> bool: - log IERC20.Transfer(_from, _to, _value) + log IERC20.Transfer(sender=_from, receiver=_to, value=_value) return True @external @@ -380,13 +381,22 @@ def test_interfaces_success(good_code): def test_imports_and_implements_within_interface(make_input_bundle): - interface_code = """ + ibar_code = """ @external def foobar(): ... """ + ifoo_code = """ +import bar - input_bundle = make_input_bundle({"foo.vyi": interface_code}) +implements: bar + +@external +def foobar(): + ... +""" + + input_bundle = make_input_bundle({"foo.vyi": ifoo_code, "bar.vyi": ibar_code}) code = """ import foo as Foo @@ -401,23 +411,218 @@ def foobar(): assert compiler.compile_code(code, input_bundle=input_bundle) is not None -def test_builtins_not_found(): +def test_builtins_not_found(make_input_bundle): code = """ from vyper.interfaces import foobar """ + input_bundle = make_input_bundle({"code.vy": code}) + file_input = input_bundle.load_file("code.vy") with pytest.raises(ModuleNotFound) as e: - compiler.compile_code(code) - + compiler.compile_from_file_input(file_input, input_bundle=input_bundle) assert e.value._message == "vyper.interfaces.foobar" assert e.value._hint == "try renaming `vyper.interfaces` to `ethereum.ercs`" + assert "code.vy:" in str(e.value) @pytest.mark.parametrize("erc", ("ERC20", "ERC721", "ERC4626")) -def test_builtins_not_found2(erc): +def test_builtins_not_found2(erc, make_input_bundle): code = f""" from ethereum.ercs import {erc} """ + input_bundle = make_input_bundle({"code.vy": code}) + file_input = input_bundle.load_file("code.vy") with pytest.raises(ModuleNotFound) as e: - compiler.compile_code(code) + compiler.compile_from_file_input(file_input, input_bundle=input_bundle) assert e.value._message == f"ethereum.ercs.{erc}" assert 
e.value._hint == f"try renaming `{erc}` to `I{erc}`" + assert "code.vy:" in str(e.value) + + +def test_interface_body_check(make_input_bundle): + interface_code = """ +@external +def foobar(): + return ... +""" + + input_bundle = make_input_bundle({"foo.vyi": interface_code}) + + code = """ +import foo as Foo + +implements: Foo + +@external +def foobar(): + pass +""" + with pytest.raises(FunctionDeclarationException) as e: + compiler.compile_code(code, input_bundle=input_bundle) + + assert e.value._message == "function body in an interface can only be `...`!" + + +def test_interface_body_check2(make_input_bundle): + interface_code = """ +@external +def foobar(): + ... + +@external +def bar(): + ... + +@external +def baz(): + ... +""" + + input_bundle = make_input_bundle({"foo.vyi": interface_code}) + + code = """ +import foo + +implements: foo + +@external +def foobar(): + pass + +@external +def bar(): + pass + +@external +def baz(): + pass +""" + + assert compiler.compile_code(code, input_bundle=input_bundle) is not None + + +invalid_visibility_code = [ + """ +import foo as Foo +implements: Foo +@external +def foobar(): + pass + """, + """ +import foo as Foo +implements: Foo +@internal +def foobar(): + pass + """, + """ +import foo as Foo +implements: Foo +def foobar(): + pass + """, +] + + +@pytest.mark.parametrize("code", invalid_visibility_code) +def test_internal_visibility_in_interface(make_input_bundle, code): + interface_code = """ +@internal +def foobar(): + ... +""" + + input_bundle = make_input_bundle({"foo.vyi": interface_code}) + + with pytest.raises(FunctionDeclarationException) as e: + compiler.compile_code(code, input_bundle=input_bundle) + + assert e.value._message == "Interface functions can only be marked as `@external`" + + +external_visibility_interface = [ + """ +@external +def foobar(): + ... +def bar(): + ... + """, + """ +def foobar(): + ... +@external +def bar(): + ... 
+ """, +] + + +@pytest.mark.parametrize("iface", external_visibility_interface) +def test_internal_implemenatation_of_external_interface(make_input_bundle, iface): + input_bundle = make_input_bundle({"foo.vyi": iface}) + + code = """ +import foo as Foo +implements: Foo +@internal +def foobar(): + pass +def bar(): + pass + """ + + with pytest.raises(InterfaceViolation) as e: + compiler.compile_code(code, input_bundle=input_bundle) + + assert e.value.message == "Contract does not implement all interface functions: bar(), foobar()" + + +def test_intrinsic_interfaces_different_types(make_input_bundle, get_contract): + lib1 = """ +@external +@view +def foo(): + pass + """ + lib2 = """ +@external +@view +def foo(): + pass + """ + main = """ +import lib1 +import lib2 + +@external +def bar(): + assert lib1.__at__(self) == lib2.__at__(self) + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + with pytest.raises(TypeMismatch): + compiler.compile_code(main, input_bundle=input_bundle) + + +@pytest.mark.xfail +def test_intrinsic_interfaces_default_function(make_input_bundle, get_contract): + lib1 = """ +@external +@payable +def __default__(): + pass + """ + main = """ +import lib1 + +@external +def bar(): + extcall lib1.__at__(self).__default__() + + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + + # TODO make the exception more precise once fixed + with pytest.raises(Exception): # noqa: B017 + compiler.compile_code(main, input_bundle=input_bundle) diff --git a/tests/functional/syntax/test_logging.py b/tests/functional/syntax/test_logging.py index b96700a128..7f8f141b99 100644 --- a/tests/functional/syntax/test_logging.py +++ b/tests/functional/syntax/test_logging.py @@ -1,7 +1,13 @@ import pytest from vyper import compiler -from vyper.exceptions import StructureException, TypeMismatch +from vyper.exceptions import ( + InstantiationException, + InvalidAttribute, + StructureException, + TypeMismatch, + UnknownAttribute, +) fail_list = [ 
""" @@ -12,7 +18,7 @@ @external def foo(): - log Bar(self.x) + log Bar(_value=self.x) """, """ event Bar: @@ -21,7 +27,7 @@ def foo(): @external def foo(): x: decimal[4] = [0.0, 0.0, 0.0, 0.0] - log Bar(x) + log Bar(_value=x) """, """ struct Foo: @@ -37,7 +43,7 @@ def foo(): @external def test(): - log Test(-7) + log Test(n=-7) """, ] @@ -48,6 +54,61 @@ def test_logging_fail(bad_code): compiler.compile_code(bad_code) +def test_logging_fail_mixed_positional_kwargs(): + code = """ +event Test: + n: uint256 + o: uint256 + +@external +def test(): + log Test(7, o=12) + """ + with pytest.raises(InstantiationException): + compiler.compile_code(code) + + +def test_logging_fail_unknown_kwarg(): + code = """ +event Test: + n: uint256 + +@external +def test(): + log Test(n=7, o=12) + """ + with pytest.raises(UnknownAttribute): + compiler.compile_code(code) + + +def test_logging_fail_missing_kwarg(): + code = """ +event Test: + n: uint256 + o: uint256 + +@external +def test(): + log Test(n=7) + """ + with pytest.raises(InstantiationException): + compiler.compile_code(code) + + +def test_logging_fail_kwargs_out_of_order(): + code = """ +event Test: + n: uint256 + o: uint256 + +@external +def test(): + log Test(o=12, n=7) + """ + with pytest.raises(InvalidAttribute): + compiler.compile_code(code) + + @pytest.mark.parametrize("mutability", ["@pure", "@view"]) @pytest.mark.parametrize("visibility", ["@internal", "@external"]) def test_logging_from_non_mutable(mutability, visibility): @@ -58,7 +119,23 @@ def test_logging_from_non_mutable(mutability, visibility): {visibility} {mutability} def test(): - log Test(1) + log Test(n=1) """ with pytest.raises(StructureException): compiler.compile_code(code) + + +def test_logging_with_positional_args(get_contract, get_logs): + # TODO: Remove when positional arguments are fully deprecated + code = """ +event Test: + n: uint256 + +@external +def test(): + log Test(1) + """ + c = get_contract(code) + c.test() + (log,) = get_logs(c, "Test") + 
assert log.args.n == 1 diff --git a/tests/functional/syntax/test_slice.py b/tests/functional/syntax/test_slice.py index 6bb666527e..6a091c9da3 100644 --- a/tests/functional/syntax/test_slice.py +++ b/tests/functional/syntax/test_slice.py @@ -53,6 +53,22 @@ def foo(inp: Bytes[10]) -> Bytes[4]: def foo() -> Bytes[10]: return slice(b"badmintonzzz", 1, 10) """, + # test constant folding for `slice()` `length` argument + """ +@external +def foo(): + x: Bytes[32] = slice(msg.data, 0, 31 + 1) + """, + """ +@external +def foo(a: address): + x: Bytes[32] = slice(a.code, 0, 31 + 1) + """, + """ +@external +def foo(inp: Bytes[5], start: uint256) -> Bytes[3]: + return slice(inp, 0, 1 + 1) + """, ] diff --git a/tests/functional/syntax/test_string.py b/tests/functional/syntax/test_string.py index 77cb7eaee6..1dc354f773 100644 --- a/tests/functional/syntax/test_string.py +++ b/tests/functional/syntax/test_string.py @@ -1,7 +1,7 @@ import pytest from vyper import compiler -from vyper.exceptions import StructureException +from vyper.exceptions import InvalidLiteral, StructureException valid_list = [ """ @@ -11,25 +11,13 @@ def foo() -> String[10]: """, """ @external -def foo(): - x: String[11] = "‘très bien!" - """, - """ -@external def foo() -> bool: - x: String[15] = "‘très bien!" + x: String[15] = "tres bien!" y: String[15] = "test" return x != y """, """ @external -def foo() -> bool: - x: String[15] = "‘très bien!" - y: String[12] = "test" - return x != y - """, - """ -@external def test() -> String[100]: return "hello world!" """, @@ -46,13 +34,36 @@ def test_string_success(good_code): """ @external def foo(): + # invalid type annotation - should be String[N] a: String = "abc" """, StructureException, - ) + ), + ( + """ +@external +@view +def compile_hash() -> bytes32: + # GH issue #3088 - ord("è") == 232 + return keccak256("è") + """, + InvalidLiteral, + ), + ( + """ +@external +def foo() -> bool: + # ord("‘") == 161 + x: String[15] = "‘très bien!" 
+ y: String[12] = "test" + return x != y + """, + InvalidLiteral, + ), ] @pytest.mark.parametrize("bad_code,exc", invalid_list) -def test_string_fail(assert_compile_failed, get_contract, bad_code, exc): - assert_compile_failed(lambda: get_contract(bad_code), exc) +def test_string_fail(get_contract, bad_code, exc): + with pytest.raises(exc): + compiler.compile_code(bad_code) diff --git a/tests/functional/syntax/test_structs.py b/tests/functional/syntax/test_structs.py index 9a9a397c48..c08859cd92 100644 --- a/tests/functional/syntax/test_structs.py +++ b/tests/functional/syntax/test_structs.py @@ -5,6 +5,7 @@ from vyper import compiler from vyper.exceptions import ( InstantiationException, + InvalidAttribute, StructureException, SyntaxException, TypeMismatch, @@ -32,7 +33,8 @@ def foo(): """, UnknownAttribute, ), - """ + ( + """ struct A: x: int128 y: int128 @@ -41,6 +43,8 @@ def foo(): def foo(): self.a = A(x=1) """, + InstantiationException, + ), """ struct A: x: int128 @@ -61,7 +65,8 @@ def foo(): def foo(): self.a = A(self.b) """, - """ + ( + """ struct A: x: int128 y: int128 @@ -70,6 +75,8 @@ def foo(): def foo(): self.a = A({x: 1}) """, + InstantiationException, + ), """ struct C: c: int128 @@ -386,7 +393,7 @@ def foo(): def foo(): self.b = B(foo=1, foo=2) """, - UnknownAttribute, + InvalidAttribute, ), ( """ diff --git a/tests/functional/syntax/test_unbalanced_return.py b/tests/functional/syntax/test_unbalanced_return.py index 04835bb0f0..a1faa1c6a5 100644 --- a/tests/functional/syntax/test_unbalanced_return.py +++ b/tests/functional/syntax/test_unbalanced_return.py @@ -195,7 +195,7 @@ def test() -> int128: if 1 == 1 : return 1 else: - assert msg.sender != msg.sender + assert msg.sender != self return 0 """, """ diff --git a/tests/functional/venom/__init__.py b/tests/functional/venom/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/functional/venom/parser/__init__.py b/tests/functional/venom/parser/__init__.py new file mode 
100644 index 0000000000..e69de29bb2 diff --git a/tests/functional/venom/parser/test_parsing.py b/tests/functional/venom/parser/test_parsing.py new file mode 100644 index 0000000000..bd536a8cfa --- /dev/null +++ b/tests/functional/venom/parser/test_parsing.py @@ -0,0 +1,352 @@ +from tests.venom_utils import assert_bb_eq, assert_ctx_eq +from vyper.venom.basicblock import IRBasicBlock, IRLabel, IRLiteral, IRVariable +from vyper.venom.context import DataItem, DataSection, IRContext +from vyper.venom.function import IRFunction +from vyper.venom.parser import parse_venom + + +def test_single_bb(): + source = """ + function main { + main: + stop + } + """ + + parsed_ctx = parse_venom(source) + + expected_ctx = IRContext() + expected_ctx.add_function(main_fn := IRFunction(IRLabel("main"))) + main_bb = main_fn.get_basic_block("main") + main_bb.append_instruction("stop") + + assert_ctx_eq(parsed_ctx, expected_ctx) + + +def test_multi_bb_single_fn(): + source = """ + function start { + start: + %1 = callvalue + jnz %1, @fine, @has_callvalue + fine: + %2 = calldataload 4 + %4 = add %2, 279387 + return %2, %4 + has_callvalue: + revert 0, 0 + } + """ + + parsed_ctx = parse_venom(source) + + expected_ctx = IRContext() + expected_ctx.add_function(start_fn := IRFunction(IRLabel("start"))) + + start_bb = start_fn.get_basic_block("start") + start_bb.append_instruction("callvalue", ret=IRVariable("1")) + start_bb.append_instruction("jnz", IRVariable("1"), IRLabel("fine"), IRLabel("has_callvalue")) + + start_fn.append_basic_block(fine_bb := IRBasicBlock(IRLabel("fine"), start_fn)) + fine_bb.append_instruction("calldataload", IRLiteral(4), ret=IRVariable("2")) + fine_bb.append_instruction("add", IRLiteral(279387), IRVariable("2"), ret=IRVariable("4")) + fine_bb.append_instruction("return", IRVariable("4"), IRVariable("2")) + + has_callvalue_bb = IRBasicBlock(IRLabel("has_callvalue"), start_fn) + start_fn.append_basic_block(has_callvalue_bb) + 
has_callvalue_bb.append_instruction("revert", IRLiteral(0), IRLiteral(0)) + has_callvalue_bb.append_instruction("stop") + + assert_ctx_eq(parsed_ctx, expected_ctx) + + +def test_data_section(): + parsed_ctx = parse_venom( + """ + function entry { + entry: + stop + } + + data readonly { + dbsection selector_buckets: + db @selector_bucket_0 + db @fallback + db @selector_bucket_2 + db @selector_bucket_3 + db @fallback + db @selector_bucket_5 + db @selector_bucket_6 + } + """ + ) + + expected_ctx = IRContext() + expected_ctx.add_function(entry_fn := IRFunction(IRLabel("entry"))) + entry_fn.get_basic_block("entry").append_instruction("stop") + + expected_ctx.data_segment = [ + DataSection( + IRLabel("selector_buckets"), + [ + DataItem(IRLabel("selector_bucket_0")), + DataItem(IRLabel("fallback")), + DataItem(IRLabel("selector_bucket_2")), + DataItem(IRLabel("selector_bucket_3")), + DataItem(IRLabel("fallback")), + DataItem(IRLabel("selector_bucket_5")), + DataItem(IRLabel("selector_bucket_6")), + ], + ) + ] + + assert_ctx_eq(parsed_ctx, expected_ctx) + + +def test_multi_function(): + parsed_ctx = parse_venom( + """ + function entry { + entry: + invoke @check_cv + jmp @wow + wow: + mstore 0, 1 + return 0, 32 + } + + function check_cv { + check_cv: + %1 = callvalue + %2 = param + jnz %1, @has_value, @no_value + no_value: + ret %2 + has_value: + revert 0, 0 + } + """ + ) + + expected_ctx = IRContext() + expected_ctx.add_function(entry_fn := IRFunction(IRLabel("entry"))) + + entry_bb = entry_fn.get_basic_block("entry") + entry_bb.append_instruction("invoke", IRLabel("check_cv")) + entry_bb.append_instruction("jmp", IRLabel("wow")) + + entry_fn.append_basic_block(wow_bb := IRBasicBlock(IRLabel("wow"), entry_fn)) + wow_bb.append_instruction("mstore", IRLiteral(1), IRLiteral(0)) + wow_bb.append_instruction("return", IRLiteral(32), IRLiteral(0)) + + expected_ctx.add_function(check_fn := IRFunction(IRLabel("check_cv"))) + + check_entry_bb = check_fn.get_basic_block("check_cv") + 
check_entry_bb.append_instruction("callvalue", ret=IRVariable("1")) + check_entry_bb.append_instruction("param", ret=IRVariable("2")) + check_entry_bb.append_instruction( + "jnz", IRVariable("1"), IRLabel("has_value"), IRLabel("no_value") + ) + check_fn.append_basic_block(no_value_bb := IRBasicBlock(IRLabel("no_value"), check_fn)) + no_value_bb.append_instruction("ret", IRVariable("2")) + + check_fn.append_basic_block(value_bb := IRBasicBlock(IRLabel("has_value"), check_fn)) + value_bb.append_instruction("revert", IRLiteral(0), IRLiteral(0)) + value_bb.append_instruction("stop") + + assert_ctx_eq(parsed_ctx, expected_ctx) + + +def test_multi_function_and_data(): + parsed_ctx = parse_venom( + """ + function entry { + entry: + invoke @check_cv + jmp @wow + wow: + mstore 0, 1 + return 0, 32 + } + + function check_cv { + check_cv: + %1 = callvalue + %2 = param + jnz %1, @has_value, @no_value + no_value: + ret %2 + has_value: + revert 0, 0 + } + + data readonly { + dbsection selector_buckets: + db @selector_bucket_0 + db @fallback + db @selector_bucket_2 + db @selector_bucket_3 + db @selector_bucket_6 + } + """ + ) + + expected_ctx = IRContext() + expected_ctx.add_function(entry_fn := IRFunction(IRLabel("entry"))) + + entry_bb = entry_fn.get_basic_block("entry") + entry_bb.append_instruction("invoke", IRLabel("check_cv")) + entry_bb.append_instruction("jmp", IRLabel("wow")) + + entry_fn.append_basic_block(wow_bb := IRBasicBlock(IRLabel("wow"), entry_fn)) + wow_bb.append_instruction("mstore", IRLiteral(1), IRLiteral(0)) + wow_bb.append_instruction("return", IRLiteral(32), IRLiteral(0)) + + expected_ctx.add_function(check_fn := IRFunction(IRLabel("check_cv"))) + + check_entry_bb = check_fn.get_basic_block("check_cv") + check_entry_bb.append_instruction("callvalue", ret=IRVariable("1")) + check_entry_bb.append_instruction("param", ret=IRVariable("2")) + check_entry_bb.append_instruction( + "jnz", IRVariable("1"), IRLabel("has_value"), IRLabel("no_value") + ) + 
check_fn.append_basic_block(no_value_bb := IRBasicBlock(IRLabel("no_value"), check_fn)) + no_value_bb.append_instruction("ret", IRVariable("2")) + + check_fn.append_basic_block(value_bb := IRBasicBlock(IRLabel("has_value"), check_fn)) + value_bb.append_instruction("revert", IRLiteral(0), IRLiteral(0)) + value_bb.append_instruction("stop") + + expected_ctx.data_segment = [ + DataSection( + IRLabel("selector_buckets"), + [ + DataItem(IRLabel("selector_bucket_0")), + DataItem(IRLabel("fallback")), + DataItem(IRLabel("selector_bucket_2")), + DataItem(IRLabel("selector_bucket_3")), + DataItem(IRLabel("selector_bucket_6")), + ], + ) + ] + + assert_ctx_eq(parsed_ctx, expected_ctx) + + +def test_phis(): + # @external + # def _loop() -> uint256: + # res: uint256 = 9 + # for i: uint256 in range(res, bound=10): + # res = res + i + # return res + source = """ + function __main_entry { + __main_entry: ; IN=[] OUT=[fallback, 1_then] => {} + %27 = 0 + %1 = calldataload %27 + %28 = %1 + %29 = 224 + %2 = shr %29, %28 + %31 = %2 + %30 = 1729138561 + %4 = xor %30, %31 + %32 = %4 + jnz %32, @fallback, @1_then + ; (__main_entry) + + + 1_then: ; IN=[__main_entry] OUT=[4_condition] => {%11, %var8_0} + %6 = callvalue + %33 = %6 + %7 = iszero %33 + %34 = %7 + assert %34 + %var8_0 = 9 + %11 = 0 + nop + jmp @4_condition + ; (__main_entry) + + + 4_condition: ; IN=[1_then, 5_body] OUT=[5_body, 7_exit] => {%11:3, %var8_0:2} + %var8_0:2 = phi @1_then, %var8_0, @5_body, %var8_0:3 + %11:3 = phi @1_then, %11, @5_body, %11:4 + %35 = %11:3 + %36 = 9 + %15 = xor %36, %35 + %37 = %15 + jnz %37, @5_body, @7_exit + ; (__main_entry) + + + 5_body: ; IN=[4_condition] OUT=[4_condition] => {%11:4, %var8_0:3} + %38 = %11:3 + %39 = %var8_0:2 + %22 = add %39, %38 + %41 = %22 + %40 = %var8_0:2 + %24 = gt %40, %41 + %42 = %24 + %25 = iszero %42 + %43 = %25 + assert %43 + %var8_0:3 = %22 + %44 = %11:3 + %45 = 1 + %11:4 = add %45, %44 + jmp @4_condition + ; (__main_entry) + + + 7_exit: ; IN=[4_condition] OUT=[] => 
{} + %46 = %var8_0:2 + %47 = 64 + mstore %47, %46 + %48 = 32 + %49 = 64 + return %49, %48 + ; (__main_entry) + + + fallback: ; IN=[__main_entry] OUT=[] => {} + %50 = 0 + %51 = 0 + revert %51, %50 + stop + ; (__main_entry) + } ; close function __main_entry + """ + ctx = parse_venom(source) + + expected_ctx = IRContext() + expected_ctx.add_function(entry_fn := IRFunction(IRLabel("__main_entry"))) + + expect_bb = IRBasicBlock(IRLabel("4_condition"), entry_fn) + entry_fn.append_basic_block(expect_bb) + + expect_bb.append_instruction( + "phi", + IRLabel("1_then"), + IRVariable("%var8_0"), + IRLabel("5_body"), + IRVariable("%var8_0:3"), + ret=IRVariable("var8_0:2"), + ) + expect_bb.append_instruction( + "phi", + IRLabel("1_then"), + IRVariable("%11"), + IRLabel("5_body"), + IRVariable("%11:4"), + ret=IRVariable("11:3"), + ) + expect_bb.append_instruction("store", IRVariable("11:3"), ret=IRVariable("%35")) + expect_bb.append_instruction("store", IRLiteral(9), ret=IRVariable("%36")) + expect_bb.append_instruction("xor", IRVariable("%35"), IRVariable("%36"), ret=IRVariable("%15")) + expect_bb.append_instruction("store", IRVariable("%15"), ret=IRVariable("%37")) + expect_bb.append_instruction("jnz", IRVariable("%37"), IRLabel("5_body"), IRLabel("7_exit")) + # other basic blocks omitted for brevity + + parsed_fn = next(iter(ctx.functions.values())) + assert_bb_eq(parsed_fn.get_basic_block(expect_bb.label.name), expect_bb) diff --git a/tests/functional/venom/test_venom_repr.py b/tests/functional/venom/test_venom_repr.py new file mode 100644 index 0000000000..1fb5d0486a --- /dev/null +++ b/tests/functional/venom/test_venom_repr.py @@ -0,0 +1,126 @@ +import copy +import glob +import textwrap + +import pytest + +from tests.venom_utils import assert_ctx_eq, parse_venom +from vyper.compiler import compile_code +from vyper.compiler.phases import generate_bytecode +from vyper.compiler.settings import OptimizationLevel +from vyper.venom import generate_assembly_experimental, 
run_passes_on +from vyper.venom.context import IRContext + +""" +Check that venom text format round-trips through parser +""" + + +def get_example_vy_filenames(): + return glob.glob("**/*.vy", root_dir="examples/", recursive=True) + + +@pytest.mark.parametrize("vy_filename", get_example_vy_filenames()) +def test_round_trip_examples(vy_filename, debug, optimize, compiler_settings, request): + """ + Check all examples round trip + """ + path = f"examples/{vy_filename}" + with open(path) as f: + vyper_source = f.read() + + if debug and optimize == OptimizationLevel.CODESIZE: + # FIXME: some round-trips fail when debug is enabled due to labels + # not getting pinned + request.node.add_marker(pytest.mark.xfail(strict=False)) + + _round_trip_helper(vyper_source, optimize, compiler_settings) + + +# pure vyper sources +vyper_sources = [ + """ + @external + def _loop() -> uint256: + res: uint256 = 9 + for i: uint256 in range(res, bound=10): + res = res + i + return res + """ +] + + +@pytest.mark.parametrize("vyper_source", vyper_sources) +def test_round_trip_sources(vyper_source, debug, optimize, compiler_settings, request): + """ + Test vyper_sources round trip + """ + vyper_source = textwrap.dedent(vyper_source) + + if debug and optimize == OptimizationLevel.CODESIZE: + # FIXME: some round-trips fail when debug is enabled due to labels + # not getting pinned + request.node.add_marker(pytest.mark.xfail(strict=False)) + + _round_trip_helper(vyper_source, optimize, compiler_settings) + + +def _round_trip_helper(vyper_source, optimize, compiler_settings): + # helper function to test venom round-tripping thru the parser + # use two helpers because run_passes_on and + # generate_assembly_experimental are both destructive (mutating) on + # the IRContext + _helper1(vyper_source, optimize) + _helper2(vyper_source, optimize, compiler_settings) + + +def _helper1(vyper_source, optimize): + """ + Check that we are able to run passes on the round-tripped venom code + and that it is 
valid (generates bytecode) + """ + # note: compiling any later stage than bb_runtime like `asm` or + # `bytecode` modifies the bb_runtime data structure in place and results + # in normalization of the venom cfg (which breaks again make_ssa) + out = compile_code(vyper_source, output_formats=["bb_runtime"]) + + bb_runtime = out["bb_runtime"] + venom_code = IRContext.__repr__(bb_runtime) + + ctx = parse_venom(venom_code) + + assert_ctx_eq(bb_runtime, ctx) + + # check it's valid to run venom passes+analyses + # (note this breaks bytecode equality, in the future we should + # test that separately) + run_passes_on(ctx, optimize) + + # test we can generate assembly+bytecode + asm = generate_assembly_experimental(ctx) + generate_bytecode(asm, compiler_metadata=None) + + +def _helper2(vyper_source, optimize, compiler_settings): + """ + Check that we can compile to bytecode, and without running venom passes, + that the output bytecode is equal to going through the normal vyper pipeline + """ + settings = copy.copy(compiler_settings) + # bytecode equivalence only makes sense if we use venom pipeline + settings.experimental_codegen = True + + out = compile_code(vyper_source, settings=settings, output_formats=["bb_runtime"]) + bb_runtime = out["bb_runtime"] + venom_code = IRContext.__repr__(bb_runtime) + + ctx = parse_venom(venom_code) + + assert_ctx_eq(bb_runtime, ctx) + + # test we can generate assembly+bytecode + asm = generate_assembly_experimental(ctx, optimize=optimize) + bytecode = generate_bytecode(asm, compiler_metadata=None) + + out = compile_code(vyper_source, settings=settings, output_formats=["bytecode_runtime"]) + assert "0x" + bytecode.hex() == out["bytecode_runtime"] diff --git a/tests/integration/test_pickle_ast.py b/tests/integration/test_pickle_ast.py new file mode 100644 index 0000000000..2c6144603a --- /dev/null +++ b/tests/integration/test_pickle_ast.py @@ -0,0 +1,19 @@ +import copy +import pickle + +from vyper.compiler.phases import CompilerData + + +def 
test_pickle_ast(): + code = """ +@external +def foo(): + self.bar() + y: uint256 = 5 + x: uint256 = 5 +def bar(): + pass + """ + f = CompilerData(code) + copy.deepcopy(f.annotated_vyper_module) + pickle.loads(pickle.dumps(f.annotated_vyper_module)) diff --git a/tests/unit/ast/nodes/test_binary.py b/tests/unit/ast/nodes/test_binary.py index d7662bc4bb..4bebe0abc2 100644 --- a/tests/unit/ast/nodes/test_binary.py +++ b/tests/unit/ast/nodes/test_binary.py @@ -1,5 +1,6 @@ import pytest +from tests.ast_utils import deepequals from vyper import ast as vy_ast from vyper.exceptions import SyntaxException @@ -18,7 +19,7 @@ def x(): """ ) - assert expected == mutated + assert deepequals(expected, mutated) def test_binary_length(): diff --git a/tests/unit/ast/nodes/test_compare_nodes.py b/tests/unit/ast/nodes/test_compare_nodes.py index 164cd3d371..d228e40bd1 100644 --- a/tests/unit/ast/nodes/test_compare_nodes.py +++ b/tests/unit/ast/nodes/test_compare_nodes.py @@ -1,3 +1,4 @@ +from tests.ast_utils import deepequals from vyper import ast as vy_ast @@ -6,21 +7,21 @@ def test_compare_different_node_clases(): left = vyper_ast.body[0].target right = vyper_ast.body[0].value - assert left != right + assert not deepequals(left, right) def test_compare_different_nodes_same_class(): vyper_ast = vy_ast.parse_to_ast("[1, 2]") left, right = vyper_ast.body[0].value.elements - assert left != right + assert not deepequals(left, right) def test_compare_different_nodes_same_value(): vyper_ast = vy_ast.parse_to_ast("[1, 1]") left, right = vyper_ast.body[0].value.elements - assert left != right + assert not deepequals(left, right) def test_compare_similar_node(): @@ -28,11 +29,11 @@ def test_compare_similar_node(): left = vy_ast.Int(value=1) right = vy_ast.Int(value=1) - assert left == right + assert deepequals(left, right) def test_compare_same_node(): vyper_ast = vy_ast.parse_to_ast("42") node = vyper_ast.body[0].value - assert node == node + assert deepequals(node, node) diff --git 
a/tests/unit/ast/nodes/test_fold_compare.py b/tests/unit/ast/nodes/test_fold_compare.py index aab8ac0b2d..fd9f65a7d3 100644 --- a/tests/unit/ast/nodes/test_fold_compare.py +++ b/tests/unit/ast/nodes/test_fold_compare.py @@ -110,3 +110,20 @@ def test_compare_type_mismatch(op): old_node = vyper_ast.body[0].value with pytest.raises(UnfoldableNode): old_node.get_folded_value() + + +@pytest.mark.parametrize("op", ["==", "!="]) +def test_compare_eq_bytes(get_contract, op): + left, right = "0xA1AAB33F", "0xa1aab33f" + source = f""" +@external +def foo(a: bytes4, b: bytes4) -> bool: + return a {op} b + """ + contract = get_contract(source) + + vyper_ast = parse_and_fold(f"{left} {op} {right}") + old_node = vyper_ast.body[0].value + new_node = old_node.get_folded_value() + + assert contract.foo(left, right) == new_node.value diff --git a/tests/unit/ast/nodes/test_hex.py b/tests/unit/ast/nodes/test_hex.py index 1b61764d57..6d82b1d2ab 100644 --- a/tests/unit/ast/nodes/test_hex.py +++ b/tests/unit/ast/nodes/test_hex.py @@ -33,11 +33,14 @@ def foo(): """ foo: constant(bytes4) = 0x12_34_56 """, + """ +foo: constant(bytes4) = 0X12345678 + """, ] @pytest.mark.parametrize("code", code_invalid_checksum) -def test_invalid_checksum(code, dummy_input_bundle): +def test_invalid_checksum(code): with pytest.raises(InvalidLiteral): vyper_module = vy_ast.parse_to_ast(code) - semantics.analyze_module(vyper_module, dummy_input_bundle) + semantics.analyze_module(vyper_module) diff --git a/tests/unit/ast/test_annotate_and_optimize_ast.py b/tests/unit/ast/test_annotate_and_optimize_ast.py index 7e1641e49e..afba043113 100644 --- a/tests/unit/ast/test_annotate_and_optimize_ast.py +++ b/tests/unit/ast/test_annotate_and_optimize_ast.py @@ -1,6 +1,6 @@ import ast as python_ast -from vyper.ast.parse import annotate_python_ast, pre_parse +from vyper.ast.parse import PreParser, annotate_python_ast class AssertionVisitor(python_ast.NodeVisitor): @@ -28,12 +28,13 @@ def foo() -> int128: def 
get_contract_info(source_code): - _, loop_var_annotations, class_types, reformatted_code = pre_parse(source_code) - py_ast = python_ast.parse(reformatted_code) + pre_parser = PreParser() + pre_parser.parse(source_code) + py_ast = python_ast.parse(pre_parser.reformatted_code) - annotate_python_ast(py_ast, reformatted_code, loop_var_annotations, class_types) + annotate_python_ast(py_ast, pre_parser.reformatted_code, pre_parser) - return py_ast, reformatted_code + return py_ast, pre_parser.reformatted_code def test_it_annotates_ast_with_source_code(): diff --git a/tests/unit/ast/test_ast_dict.py b/tests/unit/ast/test_ast_dict.py index 81c3dc46fa..cfad0795bc 100644 --- a/tests/unit/ast/test_ast_dict.py +++ b/tests/unit/ast/test_ast_dict.py @@ -1,5 +1,7 @@ +import copy import json +from tests.ast_utils import deepequals from vyper import compiler from vyper.ast.nodes import NODE_SRC_ATTRIBUTES from vyper.ast.parse import parse_to_ast @@ -137,7 +139,7 @@ def test() -> int128: new_dict = json.loads(out_json) new_ast = dict_to_ast(new_dict) - assert new_ast == original_ast + assert deepequals(new_ast, original_ast) # strip source annotations like lineno, we don't care for inspecting @@ -216,24 +218,27 @@ def foo(): input_bundle = make_input_bundle({"lib1.vy": lib1, "main.vy": main}) lib1_file = input_bundle.load_file("lib1.vy") - out = compiler.compile_from_file_input( + lib1_out = compiler.compile_from_file_input( lib1_file, input_bundle=input_bundle, output_formats=["annotated_ast_dict"] ) - lib1_ast = out["annotated_ast_dict"]["ast"] + + lib1_ast = copy.deepcopy(lib1_out["annotated_ast_dict"]["ast"]) lib1_sha256sum = lib1_ast.pop("source_sha256sum") assert lib1_sha256sum == lib1_file.sha256sum to_strip = NODE_SRC_ATTRIBUTES + ("resolved_path", "variable_reads", "variable_writes") _strip_source_annotations(lib1_ast, to_strip=to_strip) main_file = input_bundle.load_file("main.vy") - out = compiler.compile_from_file_input( + main_out = compiler.compile_from_file_input( 
main_file, input_bundle=input_bundle, output_formats=["annotated_ast_dict"] ) - main_ast = out["annotated_ast_dict"]["ast"] + main_ast = main_out["annotated_ast_dict"]["ast"] main_sha256sum = main_ast.pop("source_sha256sum") assert main_sha256sum == main_file.sha256sum _strip_source_annotations(main_ast, to_strip=to_strip) + assert main_out["annotated_ast_dict"]["imports"][0] == lib1_out["annotated_ast_dict"]["ast"] + # TODO: would be nice to refactor this into bunch of small test cases assert main_ast == { "ast_type": "Module", @@ -395,6 +400,7 @@ def foo(): "node_id": 0, "path": "main.vy", "source_id": 1, + "is_interface": False, "type": { "name": "main.vy", "type_decl_node": {"node_id": 0, "source_id": 1}, @@ -1171,6 +1177,7 @@ def foo(): "node_id": 0, "path": "lib1.vy", "source_id": 0, + "is_interface": False, "type": { "name": "lib1.vy", "type_decl_node": {"node_id": 0, "source_id": 0}, @@ -1255,7 +1262,13 @@ def qux2(): { "annotation": {"ast_type": "Name", "id": "uint256"}, "ast_type": "AnnAssign", - "target": {"ast_type": "Name", "id": "x"}, + "target": { + "ast_type": "Name", + "id": "x", + "variable_reads": [ + {"name": "x", "decl_node": {"node_id": 15, "source_id": 0}, "access_path": []} + ], + }, "value": { "ast_type": "Attribute", "attr": "counter", @@ -1300,7 +1313,13 @@ def qux2(): { "annotation": {"ast_type": "Name", "id": "uint256"}, "ast_type": "AnnAssign", - "target": {"ast_type": "Name", "id": "x"}, + "target": { + "ast_type": "Name", + "id": "x", + "variable_reads": [ + {"name": "x", "decl_node": {"node_id": 35, "source_id": 0}, "access_path": []} + ], + }, "value": { "ast_type": "Attribute", "attr": "counter", @@ -1317,7 +1336,13 @@ def qux2(): { "annotation": {"ast_type": "Name", "id": "uint256"}, "ast_type": "AnnAssign", - "target": {"ast_type": "Name", "id": "y"}, + "target": { + "ast_type": "Name", + "id": "y", + "variable_reads": [ + {"name": "y", "decl_node": {"node_id": 44, "source_id": 0}, "access_path": []} + ], + }, "value": { 
"ast_type": "Attribute", "attr": "counter", @@ -1758,3 +1783,49 @@ def qux2(): }, } ] + + +def test_annotated_ast_export_recursion(make_input_bundle): + sources = { + "main.vy": """ +import lib1 + +@external +def foo(): + lib1.foo() + """, + "lib1.vy": """ +import lib2 + +def foo(): + lib2.foo() + """, + "lib2.vy": """ +def foo(): + pass + """, + } + + input_bundle = make_input_bundle(sources) + + def compile_and_get_ast(file_name): + file = input_bundle.load_file(file_name) + output = compiler.compile_from_file_input( + file, input_bundle=input_bundle, output_formats=["annotated_ast_dict"] + ) + return output["annotated_ast_dict"] + + lib1_ast = compile_and_get_ast("lib1.vy")["ast"] + lib2_ast = compile_and_get_ast("lib2.vy")["ast"] + main_out = compile_and_get_ast("main.vy") + + lib1_import_ast = main_out["imports"][1] + lib2_import_ast = main_out["imports"][0] + + # path is once virtual, once libX.vy + # type contains name which is based on path + keys = [s for s in lib1_import_ast.keys() if s not in {"path", "type"}] + + for key in keys: + assert lib1_ast[key] == lib1_import_ast[key] + assert lib2_ast[key] == lib2_import_ast[key] diff --git a/tests/unit/ast/test_natspec.py b/tests/unit/ast/test_natspec.py index 710b7a9312..37120d2978 100644 --- a/tests/unit/ast/test_natspec.py +++ b/tests/unit/ast/test_natspec.py @@ -436,3 +436,19 @@ def test_natspec_parsed_implicitly(): # anything beyond ast is blocked with pytest.raises(NatSpecSyntaxException): compile_code(code, output_formats=["annotated_ast_dict"]) + + +def test_natspec_exception_contains_file_path(): + code = """ +@external +def foo() -> (int128,uint256): + ''' + @return int128 + @return uint256 + @return this should fail + ''' + return 1, 2 + """ + + with pytest.raises(NatSpecSyntaxException, match=r'contract "VyperContract\.vy:\d+"'): + parse_natspec(code) diff --git a/tests/unit/ast/test_parser.py b/tests/unit/ast/test_parser.py index e0bfcbc2ef..96df6cf245 100644 --- a/tests/unit/ast/test_parser.py 
+++ b/tests/unit/ast/test_parser.py @@ -1,3 +1,4 @@ +from tests.ast_utils import deepequals from vyper.ast.parse import parse_to_ast @@ -12,7 +13,7 @@ def test() -> int128: ast1 = parse_to_ast(code) ast2 = parse_to_ast("\n \n" + code + "\n\n") - assert ast1 == ast2 + assert deepequals(ast1, ast2) def test_ast_unequal(): @@ -32,4 +33,4 @@ def test() -> int128: ast1 = parse_to_ast(code1) ast2 = parse_to_ast(code2) - assert ast1 != ast2 + assert not deepequals(ast1, ast2) diff --git a/tests/unit/ast/test_pre_parser.py b/tests/unit/ast/test_pre_parser.py index da7d72b8ec..510d1e0ed2 100644 --- a/tests/unit/ast/test_pre_parser.py +++ b/tests/unit/ast/test_pre_parser.py @@ -1,7 +1,9 @@ +from pathlib import Path + import pytest from vyper import compile_code -from vyper.ast.pre_parser import pre_parse, validate_version_pragma +from vyper.ast.pre_parser import PreParser, validate_version_pragma from vyper.compiler.phases import CompilerData from vyper.compiler.settings import OptimizationLevel, Settings from vyper.exceptions import StructureException, VersionException @@ -56,6 +58,24 @@ def test_invalid_version_pragma(file_version, mock_version): validate_version_pragma(f"{file_version}", file_version, (SRC_LINE)) +def test_invalid_version_contains_file(mock_version): + mock_version(COMPILER_VERSION) + with pytest.raises(VersionException, match=r'contract "mock\.vy:\d+"'): + compile_code("# pragma version ^0.3.10", resolved_path=Path("mock.vy")) + + +def test_imported_invalid_version_contains_correct_file( + mock_version, make_input_bundle, chdir_tmp_path +): + code_a = "# pragma version ^0.3.10" + code_b = "import A" + input_bundle = make_input_bundle({"A.vy": code_a, "B.vy": code_b}) + mock_version(COMPILER_VERSION) + + with pytest.raises(VersionException, match=r'contract "A\.vy:\d+"'): + compile_code(code_b, input_bundle=input_bundle) + + prerelease_valid_versions = [ "<0.1.1-beta.9", "<0.1.1b9", @@ -174,9 +194,10 @@ def 
test_prerelease_invalid_version_pragma(file_version, mock_version): @pytest.mark.parametrize("code, pre_parse_settings, compiler_data_settings", pragma_examples) def test_parse_pragmas(code, pre_parse_settings, compiler_data_settings, mock_version): mock_version("0.3.10") - settings, _, _, _ = pre_parse(code) + pre_parser = PreParser() + pre_parser.parse(code) - assert settings == pre_parse_settings + assert pre_parser.settings == pre_parse_settings compiler_data = CompilerData(code) @@ -191,6 +212,26 @@ def test_parse_pragmas(code, pre_parse_settings, compiler_data_settings, mock_ve assert compiler_data.settings == compiler_data_settings +pragma_venom = [ + """ + #pragma venom + """, + """ + #pragma experimental-codegen + """, +] + + +@pytest.mark.parametrize("code", pragma_venom) +def test_parse_venom_pragma(code): + pre_parser = PreParser() + pre_parser.parse(code) + assert pre_parser.settings.experimental_codegen is True + + compiler_data = CompilerData(code) + assert compiler_data.settings.experimental_codegen is True + + invalid_pragmas = [ # evm-versionnn """ @@ -218,13 +259,22 @@ def test_parse_pragmas(code, pre_parse_settings, compiler_data_settings, mock_ve # pragma evm-version cancun # pragma evm-version shanghai """, + # duplicate setting of venom + """ + #pragma venom + #pragma experimental-codegen + """, + """ + #pragma venom + #pragma venom + """, ] @pytest.mark.parametrize("code", invalid_pragmas) def test_invalid_pragma(code): with pytest.raises(StructureException): - pre_parse(code) + PreParser().parse(code) def test_version_exception_in_import(make_input_bundle): diff --git a/tests/unit/ast/test_tokenizer.py b/tests/unit/ast/test_tokenizer.py new file mode 100644 index 0000000000..f6000e0425 --- /dev/null +++ b/tests/unit/ast/test_tokenizer.py @@ -0,0 +1,94 @@ +""" +Tests that the tokenizer / parser are passing correct source location +info to the AST +""" +import pytest + +from vyper.ast.parse import parse_to_ast +from vyper.compiler import 
compile_code +from vyper.exceptions import UndeclaredDefinition + + +def test_log_token_aligned(): + # GH issue 3430 + code = """ +event A: + b: uint256 + +@external +def f(): + log A(b=d) + """ + with pytest.raises(UndeclaredDefinition) as e: + compile_code(code) + + expected = """ + 'd' has not been declared. + + function "f", line 7:12 + 6 def f(): + ---> 7 log A(b=d) + -------------------^ + 8 + """ # noqa: W291 + assert expected.strip() == str(e.value).strip() + + +def test_log_token_aligned2(): + # GH issue 3059 + code = """ +interface Contract: + def foo(): nonpayable + +event MyEvent: + a: address + +@external +def foo(c: Contract): + log MyEvent(a=c.address) + """ + compile_code(code) + + +def test_log_token_aligned3(): + # https://github.com/vyperlang/vyper/pull/3808#pullrequestreview-1900570163 + code = """ +import ITest + +implements: ITest + +event Foo: + a: address + +@external +def foo(u: uint256): + log Foo(empty(address)) + log i.Foo(empty(address)) + """ + # not semantically valid code, check we can at least parse it + assert parse_to_ast(code) is not None + + +def test_log_token_aligned4(): + # GH issue 4139 + code = """ +b: public(uint256) + +event Transfer: + random: indexed(uint256) + shi: uint256 + +@external +def transfer(): + log Transfer(T(self).b(), 10) + return + """ + # not semantically valid code, check we can at least parse it + assert parse_to_ast(code) is not None + + +def test_long_string_non_coding_token(): + # GH issue 2258 + code = '\r[[]]\ndef _(e:[],l:[]):\n """"""""""""""""""""""""""""""""""""""""""""""""""""""\n f.n()' # noqa: E501 + # not valid code, but should at least parse + assert parse_to_ast(code) is not None diff --git a/tests/unit/cli/storage_layout/test_storage_layout_overrides.py b/tests/unit/cli/storage_layout/test_storage_layout_overrides.py index f02a8471e2..d2c495a2fa 100644 --- a/tests/unit/cli/storage_layout/test_storage_layout_overrides.py +++ b/tests/unit/cli/storage_layout/test_storage_layout_overrides.py 
@@ -2,6 +2,7 @@ import pytest +from vyper.cli.vyper_json import compile_json from vyper.compiler import compile_code from vyper.evm.opcodes import version_check from vyper.exceptions import StorageLayoutException @@ -13,7 +14,7 @@ def test_storage_layout_overrides(): b: uint256""" storage_layout_overrides = { - "a": {"type": "uint256", "slot": 1, "n_slots": 1}, + "a": {"type": "uint256", "slot": 5, "n_slots": 1}, "b": {"type": "uint256", "slot": 0, "n_slots": 1}, } @@ -26,6 +27,31 @@ def test_storage_layout_overrides(): assert out["layout"] == expected_output +def test_storage_layout_overrides_json(): + code = """ +a: uint256 +b: uint256""" + + storage_layout_overrides = { + "a": {"type": "uint256", "slot": 5, "n_slots": 1}, + "b": {"type": "uint256", "slot": 0, "n_slots": 1}, + } + + input_json = { + "language": "Vyper", + "sources": {"contracts/foo.vy": {"content": code}}, + "storage_layout_overrides": {"contracts/foo.vy": storage_layout_overrides}, + "settings": {"outputSelection": {"*": ["*"]}}, + } + + out = compile_code( + code, output_formats=["layout"], storage_layout_override=storage_layout_overrides + ) + assert ( + compile_json(input_json)["contracts"]["contracts/foo.vy"]["foo"]["layout"] == out["layout"] + ) + + def test_storage_layout_for_more_complex(): code = """ foo: HashMap[address, uint256] diff --git a/tests/unit/cli/vyper_compile/test_compile_files.py b/tests/unit/cli/vyper_compile/test_compile_files.py index 3856aa3362..a1f5ca098c 100644 --- a/tests/unit/cli/vyper_compile/test_compile_files.py +++ b/tests/unit/cli/vyper_compile/test_compile_files.py @@ -1,17 +1,27 @@ import contextlib +import json import sys +import warnings import zipfile from pathlib import Path import pytest +from vyper.cli.compile_archive import compiler_data_from_zip from vyper.cli.vyper_compile import compile_files -from vyper.cli.vyper_json import compile_json +from vyper.cli.vyper_json import compile_from_input_dict, compile_json +from vyper.compiler import 
INTERFACE_OUTPUT_FORMATS, OUTPUT_FORMATS from vyper.compiler.input_bundle import FilesystemInputBundle from vyper.compiler.output_bundle import OutputBundle from vyper.compiler.phases import CompilerData from vyper.utils import sha256sum +TAMPERED_INTEGRITY_SUM = sha256sum("tampered integrity sum") + +INTEGRITY_WARNING = f"Mismatched integrity sum! Expected {TAMPERED_INTEGRITY_SUM}" +INTEGRITY_WARNING += " but got {integrity}." # noqa: FS003 +INTEGRITY_WARNING += " (This likely indicates a corrupted archive)" + def test_combined_json_keys(chdir_tmp_path, make_file): make_file("bar.vy", "") @@ -297,6 +307,9 @@ def foo() -> uint256: import lib import jsonabi +a: uint256 +b: uint256 + @external def foo() -> uint256: return lib.foo() @@ -305,28 +318,52 @@ def foo() -> uint256: def bar(x: uint256) -> uint256: return extcall jsonabi(msg.sender).test_json(x) """ + storage_layout_overrides = { + "a": {"type": "uint256", "n_slots": 1, "slot": 5}, + "b": {"type": "uint256", "n_slots": 1, "slot": 0}, + } + storage_layout_source = json.dumps(storage_layout_overrides) + tmpdir = tmp_path_factory.mktemp("fake-package") with open(tmpdir / "lib.vy", "w") as f: f.write(library_source) with open(tmpdir / "jsonabi.json", "w") as f: f.write(json_source) + with open(tmpdir / "layout.json", "w") as f: + f.write(storage_layout_source) contract_file = make_file("contract.vy", contract_source) - return (tmpdir, tmpdir / "lib.vy", tmpdir / "jsonabi.json", contract_file) + contract_hash = sha256sum(contract_source) + library_hash = sha256sum(library_source) + jsonabi_hash = sha256sum(json_source) + resolved_imports_hash = sha256sum(contract_hash + sha256sum(library_hash) + jsonabi_hash) + storage_layout_hash = sha256sum(storage_layout_source) + expected_integrity = sha256sum(storage_layout_hash + resolved_imports_hash) + + return ( + tmpdir, + tmpdir / "lib.vy", + tmpdir / "jsonabi.json", + tmpdir / "layout.json", + contract_file, + expected_integrity, + ) def 
test_import_sys_path(input_files): - tmpdir, _, _, contract_file = input_files + tmpdir, _, _, _, contract_file, _ = input_files with mock_sys_path(tmpdir): assert compile_files([contract_file], ["combined_json"]) is not None def test_archive_output(input_files): - tmpdir, _, _, contract_file = input_files + tmpdir, library_file, jsonabi_file, storage_layout_path, contract_file, integrity = input_files search_paths = [".", tmpdir] - s = compile_files([contract_file], ["archive"], paths=search_paths) + s = compile_files( + [contract_file], ["archive"], paths=search_paths, storage_layout_paths=[storage_layout_path] + ) archive_bytes = s[contract_file]["archive"] archive_path = Path("foo.zip") @@ -336,13 +373,28 @@ def test_archive_output(input_files): assert zipfile.is_zipfile(archive_path) # compare compiling the two input bundles - out = compile_files([contract_file], ["integrity", "bytecode"], paths=search_paths) - out2 = compile_files([archive_path], ["integrity", "bytecode"]) + out = compile_files( + [contract_file], + ["integrity", "bytecode", "layout"], + paths=search_paths, + storage_layout_paths=[storage_layout_path], + ) + out2 = compile_files([archive_path], ["integrity", "bytecode", "layout"]) assert out[contract_file] == out2[archive_path] + # tamper with the integrity sum + archive_compiler_data = compiler_data_from_zip(archive_path, None, False) + archive_compiler_data.expected_integrity_sum = TAMPERED_INTEGRITY_SUM + + with warnings.catch_warnings(record=True) as w: + assert archive_compiler_data.integrity_sum is not None + + assert len(w) == 1, [s.message for s in w] + assert str(w[0].message).startswith(INTEGRITY_WARNING.format(integrity=integrity)) + def test_archive_b64_output(input_files): - tmpdir, _, _, contract_file = input_files + tmpdir, _, _, _, contract_file, _ = input_files search_paths = [".", tmpdir] out = compile_files( @@ -360,41 +412,155 @@ def test_archive_b64_output(input_files): assert out[contract_file] == out2[archive_path] -def 
test_solc_json_output(input_files): - tmpdir, _, _, contract_file = input_files +def test_archive_compile_options(input_files): + tmpdir, _, _, _, contract_file, _ = input_files search_paths = [".", tmpdir] - out = compile_files([contract_file], ["solc_json"], paths=search_paths) + options = ["abi_python", "json", "ast", "annotated_ast", "ir_json"] + + for option in options: + out = compile_files([contract_file], ["archive_b64", option], paths=search_paths) + + archive_b64 = out[contract_file].pop("archive_b64") + + archive_path = Path("foo.zip.b64") + with archive_path.open("w") as f: + f.write(archive_b64) + + # compare compiling the two input bundles + out2 = compile_files([archive_path], [option]) + + if option in ["ast", "annotated_ast"]: + # would have to normalize paths and imports, so just verify it compiles + continue + + assert out[contract_file] == out2[archive_path] + + +format_options = [ + "bytecode", + "bytecode_runtime", + "blueprint_bytecode", + "abi", + "abi_python", + "source_map", + "source_map_runtime", + "method_identifiers", + "userdoc", + "devdoc", + "metadata", + "combined_json", + "layout", + "ast", + "annotated_ast", + "interface", + "external_interface", + "opcodes", + "opcodes_runtime", + "ir", + "ir_json", + "ir_runtime", + "asm", + "integrity", + "archive", + "solc_json", +] + + +def test_compile_vyz_with_options(input_files): + tmpdir, _, _, _, contract_file, _ = input_files + search_paths = [".", tmpdir] + + for option in format_options: + out_archive = compile_files([contract_file], ["archive"], paths=search_paths) + + archive = out_archive[contract_file].pop("archive") + + archive_path = Path("foo.zip.out.vyz") + with archive_path.open("wb") as f: + f.write(archive) + + # compare compiling the two input bundles + out = compile_files([contract_file], [option], paths=search_paths) + out2 = compile_files([archive_path], [option]) + + if option in ["ast", "annotated_ast", "metadata"]: + # would have to normalize paths and imports, so 
just verify it compiles + continue + + if option in ["ir_runtime", "ir", "archive"]: + # ir+ir_runtime is different due to being different compiler runs + # archive is different due to different metadata (timestamps) + continue + + assert out[contract_file] == out2[archive_path] + + +def test_archive_compile_simultaneous_options(input_files): + tmpdir, _, _, _, contract_file, _ = input_files + search_paths = [".", tmpdir] + for option in format_options: + with pytest.raises(ValueError) as e: + _ = compile_files([contract_file], ["archive", option], paths=search_paths) + + err_opt = "archive" + if option in ("combined_json", "solc_json"): + err_opt = option + + assert f"If using {err_opt} it must be the only output format requested" in str(e.value) + + +def test_solc_json_output(input_files): + tmpdir, _, _, storage_layout_path, contract_file, integrity = input_files + search_paths = [".", tmpdir] + + out = compile_files( + [contract_file], + ["solc_json"], + paths=search_paths, + storage_layout_paths=[storage_layout_path], + ) json_input = out[contract_file]["solc_json"] # check that round-tripping solc_json thru standard json produces # the same as compiling directly json_out = compile_json(json_input)["contracts"]["contract.vy"] json_out_bytecode = json_out["contract"]["evm"]["bytecode"]["object"] + json_out_layout = json_out["contract"]["layout"]["storage_layout"] - out2 = compile_files([contract_file], ["integrity", "bytecode"], paths=search_paths) + out2 = compile_files( + [contract_file], + ["integrity", "bytecode", "layout"], + paths=search_paths, + storage_layout_paths=[storage_layout_path], + ) assert out2[contract_file]["bytecode"] == json_out_bytecode + assert out2[contract_file]["layout"]["storage_layout"] == json_out_layout + + # tamper with the integrity sum + json_input["integrity"] = TAMPERED_INTEGRITY_SUM + _, warn_data = compile_from_input_dict(json_input) + + w = warn_data[Path("contract.vy")] + assert len(w) == 1, [s.message for s in w] + assert 
str(w[0].message).startswith(INTEGRITY_WARNING.format(integrity=integrity)) # maybe this belongs in tests/unit/compiler? def test_integrity_sum(input_files): - tmpdir, library_file, jsonabi_file, contract_file = input_files + tmpdir, library_file, jsonabi_file, storage_layout_path, contract_file, integrity = input_files search_paths = [".", tmpdir] - out = compile_files([contract_file], ["integrity"], paths=search_paths) - - with library_file.open() as f, contract_file.open() as g, jsonabi_file.open() as h: - library_contents = f.read() - contract_contents = g.read() - jsonabi_contents = h.read() + out = compile_files( + [contract_file], + ["integrity"], + paths=search_paths, + storage_layout_paths=[storage_layout_path], + ) - contract_hash = sha256sum(contract_contents) - library_hash = sha256sum(library_contents) - jsonabi_hash = sha256sum(jsonabi_contents) - expected = sha256sum(contract_hash + sha256sum(library_hash) + jsonabi_hash) - assert out[contract_file]["integrity"] == expected + assert out[contract_file]["integrity"] == integrity # does this belong in tests/unit/compiler? @@ -425,3 +591,31 @@ def test_archive_search_path(tmp_path_factory, make_file, chdir_tmp_path): used_dir = search_paths[-1].stem # either dir1 or dir2 assert output_bundle.used_search_paths == [".", "0/" + used_dir] + + +def test_compile_interface_file(make_file): + interface = """ +@view +@external +def foo() -> String[1]: + ... + +@view +@external +def bar() -> String[1]: + ... + +@external +def baz() -> uint8: + ... 
+ + """ + file = make_file("interface.vyi", interface) + compile_files([file], INTERFACE_OUTPUT_FORMATS) + + # check unallowed output formats + for f in OUTPUT_FORMATS: + if f in INTERFACE_OUTPUT_FORMATS: + continue + with pytest.raises(ValueError): + compile_files([file], [f]) diff --git a/tests/unit/cli/vyper_json/test_compile_json.py b/tests/unit/cli/vyper_json/test_compile_json.py index ef3284cd15..f921d250a4 100644 --- a/tests/unit/cli/vyper_json/test_compile_json.py +++ b/tests/unit/cli/vyper_json/test_compile_json.py @@ -9,6 +9,7 @@ compile_json, exc_handler_to_dict, get_inputs, + get_settings, ) from vyper.compiler import OUTPUT_FORMATS, compile_code, compile_from_file_input from vyper.compiler.input_bundle import JSONInputBundle @@ -19,6 +20,9 @@ import contracts.library as library +a: uint256 +b: uint256 + @external def foo(a: address) -> bool: return extcall IBar(a).bar(1) @@ -28,16 +32,29 @@ def baz() -> uint256: return self.balance + library.foo() """ +FOO_STORAGE_LAYOUT_OVERRIDES = { + "a": {"type": "uint256", "n_slots": 1, "slot": 5}, + "b": {"type": "uint256", "n_slots": 1, "slot": 0}, +} + BAR_CODE = """ import contracts.ibar as IBar implements: IBar +c: uint256 +d: uint256 + @external def bar(a: uint256) -> bool: return True """ +BAR_STORAGE_LAYOUT_OVERRIDES = { + "c": {"type": "uint256", "n_slots": 1, "slot": 13}, + "d": {"type": "uint256", "n_slots": 1, "slot": 7}, +} + BAR_VYI = """ @external def bar(a: uint256) -> bool: @@ -72,7 +89,7 @@ def oopsie(a: uint256) -> bool: @pytest.fixture(scope="function") -def input_json(optimize, evm_version, experimental_codegen): +def input_json(optimize, evm_version, experimental_codegen, debug): return { "language": "Vyper", "sources": { @@ -86,6 +103,11 @@ def input_json(optimize, evm_version, experimental_codegen): "optimize": optimize.name.lower(), "evmVersion": evm_version, "experimentalCodegen": experimental_codegen, + "debug": debug, + }, + "storage_layout_overrides": { + "contracts/foo.vy": 
FOO_STORAGE_LAYOUT_OVERRIDES, + "contracts/bar.vy": BAR_STORAGE_LAYOUT_OVERRIDES, }, } @@ -126,7 +148,10 @@ def test_compile_json(input_json, input_bundle): del output_formats["cfg"] del output_formats["cfg_runtime"] foo = compile_from_file_input( - foo_input, output_formats=output_formats, input_bundle=input_bundle + foo_input, + output_formats=output_formats, + input_bundle=input_bundle, + storage_layout_override=FOO_STORAGE_LAYOUT_OVERRIDES, ) library_input = input_bundle.load_file("contracts/library.vy") @@ -136,7 +161,10 @@ def test_compile_json(input_json, input_bundle): bar_input = input_bundle.load_file("contracts/bar.vy") bar = compile_from_file_input( - bar_input, output_formats=output_formats, input_bundle=input_bundle + bar_input, + output_formats=output_formats, + input_bundle=input_bundle, + storage_layout_override=BAR_STORAGE_LAYOUT_OVERRIDES, ) compile_code_results = { @@ -169,6 +197,7 @@ def test_compile_json(input_json, input_bundle): "interface": data["interface"], "ir": data["ir_dict"], "userdoc": data["userdoc"], + "layout": data["layout"], "metadata": data["metadata"], "evm": { "bytecode": { @@ -216,7 +245,16 @@ def test_different_outputs(input_bundle, input_json): foo = contracts["contracts/foo.vy"]["foo"] bar = contracts["contracts/bar.vy"]["bar"] - assert sorted(bar.keys()) == ["abi", "devdoc", "evm", "interface", "ir", "metadata", "userdoc"] + assert sorted(bar.keys()) == [ + "abi", + "devdoc", + "evm", + "interface", + "ir", + "layout", + "metadata", + "userdoc", + ] assert sorted(foo.keys()) == ["evm"] @@ -237,7 +275,7 @@ def test_wrong_language(): def test_exc_handler_raises_syntax(input_json): input_json["sources"]["badcode.vy"] = {"content": BAD_SYNTAX_CODE} - with pytest.raises(SyntaxException): + with pytest.raises(SyntaxException, match=r'contract "badcode\.vy:\d+"'): compile_json(input_json) @@ -268,6 +306,14 @@ def test_exc_handler_to_dict_compiler(input_json): assert error["type"] == "TypeMismatch" +def 
test_unknown_storage_layout_overrides(input_json): + unknown_contract_path = "contracts/baz.vy" + input_json["storage_layout_overrides"] = {unknown_contract_path: FOO_STORAGE_LAYOUT_OVERRIDES} + with pytest.raises(JSONError) as e: + compile_json(input_json) + assert e.value.args[0] == f"unknown target for storage layout override: {unknown_contract_path}" + + def test_source_ids_increment(input_json): input_json["settings"]["outputSelection"] = {"*": ["ast", "evm.deployedBytecode.sourceMap"]} result = compile_json(input_json) @@ -293,5 +339,65 @@ def get(filename, contractname): def test_relative_import_paths(input_json): input_json["sources"]["contracts/potato/baz/baz.vy"] = {"content": "from ... import foo"} input_json["sources"]["contracts/potato/baz/potato.vy"] = {"content": "from . import baz"} - input_json["sources"]["contracts/potato/footato.vy"] = {"content": "from baz import baz"} + input_json["sources"]["contracts/potato/footato.vy"] = {"content": "from .baz import baz"} compile_from_input_dict(input_json) + + +def test_compile_json_with_abi_top(make_input_bundle): + stream = """ +{ + "abi": [ + { + "name": "validate", + "inputs": [ + { "name": "creator", "type": "address" }, + { "name": "token", "type": "address" }, + { "name": "amount_per_second", "type": "uint256" }, + { "name": "reason", "type": "bytes" } + ], + "outputs": [{ "name": "max_stream_life", "type": "uint256" }] + } + ] +} + """ + code = """ +from . 
import stream + """ + input_bundle = make_input_bundle({"stream.json": stream, "code.vy": code}) + file_input = input_bundle.load_file("code.vy") + vyper.compiler.compile_from_file_input(file_input, input_bundle=input_bundle) + + +def test_compile_json_with_experimental_codegen(): + code = { + "language": "Vyper", + "sources": {"foo.vy": {"content": "@external\ndef foo() -> bool:\n return True"}}, + "settings": { + "evmVersion": "cancun", + "optimize": "gas", + "venom": True, + "search_paths": [], + "outputSelection": {"*": ["ast"]}, + }, + } + + settings = get_settings(code) + assert settings.experimental_codegen is True + + +def test_compile_json_with_both_venom_aliases(): + code = { + "language": "Vyper", + "sources": {"foo.vy": {"content": ""}}, + "settings": { + "evmVersion": "cancun", + "optimize": "gas", + "experimentalCodegen": False, + "venom": False, + "search_paths": [], + "outputSelection": {"*": ["ast"]}, + }, + } + with pytest.raises(JSONError) as e: + get_settings(code) + assert e.value.args[0] == "both experimentalCodegen and venom cannot be set" diff --git a/tests/unit/cli/vyper_json/test_get_settings.py b/tests/unit/cli/vyper_json/test_get_settings.py index 540a26f062..077f424d45 100644 --- a/tests/unit/cli/vyper_json/test_get_settings.py +++ b/tests/unit/cli/vyper_json/test_get_settings.py @@ -1,6 +1,6 @@ import pytest -from vyper.cli.vyper_json import get_evm_version +from vyper.cli.vyper_json import get_evm_version, get_settings from vyper.exceptions import JSONError @@ -30,3 +30,14 @@ def test_early_evm(evm_version_str): @pytest.mark.parametrize("evm_version_str", ["london", "paris", "shanghai", "cancun"]) def test_valid_evm(evm_version_str): assert evm_version_str == get_evm_version({"settings": {"evmVersion": evm_version_str}}) + + +def test_experimental_codegen_settings(): + input_json = {"settings": {}} + assert get_settings(input_json).experimental_codegen is None + + input_json = {"settings": {"experimentalCodegen": True}} + assert 
get_settings(input_json).experimental_codegen is True + + input_json = {"settings": {"experimentalCodegen": False}} + assert get_settings(input_json).experimental_codegen is False diff --git a/tests/unit/cli/vyper_json/test_output_selection.py b/tests/unit/cli/vyper_json/test_output_selection.py index f7fbfe673c..e409c43af6 100644 --- a/tests/unit/cli/vyper_json/test_output_selection.py +++ b/tests/unit/cli/vyper_json/test_output_selection.py @@ -2,6 +2,7 @@ import pytest +from vyper import compiler from vyper.cli.vyper_json import TRANSLATE_MAP, get_output_formats from vyper.exceptions import JSONError @@ -76,3 +77,55 @@ def test_solc_style(): def test_metadata(): input_json = {"sources": {"foo.vy": ""}, "settings": {"outputSelection": {"*": ["metadata"]}}} assert get_output_formats(input_json) == {PurePath("foo.vy"): ["metadata"]} + + +def test_metadata_contain_all_reachable_functions(make_input_bundle, chdir_tmp_path): + code_a = """ +@internal +def foo() -> uint256: + return 43 + +@internal +def faa() -> uint256: + return 76 + """ + + code_b = """ +import A + +@internal +def foo() -> uint256: + return 43 + +@external +def bar(): + self.foo() + A.foo() + assert 1 != 12 + """ + + input_bundle = make_input_bundle({"A.vy": code_a, "B.vy": code_b}) + file_input = input_bundle.load_file("B.vy") + + res = compiler.compile_from_file_input( + file_input, input_bundle=input_bundle, output_formats=["metadata"] + ) + function_infos = res["metadata"]["function_info"] + + assert "foo (0)" in function_infos + assert "foo (1)" in function_infos + assert "bar (2)" in function_infos + # faa is unreachable, should not be in metadata or bytecode + assert not any("faa" in key for key in function_infos.keys()) + + assert function_infos["foo (0)"]["function_id"] == 0 + assert function_infos["foo (1)"]["function_id"] == 1 + assert function_infos["bar (2)"]["function_id"] == 2 + + assert function_infos["foo (0)"]["module_path"] == "B.vy" + assert function_infos["foo 
(1)"]["module_path"] == "A.vy" + assert function_infos["bar (2)"]["module_path"] == "B.vy" + + assert function_infos["foo (0)"]["source_id"] == input_bundle.load_file("B.vy").source_id + assert function_infos["foo (1)"]["source_id"] == input_bundle.load_file("A.vy").source_id + assert function_infos["bar (2)"]["source_id"] == input_bundle.load_file("B.vy").source_id diff --git a/tests/unit/compiler/test_bytecode_runtime.py b/tests/unit/compiler/test_bytecode_runtime.py index 213adce017..9fdc4c493f 100644 --- a/tests/unit/compiler/test_bytecode_runtime.py +++ b/tests/unit/compiler/test_bytecode_runtime.py @@ -54,17 +54,27 @@ def test_bytecode_runtime(): assert out["bytecode_runtime"].removeprefix("0x") in out["bytecode"].removeprefix("0x") -def test_bytecode_signature(): - out = vyper.compile_code(simple_contract_code, output_formats=["bytecode_runtime", "bytecode"]) +def test_bytecode_signature(optimize, debug): + out = vyper.compile_code( + simple_contract_code, output_formats=["bytecode_runtime", "bytecode", "integrity"] + ) runtime_code = bytes.fromhex(out["bytecode_runtime"].removeprefix("0x")) initcode = bytes.fromhex(out["bytecode"].removeprefix("0x")) metadata = _parse_cbor_metadata(initcode) - runtime_len, data_section_lengths, immutables_len, compiler = metadata + integrity_hash, runtime_len, data_section_lengths, immutables_len, compiler = metadata + + if debug and optimize == OptimizationLevel.CODESIZE: + # debug forces dense jumptable no matter the size of selector table + expected_data_section_lengths = [5, 7] + else: + expected_data_section_lengths = [] + + assert integrity_hash.hex() == out["integrity"] assert runtime_len == len(runtime_code) - assert data_section_lengths == [] + assert data_section_lengths == expected_data_section_lengths assert immutables_len == 0 assert compiler == {"vyper": list(vyper.version.version_tuple)} @@ -73,14 +83,18 @@ def test_bytecode_signature_dense_jumptable(): settings = Settings(optimize=OptimizationLevel.CODESIZE) 
out = vyper.compile_code( - many_functions, output_formats=["bytecode_runtime", "bytecode"], settings=settings + many_functions, + output_formats=["bytecode_runtime", "bytecode", "integrity"], + settings=settings, ) runtime_code = bytes.fromhex(out["bytecode_runtime"].removeprefix("0x")) initcode = bytes.fromhex(out["bytecode"].removeprefix("0x")) metadata = _parse_cbor_metadata(initcode) - runtime_len, data_section_lengths, immutables_len, compiler = metadata + integrity_hash, runtime_len, data_section_lengths, immutables_len, compiler = metadata + + assert integrity_hash.hex() == out["integrity"] assert runtime_len == len(runtime_code) assert data_section_lengths == [5, 35] @@ -92,14 +106,18 @@ def test_bytecode_signature_sparse_jumptable(): settings = Settings(optimize=OptimizationLevel.GAS) out = vyper.compile_code( - many_functions, output_formats=["bytecode_runtime", "bytecode"], settings=settings + many_functions, + output_formats=["bytecode_runtime", "bytecode", "integrity"], + settings=settings, ) runtime_code = bytes.fromhex(out["bytecode_runtime"].removeprefix("0x")) initcode = bytes.fromhex(out["bytecode"].removeprefix("0x")) metadata = _parse_cbor_metadata(initcode) - runtime_len, data_section_lengths, immutables_len, compiler = metadata + integrity_hash, runtime_len, data_section_lengths, immutables_len, compiler = metadata + + assert integrity_hash.hex() == out["integrity"] assert runtime_len == len(runtime_code) assert data_section_lengths == [8] @@ -107,17 +125,27 @@ def test_bytecode_signature_sparse_jumptable(): assert compiler == {"vyper": list(vyper.version.version_tuple)} -def test_bytecode_signature_immutables(): - out = vyper.compile_code(has_immutables, output_formats=["bytecode_runtime", "bytecode"]) +def test_bytecode_signature_immutables(debug, optimize): + out = vyper.compile_code( + has_immutables, output_formats=["bytecode_runtime", "bytecode", "integrity"] + ) runtime_code = bytes.fromhex(out["bytecode_runtime"].removeprefix("0x")) 
initcode = bytes.fromhex(out["bytecode"].removeprefix("0x")) metadata = _parse_cbor_metadata(initcode) - runtime_len, data_section_lengths, immutables_len, compiler = metadata + integrity_hash, runtime_len, data_section_lengths, immutables_len, compiler = metadata + + if debug and optimize == OptimizationLevel.CODESIZE: + # debug forces dense jumptable no matter the size of selector table + expected_data_section_lengths = [5, 7] + else: + expected_data_section_lengths = [] + + assert integrity_hash.hex() == out["integrity"] assert runtime_len == len(runtime_code) - assert data_section_lengths == [] + assert data_section_lengths == expected_data_section_lengths assert immutables_len == 32 assert compiler == {"vyper": list(vyper.version.version_tuple)} @@ -129,7 +157,10 @@ def test_bytecode_signature_deployed(code, get_contract, env): deployed_code = env.get_code(c.address) metadata = _parse_cbor_metadata(c.bytecode) - runtime_len, data_section_lengths, immutables_len, compiler = metadata + integrity_hash, runtime_len, data_section_lengths, immutables_len, compiler = metadata + + out = vyper.compile_code(code, output_formats=["integrity"]) + assert integrity_hash.hex() == out["integrity"] assert compiler == {"vyper": list(vyper.version.version_tuple)} diff --git a/tests/unit/compiler/test_source_map.py b/tests/unit/compiler/test_source_map.py index d99b546403..ae1999a26e 100644 --- a/tests/unit/compiler/test_source_map.py +++ b/tests/unit/compiler/test_source_map.py @@ -97,8 +97,44 @@ def update_foo(): self.foo += 1 """ error_map = compile_code(code, output_formats=["source_map"])["source_map"]["error_map"] - assert "safeadd" in list(error_map.values()) - assert "fallback function" in list(error_map.values()) + assert "safeadd" in error_map.values() + assert "fallback function" in error_map.values() + + +def test_error_map_with_user_error(): + code = """ +@external +def foo(): + raise "some error" + """ + error_map = compile_code(code, 
output_formats=["source_map"])["source_map"]["error_map"] + assert "user revert with reason" in error_map.values() + + +def test_error_map_with_user_error2(): + code = """ +@external +def foo(i: uint256): + a: DynArray[uint256, 10] = [1] + a[i % 10] = 2 + """ + error_map = compile_code(code, output_formats=["source_map"])["source_map"]["error_map"] + assert "safemod" in error_map.values() + + +def test_error_map_not_overriding_errors(): + code = """ +@external +def foo(i: uint256): + raise self.bar(5%i) + +@pure +def bar(i: uint256) -> String[32]: + return "foo foo" + """ + error_map = compile_code(code, output_formats=["source_map"])["source_map"]["error_map"] + assert "user revert with reason" in error_map.values() + assert "safemod" in error_map.values() def test_compress_source_map(): diff --git a/tests/unit/compiler/venom/test_algebraic_binopt.py b/tests/unit/compiler/venom/test_algebraic_binopt.py new file mode 100644 index 0000000000..5486787225 --- /dev/null +++ b/tests/unit/compiler/venom/test_algebraic_binopt.py @@ -0,0 +1,584 @@ +import pytest + +from tests.venom_utils import assert_ctx_eq, parse_from_basic_block +from vyper.venom.analysis import IRAnalysesCache +from vyper.venom.passes import AlgebraicOptimizationPass, StoreElimination + +""" +Test abstract binop+unop optimizations in algebraic optimizations pass +""" + + +def _sccp_algebraic_runner(pre, post): + ctx = parse_from_basic_block(pre) + + for fn in ctx.functions.values(): + ac = IRAnalysesCache(fn) + StoreElimination(ac, fn).run_pass() + AlgebraicOptimizationPass(ac, fn).run_pass() + StoreElimination(ac, fn).run_pass() + + assert_ctx_eq(ctx, parse_from_basic_block(post)) + + +def test_sccp_algebraic_opt_sub_xor(): + # x - x -> 0 + # x ^ x -> 0 + pre = """ + _global: + %par = param + %1 = sub %par, %par + %2 = xor %par, %par + return %1, %2 + """ + post = """ + _global: + %par = param + return 0, 0 + """ + + _sccp_algebraic_runner(pre, post) + + +def 
test_sccp_algebraic_opt_zero_sub_add_xor(): + # x + 0 == x - 0 == x ^ 0 -> x + # (this cannot be done for 0 - x) + pre = """ + _global: + %par = param + %1 = sub %par, 0 + %2 = xor %par, 0 + %3 = add %par, 0 + %4 = sub 0, %par + %5 = add 0, %par + %6 = xor 0, %par + return %1, %2, %3, %4, %5, %6 + """ + post = """ + _global: + %par = param + %4 = sub 0, %par + return %par, %par, %par, %4, %par, %par + """ + + _sccp_algebraic_runner(pre, post) + + +def test_sccp_algebraic_opt_sub_xor_max(): + # x ^ 0xFF..FF -> not x + # -1 - x -> ~x + pre = """ + _global: + %par = param + %tmp = -1 + %1 = xor -1, %par + %2 = xor %par, -1 + + %3 = sub -1, %par + + return %1, %2, %3 + """ + post = """ + _global: + %par = param + %1 = not %par + %2 = not %par + %3 = not %par + return %1, %2, %3 + """ + + _sccp_algebraic_runner(pre, post) + + +def test_sccp_algebraic_opt_shift(): + # x << 0 == x >> 0 == x (sar) 0 -> x + # sar is right arithmetic shift + pre = """ + _global: + %par = param + %1 = shl 0, %par + %2 = shr 0, %1 + %3 = sar 0, %2 + return %1, %2, %3 + """ + post = """ + _global: + %par = param + return %par, %par, %par + """ + + _sccp_algebraic_runner(pre, post) + + +@pytest.mark.parametrize("opcode", ("mul", "and", "div", "sdiv", "mod", "smod")) +def test_mul_by_zero(opcode): + # x * 0 == 0 * x == x % 0 == 0 % x == x // 0 == 0 // x == x & 0 == 0 & x -> 0 + pre = f""" + _global: + %par = param + %1 = {opcode} 0, %par + %2 = {opcode} %par, 0 + return %1, %2 + """ + post = """ + _global: + %par = param + return 0, 0 + """ + + _sccp_algebraic_runner(pre, post) + + +def test_sccp_algebraic_opt_multi_neutral_elem(): + # x * 1 == 1 * x == x / 1 -> x + # checks for non comutative ops + pre = """ + _global: + %par = param + %1_1 = mul 1, %par + %1_2 = mul %par, 1 + %2_1 = div 1, %par + %2_2 = div %par, 1 + %3_1 = sdiv 1, %par + %3_2 = sdiv %par, 1 + return %1_1, %1_2, %2_1, %2_2, %3_1, %3_2 + """ + post = """ + _global: + %par = param + %2_1 = div 1, %par + %3_1 = sdiv 1, %par + 
return %par, %par, %2_1, %par, %3_1, %par + """ + + _sccp_algebraic_runner(pre, post) + + +def test_sccp_algebraic_opt_mod_zero(): + # x % 1 -> 0 + pre = """ + _global: + %par = param + %1 = mod %par, 1 + %2 = smod %par, 1 + return %1, %2 + """ + post = """ + _global: + %par = param + return 0, 0 + """ + + _sccp_algebraic_runner(pre, post) + + +def test_sccp_algebraic_opt_and_max(): + # x & 0xFF..FF == 0xFF..FF & x -> x + max_uint256 = 2**256 - 1 + pre = f""" + _global: + %par = param + %tmp = {max_uint256} + %1 = and %par, %tmp + %2 = and %tmp, %par + return %1, %2 + """ + post = """ + _global: + %par = param + return %par, %par + """ + + _sccp_algebraic_runner(pre, post) + + +# test powers of 2 from n==2 to n==255. +# (skip 1 since there are specialized rules for n==1) +@pytest.mark.parametrize("n", range(2, 256)) +def test_sccp_algebraic_opt_mul_div_to_shifts(n): + # x * 2**n -> x << n + # x / 2**n -> x >> n + y = 2**n + pre = f""" + _global: + %par = param + %1 = mul %par, {y} + %2 = mod %par, {y} + %3 = div %par, {y} + %4 = mul {y}, %par + %5 = mod {y}, %par ; note: this is blocked! + %6 = div {y}, %par ; blocked! 
+ return %1, %2, %3, %4, %5, %6 + """ + post = f""" + _global: + %par = param + %1 = shl {n}, %par + %2 = and {y - 1}, %par + %3 = shr {n}, %par + %4 = shl {n}, %par + %5 = mod {y}, %par + %6 = div {y}, %par + return %1, %2, %3, %4, %5, %6 + """ + + _sccp_algebraic_runner(pre, post) + + +def test_sccp_algebraic_opt_exp(): + # x ** 0 == 0 ** x -> 1 + # x ** 1 -> x + pre = """ + _global: + %par = param + %1 = exp %par, 0 + %2 = exp 1, %par + %3 = exp 0, %par + %4 = exp %par, 1 + return %1, %2, %3, %4 + """ + post = """ + _global: + %par = param + %3 = iszero %par + return 1, 1, %3, %par + """ + + _sccp_algebraic_runner(pre, post) + + +def test_sccp_algebraic_opt_compare_self(): + # x < x == x > x -> 0 + pre = """ + _global: + %par = param + %tmp = %par + %1 = gt %tmp, %par + %2 = sgt %tmp, %par + %3 = lt %tmp, %par + %4 = slt %tmp, %par + return %1, %2, %3, %4 + """ + post = """ + _global: + %par = param + return 0, 0, 0, 0 + """ + + _sccp_algebraic_runner(pre, post) + + +def test_sccp_algebraic_opt_or(): + # x | 0 -> x + # x | 0xFF..FF -> 0xFF..FF + max_uint256 = 2**256 - 1 + pre = f""" + _global: + %par = param + %1 = or %par, 0 + %2 = or %par, {max_uint256} + %3 = or 0, %par + %4 = or {max_uint256}, %par + return %1, %2, %3, %4 + """ + post = f""" + _global: + %par = param + return %par, {max_uint256}, %par, {max_uint256} + """ + + _sccp_algebraic_runner(pre, post) + + +def test_sccp_algebraic_opt_eq(): + # (x == 0) == (0 == x) -> iszero x + # x == x -> 1 + # x == 0xFFFF..FF -> iszero(not x) + pre = """ + global: + %par = param + %1 = eq %par, 0 + %2 = eq 0, %par + + %3 = eq %par, -1 + %4 = eq -1, %par + + %5 = eq %par, %par + return %1, %2, %3, %4, %5 + """ + post = """ + global: + %par = param + %1 = iszero %par + %2 = iszero %par + %6 = not %par + %3 = iszero %6 + %7 = not %par + %4 = iszero %7 + return %1, %2, %3, %4, 1 + """ + _sccp_algebraic_runner(pre, post) + + +def test_sccp_algebraic_opt_boolean_or(): + # x | (non zero) -> 1 if it is only used as boolean 
+ some_nonzero = 123 + pre = f""" + _global: + %par = param + %1 = or %par, {some_nonzero} + %2 = or %par, {some_nonzero} + assert %1 + %3 = or {some_nonzero}, %par + %4 = or {some_nonzero}, %par + assert %3 + return %2, %4 + """ + post = f""" + _global: + %par = param + %2 = or {some_nonzero}, %par + assert 1 + %4 = or {some_nonzero}, %par + assert 1 + return %2, %4 + """ + + _sccp_algebraic_runner(pre, post) + + +def test_sccp_algebraic_opt_boolean_eq(): + # x == y -> iszero (x ^ y) if it is only used as boolean + pre = """ + _global: + %par = param + %par2 = param + %1 = eq %par, %par2 + %2 = eq %par, %par2 + assert %1 + return %2 + + """ + post = """ + _global: + %par = param + %par2 = param + %3 = xor %par, %par2 + %1 = iszero %3 + %2 = eq %par, %par2 + assert %1 + return %2 + """ + + _sccp_algebraic_runner(pre, post) + + +def test_compare_never(): + # unsigned x > 0xFF..FF == x < 0 -> 0 + # signed: x > MAX_SIGNED (0x3F..FF) == x < MIN_SIGNED (0xF0..00) -> 0 + min_int256 = -(2**255) + max_int256 = 2**255 - 1 + min_uint256 = 0 + max_uint256 = 2**256 - 1 + pre = f""" + _global: + %par = param + + %1 = slt %par, {min_int256} + %2 = sgt %par, {max_int256} + %3 = lt %par, {min_uint256} + %4 = gt %par, {max_uint256} + + return %1, %2, %3, %4 + """ + post = """ + _global: + %par = param + return 0, 0, 0, 0 + """ + + _sccp_algebraic_runner(pre, post) + + +def test_comparison_zero(): + # x > 0 => iszero(iszero x) + # 0 < x => iszero(iszero x) + pre = """ + _global: + %par = param + %1 = lt 0, %par + %2 = gt %par, 0 + return %1, %2 + """ + post = """ + _global: + %par = param + %3 = iszero %par + %1 = iszero %3 + %4 = iszero %par + %2 = iszero %4 + return %1, %2 + """ + + _sccp_algebraic_runner(pre, post) + + +def test_comparison_almost_never(): + # unsigned: + # x < 1 => eq x 0 => iszero x + # MAX_UINT - 1 < x => eq x MAX_UINT => iszero(not x) + # signed + # x < MIN_INT + 1 => eq x MIN_INT + # MAX_INT - 1 < x => eq x MAX_INT + + max_uint256 = 2**256 - 1 + max_int256 = 
2**255 - 1 + min_int256 = -(2**255) + pre1 = f""" + _global: + %par = param + %1 = lt %par, 1 + %2 = gt %par, {max_uint256 - 1} + %3 = sgt %par, {max_int256 - 1} + %4 = slt %par, {min_int256 + 1} + + return %1, %2, %3, %4 + """ + # commuted versions - produce same output + pre2 = f""" + _global: + %par = param + %1 = gt 1, %par + %2 = lt {max_uint256 - 1}, %par + %3 = slt {max_int256 - 1}, %par + %4 = sgt {min_int256 + 1}, %par + return %1, %2, %3, %4 + """ + post = f""" + _global: + %par = param + ; lt %par, 1 => eq 0, %par => iszero %par + %1 = iszero %par + ; x > MAX_UINT256 - 1 => eq MAX_UINT x => iszero(not x) + %5 = not %par + %2 = iszero %5 + %3 = eq {max_int256}, %par + %4 = eq {min_int256}, %par + return %1, %2, %3, %4 + """ + + _sccp_algebraic_runner(pre1, post) + _sccp_algebraic_runner(pre2, post) + + +def test_comparison_almost_always(): + # unsigned + # x > 0 => iszero(iszero x) + # 0 < x => iszero(iszero x) + # x < MAX_UINT => iszero(eq x MAX_UINT) => iszero(iszero(not x)) + # signed + # x < MAX_INT => iszero(eq MAX_INT) => iszero(iszero(xor MAX_INT x)) + + max_uint256 = 2**256 - 1 + max_int256 = 2**255 - 1 + min_int256 = -(2**255) + + pre1 = f""" + _global: + %par = param + %1 = gt %par, 0 + %2 = lt %par, {max_uint256} + assert %2 + %3 = slt %par, {max_int256} + assert %3 + %4 = sgt %par, {min_int256} + assert %4 + return %1 + """ + # commuted versions + pre2 = f""" + _global: + %par = param + %1 = lt 0, %par + %2 = gt {max_uint256}, %par + assert %2 + %3 = sgt {max_int256}, %par + assert %3 + %4 = slt {min_int256}, %par + assert %4 + return %1 + """ + post = f""" + _global: + %par = param + %5 = iszero %par + %1 = iszero %5 + %9 = not %par ; (eq -1 x) => (iszero (not x)) + %6 = iszero %9 + %2 = iszero %6 + assert %2 + %10 = xor %par, {max_int256} + %7 = iszero %10 + %3 = iszero %7 + assert %3 + %11 = xor %par, {min_int256} + %8 = iszero %11 + %4 = iszero %8 + assert %4 + return %1 + """ + + _sccp_algebraic_runner(pre1, post) + 
_sccp_algebraic_runner(pre2, post) + + +@pytest.mark.parametrize("val", (100, 2, 3, -100)) +def test_comparison_ge_le(val): + # iszero(x < 100) => 99 < x + # iszero(x > 100) => 101 > x + + up = val + 1 + down = val - 1 + + abs_val = abs(val) + abs_up = abs_val + 1 + abs_down = abs_val - 1 + + pre1 = f""" + _global: + %par = param + %1 = lt %par, {abs_val} + %3 = gt %par, {abs_val} + %2 = iszero %1 + %4 = iszero %3 + %5 = slt %par, {val} + %7 = sgt %par, {val} + %6 = iszero %5 + %8 = iszero %7 + return %2, %4, %6, %8 + """ + pre2 = f""" + _global: + %par = param + %1 = gt {abs_val}, %par + %3 = lt {abs_val}, %par + %2 = iszero %1 + %4 = iszero %3 + %5 = sgt {val}, %par + %7 = slt {val}, %par + %6 = iszero %5 + %8 = iszero %7 + return %2, %4, %6, %8 + """ + post = f""" + _global: + %par = param + %1 = lt {abs_down}, %par + %3 = gt {abs_up}, %par + %5 = slt {down}, %par + %7 = sgt {up}, %par + return %1, %3, %5, %7 + """ + + _sccp_algebraic_runner(pre1, post) + _sccp_algebraic_runner(pre2, post) diff --git a/tests/unit/compiler/venom/test_algebraic_optimizer.py b/tests/unit/compiler/venom/test_algebraic_optimizer.py index e0368d4197..00ccb0684a 100644 --- a/tests/unit/compiler/venom/test_algebraic_optimizer.py +++ b/tests/unit/compiler/venom/test_algebraic_optimizer.py @@ -1,11 +1,10 @@ import pytest -from vyper.venom.analysis.analysis import IRAnalysesCache +import vyper +from vyper.venom.analysis import IRAnalysesCache from vyper.venom.basicblock import IRBasicBlock, IRLabel from vyper.venom.context import IRContext -from vyper.venom.passes.algebraic_optimization import AlgebraicOptimizationPass -from vyper.venom.passes.make_ssa import MakeSSA -from vyper.venom.passes.remove_unused_variables import RemoveUnusedVariablesPass +from vyper.venom.passes import AlgebraicOptimizationPass, MakeSSA, RemoveUnusedVariablesPass @pytest.mark.parametrize("iszero_count", range(5)) @@ -127,3 +126,80 @@ def test_interleaved_case(interleave_point): assert 
bb.instructions[-1].operands[0] == op3_inv else: assert bb.instructions[-1].operands[0] == op3 + + +def test_offsets(): + ctx = IRContext() + fn = ctx.create_function("_global") + + bb = fn.get_basic_block() + + br1 = IRBasicBlock(IRLabel("then"), fn) + fn.append_basic_block(br1) + br2 = IRBasicBlock(IRLabel("else"), fn) + fn.append_basic_block(br2) + + p1 = bb.append_instruction("param") + op1 = bb.append_instruction("store", 32) + op2 = bb.append_instruction("add", 0, IRLabel("mem")) + op3 = bb.append_instruction("store", 64) + bb.append_instruction("dloadbytes", op1, op2, op3) + op5 = bb.append_instruction("mload", op3) + op6 = bb.append_instruction("iszero", op5) + bb.append_instruction("jnz", op6, br1.label, br2.label) + + op01 = br1.append_instruction("store", 32) + op02 = br1.append_instruction("add", 0, IRLabel("mem")) + op03 = br1.append_instruction("store", 64) + br1.append_instruction("dloadbytes", op01, op02, op03) + op05 = br1.append_instruction("mload", op03) + op06 = br1.append_instruction("iszero", op05) + br1.append_instruction("return", p1, op06) + + op11 = br2.append_instruction("store", 32) + op12 = br2.append_instruction("add", 0, IRLabel("mem")) + op13 = br2.append_instruction("store", 64) + br2.append_instruction("dloadbytes", op11, op12, op13) + op15 = br2.append_instruction("mload", op13) + op16 = br2.append_instruction("iszero", op15) + br2.append_instruction("return", p1, op16) + + ac = IRAnalysesCache(fn) + MakeSSA(ac, fn).run_pass() + AlgebraicOptimizationPass(ac, fn).run_pass() + RemoveUnusedVariablesPass(ac, fn).run_pass() + + offset_count = 0 + for bb in fn.get_basic_blocks(): + for instruction in bb.instructions: + assert instruction.opcode != "add" + if instruction.opcode == "offset": + offset_count += 1 + + assert offset_count == 3 + + +# Test the case of https://github.com/vyperlang/vyper/issues/4288 +def test_ssa_after_algebraic_optimization(): + code = """ +@internal +def _do_math(x: uint256) -> uint256: + value: uint256 = x + 
result: uint256 = 0 + + if (x >> 128 != 0): + x >>= 128 + if (x >> 64 != 0): + x >>= 64 + + if 1 < value: + result = 1 + + return result + +@external +def run() -> uint256: + return self._do_math(10) + """ + + vyper.compile_code(code, output_formats=["bytecode"]) diff --git a/tests/unit/compiler/venom/test_branch_optimizer.py b/tests/unit/compiler/venom/test_branch_optimizer.py index b6e806e217..a96ed0709c 100644 --- a/tests/unit/compiler/venom/test_branch_optimizer.py +++ b/tests/unit/compiler/venom/test_branch_optimizer.py @@ -1,9 +1,7 @@ -from vyper.venom.analysis.analysis import IRAnalysesCache -from vyper.venom.analysis.dfg import DFGAnalysis +from vyper.venom.analysis import DFGAnalysis, IRAnalysesCache from vyper.venom.basicblock import IRBasicBlock, IRLabel from vyper.venom.context import IRContext -from vyper.venom.passes.branch_optimization import BranchOptimizationPass -from vyper.venom.passes.make_ssa import MakeSSA +from vyper.venom.passes import BranchOptimizationPass, MakeSSA def test_simple_jump_case(): @@ -18,15 +16,16 @@ def test_simple_jump_case(): fn.append_basic_block(br2) p1 = bb.append_instruction("param") + p2 = bb.append_instruction("param") op1 = bb.append_instruction("store", p1) op2 = bb.append_instruction("store", 64) op3 = bb.append_instruction("add", op1, op2) jnz_input = bb.append_instruction("iszero", op3) bb.append_instruction("jnz", jnz_input, br1.label, br2.label) - br1.append_instruction("add", op3, 10) + br1.append_instruction("add", op3, p1) br1.append_instruction("stop") - br2.append_instruction("add", op3, p1) + br2.append_instruction("add", op3, p2) br2.append_instruction("stop") term_inst = bb.instructions[-1] @@ -49,6 +48,6 @@ def test_simple_jump_case(): # Test that the dfg is updated correctly dfg = ac.request_analysis(DFGAnalysis) - assert dfg is old_dfg, "DFG should not be invalidated by BranchOptimizationPass" + assert dfg is not old_dfg, "DFG should be invalidated by BranchOptimizationPass" assert term_inst in 
dfg.get_uses(op3), "jnz not using the new condition" assert term_inst not in dfg.get_uses(jnz_input), "jnz still using the old condition" diff --git a/tests/unit/compiler/venom/test_dominator_tree.py b/tests/unit/compiler/venom/test_dominator_tree.py index 29f86df221..30a2e4564e 100644 --- a/tests/unit/compiler/venom/test_dominator_tree.py +++ b/tests/unit/compiler/venom/test_dominator_tree.py @@ -2,12 +2,11 @@ from vyper.exceptions import CompilerPanic from vyper.utils import OrderedSet -from vyper.venom.analysis.analysis import IRAnalysesCache -from vyper.venom.analysis.dominators import DominatorTreeAnalysis +from vyper.venom.analysis import DominatorTreeAnalysis, IRAnalysesCache from vyper.venom.basicblock import IRBasicBlock, IRInstruction, IRLabel, IRLiteral, IRVariable from vyper.venom.context import IRContext from vyper.venom.function import IRFunction -from vyper.venom.passes.make_ssa import MakeSSA +from vyper.venom.passes import MakeSSA def _add_bb( diff --git a/tests/unit/compiler/venom/test_duplicate_operands.py b/tests/unit/compiler/venom/test_duplicate_operands.py index 44c4ed0404..89b06796e3 100644 --- a/tests/unit/compiler/venom/test_duplicate_operands.py +++ b/tests/unit/compiler/venom/test_duplicate_operands.py @@ -1,6 +1,8 @@ from vyper.compiler.settings import OptimizationLevel from vyper.venom import generate_assembly_experimental +from vyper.venom.analysis import IRAnalysesCache from vyper.venom.context import IRContext +from vyper.venom.passes import StoreExpansionPass def test_duplicate_operands(): @@ -13,7 +15,7 @@ def test_duplicate_operands(): %3 = mul %1, %2 stop - Should compile to: [PUSH1, 10, DUP1, DUP1, DUP1, ADD, MUL, STOP] + Should compile to: [PUSH1, 10, DUP1, DUP2, ADD, MUL, POP, STOP] """ ctx = IRContext() fn = ctx.create_function("test") @@ -23,5 +25,9 @@ def test_duplicate_operands(): bb.append_instruction("mul", sum_, op) bb.append_instruction("stop") - asm = generate_assembly_experimental(ctx, 
optimize=OptimizationLevel.GAS) - assert asm == ["PUSH1", 10, "DUP1", "DUP1", "ADD", "MUL", "STOP"] + ac = IRAnalysesCache(fn) + StoreExpansionPass(ac, fn).run_pass() + + optimize = OptimizationLevel.GAS + asm = generate_assembly_experimental(ctx, optimize=optimize) + assert asm == ["PUSH1", 10, "DUP1", "DUP2", "ADD", "MUL", "POP", "STOP"] diff --git a/tests/unit/compiler/venom/test_literals_codesize.py b/tests/unit/compiler/venom/test_literals_codesize.py new file mode 100644 index 0000000000..4de4d9de64 --- /dev/null +++ b/tests/unit/compiler/venom/test_literals_codesize.py @@ -0,0 +1,117 @@ +import pytest + +from vyper.utils import evm_not +from vyper.venom.analysis import IRAnalysesCache +from vyper.venom.basicblock import IRLiteral +from vyper.venom.context import IRContext +from vyper.venom.passes import ReduceLiteralsCodesize + + +def _calc_push_size(val: int): + s = hex(val).removeprefix("0x") + if len(s) % 2 != 0: # justify to multiple of 2 + s = "0" + s + return 1 + len(s) + + +should_invert = [2**256 - 1] + [((2**i) - 1) << (256 - i) for i in range(121, 256 + 1)] + + +@pytest.mark.parametrize("orig_value", should_invert) +def test_literal_codesize_ff_inversion(orig_value): + """ + Test that literals like 0xfffffffffffabcd get inverted to `not 0x5432` + """ + ctx = IRContext() + fn = ctx.create_function("_global") + bb = fn.get_basic_block() + + bb.append_instruction("store", IRLiteral(orig_value)) + bb.append_instruction("stop") + ac = IRAnalysesCache(fn) + ReduceLiteralsCodesize(ac, fn).run_pass() + + inst0 = bb.instructions[0] + assert inst0.opcode == "not" + op0 = inst0.operands[0] + assert evm_not(op0.value) == orig_value + # check the optimization actually improved codesize, after accounting + # for the addl NOT instruction + assert _calc_push_size(op0.value) + 1 < _calc_push_size(orig_value) + + +should_not_invert = [1, 0xFE << 248 | (2**248 - 1)] + [ + ((2**255 - 1) >> i) << i for i in range(0, 3 * 8) +] + + +@pytest.mark.parametrize("orig_value", 
should_not_invert) +def test_literal_codesize_no_inversion(orig_value): + """ + Check funky cases where inversion would result in bytecode increase + """ + ctx = IRContext() + fn = ctx.create_function("_global") + bb = fn.get_basic_block() + + bb.append_instruction("store", IRLiteral(orig_value)) + bb.append_instruction("stop") + ac = IRAnalysesCache(fn) + ReduceLiteralsCodesize(ac, fn).run_pass() + + assert bb.instructions[0].opcode == "store" + assert bb.instructions[0].operands[0].value == orig_value + + +should_shl = ( + [2**i for i in range(3 * 8, 255)] + + [((2**i) - 1) << (256 - i) for i in range(1, 121)] + + [((2**255 - 1) >> i) << i for i in range(3 * 8, 254)] +) + + +@pytest.mark.parametrize("orig_value", should_shl) +def test_literal_codesize_shl(orig_value): + """ + Test that literals like 0xabcd00000000 get transformed to `shl 32 0xabcd` + """ + ctx = IRContext() + fn = ctx.create_function("_global") + bb = fn.get_basic_block() + + bb.append_instruction("store", IRLiteral(orig_value)) + bb.append_instruction("stop") + ac = IRAnalysesCache(fn) + ReduceLiteralsCodesize(ac, fn).run_pass() + + assert bb.instructions[0].opcode == "shl" + op0, op1 = bb.instructions[0].operands + assert op0.value << op1.value == orig_value + + # check the optimization actually improved codesize, after accounting + # for the addl PUSH and SHL instructions + assert _calc_push_size(op0.value) + _calc_push_size(op1.value) + 1 < _calc_push_size(orig_value) + + +should_not_shl = [1 << i for i in range(0, 3 * 8)] + [ + 0x0, + (((2 ** (256 - 2)) - 1) << (2 * 8)) ^ (2**255), +] + + +@pytest.mark.parametrize("orig_value", should_not_shl) +def test_literal_codesize_no_shl(orig_value): + """ + Check funky cases where shl transformation would result in bytecode increase + """ + ctx = IRContext() + fn = ctx.create_function("_global") + bb = fn.get_basic_block() + + bb.append_instruction("store", IRLiteral(orig_value)) + bb.append_instruction("stop") + ac = IRAnalysesCache(fn) + 
ReduceLiteralsCodesize(ac, fn).run_pass() + + assert bb.instructions[0].opcode == "store" + assert bb.instructions[0].operands[0].value == orig_value diff --git a/tests/unit/compiler/venom/test_load_elimination.py b/tests/unit/compiler/venom/test_load_elimination.py new file mode 100644 index 0000000000..52c7baf3c9 --- /dev/null +++ b/tests/unit/compiler/venom/test_load_elimination.py @@ -0,0 +1,129 @@ +from tests.venom_utils import assert_ctx_eq, parse_from_basic_block +from vyper.venom.analysis.analysis import IRAnalysesCache +from vyper.venom.passes.load_elimination import LoadElimination + + +def _check_pre_post(pre, post): + ctx = parse_from_basic_block(pre) + + for fn in ctx.functions.values(): + ac = IRAnalysesCache(fn) + LoadElimination(ac, fn).run_pass() + + assert_ctx_eq(ctx, parse_from_basic_block(post)) + + +def _check_no_change(pre): + _check_pre_post(pre, pre) + + +def test_simple_load_elimination(): + pre = """ + main: + %ptr = 11 + %1 = mload %ptr + + %2 = mload %ptr + + stop + """ + post = """ + main: + %ptr = 11 + %1 = mload %ptr + + %2 = %1 + + stop + """ + _check_pre_post(pre, post) + + +def test_equivalent_var_elimination(): + """ + Test that the lattice can "peer through" equivalent vars + """ + pre = """ + main: + %1 = 11 + %2 = %1 + %3 = mload %1 + + %4 = mload %2 + + stop + """ + post = """ + main: + %1 = 11 + %2 = %1 + %3 = mload %1 + + %4 = %3 # %2 == %1 + + stop + """ + _check_pre_post(pre, post) + + +def test_elimination_barrier(): + """ + Check for barrier between load/load + """ + pre = """ + main: + %1 = 11 + %2 = mload %1 + %3 = %100 + # fence - writes to memory + staticcall %3, %3, %3, %3 + %4 = mload %1 + """ + _check_no_change(pre) + + +def test_store_load_elimination(): + """ + Check that lattice stores the result of mstores (even through + equivalent variables) + """ + pre = """ + main: + %val = 55 + %ptr1 = 11 + %ptr2 = %ptr1 + mstore %ptr1, %val + + %3 = mload %ptr2 + + stop + """ + post = """ + main: + %val = 55 + %ptr1 = 11 
+ %ptr2 = %ptr1 + mstore %ptr1, %val + + %3 = %val + + stop + """ + _check_pre_post(pre, post) + + +def test_store_load_barrier(): + """ + Check for barrier between store/load + """ + pre = """ + main: + %ptr = 11 + %val = 55 + mstore %ptr, %val + %3 = %100 ; arbitrary + # fence + staticcall %3, %3, %3, %3 + %4 = mload %ptr + """ + _check_no_change(pre) diff --git a/tests/unit/compiler/venom/test_make_ssa.py b/tests/unit/compiler/venom/test_make_ssa.py index 9cea1a20a4..7f6b2c0cba 100644 --- a/tests/unit/compiler/venom/test_make_ssa.py +++ b/tests/unit/compiler/venom/test_make_ssa.py @@ -1,48 +1,52 @@ -from vyper.venom.analysis.analysis import IRAnalysesCache -from vyper.venom.basicblock import IRBasicBlock, IRLabel -from vyper.venom.context import IRContext -from vyper.venom.passes.make_ssa import MakeSSA +from tests.venom_utils import assert_ctx_eq, parse_venom +from vyper.venom.analysis import IRAnalysesCache +from vyper.venom.passes import MakeSSA -def test_phi_case(): - ctx = IRContext() - fn = ctx.create_function("_global") - - bb = fn.get_basic_block() - - bb_cont = IRBasicBlock(IRLabel("condition"), fn) - bb_then = IRBasicBlock(IRLabel("then"), fn) - bb_else = IRBasicBlock(IRLabel("else"), fn) - bb_if_exit = IRBasicBlock(IRLabel("if_exit"), fn) - fn.append_basic_block(bb_cont) - fn.append_basic_block(bb_then) - fn.append_basic_block(bb_else) - fn.append_basic_block(bb_if_exit) - - v = bb.append_instruction("mload", 64) - bb_cont.append_instruction("jnz", v, bb_then.label, bb_else.label) - - bb_if_exit.append_instruction("add", v, 1, ret=v) - bb_if_exit.append_instruction("jmp", bb_cont.label) +def _check_pre_post(pre, post): + ctx = parse_venom(pre) + for fn in ctx.functions.values(): + ac = IRAnalysesCache(fn) + MakeSSA(ac, fn).run_pass() + assert_ctx_eq(ctx, parse_venom(post)) - bb_then.append_instruction("assert", bb_then.append_instruction("mload", 96)) - bb_then.append_instruction("jmp", bb_if_exit.label) - bb_else.append_instruction("jmp", 
bb_if_exit.label) - bb.append_instruction("jmp", bb_cont.label) - - ac = IRAnalysesCache(fn) - MakeSSA(ac, fn).run_pass() - - condition_block = fn.get_basic_block("condition") - assert len(condition_block.instructions) == 2 - - phi_inst = condition_block.instructions[0] - assert phi_inst.opcode == "phi" - assert phi_inst.operands[0].name == "_global" - assert phi_inst.operands[1].name == "%1" - assert phi_inst.operands[2].name == "if_exit" - assert phi_inst.operands[3].name == "%1" - assert phi_inst.output.name == "%1" - assert phi_inst.output.value != phi_inst.operands[1].value - assert phi_inst.output.value != phi_inst.operands[3].value +def test_phi_case(): + pre = """ + function loop { + main: + %v = mload 64 + jmp @test + test: + jnz %v, @then, @else + then: + %t = mload 96 + assert %t + jmp @if_exit + else: + jmp @if_exit + if_exit: + %v = add %v, 1 + jmp @test + } + """ + post = """ + function loop { + main: + %v = mload 64 + jmp @test + test: + %v:1 = phi @main, %v, @if_exit, %v:2 + jnz %v:1, @then, @else + then: + %t = mload 96 + assert %t + jmp @if_exit + else: + jmp @if_exit + if_exit: + %v:2 = add %v:1, 1 + jmp @test + } + """ + _check_pre_post(pre, post) diff --git a/tests/unit/compiler/venom/test_memmerging.py b/tests/unit/compiler/venom/test_memmerging.py new file mode 100644 index 0000000000..d309752621 --- /dev/null +++ b/tests/unit/compiler/venom/test_memmerging.py @@ -0,0 +1,1065 @@ +import pytest + +from tests.venom_utils import assert_ctx_eq, parse_from_basic_block, parse_venom +from vyper.evm.opcodes import version_check +from vyper.venom.analysis import IRAnalysesCache +from vyper.venom.passes import SCCP, MemMergePass + + +def _check_pre_post(pre, post): + ctx = parse_from_basic_block(pre) + for fn in ctx.functions.values(): + ac = IRAnalysesCache(fn) + MemMergePass(ac, fn).run_pass() + assert_ctx_eq(ctx, parse_from_basic_block(post)) + + +def _check_no_change(pre): + _check_pre_post(pre, pre) + + +# for parametrizing tests +LOAD_COPY = 
[("dload", "dloadbytes"), ("calldataload", "calldatacopy")] + + +def test_memmerging(): + """ + Basic memory merge test + All mloads and mstores can be + transformed into mcopy + """ + if not version_check(begin="cancun"): + return + + pre = """ + _global: + %1 = mload 0 + %2 = mload 32 + %3 = mload 64 + mstore 1000, %1 + mstore 1032, %2 + mstore 1064, %3 + stop + """ + + post = """ + _global: + mcopy 1000, 0, 96 + stop + """ + _check_pre_post(pre, post) + + +def test_memmerging_out_of_order(): + """ + interleaved mloads/mstores which can be transformed into mcopy + """ + if not version_check(begin="cancun"): + return + + pre = """ + _global: + %1 = mload 32 + %2 = mload 0 + mstore 132, %1 + %3 = mload 64 + mstore 164, %3 + mstore 100, %2 + stop + """ + + post = """ + _global: + mcopy 100, 0, 96 + stop + """ + _check_pre_post(pre, post) + + +def test_memmerging_imposs(): + """ + Test case of impossible merge + Impossible because of the overlap + [0 96] + [32 128] + """ + if not version_check(begin="cancun"): + return + + pre = """ + _global: + %1 = mload 0 + %2 = mload 32 + %3 = mload 64 + mstore 32, %1 + + ; BARRIER - overlap between src and dst + ; (writes to source of potential mcopy) + mstore 64, %2 + + mstore 96, %3 + stop + """ + _check_no_change(pre) + + +def test_memmerging_imposs_mstore(): + """ + Test case of impossible merge + Impossible because of the mstore barrier + """ + if not version_check(begin="cancun"): + return + + pre = """ + _global: + %1 = mload 0 + %2 = mload 16 + mstore 1000, %1 + %3 = mload 1000 ; BARRIER - load from dst of potential mcopy + mstore 1016, %2 + mstore 2000, %3 + stop + """ + _check_no_change(pre) + + +@pytest.mark.xfail +def test_memmerging_bypass_fence(): + """ + We should be able to optimize this to an mcopy(0, 1000, 64), but + currently do not + """ + if not version_check(begin="cancun"): + raise AssertionError() # xfail + + pre = """ + function _global { + _global: + %1 = mload 0 + %2 = mload 32 + mstore %1, 1000 + %3 = 
mload 1000 + mstore 1032, %2 + mstore 2000, %3 + stop + } + """ + + ctx = parse_venom(pre) + + for fn in ctx.functions.values(): + ac = IRAnalysesCache(fn) + SCCP(ac, fn).run_pass() + MemMergePass(ac, fn).run_pass() + + fn = next(iter(ctx.functions.values())) + bb = fn.entry + assert any(inst.opcode == "mcopy" for inst in bb.instructions) + + +def test_memmerging_imposs_unkown_place(): + """ + Test case of impossible merge + Impossible because of the + non constant address mload and mstore barier + """ + if not version_check(begin="cancun"): + return + + pre = """ + _global: + %1 = param + %2 = mload 0 + %3 = mload %1 ; BARRIER + %4 = mload 32 + %5 = mload 64 + mstore 1000, %2 + mstore 1032, %4 + mstore 10, %1 ; BARRIER + mstore 1064, %5 + stop + """ + _check_no_change(pre) + + +def test_memmerging_imposs_msize(): + """ + Test case of impossible merge + Impossible because of the msize barier + """ + if not version_check(begin="cancun"): + return + + pre = """ + _global: + %1 = mload 0 + %2 = msize ; BARRIER + %3 = mload 32 + %4 = mload 64 + mstore 1000, %1 + mstore 1032, %3 + %5 = msize ; BARRIER + mstore 1064, %4 + return %2, %5 + """ + _check_no_change(pre) + + +def test_memmerging_partial_msize(): + """ + Only partial merge possible + because of the msize barier + """ + if not version_check(begin="cancun"): + return + + pre = """ + _global: + %1 = mload 0 + %2 = mload 32 + %3 = mload 64 + mstore 1000, %1 + mstore 1032, %2 + %4 = msize ; BARRIER + mstore 1064, %3 + return %4 + """ + + post = """ + _global: + %3 = mload 64 + mcopy 1000, 0, 64 + %4 = msize + mstore 1064, %3 + return %4 + """ + _check_pre_post(pre, post) + + +def test_memmerging_partial_overlap(): + """ + Two different copies from overlapping + source range + + [0 128] + [24 88] + """ + if not version_check(begin="cancun"): + return + + pre = """ + _global: + %1 = mload 0 + %2 = mload 32 + %3 = mload 64 + %4 = mload 96 + %5 = mload 24 + %6 = mload 56 + mstore 1064, %3 + mstore 1096, %4 + mstore 
1000, %1 + mstore 1032, %2 + mstore 2024, %5 + mstore 2056, %6 + stop + """ + + post = """ + _global: + mcopy 1000, 0, 128 + mcopy 2024, 24, 64 + stop + """ + _check_pre_post(pre, post) + + +def test_memmerging_partial_different_effect(): + """ + Only partial merge possible + because of the generic memory + effect barier + """ + if not version_check(begin="cancun"): + return + + pre = """ + _global: + %1 = mload 0 + %2 = mload 32 + %3 = mload 64 + mstore 1000, %1 + mstore 1032, %2 + dloadbytes 2000, 1000, 1000 ; BARRIER + mstore 1064, %3 + stop + """ + + post = """ + _global: + %3 = mload 64 + mcopy 1000, 0, 64 + dloadbytes 2000, 1000, 1000 + mstore 1064, %3 + stop + """ + _check_pre_post(pre, post) + + +def test_memmerge_ok_interval_subset(): + """ + Test subintervals get subsumed by larger intervals + mstore(, mload()) + mcopy(, , 64) + => + mcopy(, , 64) + Because the first mload/mstore is contained in the mcopy + """ + if not version_check(begin="cancun"): + return + + pre = """ + _global: + %1 = mload 0 + mstore 100, %1 + mcopy 100, 0, 33 + stop + """ + + post = """ + _global: + mcopy 100, 0, 33 + stop + """ + _check_pre_post(pre, post) + + +def test_memmerging_ok_overlap(): + """ + Test for with source overlap + which is ok to do + """ + if not version_check(begin="cancun"): + return + + pre = """ + _global: + %1 = mload 0 + %2 = mload 24 + %3 = mload 48 + mstore 1000, %1 + mstore 1024, %2 + mstore 1048, %3 + stop + """ + + post = """ + _global: + mcopy 1000, 0, 80 + stop + """ + + _check_pre_post(pre, post) + + +def test_memmerging_mcopy(): + """ + Test that sequences of mcopy get merged (not just loads/stores) + """ + if not version_check(begin="cancun"): + return + + pre = """ + _global: + mcopy 1000, 0, 32 + mcopy 1032, 32, 32 + mcopy 1064, 64, 64 + stop + """ + + post = """ + _global: + mcopy 1000, 0, 128 + stop + """ + _check_pre_post(pre, post) + + +def test_memmerging_mcopy_small(): + """ + Test that sequences of mcopies get merged, and that mcopy of 
32 bytes + gets transformed to mload/mstore (saves 1 byte) + """ + if not version_check(begin="cancun"): + return + + pre = """ + _global: + mcopy 1000, 0, 16 + mcopy 1016, 16, 16 + stop + """ + + post = """ + _global: + %1 = mload 0 + mstore 1000, %1 + stop + """ + _check_pre_post(pre, post) + + +def test_memmerging_mcopy_weird_bisect(): + """ + Check that bisect_left finds the correct merge + copy(80, 100, 2) + copy(150, 60, 1) + copy(82, 102, 3) + """ + if not version_check(begin="cancun"): + return + + pre = """ + _global: + mcopy 80, 100, 2 + mcopy 150, 60, 1 + mcopy 82, 102, 3 + stop + """ + + post = """ + _global: + mcopy 150, 60, 1 + mcopy 80, 100, 5 + stop + """ + _check_pre_post(pre, post) + + +def test_memmerging_mcopy_weird_bisect2(): + """ + Check that bisect_left finds the correct merge + copy(80, 50, 2) + copy(20, 100, 1) + copy(82, 52, 3) + """ + if not version_check(begin="cancun"): + return + + pre = """ + _global: + mcopy 80, 50, 2 + mcopy 20, 100, 1 + mcopy 82, 52, 3 + stop + """ + + post = """ + _global: + mcopy 20, 100, 1 + mcopy 80, 50, 5 + stop + """ + _check_pre_post(pre, post) + + +def test_memmerging_allowed_overlapping(): + """ + Test merge of interleaved mload/mstore/mcopy works + """ + if not version_check(begin="cancun"): + return + + pre = """ + _global: + %1 = mload 32 + mcopy 1000, 32, 128 + %2 = mload 0 + mstore 2032, %1 + mstore 2000, %2 + stop + """ + + post = """ + _global: + mcopy 1000, 32, 128 + mcopy 2000, 0, 64 + stop + """ + _check_pre_post(pre, post) + + +def test_memmerging_allowed_overlapping2(): + if not version_check(begin="cancun"): + return + + pre = """ + _global: + mcopy 1000, 0, 64 + %1 = mload 1032 + mstore 2000, %1 + %2 = mload 1064 + mstore 2032, %2 + stop + """ + + post = """ + _global: + mcopy 1000, 0, 64 + mcopy 2000, 1032, 64 + stop + """ + _check_pre_post(pre, post) + + +def test_memmerging_unused_mload(): + if not version_check(begin="cancun"): + return + + pre = """ + _global: + %1 = mload 100 + %2 = 
mload 132 + mstore 64, %2 + + # does not interfere with the mload/mstore merging even though + # it cannot be removed + %3 = mload 32 + + mstore 32, %1 + return %3, %3 + """ + + post = """ + _global: + %3 = mload 32 + mcopy 32, 100, 64 + return %3, %3 + """ + + _check_pre_post(pre, post) + + +def test_memmerging_unused_mload1(): + if not version_check(begin="cancun"): + return + + pre = """ + _global: + %1 = mload 100 + %2 = mload 132 + mstore 0, %1 + + # does not interfere with the mload/mstore merging even though + # it cannot be removed + %3 = mload 32 + + mstore 32, %2 + return %3, %3 + """ + + post = """ + _global: + %3 = mload 32 + mcopy 0, 100, 64 + return %3, %3 + """ + _check_pre_post(pre, post) + + +def test_memmerging_mload_read_after_write_hazard(): + if not version_check(begin="cancun"): + return + + pre = """ + _global: + %1 = mload 100 + %2 = mload 132 + mstore 0, %1 + %3 = mload 32 + mstore 32, %2 + %4 = mload 64 + + ; BARRIER - the load is overriden by existing copy + mstore 1000, %3 + mstore 1032, %4 + stop + """ + + post = """ + _global: + %3 = mload 32 + mcopy 0, 100, 64 + %4 = mload 64 + mstore 1000, %3 + mstore 1032, %4 + stop + """ + _check_pre_post(pre, post) + + +def test_memmerging_mcopy_read_after_write_hazard(): + if not version_check(begin="cancun"): + return + + pre = """ + _global: + mcopy 1000, 32, 64 + mcopy 2000, 1000, 64 ; BARRIER + mcopy 1064, 96, 64 + stop + """ + _check_no_change(pre) + + +def test_memmerging_write_after_write(): + """ + Check that conflicting writes (from different source locations) + produce a barrier - mstore+mstore version + """ + if not version_check(begin="cancun"): + return + + pre = """ + _global: + %1 = mload 0 + %2 = mload 100 + %3 = mload 32 + %4 = mload 132 + mstore 1000, %1 + mstore 1000, %2 ; BARRIER + mstore 1032, %4 + mstore 1032, %3 ; BARRIER + """ + _check_no_change(pre) + + +def test_memmerging_write_after_write_mstore_and_mcopy(): + """ + Check that conflicting writes (from different source 
locations) + produce a barrier - mstore+mcopy version + """ + if not version_check(begin="cancun"): + return + + pre = """ + _global: + %1 = mload 0 + %2 = mload 132 + mstore 1000, %1 + mcopy 1000, 100, 16 ; write barrier + mstore 1032, %2 + mcopy 1016, 116, 64 + stop + """ + _check_no_change(pre) + + +def test_memmerging_write_after_write_only_mcopy(): + """ + Check that conflicting writes (from different source locations) + produce a barrier - mcopy+mcopy version + """ + if not version_check(begin="cancun"): + return + + pre = """ + _global: + mcopy 1000, 0, 16 + mcopy 1000, 100, 16 ; write barrier + mcopy 1016, 116, 64 + mcopy 1016, 16, 64 + stop + """ + + post = """ + _global: + mcopy 1000, 0, 16 + mcopy 1000, 100, 80 + mcopy 1016, 16, 64 + stop + """ + _check_pre_post(pre, post) + + +def test_memmerging_not_allowed_overlapping(): + if not version_check(begin="cancun"): + return + + # NOTE: maybe optimization is possible here, to: + # mcopy 2000, 1000, 64 + # mcopy 1000, 0, 128 + pre = """ + _global: + %1 = mload 1000 + %2 = mload 1032 + mcopy 1000, 0, 128 + mstore 2000, %1 ; BARRIER - the mload and mcopy cannot be combined + mstore 2032, %2 + stop + """ + _check_no_change(pre) + + +def test_memmerging_not_allowed_overlapping2(): + if not version_check(begin="cancun"): + return + + # NOTE: maybe optimization is possible here, to: + # mcopy 2000, 1000, 64 + # mcopy 1000, 0, 128 + pre = """ + _global: + %1 = mload 1032 + mcopy 1000, 0, 64 + mstore 2000, %1 + %2 = mload 1064 + mstore 2032, %2 + stop + """ + + _check_no_change(pre) + + +def test_memmerging_existing_copy_overwrite(): + """ + Check that memmerge does not write over source of another copy + """ + if not version_check(begin="cancun"): + return + + pre = """ + _global: + mcopy 1000, 0, 64 + %1 = mload 2000 + + # barrier, write over source of existing copy + mstore 0, %1 + + mcopy 1064, 64, 64 + stop + """ + + _check_no_change(pre) + + +def test_memmerging_double_use(): + if not 
version_check(begin="cancun"): + return + + pre = """ + _global: + %1 = mload 0 + %2 = mload 32 + mstore 1000, %1 + mstore 1032, %2 + return %1 + """ + + post = """ + _global: + %1 = mload 0 + mcopy 1000, 0, 64 + return %1 + """ + + _check_pre_post(pre, post) + + +@pytest.mark.parametrize("load_opcode,copy_opcode", LOAD_COPY) +def test_memmerging_load(load_opcode, copy_opcode): + pre = f""" + _global: + %1 = {load_opcode} 0 + mstore 32, %1 + %2 = {load_opcode} 32 + mstore 64, %2 + stop + """ + + post = f""" + _global: + {copy_opcode} 32, 0, 64 + stop + """ + _check_pre_post(pre, post) + + +@pytest.mark.parametrize("load_opcode,copy_opcode", LOAD_COPY) +def test_memmerging_two_intervals_diff_offset(load_opcode, copy_opcode): + """ + Test different dloadbytes/calldatacopy sequences are separately merged + """ + pre = f""" + _global: + %1 = {load_opcode} 0 + mstore 0, %1 + {copy_opcode} 32, 32, 64 + %2 = {load_opcode} 0 + mstore 8, %2 + {copy_opcode} 40, 32, 64 + stop + """ + + post = f""" + _global: + {copy_opcode} 0, 0, 96 + {copy_opcode} 8, 0, 96 + stop + """ + _check_pre_post(pre, post) + + +def test_memzeroing_1(): + """ + Test of basic memzeroing done with mstore only + """ + + pre = """ + _global: + mstore 32, 0 + mstore 64, 0 + mstore 96, 0 + stop + """ + + post = """ + _global: + %1 = calldatasize + calldatacopy 32, %1, 96 + stop + """ + _check_pre_post(pre, post) + + +def test_memzeroing_2(): + """ + Test of basic memzeroing done with calldatacopy only + + sequence of these instruction will + zero out the memory at destination + %1 = calldatasize + calldatacopy %1 + """ + + pre = """ + _global: + %1 = calldatasize + calldatacopy 64, %1, 128 + %2 = calldatasize + calldatacopy 192, %2, 128 + stop + """ + + post = """ + _global: + %1 = calldatasize + %2 = calldatasize + %3 = calldatasize + calldatacopy 64, %3, 256 + stop + """ + _check_pre_post(pre, post) + + +def test_memzeroing_3(): + """ + Test of basic memzeroing done with combination of + mstores and 
calldatacopies + """ + + pre = """ + _global: + %1 = calldatasize + calldatacopy 0, %1, 100 + mstore 100, 0 + %2 = calldatasize + calldatacopy 132, %2, 100 + mstore 232, 0 + stop + """ + + post = """ + _global: + %1 = calldatasize + %2 = calldatasize + %3 = calldatasize + calldatacopy 0, %3, 264 + stop + """ + _check_pre_post(pre, post) + + +def test_memzeroing_small_calldatacopy(): + """ + Test of converting calldatacopy of + size 32 into mstore + """ + + pre = """ + _global: + %1 = calldatasize + calldatacopy 0, %1, 32 + stop + """ + + post = """ + _global: + %1 = calldatasize + mstore 0, 0 + stop + """ + _check_pre_post(pre, post) + + +def test_memzeroing_smaller_calldatacopy(): + """ + Test merging smaller (<32) calldatacopies + into either calldatacopy or mstore + """ + + pre = """ + _global: + %1 = calldatasize + calldatacopy 0, %1, 8 + %2 = calldatasize + calldatacopy 8, %2, 16 + %3 = calldatasize + calldatacopy 100, %3, 8 + %4 = calldatasize + calldatacopy 108, %4, 16 + %5 = calldatasize + calldatacopy 124, %5, 8 + stop + """ + + post = """ + _global: + %1 = calldatasize + %2 = calldatasize + %6 = calldatasize + calldatacopy 0, %6, 24 + %3 = calldatasize + %4 = calldatasize + %5 = calldatasize + mstore 100, 0 + stop + """ + _check_pre_post(pre, post) + + +def test_memzeroing_overlap(): + """ + Test of merging overlaping zeroing intervals + + [128 160] + [136 192] + """ + + pre = """ + _global: + mstore 100, 0 + %1 = calldatasize + calldatacopy 108, %1, 56 + stop + """ + + post = """ + _global: + %1 = calldatasize + %2 = calldatasize + calldatacopy 100, %2, 64 + stop + """ + _check_pre_post(pre, post) + + +def test_memzeroing_imposs(): + """ + Test of memzeroing barriers caused + by non constant arguments + """ + + pre = """ + _global: + %1 = param ; abstract location, causes barrier + mstore 32, 0 + mstore %1, 0 + mstore 64, 0 + %2 = calldatasize + calldatacopy %1, %2, 10 + mstore 96, 0 + %3 = calldatasize + calldatacopy 10, %3, %1 + mstore 128, 0 + 
calldatacopy 10, %1, 10 + mstore 160, 0 + stop + """ + _check_no_change(pre) + + +def test_memzeroing_imposs_effect(): + """ + Test of memzeroing bariers caused + by different effect + """ + + pre = """ + _global: + mstore 32, 0 + dloadbytes 10, 20, 30 ; BARRIER + mstore 64, 0 + stop + """ + _check_no_change(pre) + + +def test_memzeroing_overlaping(): + """ + Test merging overlapping memzeroes (they can be merged + since both result in zeroes being written to destination) + """ + + pre = """ + _global: + mstore 32, 0 + mstore 96, 0 + mstore 32, 0 + mstore 64, 0 + stop + """ + + post = """ + _global: + %1 = calldatasize + calldatacopy 32, %1, 96 + stop + """ + _check_pre_post(pre, post) + + +def test_memzeroing_interleaved(): + """ + Test merging overlapping memzeroes (they can be merged + since both result in zeroes being written to destination) + """ + + pre = """ + _global: + mstore 32, 0 + mstore 1000, 0 + mstore 64, 0 + mstore 1032, 0 + stop + """ + + post = """ + _global: + %1 = calldatasize + calldatacopy 32, %1, 64 + %2 = calldatasize + calldatacopy 1000, %2, 64 + stop + """ + _check_pre_post(pre, post) diff --git a/tests/unit/compiler/venom/test_multi_entry_block.py b/tests/unit/compiler/venom/test_multi_entry_block.py index 313fbb3ebd..a38e4b4158 100644 --- a/tests/unit/compiler/venom/test_multi_entry_block.py +++ b/tests/unit/compiler/venom/test_multi_entry_block.py @@ -1,8 +1,7 @@ -from vyper.venom.analysis.analysis import IRAnalysesCache -from vyper.venom.analysis.cfg import CFGAnalysis +from vyper.venom.analysis import CFGAnalysis, IRAnalysesCache from vyper.venom.context import IRContext from vyper.venom.function import IRBasicBlock, IRLabel -from vyper.venom.passes.normalization import NormalizationPass +from vyper.venom.passes import NormalizationPass def test_multi_entry_block_1(): diff --git a/tests/unit/compiler/venom/test_sccp.py b/tests/unit/compiler/venom/test_sccp.py index e65839136e..0d46b61acd 100644 --- 
a/tests/unit/compiler/venom/test_sccp.py +++ b/tests/unit/compiler/venom/test_sccp.py @@ -1,11 +1,10 @@ import pytest from vyper.exceptions import StaticAssertionException -from vyper.venom.analysis.analysis import IRAnalysesCache +from vyper.venom.analysis import IRAnalysesCache from vyper.venom.basicblock import IRBasicBlock, IRLabel, IRLiteral, IRVariable from vyper.venom.context import IRContext -from vyper.venom.passes.make_ssa import MakeSSA -from vyper.venom.passes.sccp import SCCP +from vyper.venom.passes import SCCP, MakeSSA from vyper.venom.passes.sccp.sccp import LatticeEnum @@ -168,8 +167,8 @@ def test_cont_phi_case(): assert sccp.lattice[IRVariable("%2")].value == 32 assert sccp.lattice[IRVariable("%3")].value == 64 assert sccp.lattice[IRVariable("%4")].value == 96 - assert sccp.lattice[IRVariable("%5", version=1)].value == 106 - assert sccp.lattice[IRVariable("%5", version=2)] == LatticeEnum.BOTTOM + assert sccp.lattice[IRVariable("%5", version=2)].value == 106 + assert sccp.lattice[IRVariable("%5", version=1)] == LatticeEnum.BOTTOM assert sccp.lattice[IRVariable("%5")].value == 2 @@ -208,6 +207,38 @@ def test_cont_phi_const_case(): assert sccp.lattice[IRVariable("%2")].value == 32 assert sccp.lattice[IRVariable("%3")].value == 64 assert sccp.lattice[IRVariable("%4")].value == 96 - assert sccp.lattice[IRVariable("%5", version=1)].value == 106 - assert sccp.lattice[IRVariable("%5", version=2)].value == 97 + # dependent on cfg traversal order + assert sccp.lattice[IRVariable("%5", version=2)].value == 106 + assert sccp.lattice[IRVariable("%5", version=1)].value == 97 assert sccp.lattice[IRVariable("%5")].value == 2 + + +def test_phi_reduction_after_unreachable_block(): + ctx = IRContext() + fn = ctx.create_function("_global") + + bb = fn.get_basic_block() + + br1 = IRBasicBlock(IRLabel("then"), fn) + fn.append_basic_block(br1) + join = IRBasicBlock(IRLabel("join"), fn) + fn.append_basic_block(join) + + op = bb.append_instruction("store", 1) + true = 
IRLiteral(1) + bb.append_instruction("jnz", true, br1.label, join.label) + + op1 = br1.append_instruction("store", 2) + + br1.append_instruction("jmp", join.label) + + join.append_instruction("phi", bb.label, op, br1.label, op1) + join.append_instruction("stop") + + ac = IRAnalysesCache(fn) + SCCP(ac, fn).run_pass() + + assert join.instructions[0].opcode == "store", join.instructions[0] + assert join.instructions[0].operands == [op1] + + assert join.instructions[1].opcode == "stop" diff --git a/tests/unit/compiler/venom/test_simplify_cfg.py b/tests/unit/compiler/venom/test_simplify_cfg.py new file mode 100644 index 0000000000..583f10efda --- /dev/null +++ b/tests/unit/compiler/venom/test_simplify_cfg.py @@ -0,0 +1,37 @@ +from tests.venom_utils import assert_ctx_eq, parse_venom +from vyper.venom.analysis import IRAnalysesCache +from vyper.venom.passes import SCCP, SimplifyCFGPass + + +def test_phi_reduction_after_block_pruning(): + pre = """ + function _global { + _global: + jnz 1, @then, @else + then: + %1 = 1 + jmp @join + else: + %2 = 2 + jmp @join + join: + %3 = phi @then, %1, @else, %2 + stop + } + """ + post = """ + function _global { + _global: + %1 = 1 + %3 = %1 + stop + } + """ + ctx1 = parse_venom(pre) + for fn in ctx1.functions.values(): + ac = IRAnalysesCache(fn) + SCCP(ac, fn).run_pass() + SimplifyCFGPass(ac, fn).run_pass() + + ctx2 = parse_venom(post) + assert_ctx_eq(ctx1, ctx2) diff --git a/tests/unit/compiler/venom/test_stack_cleanup.py b/tests/unit/compiler/venom/test_stack_cleanup.py new file mode 100644 index 0000000000..7198861771 --- /dev/null +++ b/tests/unit/compiler/venom/test_stack_cleanup.py @@ -0,0 +1,17 @@ +from vyper.compiler.settings import OptimizationLevel +from vyper.venom import generate_assembly_experimental +from vyper.venom.context import IRContext + + +def test_cleanup_stack(): + ctx = IRContext() + fn = ctx.create_function("test") + bb = fn.get_basic_block() + ret_val = bb.append_instruction("param") + op = 
bb.append_instruction("store", 10) + op2 = bb.append_instruction("store", op) + bb.append_instruction("add", op, op2) + bb.append_instruction("ret", ret_val) + + asm = generate_assembly_experimental(ctx, optimize=OptimizationLevel.GAS) + assert asm == ["PUSH1", 10, "DUP1", "ADD", "POP", "JUMP"] diff --git a/tests/unit/compiler/venom/test_stack_reorder.py b/tests/unit/compiler/venom/test_stack_reorder.py new file mode 100644 index 0000000000..8f38e00cdb --- /dev/null +++ b/tests/unit/compiler/venom/test_stack_reorder.py @@ -0,0 +1,33 @@ +from vyper.venom import generate_assembly_experimental +from vyper.venom.analysis import IRAnalysesCache +from vyper.venom.context import IRContext +from vyper.venom.passes import StoreExpansionPass + + +def test_stack_reorder(): + """ + Test to was created from the example in the + issue https://github.com/vyperlang/vyper/issues/4215 + this example should fail with original stack reorder + algorithm but succeed with new one + """ + ctx = IRContext() + fn = ctx.create_function("_global") + + bb = fn.get_basic_block() + var0 = bb.append_instruction("store", 1) + var1 = bb.append_instruction("store", 2) + var2 = bb.append_instruction("store", 3) + var3 = bb.append_instruction("store", 4) + var4 = bb.append_instruction("store", 5) + + bb.append_instruction("staticcall", var0, var1, var2, var3, var4, var3) + + ret_val = bb.append_instruction("add", var4, var4) + + bb.append_instruction("ret", ret_val) + + ac = IRAnalysesCache(fn) + StoreExpansionPass(ac, fn).run_pass() + + generate_assembly_experimental(ctx) diff --git a/tests/unit/semantics/analysis/test_array_index.py b/tests/unit/semantics/analysis/test_array_index.py index b5bf86494d..aa9a702be3 100644 --- a/tests/unit/semantics/analysis/test_array_index.py +++ b/tests/unit/semantics/analysis/test_array_index.py @@ -11,7 +11,7 @@ @pytest.mark.parametrize("value", ["address", "Bytes[10]", "decimal", "bool"]) -def test_type_mismatch(namespace, value, dummy_input_bundle): +def 
test_type_mismatch(namespace, value): code = f""" a: uint256[3] @@ -22,11 +22,11 @@ def foo(b: {value}): """ vyper_module = parse_to_ast(code) with pytest.raises(TypeMismatch): - analyze_module(vyper_module, dummy_input_bundle) + analyze_module(vyper_module) @pytest.mark.parametrize("value", ["1.0", "0.0", "'foo'", "0x00", "b'\x01'", "False"]) -def test_invalid_literal(namespace, value, dummy_input_bundle): +def test_invalid_literal(namespace, value): code = f""" a: uint256[3] @@ -37,11 +37,11 @@ def foo(): """ vyper_module = parse_to_ast(code) with pytest.raises(TypeMismatch): - analyze_module(vyper_module, dummy_input_bundle) + analyze_module(vyper_module) @pytest.mark.parametrize("value", [-1, 3, -(2**127), 2**127 - 1, 2**256 - 1]) -def test_out_of_bounds(namespace, value, dummy_input_bundle): +def test_out_of_bounds(namespace, value): code = f""" a: uint256[3] @@ -52,11 +52,11 @@ def foo(): """ vyper_module = parse_to_ast(code) with pytest.raises(ArrayIndexException): - analyze_module(vyper_module, dummy_input_bundle) + analyze_module(vyper_module) @pytest.mark.parametrize("value", ["b", "self.b"]) -def test_undeclared_definition(namespace, value, dummy_input_bundle): +def test_undeclared_definition(namespace, value): code = f""" a: uint256[3] @@ -67,11 +67,11 @@ def foo(): """ vyper_module = parse_to_ast(code) with pytest.raises(UndeclaredDefinition): - analyze_module(vyper_module, dummy_input_bundle) + analyze_module(vyper_module) @pytest.mark.parametrize("value", ["a", "foo", "int128"]) -def test_invalid_reference(namespace, value, dummy_input_bundle): +def test_invalid_reference(namespace, value): code = f""" a: uint256[3] @@ -82,4 +82,4 @@ def foo(): """ vyper_module = parse_to_ast(code) with pytest.raises(InvalidReference): - analyze_module(vyper_module, dummy_input_bundle) + analyze_module(vyper_module) diff --git a/tests/unit/semantics/analysis/test_cyclic_function_calls.py b/tests/unit/semantics/analysis/test_cyclic_function_calls.py index 
990c839fde..da2e63c5fc 100644 --- a/tests/unit/semantics/analysis/test_cyclic_function_calls.py +++ b/tests/unit/semantics/analysis/test_cyclic_function_calls.py @@ -5,18 +5,37 @@ from vyper.semantics.analysis import analyze_module -def test_self_function_call(dummy_input_bundle): +def test_self_function_call(): code = """ @internal def foo(): self.foo() """ vyper_module = parse_to_ast(code) - with pytest.raises(CallViolation): - analyze_module(vyper_module, dummy_input_bundle) + with pytest.raises(CallViolation) as e: + analyze_module(vyper_module) + assert e.value.message == "Contract contains cyclic function call: foo -> foo" -def test_cyclic_function_call(dummy_input_bundle): + +def test_self_function_call2(): + code = """ +@external +def foo(): + self.bar() + +@internal +def bar(): + self.bar() + """ + vyper_module = parse_to_ast(code) + with pytest.raises(CallViolation) as e: + analyze_module(vyper_module) + + assert e.value.message == "Contract contains cyclic function call: foo -> bar -> bar" + + +def test_cyclic_function_call(): code = """ @internal def foo(): @@ -27,11 +46,13 @@ def bar(): self.foo() """ vyper_module = parse_to_ast(code) - with pytest.raises(CallViolation): - analyze_module(vyper_module, dummy_input_bundle) + with pytest.raises(CallViolation) as e: + analyze_module(vyper_module) + assert e.value.message == "Contract contains cyclic function call: foo -> bar -> foo" -def test_multi_cyclic_function_call(dummy_input_bundle): + +def test_multi_cyclic_function_call(): code = """ @internal def foo(): @@ -50,11 +71,42 @@ def potato(): self.foo() """ vyper_module = parse_to_ast(code) - with pytest.raises(CallViolation): - analyze_module(vyper_module, dummy_input_bundle) + with pytest.raises(CallViolation) as e: + analyze_module(vyper_module) + + expected_message = "Contract contains cyclic function call: foo -> bar -> baz -> potato -> foo" + + assert e.value.message == expected_message + + +def test_multi_cyclic_function_call2(): + code = """ 
+@internal +def foo(): + self.bar() + +@internal +def bar(): + self.baz() + +@internal +def baz(): + self.potato() + +@internal +def potato(): + self.bar() + """ + vyper_module = parse_to_ast(code) + with pytest.raises(CallViolation) as e: + analyze_module(vyper_module) + + expected_message = "Contract contains cyclic function call: foo -> bar -> baz -> potato -> bar" + + assert e.value.message == expected_message -def test_global_ann_assign_callable_no_crash(dummy_input_bundle): +def test_global_ann_assign_callable_no_crash(): code = """ balanceOf: public(HashMap[address, uint256]) @@ -64,5 +116,5 @@ def foo(to : address): """ vyper_module = parse_to_ast(code) with pytest.raises(StructureException) as excinfo: - analyze_module(vyper_module, dummy_input_bundle) + analyze_module(vyper_module) assert excinfo.value.message == "HashMap[address, uint256] is not callable" diff --git a/tests/unit/semantics/analysis/test_for_loop.py b/tests/unit/semantics/analysis/test_for_loop.py index d7d4f7083b..810ff0a8b9 100644 --- a/tests/unit/semantics/analysis/test_for_loop.py +++ b/tests/unit/semantics/analysis/test_for_loop.py @@ -5,7 +5,7 @@ from vyper.semantics.analysis import analyze_module -def test_modify_iterator_function_outside_loop(dummy_input_bundle): +def test_modify_iterator_function_outside_loop(): code = """ a: uint256[3] @@ -21,10 +21,10 @@ def bar(): pass """ vyper_module = parse_to_ast(code) - analyze_module(vyper_module, dummy_input_bundle) + analyze_module(vyper_module) -def test_pass_memory_var_to_other_function(dummy_input_bundle): +def test_pass_memory_var_to_other_function(): code = """ @internal @@ -41,10 +41,10 @@ def bar(): self.foo(a) """ vyper_module = parse_to_ast(code) - analyze_module(vyper_module, dummy_input_bundle) + analyze_module(vyper_module) -def test_modify_iterator(dummy_input_bundle): +def test_modify_iterator(): code = """ a: uint256[3] @@ -56,10 +56,10 @@ def bar(): """ vyper_module = parse_to_ast(code) with 
pytest.raises(ImmutableViolation): - analyze_module(vyper_module, dummy_input_bundle) + analyze_module(vyper_module) -def test_bad_keywords(dummy_input_bundle): +def test_bad_keywords(): code = """ @internal @@ -70,10 +70,10 @@ def bar(n: uint256): """ vyper_module = parse_to_ast(code) with pytest.raises(ArgumentException): - analyze_module(vyper_module, dummy_input_bundle) + analyze_module(vyper_module) -def test_bad_bound(dummy_input_bundle): +def test_bad_bound(): code = """ @internal @@ -84,10 +84,10 @@ def bar(n: uint256): """ vyper_module = parse_to_ast(code) with pytest.raises(StructureException): - analyze_module(vyper_module, dummy_input_bundle) + analyze_module(vyper_module) -def test_modify_iterator_function_call(dummy_input_bundle): +def test_modify_iterator_function_call(): code = """ a: uint256[3] @@ -103,10 +103,10 @@ def bar(): """ vyper_module = parse_to_ast(code) with pytest.raises(ImmutableViolation): - analyze_module(vyper_module, dummy_input_bundle) + analyze_module(vyper_module) -def test_modify_iterator_recursive_function_call(dummy_input_bundle): +def test_modify_iterator_recursive_function_call(): code = """ a: uint256[3] @@ -126,10 +126,10 @@ def baz(): """ vyper_module = parse_to_ast(code) with pytest.raises(ImmutableViolation): - analyze_module(vyper_module, dummy_input_bundle) + analyze_module(vyper_module) -def test_modify_iterator_recursive_function_call_topsort(dummy_input_bundle): +def test_modify_iterator_recursive_function_call_topsort(): # test the analysis works no matter the order of functions code = """ a: uint256[3] @@ -149,12 +149,12 @@ def foo(): """ vyper_module = parse_to_ast(code) with pytest.raises(ImmutableViolation) as e: - analyze_module(vyper_module, dummy_input_bundle) + analyze_module(vyper_module) assert e.value._message == "Cannot modify loop variable `a`" -def test_modify_iterator_through_struct(dummy_input_bundle): +def test_modify_iterator_through_struct(): # GH issue 3429 code = """ struct A: @@ -170,12 
+170,12 @@ def foo(): """ vyper_module = parse_to_ast(code) with pytest.raises(ImmutableViolation) as e: - analyze_module(vyper_module, dummy_input_bundle) + analyze_module(vyper_module) assert e.value._message == "Cannot modify loop variable `a`" -def test_modify_iterator_complex_expr(dummy_input_bundle): +def test_modify_iterator_complex_expr(): # GH issue 3429 # avoid false positive! code = """ @@ -189,10 +189,10 @@ def foo(): self.b[self.a[1]] = i """ vyper_module = parse_to_ast(code) - analyze_module(vyper_module, dummy_input_bundle) + analyze_module(vyper_module) -def test_modify_iterator_siblings(dummy_input_bundle): +def test_modify_iterator_siblings(): # test we can modify siblings in an access tree code = """ struct Foo: @@ -207,10 +207,10 @@ def foo(): self.f.b += i """ vyper_module = parse_to_ast(code) - analyze_module(vyper_module, dummy_input_bundle) + analyze_module(vyper_module) -def test_modify_subscript_barrier(dummy_input_bundle): +def test_modify_subscript_barrier(): # test that Subscript nodes are a barrier for analysis code = """ struct Foo: @@ -229,7 +229,7 @@ def foo(): """ vyper_module = parse_to_ast(code) with pytest.raises(ImmutableViolation) as e: - analyze_module(vyper_module, dummy_input_bundle) + analyze_module(vyper_module) assert e.value._message == "Cannot modify loop variable `b`" @@ -269,7 +269,7 @@ def foo(): @pytest.mark.parametrize("code", iterator_inference_codes) -def test_iterator_type_inference_checker(code, dummy_input_bundle): +def test_iterator_type_inference_checker(code): vyper_module = parse_to_ast(code) with pytest.raises(TypeMismatch): - analyze_module(vyper_module, dummy_input_bundle) + analyze_module(vyper_module) diff --git a/tests/utils.py b/tests/utils.py index 8548c4f47a..b9dc443c0d 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -3,6 +3,7 @@ import os from vyper import ast as vy_ast +from vyper.compiler.phases import CompilerData from vyper.semantics.analysis.constant_folding import constant_fold from 
vyper.utils import DECIMAL_EPSILON, round_towards_zero @@ -28,3 +29,24 @@ def parse_and_fold(source_code): def decimal_to_int(*args): s = decimal.Decimal(*args) return round_towards_zero(s / DECIMAL_EPSILON) + + +def check_precompile_asserts(source_code): + # common sanity check for some tests, that calls to precompiles + # are correctly wrapped in an assert. + + compiler_data = CompilerData(source_code) + deploy_ir = compiler_data.ir_nodes + runtime_ir = compiler_data.ir_runtime + + def _check(ir_node, parent=None): + if ir_node.value == "staticcall": + precompile_addr = ir_node.args[1] + if isinstance(precompile_addr.value, int) and precompile_addr.value < 10: + assert parent is not None and parent.value == "assert" + for arg in ir_node.args: + _check(arg, ir_node) + + _check(deploy_ir) + # technically runtime_ir is contained in deploy_ir, but check it anyways. + _check(runtime_ir) diff --git a/tests/venom_utils.py b/tests/venom_utils.py new file mode 100644 index 0000000000..6ddc61f615 --- /dev/null +++ b/tests/venom_utils.py @@ -0,0 +1,49 @@ +from vyper.venom.basicblock import IRBasicBlock, IRInstruction +from vyper.venom.context import IRContext +from vyper.venom.function import IRFunction +from vyper.venom.parser import parse_venom + + +def parse_from_basic_block(source: str, funcname="_global"): + """ + Parse an IRContext from a basic block + """ + source = f"function {funcname} {{\n{source}\n}}" + return parse_venom(source) + + +def instructions_eq(i1: IRInstruction, i2: IRInstruction) -> bool: + return i1.output == i2.output and i1.opcode == i2.opcode and i1.operands == i2.operands + + +def assert_bb_eq(bb1: IRBasicBlock, bb2: IRBasicBlock): + assert bb1.label.value == bb2.label.value + for i1, i2 in zip(bb1.instructions, bb2.instructions): + assert instructions_eq(i1, i2), (bb1, f"[{i1}] != [{i2}]") + + # assert after individual instruction checks, makes it easier to debug + # if there is a difference. 
+ assert len(bb1.instructions) == len(bb2.instructions) + + +def assert_fn_eq(fn1: IRFunction, fn2: IRFunction): + assert fn1.name.value == fn2.name.value + assert len(fn1._basic_block_dict) == len(fn2._basic_block_dict) + + for name1, bb1 in fn1._basic_block_dict.items(): + assert name1 in fn2._basic_block_dict + assert_bb_eq(bb1, fn2._basic_block_dict[name1]) + + # check function entry is the same + assert fn1.entry.label == fn2.entry.label + + +def assert_ctx_eq(ctx1: IRContext, ctx2: IRContext): + for label1, fn1 in ctx1.functions.items(): + assert label1 in ctx2.functions + assert_fn_eq(fn1, ctx2.functions[label1]) + assert len(ctx1.functions) == len(ctx2.functions) + + # check entry function is the same + assert next(iter(ctx1.functions.keys())) == next(iter(ctx2.functions.keys())) + assert ctx1.data_segment == ctx2.data_segment, ctx2.data_segment diff --git a/vyper/ast/grammar.lark b/vyper/ast/grammar.lark index 97f9f70e24..bc2f9ba77c 100644 --- a/vyper/ast/grammar.lark +++ b/vyper/ast/grammar.lark @@ -318,7 +318,7 @@ COMMENT: /#[^\n\r]*/ _NEWLINE: ( /\r?\n[\t ]*/ | COMMENT )+ -STRING: /b?("(?!"").*?(? NatspecOutput: + try: + return _parse_natspec(annotated_vyper_module) + except NatSpecSyntaxException as e: + e.resolved_path = annotated_vyper_module.resolved_path + raise e + + +def _parse_natspec(annotated_vyper_module: vy_ast.Module) -> NatspecOutput: """ Parses NatSpec documentation from a contract. 
diff --git a/vyper/ast/nodes.py b/vyper/ast/nodes.py index b4042c75a7..24a0f9ade3 100644 --- a/vyper/ast/nodes.py +++ b/vyper/ast/nodes.py @@ -331,26 +331,10 @@ def get_fields(cls) -> set: slot_fields = [x for i in cls.__mro__ for x in getattr(i, "__slots__", [])] return set(i for i in slot_fields if not i.startswith("_")) - def __hash__(self): - values = [getattr(self, i, None) for i in VyperNode._public_slots] - return hash(tuple(values)) - def __deepcopy__(self, memo): # default implementation of deepcopy is a hotspot return pickle.loads(pickle.dumps(self)) - def __eq__(self, other): - # CMC 2024-03-03 I'm not sure it makes much sense to compare AST - # nodes, especially if they come from other modules - if not isinstance(other, type(self)): - return False - if getattr(other, "node_id", None) != getattr(self, "node_id", None): - return False - for field_name in (i for i in self.get_fields() if i not in VyperNode.__slots__): - if getattr(self, field_name, None) != getattr(other, field_name, None): - return False - return True - def __repr__(self): cls = type(self) class_repr = f"{cls.__module__}.{cls.__qualname__}" @@ -638,7 +622,7 @@ class TopLevel(VyperNode): class Module(TopLevel): # metadata - __slots__ = ("path", "resolved_path", "source_id") + __slots__ = ("path", "resolved_path", "source_id", "is_interface") def to_dict(self): return dict(source_sha256sum=self.source_sha256sum, **super().to_dict()) @@ -676,7 +660,6 @@ class DocStr(VyperNode): """ __slots__ = ("value",) - _translated_fields = {"s": "value"} class arguments(VyperNode): @@ -854,10 +837,15 @@ class Hex(Constant): def validate(self): if "_" in self.value: + # TODO: revisit this, we should probably allow underscores raise InvalidLiteral("Underscores not allowed in hex literals", self) if len(self.value) % 2: raise InvalidLiteral("Hex notation requires an even number of digits", self) + if self.value.startswith("0X"): + hint = f"Did you mean `0x{self.value[2:]}`?" 
+ raise InvalidLiteral("Hex literal begins with 0X!", self, hint=hint) + @property def n_nibbles(self): """ @@ -882,17 +870,18 @@ def bytes_value(self): class Str(Constant): __slots__ = () - _translated_fields = {"s": "value"} def validate(self): for c in self.value: - if ord(c) >= 256: + # in utf-8, bytes in the 128 and up range deviate from latin1 and + # can be control bytes, allowing multi-byte characters. + # reject them here. + if ord(c) >= 128: raise InvalidLiteral(f"'{c}' is not an allowed string literal character", self) class Bytes(Constant): __slots__ = () - _translated_fields = {"s": "value"} def __init__(self, parent: Optional["VyperNode"] = None, **kwargs: dict): super().__init__(parent, **kwargs) @@ -906,9 +895,19 @@ def to_dict(self): ast_dict["value"] = f"0x{self.value.hex()}" return ast_dict - @property - def s(self): - return self.value + +class HexBytes(Constant): + __slots__ = () + + def __init__(self, parent: Optional["VyperNode"] = None, **kwargs: dict): + super().__init__(parent, **kwargs) + if isinstance(self.value, str): + self.value = bytes.fromhex(self.value) + + def to_dict(self): + ast_dict = super().to_dict() + ast_dict["value"] = f"0x{self.value.hex()}" + return ast_dict class List(ExprNode): @@ -944,6 +943,12 @@ def validate(self): class Ellipsis(Constant): __slots__ = () + def to_dict(self): + ast_dict = super().to_dict() + # python ast ellipsis() is not json serializable; use a string + ast_dict["value"] = self.node_source_code + return ast_dict + class Dict(ExprNode): __slots__ = ("keys", "values") @@ -1053,7 +1058,7 @@ def _op(self, left, right): raise OverflowException(msg, self) from None -class FloorDiv(VyperNode): +class FloorDiv(Operator): __slots__ = () _description = "integer division" _pretty = "//" @@ -1334,7 +1339,7 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) -class AnnAssign(VyperNode): +class AnnAssign(Stmt): __slots__ = ("target", "annotation", "value") diff --git a/vyper/ast/nodes.pyi 
b/vyper/ast/nodes.pyi index 1c7aaf55ee..b00354c03a 100644 --- a/vyper/ast/nodes.pyi +++ b/vyper/ast/nodes.pyi @@ -23,6 +23,7 @@ class VyperNode: end_col_offset: int = ... _metadata: dict = ... _original_node: Optional[VyperNode] = ... + _children: list[VyperNode] = ... def __init__(self, parent: Optional[VyperNode] = ..., **kwargs: Any) -> None: ... def __hash__(self) -> Any: ... def __eq__(self, other: Any) -> Any: ... @@ -69,6 +70,8 @@ class TopLevel(VyperNode): class Module(TopLevel): path: str = ... resolved_path: str = ... + source_id: int = ... + is_interface: bool = ... def namespace(self) -> Any: ... # context manager class FunctionDef(TopLevel): @@ -108,9 +111,7 @@ class ExprNode(VyperNode): class Constant(ExprNode): value: Any = ... -class Num(Constant): - @property - def n(self): ... +class Num(Constant): ... class Int(Num): value: int = ... @@ -121,14 +122,9 @@ class Hex(Num): @property def n_bytes(self): ... -class Str(Constant): - @property - def s(self): ... - -class Bytes(Constant): - @property - def s(self): ... - +class Str(Constant): ... +class Bytes(Constant): ... +class HexBytes(Constant): ... class NameConstant(Constant): ... class Ellipsis(Constant): ... 
diff --git a/vyper/ast/parse.py b/vyper/ast/parse.py index d4569dd644..a7cd0464ed 100644 --- a/vyper/ast/parse.py +++ b/vyper/ast/parse.py @@ -1,15 +1,14 @@ import ast as python_ast +import pickle import tokenize from decimal import Decimal +from functools import cached_property from typing import Any, Dict, List, Optional, Union -import asttokens - from vyper.ast import nodes as vy_ast -from vyper.ast.pre_parser import pre_parse +from vyper.ast.pre_parser import PreParser from vyper.compiler.settings import Settings from vyper.exceptions import CompilerPanic, ParserException, SyntaxException -from vyper.typing import ModificationOffsets from vyper.utils import sha256sum, vyper_warn @@ -24,6 +23,24 @@ def parse_to_ast_with_settings( module_path: Optional[str] = None, resolved_path: Optional[str] = None, add_fn_node: Optional[str] = None, + is_interface: bool = False, +) -> tuple[Settings, vy_ast.Module]: + try: + return _parse_to_ast_with_settings( + vyper_source, source_id, module_path, resolved_path, add_fn_node, is_interface + ) + except SyntaxException as e: + e.resolved_path = resolved_path + raise e + + +def _parse_to_ast_with_settings( + vyper_source: str, + source_id: int = 0, + module_path: Optional[str] = None, + resolved_path: Optional[str] = None, + add_fn_node: Optional[str] = None, + is_interface: bool = False, ) -> tuple[Settings, vy_ast.Module]: """ Parses a Vyper source string and generates basic Vyper AST nodes. @@ -47,6 +64,9 @@ def parse_to_ast_with_settings( resolved_path: str, optional The resolved path of the source code Corresponds to FileInput.resolved_path + is_interface: bool + Indicates whether the source code should + be parsed as an interface file. 
Returns ------- @@ -55,12 +75,36 @@ def parse_to_ast_with_settings( """ if "\x00" in vyper_source: raise ParserException("No null bytes (\\x00) allowed in the source code.") - settings, class_types, for_loop_annotations, python_source = pre_parse(vyper_source) + pre_parser = PreParser() + pre_parser.parse(vyper_source) try: - py_ast = python_ast.parse(python_source) + py_ast = python_ast.parse(pre_parser.reformatted_code) except SyntaxError as e: - # TODO: Ensure 1-to-1 match of source_code:reformatted_code SyntaxErrors - raise SyntaxException(str(e), vyper_source, e.lineno, e.offset) from None + offset = e.offset + if offset is not None: + # SyntaxError offset is 1-based, not 0-based (see: + # https://docs.python.org/3/library/exceptions.html#SyntaxError.offset) + offset -= 1 + + # adjust the column of the error if it was modified by the pre-parser + if e.lineno is not None: # help mypy + offset += pre_parser.adjustments.get((e.lineno, offset), 0) + + new_e = SyntaxException(str(e), vyper_source, e.lineno, offset) + + likely_errors = ("staticall", "staticcal") + tmp = str(new_e) + for s in likely_errors: + if s in tmp: + new_e._hint = "did you mean `staticcall`?" + break + + raise new_e from None + + # some python AST node instances are singletons and are reused between + # parse() invocations. copy the python AST so that we are using fresh + # objects. 
+ py_ast = _deepcopy_ast(py_ast) # Add dummy function node to ensure local variables are treated as `AnnAssign` # instead of state variables (`VariableDecl`) @@ -73,21 +117,28 @@ def parse_to_ast_with_settings( annotate_python_ast( py_ast, vyper_source, - class_types, - for_loop_annotations, + pre_parser, source_id=source_id, module_path=module_path, resolved_path=resolved_path, ) # postcondition: consumed all the for loop annotations - assert len(for_loop_annotations) == 0 + assert len(pre_parser.for_loop_annotations) == 0 + + # postcondition: we have used all the hex strings found by the + # pre-parser + assert len(pre_parser.hex_string_locations) == 0 # Convert to Vyper AST. module = vy_ast.get_node(py_ast) assert isinstance(module, vy_ast.Module) # mypy hint + module.is_interface = is_interface + + return pre_parser.settings, module + - return settings, module +LINE_INFO_FIELDS = ("lineno", "col_offset", "end_lineno", "end_col_offset") def ast_to_dict(ast_struct: Union[vy_ast.VyperNode, List]) -> Union[Dict, List]: @@ -116,10 +167,9 @@ def dict_to_ast(ast_struct: Union[Dict, List]) -> Union[vy_ast.VyperNode, List]: def annotate_python_ast( - parsed_ast: python_ast.AST, + parsed_ast: python_ast.Module, vyper_source: str, - modification_offsets: ModificationOffsets, - for_loop_annotations: dict, + pre_parser: PreParser, source_id: int = 0, module_path: Optional[str] = None, resolved_path: Optional[str] = None, @@ -133,58 +183,98 @@ def annotate_python_ast( The AST to be annotated and optimized. vyper_source: str The original vyper source code - loop_var_annotations: dict - A mapping of line numbers of `For` nodes to the tokens of the type - annotation of the iterator extracted during pre-parsing. - modification_offsets : dict - A mapping of class names to their original class types. + pre_parser: PreParser + PreParser object. Returns ------- The annotated and optimized AST. 
""" - tokens = asttokens.ASTTokens(vyper_source) - assert isinstance(parsed_ast, python_ast.Module) # help mypy - tokens.mark_tokens(parsed_ast) visitor = AnnotatingVisitor( - vyper_source, - modification_offsets, - for_loop_annotations, - tokens, - source_id, - module_path=module_path, - resolved_path=resolved_path, + vyper_source, pre_parser, source_id, module_path=module_path, resolved_path=resolved_path ) - visitor.visit(parsed_ast) + visitor.start(parsed_ast) return parsed_ast +def _deepcopy_ast(ast_node: python_ast.AST): + # pickle roundtrip is faster than copy.deepcopy() here. + return pickle.loads(pickle.dumps(ast_node)) + + class AnnotatingVisitor(python_ast.NodeTransformer): _source_code: str - _modification_offsets: ModificationOffsets - _loop_var_annotations: dict[int, dict[str, Any]] + _pre_parser: PreParser def __init__( self, source_code: str, - modification_offsets: ModificationOffsets, - for_loop_annotations: dict, - tokens: asttokens.ASTTokens, + pre_parser: PreParser, source_id: int, module_path: Optional[str] = None, resolved_path: Optional[str] = None, ): - self._tokens = tokens self._source_id = source_id self._module_path = module_path self._resolved_path = resolved_path self._source_code = source_code - self._modification_offsets = modification_offsets - self._for_loop_annotations = for_loop_annotations + self._pre_parser = pre_parser self.counter: int = 0 + @cached_property + def source_lines(self): + return self._source_code.splitlines(keepends=True) + + @cached_property + def line_offsets(self): + ofst = 0 + # ensure line_offsets has at least 1 entry for 0-line source + ret = {1: ofst} + for lineno, line in enumerate(self.source_lines): + ret[lineno + 1] = ofst + ofst += len(line) + return ret + + def start(self, node: python_ast.Module): + self._fix_missing_locations(node) + self.visit(node) + + def _fix_missing_locations(self, ast_node: python_ast.Module): + """ + adapted from cpython Lib/ast.py. 
adds line/col info to ast, + but unlike Lib/ast.py, adjusts *all* ast nodes, not just the + one that python defines to have line/col info. + https://github.com/python/cpython/blob/62729d79206014886f5d/Lib/ast.py#L228 + """ + assert isinstance(ast_node, python_ast.Module) + ast_node.lineno = 1 + ast_node.col_offset = 0 + ast_node.end_lineno = max(1, len(self.source_lines)) + + if len(self.source_lines) > 0: + ast_node.end_col_offset = len(self.source_lines[-1]) + else: + ast_node.end_col_offset = 0 + + def _fix(node, parent=None): + for field in LINE_INFO_FIELDS: + if parent is not None: + val = getattr(node, field, None) + # special case for USub - heisenbug when coverage is + # enabled in the test suite. + if val is None or isinstance(node, python_ast.USub): + val = getattr(parent, field) + setattr(node, field, val) + else: + assert hasattr(node, field), node + + for child in python_ast.iter_child_nodes(node): + _fix(child, node) + + _fix(ast_node) + def generic_visit(self, node): """ Annotate a node with information that simplifies Vyper node generation. @@ -192,38 +282,28 @@ def generic_visit(self, node): # Decorate every node with the original source code to allow pretty-printing errors node.full_source_code = self._source_code node.node_id = self.counter - node.ast_type = node.__class__.__name__ self.counter += 1 + node.ast_type = node.__class__.__name__ + + adjustments = self._pre_parser.adjustments + + # Load and Store behave differently inside of fix_missing_locations; + # we don't use them in the vyper AST so just skip adjusting the line + # info. 
+ if isinstance(node, (python_ast.Load, python_ast.Store)): + return super().generic_visit(node) - # Decorate every node with source end offsets - start = (None, None) - if hasattr(node, "first_token"): - start = node.first_token.start - end = (None, None) - if hasattr(node, "last_token"): - end = node.last_token.end - if node.last_token.type == 4: - # token type 4 is a `\n`, some nodes include a trailing newline - # here we ignore it when building the node offsets - end = (end[0], end[1] - 1) - - node.lineno = start[0] - node.col_offset = start[1] - node.end_lineno = end[0] - node.end_col_offset = end[1] - - # TODO: adjust end_lineno and end_col_offset when this node is in - # modification_offsets - - if hasattr(node, "last_token"): - start_pos = node.first_token.startpos - end_pos = node.last_token.endpos - - if node.last_token.type == 4: - # ignore trailing newline once more - end_pos -= 1 - node.src = f"{start_pos}:{end_pos-start_pos}:{self._source_id}" - node.node_source_code = self._source_code[start_pos:end_pos] + adj = adjustments.get((node.lineno, node.col_offset), 0) + node.col_offset += adj + + adj = adjustments.get((node.end_lineno, node.end_col_offset), 0) + node.end_col_offset += adj + + start_pos = self.line_offsets[node.lineno] + node.col_offset + end_pos = self.line_offsets[node.end_lineno] + node.end_col_offset + + node.src = f"{start_pos}:{end_pos-start_pos}:{self._source_id}" + node.node_source_code = self._source_code[start_pos:end_pos] return super().generic_visit(node) @@ -257,12 +337,6 @@ def visit_Module(self, node): return self._visit_docstring(node) def visit_FunctionDef(self, node): - if node.decorator_list: - # start the source highlight at `def` to improve annotation readability - decorator_token = node.decorator_list[-1].last_token - def_token = self._tokens.find_token(decorator_token, tokenize.NAME, tok_str="def") - node.first_token = def_token - return self._visit_docstring(node) def visit_ClassDef(self, node): @@ -275,7 +349,7 @@ 
def visit_ClassDef(self, node): """ self.generic_visit(node) - node.ast_type = self._modification_offsets[(node.lineno, node.col_offset)] + node.ast_type = self._pre_parser.keyword_translations[(node.lineno, node.col_offset)] return node def visit_For(self, node): @@ -283,7 +357,8 @@ def visit_For(self, node): Visit a For node, splicing in the loop variable annotation provided by the pre-parser """ - annotation_tokens = self._for_loop_annotations.pop((node.lineno, node.col_offset)) + key = (node.lineno, node.col_offset) + annotation_tokens = self._pre_parser.for_loop_annotations.pop(key) if not annotation_tokens: # a common case for people migrating to 0.4.0, provide a more @@ -317,16 +392,13 @@ def visit_For(self, node): try: fake_node = python_ast.parse(annotation_str).body[0] + # do we need to fix location info here? + fake_node = _deepcopy_ast(fake_node) except SyntaxError as e: raise SyntaxException( "invalid type annotation", self._source_code, node.lineno, node.col_offset ) from e - # fill in with asttokens info. note we can use `self._tokens` because - # it is indented to exactly the same position where it appeared - # in the original source! - self._tokens.mark_tokens(fake_node) - # replace the dummy target name with the real target name. 
fake_node.target = node.target # replace the For node target with the new ann_assign @@ -350,14 +422,15 @@ def visit_Expr(self, node): if isinstance(node.value, python_ast.Yield): # CMC 2024-03-03 consider unremoving this from the enclosing Expr node = node.value - node.ast_type = self._modification_offsets[(node.lineno, node.col_offset)] + key = (node.lineno, node.col_offset) + node.ast_type = self._pre_parser.keyword_translations[key] return node def visit_Await(self, node): - start_pos = node.lineno, node.col_offset # grab these before generic_visit modifies them + start_pos = node.lineno, node.col_offset self.generic_visit(node) - node.ast_type = self._modification_offsets[start_pos] + node.ast_type = self._pre_parser.keyword_translations[start_pos] return node def visit_Call(self, node): @@ -377,6 +450,9 @@ def visit_Call(self, node): assert len(dict_.keys) == len(dict_.values) for key, value in zip(dict_.keys, dict_.values): replacement_kw_node = python_ast.keyword(key.id, value) + # set locations + for attr in LINE_INFO_FIELDS: + setattr(replacement_kw_node, attr, getattr(key, attr)) kw_list.append(replacement_kw_node) node.args = [] @@ -401,7 +477,19 @@ def visit_Constant(self, node): if node.value is None or isinstance(node.value, bool): node.ast_type = "NameConstant" elif isinstance(node.value, str): - node.ast_type = "Str" + key = (node.lineno, node.col_offset) + if key in self._pre_parser.hex_string_locations: + if len(node.value) % 2 != 0: + raise SyntaxException( + "Hex string must have an even number of characters", + self._source_code, + node.lineno, + node.col_offset, + ) + node.ast_type = "HexBytes" + self._pre_parser.hex_string_locations.remove(key) + else: + node.ast_type = "Str" elif isinstance(node.value, bytes): node.ast_type = "Bytes" elif isinstance(node.value, Ellipsis.__class__): diff --git a/vyper/ast/pre_parser.py b/vyper/ast/pre_parser.py index b12aecd0bf..8e221fb7e6 100644 --- a/vyper/ast/pre_parser.py +++ b/vyper/ast/pre_parser.py @@ 
-2,7 +2,7 @@ import io import re from collections import defaultdict -from tokenize import COMMENT, NAME, OP, TokenError, TokenInfo, tokenize, untokenize +from tokenize import COMMENT, NAME, OP, STRING, TokenError, TokenInfo, tokenize, untokenize from packaging.specifiers import InvalidSpecifier, SpecifierSet @@ -12,7 +12,7 @@ # evm-version pragma from vyper.evm.opcodes import EVM_VERSIONS from vyper.exceptions import StructureException, SyntaxException, VersionException -from vyper.typing import ModificationOffsets, ParserPosition +from vyper.typing import ParserPosition def validate_version_pragma(version_str: str, full_source_code: str, start: ParserPosition) -> None: @@ -48,7 +48,7 @@ def validate_version_pragma(version_str: str, full_source_code: str, start: Pars ) -class ForParserState(enum.Enum): +class ParserState(enum.Enum): NOT_RUNNING = enum.auto() START_SOON = enum.auto() RUNNING = enum.auto() @@ -63,7 +63,7 @@ def __init__(self, code): self.annotations = {} self._current_annotation = None - self._state = ForParserState.NOT_RUNNING + self._state = ParserState.NOT_RUNNING self._current_for_loop = None def consume(self, token): @@ -71,15 +71,15 @@ def consume(self, token): if token.type == NAME and token.string == "for": # note: self._state should be NOT_RUNNING here, but we don't sanity # check here as that should be an error the parser will handle. - self._state = ForParserState.START_SOON + self._state = ParserState.START_SOON self._current_for_loop = token.start - if self._state == ForParserState.NOT_RUNNING: + if self._state == ParserState.NOT_RUNNING: return False # state machine: start slurping tokens if token.type == OP and token.string == ":": - self._state = ForParserState.RUNNING + self._state = ParserState.RUNNING # sanity check -- this should never really happen, but if it does, # try to raise an exception which pinpoints the source. 
@@ -93,12 +93,12 @@ def consume(self, token): # state machine: end slurping tokens if token.type == NAME and token.string == "in": - self._state = ForParserState.NOT_RUNNING + self._state = ParserState.NOT_RUNNING self.annotations[self._current_for_loop] = self._current_annotation or [] self._current_annotation = None return False - if self._state != ForParserState.RUNNING: + if self._state != ParserState.RUNNING: return False # slurp the token @@ -106,6 +106,45 @@ def consume(self, token): return True +class HexStringParser: + def __init__(self): + self.locations = [] + self._tokens = [] + self._state = ParserState.NOT_RUNNING + + def consume(self, token, result): + # prepare to check if the next token is a STRING + if self._state == ParserState.NOT_RUNNING: + if token.type == NAME and token.string == "x": + self._tokens.append(token) + self._state = ParserState.RUNNING + return True + + return False + + assert self._state == ParserState.RUNNING, "unreachable" + + self._state = ParserState.NOT_RUNNING + + if token.type != STRING: + # flush the tokens we have accumulated and move on + result.extend(self._tokens) + self._tokens = [] + return False + + # mark hex string in locations for later processing + self.locations.append(token.start) + + # discard the `x` token and apply sanity checks - + # we should only be discarding one token. + assert len(self._tokens) == 1 + assert (x_tok := self._tokens[0]).type == NAME and x_tok.string == "x" + self._tokens = [] # discard tokens + + result.append(token) + return True + + # compound statements that are replaced with `class` # TODO remove enum in favor of flag VYPER_CLASS_TYPES = { @@ -122,45 +161,59 @@ def consume(self, token): CUSTOM_EXPRESSION_TYPES = {"extcall": "ExtCall", "staticcall": "StaticCall"} -def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, dict, str]: - """ - Re-formats a vyper source string into a python source string and performs - some validation. 
More specifically, - - * Translates "interface", "struct", "flag", and "event" keywords into python "class" keyword - * Validates "@version" pragma against current compiler version - * Prevents direct use of python "class" keyword - * Prevents use of python semi-colon statement separator - * Extracts type annotation of for loop iterators into a separate dictionary - - Also returns a mapping of detected interface and struct names to their - respective vyper class types ("interface" or "struct"), and a mapping of line numbers - of for loops to the type annotation of their iterators. - - Parameters - ---------- - code : str - The vyper source code to be re-formatted. - - Returns - ------- - Settings - Compilation settings based on the directives in the source code - ModificationOffsets - A mapping of class names to their original class types. - dict[tuple[int, int], list[TokenInfo]] - A mapping of line/column offsets of `For` nodes to the annotation of the for loop target - str - Reformatted python source string. - """ - result = [] - modification_offsets: ModificationOffsets = {} - settings = Settings() - for_parser = ForParser(code) +class PreParser: + # Compilation settings based on the directives in the source code + settings: Settings + + # A mapping of offsets to new class names + keyword_translations: dict[tuple[int, int], str] + + # Map from offsets in the original vyper source code to offsets + # in the new ("reformatted", i.e. python-compatible) source code + adjustments: dict[tuple[int, int], int] + + # A mapping of line/column offsets of `For` nodes to the annotation of the for loop target + for_loop_annotations: dict[tuple[int, int], list[TokenInfo]] + # A list of line/column offsets of hex string literals + hex_string_locations: list[tuple[int, int]] + # Reformatted python source string. + reformatted_code: str + + def parse(self, code: str): + """ + Re-formats a vyper source string into a python source string and performs + some validation. 
More specifically, + + * Translates "interface", "struct", "flag", and "event" keywords into python "class" keyword + * Validates "@version" pragma against current compiler version + * Prevents direct use of python "class" keyword + * Prevents use of python semi-colon statement separator + * Extracts type annotation of for loop iterators into a separate dictionary + + Stores a mapping of detected interface and struct names to their + respective vyper class types ("interface" or "struct"), and a mapping of line numbers + of for loops to the type annotation of their iterators. + + Parameters + ---------- + code : str + The vyper source code to be re-formatted. + """ + try: + self._parse(code) + except TokenError as e: + raise SyntaxException(e.args[0], code, e.args[1][0], e.args[1][1]) from e + + def _parse(self, code: str): + adjustments: dict = {} + result: list[TokenInfo] = [] + keyword_translations: dict[tuple[int, int], str] = {} + settings = Settings() + for_parser = ForParser(code) + hex_string_parser = HexStringParser() + + _col_adjustments: dict[int, int] = defaultdict(lambda: 0) - _col_adjustments: dict[int, int] = defaultdict(lambda: 0) - - try: code_bytes = code.encode("utf-8") token_list = list(tokenize(io.BytesIO(code_bytes).readline)) @@ -173,6 +226,12 @@ def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, dict, str]: end = token.end line = token.line + # handle adjustments + lineno, col = token.start + adj = _col_adjustments[lineno] + newstart = lineno, col - adj + adjustments[lineno, col - adj] = adj + if typ == COMMENT: contents = string[1:].strip() if contents.startswith("@version"): @@ -207,10 +266,10 @@ def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, dict, str]: if evm_version not in EVM_VERSIONS: raise StructureException(f"Invalid evm version: `{evm_version}`", start) settings.evm_version = evm_version - elif pragma.startswith("experimental-codegen"): + elif pragma.startswith("experimental-codegen") or 
pragma.startswith("venom"): if settings.experimental_codegen is not None: raise StructureException( - "pragma experimental-codegen specified twice!", start + "pragma experimental-codegen/venom specified twice!", start ) settings.experimental_codegen = True elif pragma.startswith("enable-decimals"): @@ -229,49 +288,46 @@ def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, dict, str]: ) if typ == NAME: + # see if it's a keyword we need to replace + new_keyword = None if string in VYPER_CLASS_TYPES and start[1] == 0: - toks = [TokenInfo(NAME, "class", start, end, line)] - modification_offsets[start] = VYPER_CLASS_TYPES[string] + new_keyword = "class" + vyper_type = VYPER_CLASS_TYPES[string] elif string in CUSTOM_STATEMENT_TYPES: new_keyword = "yield" - adjustment = len(new_keyword) - len(string) - # adjustments for following staticcall/extcall modification_offsets - _col_adjustments[start[0]] += adjustment - toks = [TokenInfo(NAME, new_keyword, start, end, line)] - modification_offsets[start] = CUSTOM_STATEMENT_TYPES[string] + vyper_type = CUSTOM_STATEMENT_TYPES[string] elif string in CUSTOM_EXPRESSION_TYPES: - # a bit cursed technique to get untokenize to put - # the new tokens in the right place so that modification_offsets - # will work correctly. 
- # (recommend comparing the result of pre_parse with the - # source code side by side to visualize the whitespace) new_keyword = "await" vyper_type = CUSTOM_EXPRESSION_TYPES[string] - lineno, col_offset = start - - # fixup for when `extcall/staticcall` follows `log` - adjustment = _col_adjustments[lineno] - new_start = (lineno, col_offset + adjustment) - modification_offsets[new_start] = vyper_type + if new_keyword is not None: + keyword_translations[newstart] = vyper_type - # tells untokenize to add whitespace, preserving locations - diff = len(new_keyword) - len(string) - new_end = end[0], end[1] + diff + adjustment = len(string) - len(new_keyword) + # adjustments for following tokens + lineno, col = start + _col_adjustments[lineno] += adjustment - toks = [TokenInfo(NAME, new_keyword, start, new_end, line)] + # a bit cursed technique to get untokenize to put + # the new tokens in the right place so that + # `keyword_translations` will work correctly. + # (recommend comparing the result of parse with the + # source code side by side to visualize the whitespace) + toks = [TokenInfo(NAME, new_keyword, start, end, line)] if (typ, string) == (OP, ";"): raise SyntaxException("Semi-colon statements not allowed", code, start[0], start[1]) - if not for_parser.consume(token): + if not for_parser.consume(token) and not hex_string_parser.consume(token, result): result.extend(toks) - except TokenError as e: - raise SyntaxException(e.args[0], code, e.args[1][0], e.args[1][1]) from e - - for_loop_annotations = {} - for k, v in for_parser.annotations.items(): - for_loop_annotations[k] = v.copy() + for_loop_annotations = {} + for k, v in for_parser.annotations.items(): + for_loop_annotations[k] = v.copy() - return settings, modification_offsets, for_loop_annotations, untokenize(result).decode("utf-8") + self.adjustments = adjustments + self.settings = settings + self.keyword_translations = keyword_translations + self.for_loop_annotations = for_loop_annotations + 
self.hex_string_locations = hex_string_parser.locations + self.reformatted_code = untokenize(result).decode("utf-8") diff --git a/vyper/builtins/_convert.py b/vyper/builtins/_convert.py index aa53dee429..a494e4a344 100644 --- a/vyper/builtins/_convert.py +++ b/vyper/builtins/_convert.py @@ -463,7 +463,7 @@ def to_flag(expr, arg, out_typ): def convert(expr, context): assert len(expr.args) == 2, "bad typecheck: convert" - arg_ast = expr.args[0] + arg_ast = expr.args[0].reduced() arg = Expr(arg_ast, context).ir_node original_arg = arg diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index f59a33e2a2..17ffde0728 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -306,7 +306,7 @@ def fetch_call_return(self, node): arg = node.args[0] start_expr = node.args[1] - length_expr = node.args[2] + length_expr = node.args[2].reduced() # CMC 2022-03-22 NOTE slight code duplication with semantics/analysis/local is_adhoc_slice = arg.get("attr") == "code" or ( @@ -781,7 +781,7 @@ def build_IR(self, expr, args, kwargs, context): ["mstore", add_ofst(input_buf, 32), args[1]], ["mstore", add_ofst(input_buf, 64), args[2]], ["mstore", add_ofst(input_buf, 96), args[3]], - ["staticcall", "gas", 1, input_buf, 128, output_buf, 32], + ["assert", ["staticcall", "gas", 1, input_buf, 128, output_buf, 32]], ["mload", output_buf], ], typ=AddressT(), @@ -1257,7 +1257,8 @@ def fetch_call_return(self, node): def infer_arg_types(self, node, expected_return_typ=None): self._validate_arg_types(node) - if not isinstance(node.args[0], vy_ast.List) or len(node.args[0].elements) > 4: + arg = node.args[0].reduced() + if not isinstance(arg, vy_ast.List) or len(arg.elements) > 4: raise InvalidType("Expecting a list of 0-4 topics as first argument", node.args[0]) # return a concrete type for `data` @@ -1267,7 +1268,9 @@ def infer_arg_types(self, node, expected_return_typ=None): @process_inputs def build_IR(self, expr, args, kwargs, context): - topics_length = 
len(expr.args[0].elements) + context.check_is_not_constant(f"use {self._id}", expr) + + topics_length = len(expr.args[0].reduced().elements) topics = args[0].args topics = [unwrap_location(topic) for topic in topics] @@ -2164,10 +2167,9 @@ def build_IR(self, expr, args, kwargs, context): variables_2=variables_2, memory_allocator=context.memory_allocator, ) + z_ir = new_ctx.vars["z"].as_ir_node() ret = IRnode.from_list( - ["seq", placeholder_copy, sqrt_ir, new_ctx.vars["z"].pos], # load x variable - typ=DecimalT(), - location=MEMORY, + ["seq", placeholder_copy, sqrt_ir, z_ir], typ=DecimalT(), location=MEMORY ) return b1.resolve(ret) @@ -2361,7 +2363,13 @@ def infer_kwarg_types(self, node): for kwarg in node.keywords: kwarg_name = kwarg.arg validate_expected_type(kwarg.value, self._kwargs[kwarg_name].typ) - ret[kwarg_name] = get_exact_type_from_node(kwarg.value) + + typ = get_exact_type_from_node(kwarg.value) + if kwarg_name == "method_id" and isinstance(typ, BytesT): + if typ.length != 4: + raise InvalidLiteral("method_id must be exactly 4 bytes!", kwarg.value) + + ret[kwarg_name] = typ return ret def fetch_call_return(self, node): diff --git a/vyper/builtins/interfaces/IERC4626.vyi b/vyper/builtins/interfaces/IERC4626.vyi index 6d9e4c6ef7..0dd398d1f3 100644 --- a/vyper/builtins/interfaces/IERC4626.vyi +++ b/vyper/builtins/interfaces/IERC4626.vyi @@ -44,7 +44,7 @@ def previewDeposit(assets: uint256) -> uint256: ... @external -def deposit(assets: uint256, receiver: address=msg.sender) -> uint256: +def deposit(assets: uint256, receiver: address) -> uint256: ... @view @@ -58,7 +58,7 @@ def previewMint(shares: uint256) -> uint256: ... @external -def mint(shares: uint256, receiver: address=msg.sender) -> uint256: +def mint(shares: uint256, receiver: address) -> uint256: ... @view @@ -72,7 +72,7 @@ def previewWithdraw(assets: uint256) -> uint256: ... 
@external -def withdraw(assets: uint256, receiver: address=msg.sender, owner: address=msg.sender) -> uint256: +def withdraw(assets: uint256, receiver: address, owner: address) -> uint256: ... @view @@ -86,5 +86,5 @@ def previewRedeem(shares: uint256) -> uint256: ... @external -def redeem(shares: uint256, receiver: address=msg.sender, owner: address=msg.sender) -> uint256: +def redeem(shares: uint256, receiver: address, owner: address) -> uint256: ... diff --git a/vyper/cli/compile_archive.py b/vyper/cli/compile_archive.py index 1b52343c1c..d1dd2588ad 100644 --- a/vyper/cli/compile_archive.py +++ b/vyper/cli/compile_archive.py @@ -8,8 +8,9 @@ import zipfile from pathlib import PurePath -from vyper.compiler import compile_from_file_input +from vyper.compiler import outputs_from_compiler_data from vyper.compiler.input_bundle import FileInput, ZipInputBundle +from vyper.compiler.phases import CompilerData from vyper.compiler.settings import Settings, merge_settings from vyper.exceptions import BadArchive @@ -19,6 +20,11 @@ class NotZipInput(Exception): def compile_from_zip(file_name, output_formats, settings, no_bytecode_metadata): + compiler_data = compiler_data_from_zip(file_name, settings, no_bytecode_metadata) + return outputs_from_compiler_data(compiler_data, output_formats) + + +def compiler_data_from_zip(file_name, settings, no_bytecode_metadata): with open(file_name, "rb") as f: bcontents = f.read() @@ -39,6 +45,11 @@ def compile_from_zip(file_name, output_formats, settings, no_bytecode_metadata): fcontents = archive.read("MANIFEST/compilation_targets").decode("utf-8") compilation_targets = fcontents.splitlines() + storage_layout_path = "MANIFEST/storage_layout.json" + storage_layout = None + if storage_layout_path in archive.namelist(): + storage_layout = json.loads(archive.read(storage_layout_path).decode("utf-8")) + if len(compilation_targets) != 1: raise BadArchive("Multiple compilation targets not supported!") @@ -59,11 +70,10 @@ def 
compile_from_zip(file_name, output_formats, settings, no_bytecode_metadata): settings, archive_settings, lhs_source="command line", rhs_source="archive settings" ) - # TODO: validate integrity sum (probably in CompilerData) - return compile_from_file_input( + return CompilerData( file, input_bundle=input_bundle, - output_formats=output_formats, + storage_layout=storage_layout, integrity_sum=integrity, settings=settings, no_bytecode_metadata=no_bytecode_metadata, diff --git a/vyper/cli/venom_main.py b/vyper/cli/venom_main.py new file mode 100755 index 0000000000..3114246e04 --- /dev/null +++ b/vyper/cli/venom_main.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python3 +import argparse +import sys + +import vyper +import vyper.evm.opcodes as evm +from vyper.compiler.phases import generate_bytecode +from vyper.compiler.settings import OptimizationLevel, Settings, set_global_settings +from vyper.venom import generate_assembly_experimental, run_passes_on +from vyper.venom.parser import parse_venom + +""" +Standalone entry point into venom compiler. Parses venom input and emits +bytecode. 
+""" + + +def _parse_cli_args(): + return _parse_args(sys.argv[1:]) + + +def _parse_args(argv: list[str]): + parser = argparse.ArgumentParser( + description="Venom EVM IR parser & compiler", formatter_class=argparse.RawTextHelpFormatter + ) + parser.add_argument("input_file", help="Venom sourcefile", nargs="?") + parser.add_argument("--version", action="version", version=vyper.__long_version__) + parser.add_argument( + "--evm-version", + help=f"Select desired EVM version (default {evm.DEFAULT_EVM_VERSION})", + choices=list(evm.EVM_VERSIONS), + dest="evm_version", + ) + parser.add_argument( + "--stdin", action="store_true", help="whether to pull venom input from stdin" + ) + + args = parser.parse_args(argv) + + if args.evm_version is not None: + set_global_settings(Settings(evm_version=args.evm_version)) + + if args.stdin: + if not sys.stdin.isatty(): + venom_source = sys.stdin.read() + else: + # No input provided + print("Error: --stdin flag used but no input provided") + sys.exit(1) + else: + if args.input_file is None: + print("Error: No input file provided, either use --stdin or provide a path") + sys.exit(1) + with open(args.input_file, "r") as f: + venom_source = f.read() + + ctx = parse_venom(venom_source) + run_passes_on(ctx, OptimizationLevel.default()) + asm = generate_assembly_experimental(ctx) + bytecode = generate_bytecode(asm, compiler_metadata=None) + print(f"0x{bytecode.hex()}") + + +if __name__ == "__main__": + _parse_args(sys.argv[1:]) diff --git a/vyper/cli/vyper_compile.py b/vyper/cli/vyper_compile.py index b7e664b975..09f8324dcf 100755 --- a/vyper/cli/vyper_compile.py +++ b/vyper/cli/vyper_compile.py @@ -5,7 +5,7 @@ import sys import warnings from pathlib import Path -from typing import Any, Iterable, Iterator, Optional, Set, TypeVar +from typing import Any, Optional import vyper import vyper.codegen.ir_node as ir_node @@ -15,8 +15,7 @@ from vyper.compiler.input_bundle import FileInput, FilesystemInputBundle from vyper.compiler.settings import 
VYPER_TRACEBACK_LIMIT, OptimizationLevel, Settings from vyper.typing import ContractPath, OutputFormats - -T = TypeVar("T") +from vyper.utils import uniq format_options_help = """Format to print, one or more of: bytecode (default) - Deployable bytecode @@ -34,6 +33,8 @@ layout - Storage layout of a Vyper contract ast - AST (not yet annotated) in JSON format annotated_ast - Annotated AST in JSON format +cfg - Control flow graph of deployable bytecode +cfg_runtime - Control flow graph of runtime bytecode interface - Vyper interface of a contract external_interface - External interface of a contract, used for outside contract calls opcodes - List of opcodes as a string @@ -41,6 +42,8 @@ ir - Intermediate representation in list format ir_json - Intermediate representation in JSON format ir_runtime - Intermediate representation of runtime bytecode in list format +bb - Basic blocks of Venom IR for deployable bytecode +bb_runtime - Basic blocks of Venom IR for runtime bytecode asm - Output the EVM assembly of the deployable bytecode integrity - Output the integrity hash of the source code archive - Output the build as an archive file @@ -121,8 +124,7 @@ def _parse_args(argv): ) parser.add_argument( "--evm-version", - help=f"Select desired EVM version (default {evm.DEFAULT_EVM_VERSION}). " - "note: cancun support is EXPERIMENTAL", + help=f"Select desired EVM version (default {evm.DEFAULT_EVM_VERSION})", choices=list(evm.EVM_VERSIONS), dest="evm_version", ) @@ -177,7 +179,8 @@ def _parse_args(argv): parser.add_argument("-o", help="Set the output path", dest="output_path") parser.add_argument( "--experimental-codegen", - help="The compiler use the new IR codegen. This is an experimental feature.", + "--venom", + help="The compiler uses the new IR codegen. 
This is an experimental feature.", action="store_true", dest="experimental_codegen", ) @@ -259,20 +262,6 @@ def _parse_args(argv): _cli_helper(f, output_formats, compiled) -def uniq(seq: Iterable[T]) -> Iterator[T]: - """ - Yield unique items in ``seq`` in order. - """ - seen: Set[T] = set() - - for x in seq: - if x in seen: - continue - - seen.add(x) - yield x - - def exc_handler(contract_path: ContractPath, exception: Exception) -> None: print(f"Error compiling: {contract_path}") raise exception @@ -355,7 +344,7 @@ def compile_files( # we allow this instead of requiring a different mode (like # `--zip`) so that verifier pipelines do not need a different # workflow for archive files and single-file contracts. - output = compile_from_zip(file_name, output_formats, settings, no_bytecode_metadata) + output = compile_from_zip(file_name, final_formats, settings, no_bytecode_metadata) ret[file_path] = output continue except NotZipInput: diff --git a/vyper/cli/vyper_json.py b/vyper/cli/vyper_json.py index beab06e3df..5f632f4167 100755 --- a/vyper/cli/vyper_json.py +++ b/vyper/cli/vyper_json.py @@ -12,6 +12,7 @@ from vyper.compiler.settings import OptimizationLevel, Settings from vyper.evm.opcodes import EVM_VERSIONS from vyper.exceptions import JSONError +from vyper.typing import StorageLayout from vyper.utils import OrderedSet, keccak256 TRANSLATE_MAP = { @@ -206,6 +207,19 @@ def get_inputs(input_dict: dict) -> dict[PurePath, Any]: return ret +def get_storage_layout_overrides(input_dict: dict) -> dict[PurePath, StorageLayout]: + storage_layout_overrides: dict[PurePath, StorageLayout] = {} + + for path, value in input_dict.get("storage_layout_overrides", {}).items(): + if path not in input_dict["sources"]: + raise JSONError(f"unknown target for storage layout override: {path}") + + path = PurePath(path) + storage_layout_overrides[path] = value + + return storage_layout_overrides + + # get unique output formats for each contract, given the input_dict # NOTE: would maybe 
be nice to raise on duplicated output formats def get_output_formats(input_dict: dict) -> dict[PurePath, list[str]]: @@ -249,16 +263,17 @@ def get_search_paths(input_dict: dict) -> list[PurePath]: return [PurePath(p) for p in ret] -def compile_from_input_dict( - input_dict: dict, exc_handler: Callable = exc_handler_raises -) -> tuple[dict, dict]: - if input_dict["language"] != "Vyper": - raise JSONError(f"Invalid language '{input_dict['language']}' - Only Vyper is supported.") - +def get_settings(input_dict: dict) -> Settings: evm_version = get_evm_version(input_dict) optimize = input_dict["settings"].get("optimize") - experimental_codegen = input_dict["settings"].get("experimentalCodegen", False) + + experimental_codegen = input_dict["settings"].get("experimentalCodegen") + if experimental_codegen is None: + experimental_codegen = input_dict["settings"].get("venom") + elif input_dict["settings"].get("venom") is not None: + raise JSONError("both experimentalCodegen and venom cannot be set") + if isinstance(optimize, bool): # bool optimization level for backwards compatibility warnings.warn( @@ -271,15 +286,34 @@ def compile_from_input_dict( else: assert optimize is None - settings = Settings( - evm_version=evm_version, optimize=optimize, experimental_codegen=experimental_codegen + debug = input_dict["settings"].get("debug", None) + + # TODO: maybe change these to camelCase for consistency + enable_decimals = input_dict["settings"].get("enable_decimals", None) + + return Settings( + evm_version=evm_version, + optimize=optimize, + experimental_codegen=experimental_codegen, + debug=debug, + enable_decimals=enable_decimals, ) + +def compile_from_input_dict( + input_dict: dict, exc_handler: Callable = exc_handler_raises +) -> tuple[dict, dict]: + if input_dict["language"] != "Vyper": + raise JSONError(f"Invalid language '{input_dict['language']}' - Only Vyper is supported.") + + settings = get_settings(input_dict) + no_bytecode_metadata = not 
input_dict["settings"].get("bytecodeMetadata", True) integrity = input_dict.get("integrity") sources = get_inputs(input_dict) + storage_layout_overrides = get_storage_layout_overrides(input_dict) output_formats = get_output_formats(input_dict) compilation_targets = list(output_formats.keys()) search_paths = get_search_paths(input_dict) @@ -289,6 +323,7 @@ def compile_from_input_dict( res, warnings_dict = {}, {} warnings.simplefilter("always") for contract_path in compilation_targets: + storage_layout_override = storage_layout_overrides.get(contract_path) with warnings.catch_warnings(record=True) as caught_warnings: try: # use load_file to get a unique source_id @@ -298,6 +333,7 @@ def compile_from_input_dict( file, input_bundle=input_bundle, output_formats=output_formats[contract_path], + storage_layout_override=storage_layout_override, integrity_sum=integrity, settings=settings, no_bytecode_metadata=no_bytecode_metadata, @@ -337,6 +373,9 @@ def format_to_output_dict(compiler_data: dict) -> dict: if key in data: output_contracts[key] = data[key] + if "layout" in data: + output_contracts["layout"] = data["layout"] + if "method_identifiers" in data: output_contracts["evm"] = {"methodIdentifiers": data["method_identifiers"]} diff --git a/vyper/codegen/abi_encoder.py b/vyper/codegen/abi_encoder.py index 09a22cd857..2ea8e3b6fd 100644 --- a/vyper/codegen/abi_encoder.py +++ b/vyper/codegen/abi_encoder.py @@ -73,7 +73,6 @@ def _encode_dyn_array_helper(dst, ir_node, context): # TODO handle this upstream somewhere if ir_node.value == "multi": buf = context.new_internal_variable(dst.typ) - buf = IRnode.from_list(buf, typ=dst.typ, location=MEMORY) _bufsz = dst.typ.abi_type.size_bound() return [ "seq", diff --git a/vyper/codegen/context.py b/vyper/codegen/context.py index f49914ac78..7995b7b9f5 100644 --- a/vyper/codegen/context.py +++ b/vyper/codegen/context.py @@ -15,6 +15,17 @@ class Constancy(enum.Enum): Constant = 1 +_alloca_id = 0 + + +def _generate_alloca_id(): + # note: 
this gets reset between compiler runs by codegen.core.reset_names + global _alloca_id + + _alloca_id += 1 + return _alloca_id + + @dataclass(frozen=True) class Alloca: name: str @@ -22,6 +33,8 @@ class Alloca: typ: VyperType size: int + _id: int + def __post_init__(self): assert self.typ.memory_bytes_required == self.size @@ -233,7 +246,9 @@ def _new_variable( pos = f"$palloca_{ofst}_{size}" else: pos = f"$alloca_{ofst}_{size}" - alloca = Alloca(name=name, offset=ofst, typ=typ, size=size) + + alloca_id = _generate_alloca_id() + alloca = Alloca(name=name, offset=ofst, typ=typ, size=size, _id=alloca_id) var = VariableRecord( name=name, diff --git a/vyper/codegen/core.py b/vyper/codegen/core.py index eddba9e1b1..448c38296b 100644 --- a/vyper/codegen/core.py +++ b/vyper/codegen/core.py @@ -1,3 +1,4 @@ +import vyper.codegen.context as ctx from vyper.codegen.ir_node import Encoding, IRnode from vyper.compiler.settings import _opt_codesize, _opt_gas, _opt_none from vyper.evm.address_space import ( @@ -325,7 +326,7 @@ def copy_bytes(dst, src, length, length_bound): copy_op = ["mcopy", dst, src, length] gas_bound = _mcopy_gas_bound(length_bound) else: - copy_op = ["staticcall", "gas", 4, src, length, dst, length] + copy_op = ["assert", ["staticcall", "gas", 4, src, length, dst, length]] gas_bound = _identity_gas_bound(length_bound) elif src.location == CALLDATA: copy_op = ["calldatacopy", dst, src, length] @@ -855,6 +856,9 @@ def reset_names(): global _label _label = 0 + # could be refactored + ctx._alloca_id = 0 + # returns True if t is ABI encoded and is a type that needs any kind of # validation @@ -924,6 +928,26 @@ def potential_overlap(left, right): return False +# similar to `potential_overlap()`, but compares left's _reads_ vs +# right's _writes_. +# TODO: `potential_overlap()` can probably be replaced by this function, +# but all the cases need to be checked. 
+def read_write_overlap(left, right): + if not isinstance(left, IRnode) or not isinstance(right, IRnode): + return False + + if left.typ._is_prim_word and right.typ._is_prim_word: + return False + + if len(left.referenced_variables & right.variable_writes) > 0: + return True + + if len(left.referenced_variables) > 0 and right.contains_risky_call: + return True + + return False + + # Create an x=y statement, where the types may be compound def make_setter(left, right, hi=None): check_assign(left, right) @@ -1097,7 +1121,7 @@ def ensure_in_memory(ir_var, context): return ir_var typ = ir_var.typ - buf = IRnode.from_list(context.new_internal_variable(typ), typ=typ, location=MEMORY) + buf = context.new_internal_variable(typ) do_copy = make_setter(buf, ir_var) return IRnode.from_list(["seq", do_copy, buf], typ=typ, location=MEMORY) diff --git a/vyper/codegen/expr.py b/vyper/codegen/expr.py index 65df5a0930..d3059e4245 100644 --- a/vyper/codegen/expr.py +++ b/vyper/codegen/expr.py @@ -21,6 +21,7 @@ make_setter, pop_dyn_array, potential_overlap, + read_write_overlap, sar, shl, shr, @@ -40,6 +41,7 @@ UnimplementedException, tag_exceptions, ) +from vyper.semantics.analysis.utils import get_expr_writes from vyper.semantics.types import ( AddressT, BoolT, @@ -49,6 +51,7 @@ FlagT, HashMapT, InterfaceT, + ModuleT, SArrayT, StringT, StructT, @@ -58,13 +61,7 @@ from vyper.semantics.types.bytestrings import _BytestringT from vyper.semantics.types.function import ContractFunctionT, MemberFunctionT from vyper.semantics.types.shortcuts import BYTES32_T, UINT256_T -from vyper.utils import ( - DECIMAL_DIVISOR, - bytes_to_int, - is_checksum_encoded, - string_to_bytes, - vyper_warn, -) +from vyper.utils import DECIMAL_DIVISOR, bytes_to_int, is_checksum_encoded, vyper_warn ENVIRONMENT_VARIABLES = {"block", "msg", "tx", "chain"} @@ -86,6 +83,9 @@ def __init__(self, node, context, is_stmt=False): self.ir_node = fn() assert isinstance(self.ir_node, IRnode), self.ir_node + writes = 
set(access.variable for access in get_expr_writes(self.expr)) + self.ir_node._writes = writes + self.ir_node.annotation = self.expr.get("node_source_code") self.ir_node.ast_source = self.expr @@ -129,18 +129,21 @@ def parse_Hex(self): # String literals def parse_Str(self): - bytez, bytez_length = string_to_bytes(self.expr.value) - typ = StringT(bytez_length) - return self._make_bytelike(typ, bytez, bytez_length) + bytez = self.expr.value.encode("utf-8") + return self._make_bytelike(StringT, bytez) # Byte literals def parse_Bytes(self): - bytez = self.expr.value - bytez_length = len(self.expr.value) - typ = BytesT(bytez_length) - return self._make_bytelike(typ, bytez, bytez_length) + return self._make_bytelike(BytesT, self.expr.value) + + def parse_HexBytes(self): + # HexBytes already has value as bytes + assert isinstance(self.expr.value, bytes) + return self._make_bytelike(BytesT, self.expr.value) - def _make_bytelike(self, btype, bytez, bytez_length): + def _make_bytelike(self, typeclass, bytez): + bytez_length = len(bytez) + btype = typeclass(bytez_length) placeholder = self.context.new_internal_variable(btype) seq = [] seq.append(["mstore", placeholder, bytez_length]) @@ -264,7 +267,8 @@ def parse_Attribute(self): return IRnode.from_list(["~calldata"], typ=BytesT(0)) elif key == "msg.value" and self.context.is_payable: return IRnode.from_list(["callvalue"], typ=UINT256_T) - elif key == "msg.gas": + elif key in ("msg.gas", "msg.mana"): + # NOTE: `msg.mana` is an alias for `msg.gas` return IRnode.from_list(["gas"], typ=UINT256_T) elif key == "block.prevrandao": if not version_check(begin="paris"): @@ -352,6 +356,8 @@ def parse_Subscript(self): elif is_array_like(sub.typ): index = Expr.parse_value_expr(self.expr.slice, self.context) + if read_write_overlap(sub, index): + raise CompilerPanic("risky overlap") elif is_tuple_like(sub.typ): # should we annotate expr.slice in the frontend with the @@ -666,7 +672,8 @@ def parse_Call(self): # TODO fix cyclic import from 
vyper.builtins._signatures import BuiltinFunctionT - func_t = self.expr.func._metadata["type"] + func = self.expr.func + func_t = func._metadata["type"] if isinstance(func_t, BuiltinFunctionT): return func_t.build_IR(self.expr, self.context) @@ -677,8 +684,14 @@ def parse_Call(self): return self.handle_struct_literal() # Interface constructor. Bar(
). - if is_type_t(func_t, InterfaceT): + if is_type_t(func_t, InterfaceT) or func.get("attr") == "__at__": assert not self.is_stmt # sanity check typechecker + + # magic: do sanity checks for module.__at__ + if func.get("attr") == "__at__": + assert isinstance(func_t, MemberFunctionT) + assert isinstance(func.value._metadata["type"], ModuleT) + (arg0,) = self.expr.args arg_ir = Expr(arg0, self.context).ir_node @@ -688,16 +701,16 @@ def parse_Call(self): return arg_ir if isinstance(func_t, MemberFunctionT): - darray = Expr(self.expr.func.value, self.context).ir_node + # TODO consider moving these to builtins or a dedicated file + darray = Expr(func.value, self.context).ir_node assert isinstance(darray.typ, DArrayT) args = [Expr(x, self.context).ir_node for x in self.expr.args] - if self.expr.func.attr == "pop": - # TODO consider moving this to builtins - darray = Expr(self.expr.func.value, self.context).ir_node + if func.attr == "pop": + darray = Expr(func.value, self.context).ir_node assert len(self.expr.args) == 0 return_item = not self.is_stmt return pop_dyn_array(darray, return_popped_item=return_item) - elif self.expr.func.attr == "append": + elif func.attr == "append": (arg,) = args check_assign( dummy_node_for_type(darray.typ.value_type), dummy_node_for_type(arg.typ) @@ -706,13 +719,14 @@ def parse_Call(self): ret = ["seq"] if potential_overlap(darray, arg): tmp = self.context.new_internal_variable(arg.typ) - tmp = IRnode.from_list(tmp, typ=arg.typ, location=MEMORY) ret.append(make_setter(tmp, arg)) arg = tmp ret.append(append_dyn_array(darray, arg)) return IRnode.from_list(ret) + raise CompilerPanic("unreachable!") # pragma: nocover + assert isinstance(func_t, ContractFunctionT) assert func_t.is_internal or func_t.is_constructor diff --git a/vyper/codegen/external_call.py b/vyper/codegen/external_call.py index 72fff5378f..331b991bfe 100644 --- a/vyper/codegen/external_call.py +++ b/vyper/codegen/external_call.py @@ -71,7 +71,9 @@ def _pack_arguments(fn_type, 
args, context): pack_args.append(["mstore", buf, util.method_id_int(abi_signature)]) if len(args) != 0: - pack_args.append(abi_encode(add_ofst(buf, 32), args_as_tuple, context, bufsz=buflen)) + encode_buf = add_ofst(buf, 32) + encode_buflen = buflen - 32 + pack_args.append(abi_encode(encode_buf, args_as_tuple, context, bufsz=encode_buflen)) return buf, pack_args, args_ofst, args_len @@ -107,8 +109,7 @@ def _unpack_returndata(buf, fn_type, call_kwargs, contract_address, context, exp # unpack strictly if not needs_clamp(wrapped_return_t, encoding): - # revert when returndatasize is not in bounds, except when - # skip_contract_check is enabled. + # revert when returndatasize is not in bounds # NOTE: there is an optimization here: when needs_clamp is True, # make_setter (implicitly) checks returndatasize during abi # decoding. @@ -123,14 +124,13 @@ def _unpack_returndata(buf, fn_type, call_kwargs, contract_address, context, exp # another thing we could do instead once we have the machinery is to # simply always use make_setter instead of having this assertion, and # rely on memory analyser to optimize out the memory movement. 
- if not call_kwargs.skip_contract_check: - assertion = IRnode.from_list( - ["assert", ["ge", "returndatasize", min_return_size]], - error_msg="returndatasize too small", - ) - unpacker.append(assertion) - return_buf = buf + assertion = IRnode.from_list( + ["assert", ["ge", "returndatasize", min_return_size]], + error_msg="returndatasize too small", + ) + unpacker.append(assertion) + return_buf = buf else: return_buf = context.new_internal_variable(wrapped_return_t) diff --git a/vyper/codegen/ir_node.py b/vyper/codegen/ir_node.py index 97d9c45fb6..81ec47f10f 100644 --- a/vyper/codegen/ir_node.py +++ b/vyper/codegen/ir_node.py @@ -6,7 +6,7 @@ from typing import Any, List, Optional, Union import vyper.ast as vy_ast -from vyper.compiler.settings import VYPER_COLOR_OUTPUT +from vyper.compiler.settings import VYPER_COLOR_OUTPUT, get_global_settings from vyper.evm.address_space import AddrSpace from vyper.evm.opcodes import get_ir_opcodes from vyper.exceptions import CodegenPanic, CompilerPanic @@ -378,13 +378,18 @@ def is_complex_ir(self): and self.value.lower() not in do_not_cache ) - # set an error message and push down into all children. - # useful for overriding an error message generated by a helper - # function with a more specific error message. + # set an error message and push down to its children that don't have error_msg set def set_error_msg(self, error_msg: str) -> None: + if self.error_msg is not None: + raise CompilerPanic(f"{self.value} already has error message {self.error_msg}") + self._set_error_msg(error_msg) + + def _set_error_msg(self, error_msg: str) -> None: + if self.error_msg is not None: + return self.error_msg = error_msg for arg in self.args: - arg.set_error_msg(error_msg) + arg._set_error_msg(error_msg) # get the unique symbols contained in this node, which provides # sanity check invariants for the optimizer. 
@@ -426,6 +431,10 @@ def is_pointer(self) -> bool: @property # probably could be cached_property but be paranoid def _optimized(self): + if get_global_settings().experimental_codegen: + # in venom pipeline, we don't need to inline constants. + return self + # TODO figure out how to fix this circular import from vyper.ir.optimizer import optimize @@ -467,6 +476,18 @@ def referenced_variables(self): return ret + @cached_property + def variable_writes(self): + ret = getattr(self, "_writes", set()) + + for arg in self.args: + ret |= arg.variable_writes + + if getattr(self, "is_self_call", False): + ret |= self.invoked_function_ir.func_ir.variable_writes + + return ret + @cached_property def contains_risky_call(self): ret = self.value in ("call", "delegatecall", "staticcall", "create", "create2") @@ -611,7 +632,7 @@ def from_list( else: return cls( obj[0], - [cls.from_list(o, ast_source=ast_source) for o in obj[1:]], + [cls.from_list(o, ast_source=ast_source, error_msg=error_msg) for o in obj[1:]], typ, location=location, annotation=annotation, diff --git a/vyper/codegen/self_call.py b/vyper/codegen/self_call.py index 8cd599d3e7..fef6070d14 100644 --- a/vyper/codegen/self_call.py +++ b/vyper/codegen/self_call.py @@ -59,10 +59,8 @@ def ir_for_self_call(stmt_expr, context): # allocate space for the return buffer # TODO allocate in stmt and/or expr.py if func_t.return_type is not None: - return_buffer = IRnode.from_list( - context.new_internal_variable(func_t.return_type), - annotation=f"{return_label}_return_buf", - ) + return_buffer = context.new_internal_variable(func_t.return_type) + return_buffer.annotation = f"{return_label}_return_buf" else: return_buffer = None diff --git a/vyper/codegen/stmt.py b/vyper/codegen/stmt.py index 830f2f923d..b1e26d7d5f 100644 --- a/vyper/codegen/stmt.py +++ b/vyper/codegen/stmt.py @@ -19,7 +19,6 @@ ) from vyper.codegen.expr import Expr from vyper.codegen.return_ import make_return_stmt -from vyper.evm.address_space import MEMORY from 
vyper.exceptions import CodegenPanic, StructureException, TypeCheckFailure, tag_exceptions from vyper.semantics.types import DArrayT from vyper.semantics.types.shortcuts import UINT256_T @@ -56,13 +55,11 @@ def parse_Name(self): def parse_AnnAssign(self): ltyp = self.stmt.target._metadata["type"] varname = self.stmt.target.id - alloced = self.context.new_variable(varname, ltyp) + lhs = self.context.new_variable(varname, ltyp) assert self.stmt.value is not None rhs = Expr(self.stmt.value, self.context).ir_node - lhs = IRnode.from_list(alloced, typ=ltyp, location=MEMORY) - return make_setter(lhs, rhs) def parse_Assign(self): @@ -76,7 +73,6 @@ def parse_Assign(self): # complex - i.e., it spans multiple words. for safety, we # copy to a temporary buffer before copying to the destination. tmp = self.context.new_internal_variable(src.typ) - tmp = IRnode.from_list(tmp, typ=src.typ, location=MEMORY) ret.append(make_setter(tmp, src)) src = tmp @@ -97,7 +93,13 @@ def parse_If(self): def parse_Log(self): event = self.stmt._metadata["type"] - args = [Expr(arg, self.context).ir_node for arg in self.stmt.value.args] + if len(self.stmt.value.keywords) > 0: + # keyword arguments + to_compile = [arg.value for arg in self.stmt.value.keywords] + else: + # positional arguments + to_compile = self.stmt.value.args + args = [Expr(arg, self.context).ir_node for arg in to_compile] topic_ir = [] data_ir = [] @@ -247,9 +249,7 @@ def _parse_For_list(self): # user-supplied name for loop variable varname = self.stmt.target.target.id - loop_var = IRnode.from_list( - self.context.new_variable(varname, target_type), typ=target_type, location=MEMORY - ) + loop_var = self.context.new_variable(varname, target_type) i = IRnode.from_list(self.context.fresh_varname("for_list_ix"), typ=UINT256_T) @@ -259,11 +259,7 @@ def _parse_For_list(self): # list literal, force it to memory first if isinstance(self.stmt.iter, vy_ast.List): - tmp_list = IRnode.from_list( - 
self.context.new_internal_variable(iter_list.typ), - typ=iter_list.typ, - location=MEMORY, - ) + tmp_list = self.context.new_internal_variable(iter_list.typ) ret.append(make_setter(tmp_list, iter_list)) iter_list = tmp_list diff --git a/vyper/compiler/__init__.py b/vyper/compiler/__init__.py index 0345c24931..57bd2f4096 100644 --- a/vyper/compiler/__init__.py +++ b/vyper/compiler/__init__.py @@ -46,6 +46,13 @@ "opcodes_runtime": output.build_opcodes_runtime_output, } +INTERFACE_OUTPUT_FORMATS = [ + "ast_dict", + "annotated_ast_dict", + "interface", + "external_interface", + "abi", +] UNKNOWN_CONTRACT_NAME = "" @@ -99,13 +106,6 @@ def compile_from_file_input( """ settings = settings or get_global_settings() or Settings() - if output_formats is None: - output_formats = ("bytecode",) - - # make IR output the same between runs - # TODO: move this to CompilerData.__init__() - codegen.reset_names() - compiler_data = CompilerData( file_input, input_bundle, @@ -116,17 +116,36 @@ def compile_from_file_input( no_bytecode_metadata=no_bytecode_metadata, ) + return outputs_from_compiler_data(compiler_data, output_formats, exc_handler) + + +def outputs_from_compiler_data( + compiler_data: CompilerData, + output_formats: Optional[OutputFormats] = None, + exc_handler: Optional[Callable] = None, +): + if output_formats is None: + output_formats = ("bytecode",) + ret = {} + with anchor_settings(compiler_data.settings): for output_format in output_formats: if output_format not in OUTPUT_FORMATS: raise ValueError(f"Unsupported format type {repr(output_format)}") + + is_vyi = compiler_data.file_input.resolved_path.suffix == ".vyi" + if is_vyi and output_format not in INTERFACE_OUTPUT_FORMATS: + raise ValueError( + f"Unsupported format for compiling interface: {repr(output_format)}" + ) + try: formatter = OUTPUT_FORMATS[output_format] ret[output_format] = formatter(compiler_data) except Exception as exc: if exc_handler is not None: - exc_handler(str(file_input.path), exc) + 
exc_handler(str(compiler_data.file_input.path), exc) else: raise exc diff --git a/vyper/compiler/input_bundle.py b/vyper/compiler/input_bundle.py index a928989393..06fee78613 100644 --- a/vyper/compiler/input_bundle.py +++ b/vyper/compiler/input_bundle.py @@ -52,6 +52,8 @@ class ABIInput(CompilerInput): def try_parse_abi(file_input: FileInput) -> CompilerInput: try: s = json.loads(file_input.source_code) + if isinstance(s, dict) and "abi" in s: + s = s["abi"] return ABIInput(**asdict(file_input), abi=s) except (ValueError, TypeError): return file_input @@ -135,9 +137,6 @@ def load_file(self, path: PathLike | str) -> CompilerInput: return res - def add_search_path(self, path: PathLike) -> None: - self.search_paths.append(path) - # temporarily add something to the search path (within the # scope of the context manager) with highest precedence. # if `path` is None, do nothing @@ -153,16 +152,15 @@ def search_path(self, path: Optional[PathLike]) -> Iterator[None]: finally: self.search_paths.pop() - # temporarily modify the top of the search path (within the - # scope of the context manager) with highest precedence to something else + # temporarily set search paths to a given list @contextlib.contextmanager - def poke_search_path(self, path: PathLike) -> Iterator[None]: - tmp = self.search_paths[-1] - self.search_paths[-1] = path + def temporary_search_paths(self, new_paths: list[PathLike]) -> Iterator[None]: + original_paths = self.search_paths + self.search_paths = new_paths try: yield finally: - self.search_paths[-1] = tmp + self.search_paths = original_paths # regular input. 
takes a search path(s), and `load_file()` will search all diff --git a/vyper/compiler/output.py b/vyper/compiler/output.py index 577afd3822..b6a0e8ac8c 100644 --- a/vyper/compiler/output.py +++ b/vyper/compiler/output.py @@ -3,7 +3,8 @@ from collections import deque from pathlib import PurePath -from vyper.ast import ast_to_dict +import vyper.ast as vy_ast +from vyper.ast.utils import ast_to_dict from vyper.codegen.ir_node import IRnode from vyper.compiler.output_bundle import SolcJSONWriter, VyperArchiveWriter from vyper.compiler.phases import CompilerData @@ -11,9 +12,11 @@ from vyper.evm import opcodes from vyper.exceptions import VyperException from vyper.ir import compile_ir -from vyper.semantics.types.function import FunctionVisibility, StateMutability +from vyper.semantics.analysis.base import ModuleInfo +from vyper.semantics.types.function import ContractFunctionT, FunctionVisibility, StateMutability +from vyper.semantics.types.module import InterfaceT from vyper.typing import StorageLayout -from vyper.utils import vyper_warn +from vyper.utils import safe_relpath, vyper_warn from vyper.warnings import ContractSizeLimitWarning @@ -26,9 +29,32 @@ def build_ast_dict(compiler_data: CompilerData) -> dict: def build_annotated_ast_dict(compiler_data: CompilerData) -> dict: + module_t = compiler_data.annotated_vyper_module._metadata["type"] + # get all reachable imports including recursion + imported_module_infos = module_t.reachable_imports + unique_modules: dict[str, vy_ast.Module] = {} + for info in imported_module_infos: + if isinstance(info.typ, InterfaceT): + ast = info.typ.decl_node + if ast is None: # json abi + continue + else: + assert isinstance(info.typ, ModuleInfo) + ast = info.typ.module_t._module + + assert isinstance(ast, vy_ast.Module) # help mypy + # use resolved_path for uniqueness, since Module objects can actually + # come from multiple InputBundles (particularly builtin interfaces), + # so source_id is not guaranteed to be unique. 
+ if ast.resolved_path in unique_modules: + # sanity check -- objects must be identical + assert unique_modules[ast.resolved_path] is ast + unique_modules[ast.resolved_path] = ast + annotated_ast_dict = { "contract_name": str(compiler_data.contract_path), "ast": ast_to_dict(compiler_data.annotated_vyper_module), + "imports": [ast_to_dict(ast) for ast in unique_modules.values()], } return annotated_ast_dict @@ -76,15 +102,14 @@ def build_archive_b64(compiler_data: CompilerData) -> str: def build_integrity(compiler_data: CompilerData) -> str: - return compiler_data.compilation_target._metadata["type"].integrity_sum + return compiler_data.integrity_sum def build_external_interface_output(compiler_data: CompilerData) -> str: interface = compiler_data.annotated_vyper_module._metadata["type"].interface stem = PurePath(compiler_data.contract_path).stem - # capitalize words separated by '_' - # ex: test_interface.vy -> TestInterface - name = "".join([x.capitalize() for x in stem.split("_")]) + + name = stem.title().replace("_", "") out = f"\n# External Interfaces\ninterface {name}:\n" for func in interface.functions.values(): @@ -102,22 +127,41 @@ def build_interface_output(compiler_data: CompilerData) -> str: interface = compiler_data.annotated_vyper_module._metadata["type"].interface out = "" - if interface.events: - out = "# Events\n\n" + if len(interface.structs) > 0: + out += "# Structs\n\n" + for struct in interface.structs.values(): + out += f"struct {struct.name}:\n" + for member_name, member_type in struct.members.items(): + out += f" {member_name}: {member_type}\n" + out += "\n\n" + + if len(interface.flags) > 0: + out += "# Flags\n\n" + for flag in interface.flags.values(): + out += f"flag {flag.name}:\n" + for flag_value in flag._flag_members: + out += f" {flag_value}\n" + out += "\n\n" + + if len(interface.events) > 0: + out += "# Events\n\n" for event in interface.events.values(): encoded_args = "\n ".join(f"{name}: {typ}" for name, typ in 
event.arguments.items()) - out = f"{out}event {event.name}:\n {encoded_args if event.arguments else 'pass'}\n" + out += f"event {event.name}:\n {encoded_args if event.arguments else 'pass'}\n\n\n" - if interface.functions: - out = f"{out}\n# Functions\n\n" + if len(interface.functions) > 0: + out += "# Functions\n\n" for func in interface.functions.values(): if func.visibility == FunctionVisibility.INTERNAL or func.name == "__init__": continue if func.mutability != StateMutability.NONPAYABLE: - out = f"{out}@{func.mutability.value}\n" + out += f"@{func.mutability.value}\n" args = ", ".join([f"{arg.name}: {arg.typ}" for arg in func.arguments]) return_value = f" -> {func.return_type}" if func.return_type is not None else "" - out = f"{out}@external\ndef {func.name}({args}){return_value}:\n ...\n\n" + out += f"@external\ndef {func.name}({args}){return_value}:\n ...\n\n\n" + + out = out.rstrip("\n") + out += "\n" return out @@ -169,17 +213,27 @@ def build_ir_runtime_dict_output(compiler_data: CompilerData) -> dict: def build_metadata_output(compiler_data: CompilerData) -> dict: - sigs = compiler_data.function_signatures - - def _var_rec_dict(variable_record): - ret = vars(variable_record).copy() - ret["typ"] = str(ret["typ"]) - if ret["data_offset"] is None: - del ret["data_offset"] - for k in ("blockscopes", "defined_at", "encoding"): - del ret[k] - ret["location"] = ret["location"].name - return ret + # need ir info to be computed + _ = compiler_data.function_signatures + module_t = compiler_data.annotated_vyper_module._metadata["type"] + sigs = dict[str, ContractFunctionT]() + + def _fn_identifier(fn_t): + fn_id = fn_t._function_id + return f"{fn_t.name} ({fn_id})" + + for fn_t in module_t.exposed_functions: + assert isinstance(fn_t.ast_def, vy_ast.FunctionDef) + for rif_t in fn_t.reachable_internal_functions: + k = _fn_identifier(rif_t) + if k in sigs: + # sanity check that keys are injective with functions + assert sigs[k] == rif_t, (k, sigs[k], rif_t) + sigs[k] = 
rif_t + + fn_id = _fn_identifier(fn_t) + assert fn_id not in sigs + sigs[fn_id] = fn_t def _to_dict(func_t): ret = vars(func_t).copy() @@ -201,6 +255,10 @@ def _to_dict(func_t): ret["frame_info"] = vars(func_t._ir_info.frame_info).copy() del ret["frame_info"]["frame_vars"] # frame_var.pos might be IR, cannot serialize + ret["module_path"] = safe_relpath(func_t.decl_node.module_node.resolved_path) + ret["source_id"] = func_t.decl_node.module_node.source_id + ret["function_id"] = func_t._function_id + keep_keys = { "name", "return_type", @@ -212,6 +270,9 @@ def _to_dict(func_t): "visibility", "_ir_identifier", "nonreentrant_key", + "module_path", + "source_id", + "function_id", } ret = {k: v for k, v in ret.items() if k in keep_keys} return ret @@ -228,9 +289,13 @@ def build_method_identifiers_output(compiler_data: CompilerData) -> dict: def build_abi_output(compiler_data: CompilerData) -> list: module_t = compiler_data.annotated_vyper_module._metadata["type"] - _ = compiler_data.ir_runtime # ensure _ir_info is generated + if not compiler_data.annotated_vyper_module.is_interface: + _ = compiler_data.ir_runtime # ensure _ir_info is generated abi = module_t.interface.to_toplevel_abi_dict() + if module_t.init_function: + abi += module_t.init_function.to_toplevel_abi_dict() + if compiler_data.show_gas_estimates: # Add gas estimates for each function to ABI gas_estimates = build_gas_estimates(compiler_data.function_signatures) @@ -320,15 +385,13 @@ def _build_source_map_output(compiler_data, bytecode, pc_maps): def build_source_map_output(compiler_data: CompilerData) -> dict: - bytecode, pc_maps = compile_ir.assembly_to_evm( - compiler_data.assembly, insert_compiler_metadata=False - ) + bytecode, pc_maps = compile_ir.assembly_to_evm(compiler_data.assembly, compiler_metadata=None) return _build_source_map_output(compiler_data, bytecode, pc_maps) def build_source_map_runtime_output(compiler_data: CompilerData) -> dict: bytecode, pc_maps = compile_ir.assembly_to_evm( - 
compiler_data.assembly_runtime, insert_compiler_metadata=False + compiler_data.assembly_runtime, compiler_metadata=None ) return _build_source_map_output(compiler_data, bytecode, pc_maps) diff --git a/vyper/compiler/output_bundle.py b/vyper/compiler/output_bundle.py index 92494e3a70..8af1b72289 100644 --- a/vyper/compiler/output_bundle.py +++ b/vyper/compiler/output_bundle.py @@ -1,7 +1,6 @@ import importlib import io import json -import os import zipfile from dataclasses import dataclass from functools import cached_property @@ -12,8 +11,9 @@ from vyper.compiler.phases import CompilerData from vyper.compiler.settings import Settings from vyper.exceptions import CompilerPanic -from vyper.semantics.analysis.module import _is_builtin -from vyper.utils import get_long_version +from vyper.semantics.analysis.imports import _is_builtin +from vyper.typing import StorageLayout +from vyper.utils import get_long_version, safe_relpath # data structures and routines for constructing "output bundles", # basically reproducible builds of a vyper contract, with varying @@ -62,7 +62,7 @@ def compiler_inputs(self) -> dict[str, CompilerInput]: sources = {} for c in inputs: - path = os.path.relpath(c.resolved_path) + path = safe_relpath(c.resolved_path) # note: there should be a 1:1 correspondence between # resolved_path and source_id, but for clarity use resolved_path # since it corresponds more directly to search path semantics. 
@@ -73,7 +73,7 @@ def compiler_inputs(self) -> dict[str, CompilerInput]: @cached_property def compilation_target_path(self): p = PurePath(self.compiler_data.file_input.resolved_path) - p = os.path.relpath(p) + p = safe_relpath(p) return _anonymize(p) @cached_property @@ -121,7 +121,7 @@ def used_search_paths(self) -> list[str]: sps = [sp for sp, count in tmp.items() if count > 0] assert len(sps) > 0 - return [_anonymize(os.path.relpath(sp)) for sp in sps] + return [_anonymize(safe_relpath(sp)) for sp in sps] class OutputBundleWriter: @@ -135,6 +135,11 @@ def bundle(self): def write_sources(self, sources: dict[str, CompilerInput]): raise NotImplementedError(f"write_sources: {self.__class__}") + def write_storage_layout_overrides( + self, compilation_target_path: str, storage_layout_override: StorageLayout + ): + raise NotImplementedError(f"write_storage_layout_overrides: {self.__class__}") + def write_search_paths(self, search_paths: list[str]): raise NotImplementedError(f"write_search_paths: {self.__class__}") @@ -159,8 +164,12 @@ def write(self): self.write_compilation_target([self.bundle.compilation_target_path]) self.write_search_paths(self.bundle.used_search_paths) self.write_settings(self.compiler_data.original_settings) - self.write_integrity(self.bundle.compilation_target.integrity_sum) + self.write_integrity(self.compiler_data.integrity_sum) self.write_sources(self.bundle.compiler_inputs) + if self.compiler_data.storage_layout_override is not None: + self.write_storage_layout_overrides( + self.bundle.compilation_target_path, self.compiler_data.storage_layout_override + ) class SolcJSONWriter(OutputBundleWriter): @@ -176,6 +185,13 @@ def write_sources(self, sources: dict[str, CompilerInput]): self._output["sources"].update(out) + def write_storage_layout_overrides( + self, compilation_target_path: str, storage_layout_override: StorageLayout + ): + self._output["storage_layout_overrides"] = { + compilation_target_path: storage_layout_override + } + def 
write_search_paths(self, search_paths: list[str]): self._output["settings"]["search_paths"] = search_paths @@ -238,6 +254,11 @@ def write_sources(self, sources: dict[str, CompilerInput]): for path, c in sources.items(): self.archive.writestr(_anonymize(path), c.contents) + def write_storage_layout_overrides( + self, compilation_target_path: str, storage_layout_override: StorageLayout + ): + self.archive.writestr("MANIFEST/storage_layout.json", json.dumps(storage_layout_override)) + def write_search_paths(self, search_paths: list[str]): self.archive.writestr("MANIFEST/searchpaths", "\n".join(search_paths)) diff --git a/vyper/compiler/phases.py b/vyper/compiler/phases.py index 147af24d67..17812ee535 100644 --- a/vyper/compiler/phases.py +++ b/vyper/compiler/phases.py @@ -1,9 +1,11 @@ import copy +import json import warnings from functools import cached_property from pathlib import Path, PurePath -from typing import Optional +from typing import Any, Optional +import vyper.codegen.core as codegen from vyper import ast as vy_ast from vyper.ast import natspec from vyper.codegen import module @@ -11,12 +13,14 @@ from vyper.compiler.input_bundle import FileInput, FilesystemInputBundle, InputBundle from vyper.compiler.settings import OptimizationLevel, Settings, anchor_settings, merge_settings from vyper.ir import compile_ir, optimizer +from vyper.ir.compile_ir import reset_symbols from vyper.semantics import analyze_module, set_data_positions, validate_compilation_target from vyper.semantics.analysis.data_positions import generate_layout_export +from vyper.semantics.analysis.imports import resolve_imports from vyper.semantics.types.function import ContractFunctionT from vyper.semantics.types.module import ModuleT from vyper.typing import StorageLayout -from vyper.utils import ERC5202_PREFIX, vyper_warn +from vyper.utils import ERC5202_PREFIX, sha256sum, vyper_warn from vyper.venom import generate_assembly_experimental, generate_ir DEFAULT_CONTRACT_PATH = 
PurePath("VyperContract.vy") @@ -111,11 +115,14 @@ def contract_path(self): @cached_property def _generate_ast(self): + is_vyi = self.contract_path.suffix == ".vyi" + settings, ast = vy_ast.parse_to_ast_with_settings( self.source_code, self.source_id, module_path=self.contract_path.as_posix(), resolved_path=self.file_input.resolved_path.as_posix(), + is_interface=is_vyi, ) if self.original_settings: @@ -145,9 +152,46 @@ def vyper_module(self): _, ast = self._generate_ast return ast + def _compute_integrity_sum(self, imports_integrity_sum: str) -> str: + if self.storage_layout_override is not None: + layout_sum = sha256sum(json.dumps(self.storage_layout_override)) + return sha256sum(layout_sum + imports_integrity_sum) + return imports_integrity_sum + + @cached_property + def _resolve_imports(self): + # deepcopy so as to not interfere with `-f ast` output + vyper_module = copy.deepcopy(self.vyper_module) + with self.input_bundle.search_path(Path(vyper_module.resolved_path).parent): + imports = resolve_imports(vyper_module, self.input_bundle) + + # check integrity sum + integrity_sum = self._compute_integrity_sum(imports._integrity_sum) + + expected = self.expected_integrity_sum + if expected is not None and integrity_sum != expected: + # warn for now. strict/relaxed mode was considered but it costs + # interface and testing complexity to add another feature flag. + vyper_warn( + f"Mismatched integrity sum! Expected {expected}" + f" but got {integrity_sum}." 
+ " (This likely indicates a corrupted archive)" + ) + + return vyper_module, imports, integrity_sum + + @cached_property + def integrity_sum(self): + return self._resolve_imports[2] + + @cached_property + def resolved_imports(self): + return self._resolve_imports[1] + @cached_property def _annotate(self) -> tuple[natspec.NatspecOutput, vy_ast.Module]: - module = generate_annotated_ast(self.vyper_module, self.input_bundle) + module = self._resolve_imports[0] + analyze_module(module) nspec = natspec.parse_natspec(module) return nspec, module @@ -167,17 +211,6 @@ def compilation_target(self): """ module_t = self.annotated_vyper_module._metadata["type"] - expected = self.expected_integrity_sum - - if expected is not None and module_t.integrity_sum != expected: - # warn for now. strict/relaxed mode was considered but it costs - # interface and testing complexity to add another feature flag. - vyper_warn( - f"Mismatched integrity sum! Expected {expected}" - f" but got {module_t.integrity_sum}." 
- " (This likely indicates a corrupted archive)" - ) - validate_compilation_target(module_t) return self.annotated_vyper_module @@ -249,12 +282,14 @@ def assembly_runtime(self) -> list: @cached_property def bytecode(self) -> bytes: - insert_compiler_metadata = not self.no_bytecode_metadata - return generate_bytecode(self.assembly, insert_compiler_metadata=insert_compiler_metadata) + metadata = None + if not self.no_bytecode_metadata: + metadata = bytes.fromhex(self.integrity_sum) + return generate_bytecode(self.assembly, compiler_metadata=metadata) @cached_property def bytecode_runtime(self) -> bytes: - return generate_bytecode(self.assembly_runtime, insert_compiler_metadata=False) + return generate_bytecode(self.assembly_runtime, compiler_metadata=None) @cached_property def blueprint_bytecode(self) -> bytes: @@ -267,28 +302,6 @@ def blueprint_bytecode(self) -> bytes: return deploy_bytecode + blueprint_bytecode -def generate_annotated_ast(vyper_module: vy_ast.Module, input_bundle: InputBundle) -> vy_ast.Module: - """ - Validates and annotates the Vyper AST. - - Arguments - --------- - vyper_module : vy_ast.Module - Top-level Vyper AST node - - Returns - ------- - vy_ast.Module - Annotated Vyper AST - """ - vyper_module = copy.deepcopy(vyper_module) - with input_bundle.search_path(Path(vyper_module.resolved_path).parent): - # note: analyze_module does type inference on the AST - analyze_module(vyper_module, input_bundle) - - return vyper_module - - def generate_ir_nodes(global_ctx: ModuleT, settings: Settings) -> tuple[IRnode, IRnode]: """ Generate the intermediate representation (IR) from the contextualized AST. 
@@ -309,6 +322,10 @@ def generate_ir_nodes(global_ctx: ModuleT, settings: Settings) -> tuple[IRnode, IR to generate deployment bytecode IR to generate runtime bytecode """ + # make IR output the same between runs + codegen.reset_names() + reset_symbols() + with anchor_settings(settings): ir_nodes, ir_runtime = module.generate_ir_for_module(global_ctx) if settings.optimize != OptimizationLevel.NONE: @@ -351,7 +368,7 @@ def _find_nested_opcode(assembly, key): return any(_find_nested_opcode(x, key) for x in sublists) -def generate_bytecode(assembly: list, insert_compiler_metadata: bool) -> bytes: +def generate_bytecode(assembly: list, compiler_metadata: Optional[Any]) -> bytes: """ Generate bytecode from assembly instructions. @@ -365,6 +382,4 @@ def generate_bytecode(assembly: list, insert_compiler_metadata: bool) -> bytes: bytes Final compiled bytecode. """ - return compile_ir.assembly_to_evm(assembly, insert_compiler_metadata=insert_compiler_metadata)[ - 0 - ] + return compile_ir.assembly_to_evm(assembly, compiler_metadata=compiler_metadata)[0] diff --git a/vyper/compiler/settings.py b/vyper/compiler/settings.py index 7c20e03906..e9840e8334 100644 --- a/vyper/compiler/settings.py +++ b/vyper/compiler/settings.py @@ -77,7 +77,7 @@ def as_cli(self): if self.optimize is not None: ret.append(" --optimize " + str(self.optimize)) if self.experimental_codegen is True: - ret.append(" --experimental-codegen") + ret.append(" --venom") if self.evm_version is not None: ret.append(" --evm-version " + self.evm_version) if self.debug is True: @@ -120,12 +120,12 @@ def _merge_one(lhs, rhs, helpstr): return lhs if rhs is None else rhs ret = Settings() - ret.evm_version = _merge_one(one.evm_version, two.evm_version, "evm version") - ret.optimize = _merge_one(one.optimize, two.optimize, "optimize") - ret.experimental_codegen = _merge_one( - one.experimental_codegen, two.experimental_codegen, "experimental codegen" - ) - ret.enable_decimals = _merge_one(one.enable_decimals, 
two.enable_decimals, "enable-decimals") + for field in dataclasses.fields(ret): + if field.name == "compiler_version": + continue + pretty_name = field.name.replace("_", "-") # e.g. evm_version -> evm-version + val = _merge_one(getattr(one, field.name), getattr(two, field.name), pretty_name) + setattr(ret, field.name, val) return ret diff --git a/vyper/exceptions.py b/vyper/exceptions.py index 3c0444b1ca..04a60b6306 100644 --- a/vyper/exceptions.py +++ b/vyper/exceptions.py @@ -54,6 +54,7 @@ def __init__(self, message="Error Message not found.", *items, hint=None, prev_d self.lineno = None self.col_offset = None self.annotations = None + self.resolved_path = None if len(items) == 1 and isinstance(items[0], tuple) and isinstance(items[0][0], int): # support older exceptions that don't annotate - remove this in the future! @@ -97,10 +98,7 @@ def hint(self): @property def message(self): - msg = self._message - if self.hint: - msg += f"\n\n (hint: {self.hint})" - return msg + return self._message def format_annotation(self, value): from vyper import ast as vy_ast @@ -130,13 +128,18 @@ def format_annotation(self, value): module_node = node.module_node # TODO: handle cases where module is None or vy_ast.Module - if module_node.get("path") not in (None, ""): - node_msg = f'{node_msg}contract "{module_node.path}:{node.lineno}", ' + if module_node.get("resolved_path") not in (None, ""): + node_msg = self._format_contract_details( + node_msg, module_node.resolved_path, node.lineno + ) fn_node = node.get_ancestor(vy_ast.FunctionDef) if fn_node: node_msg = f'{node_msg}function "{fn_node.name}", ' + elif self.resolved_path is not None: + node_msg = self._format_contract_details(node_msg, self.resolved_path, node.lineno) + col_offset_str = "" if node.col_offset is None else str(node.col_offset) node_msg = f"{node_msg}line {node.lineno}:{col_offset_str} \n{source_annotation}\n" @@ -148,7 +151,21 @@ def format_annotation(self, value): node_msg = textwrap.indent(node_msg, " ") 
return node_msg + def _add_hint(self, msg): + hint = self.hint + if hint is None: + return msg + return msg + f"\n (hint: {self.hint})" + + def _format_contract_details(self, msg, path, lineno): + from vyper.utils import safe_relpath + + return f'{msg}contract "{safe_relpath(path)}:{lineno}", ' + def __str__(self): + return self._add_hint(self._str_helper()) + + def _str_helper(self): if not self.annotations: if self.lineno is not None and self.col_offset is not None: return f"line {self.lineno}:{self.col_offset} {self.message}" @@ -175,15 +192,14 @@ class VyperException(_BaseVyperException): class SyntaxException(VyperException): - """Invalid syntax.""" - def __init__(self, message, source_code, lineno, col_offset): + def __init__(self, message, source_code, lineno, col_offset, hint=None): item = types.SimpleNamespace() # TODO: Create an actual object for this item.lineno = lineno item.col_offset = col_offset item.full_source_code = source_code - super().__init__(message, item) + super().__init__(message, item, hint=hint) class DecimalOverrideException(VyperException): diff --git a/vyper/ir/compile_ir.py b/vyper/ir/compile_ir.py index 4c68aa2c8f..936e6d5d72 100644 --- a/vyper/ir/compile_ir.py +++ b/vyper/ir/compile_ir.py @@ -54,6 +54,11 @@ def mksymbol(name=""): return f"_sym_{name}{_next_symbol}" +def reset_symbols(): + global _next_symbol + _next_symbol = 0 + + def mkdebug(pc_debugger, ast_source): i = Instruction("DEBUG", ast_source) i.pc_debugger = pc_debugger @@ -1033,6 +1038,9 @@ def _stack_peephole_opts(assembly): if assembly[i] == "SWAP1" and assembly[i + 1].lower() in COMMUTATIVE_OPS: changed = True del assembly[i] + if assembly[i] == "DUP1" and assembly[i + 1] == "SWAP1": + changed = True + del assembly[i + 1] i += 1 return changed @@ -1155,22 +1163,24 @@ def _relocate_segments(assembly): # TODO: change API to split assembly_to_evm and assembly_to_source/symbol_maps -def assembly_to_evm(assembly, pc_ofst=0, insert_compiler_metadata=False): +def 
assembly_to_evm(assembly, pc_ofst=0, compiler_metadata=None): bytecode, source_maps, _ = assembly_to_evm_with_symbol_map( - assembly, pc_ofst=pc_ofst, insert_compiler_metadata=insert_compiler_metadata + assembly, pc_ofst=pc_ofst, compiler_metadata=compiler_metadata ) return bytecode, source_maps -def assembly_to_evm_with_symbol_map(assembly, pc_ofst=0, insert_compiler_metadata=False): +def assembly_to_evm_with_symbol_map(assembly, pc_ofst=0, compiler_metadata=None): """ Assembles assembly into EVM assembly: list of asm instructions pc_ofst: when constructing the source map, the amount to offset all pcs by (no effect until we add deploy code source map) - insert_compiler_metadata: whether to append vyper metadata to output - (should be true for runtime code) + compiler_metadata: any compiler metadata to add. pass `None` to indicate + no metadata to be added (should always be `None` for + runtime code). the value is opaque, and will be passed + directly to `cbor2.dumps()`. """ line_number_map = { "breakpoints": set(), @@ -1278,10 +1288,11 @@ def assembly_to_evm_with_symbol_map(assembly, pc_ofst=0, insert_compiler_metadat pc += 1 bytecode_suffix = b"" - if insert_compiler_metadata: + if compiler_metadata is not None: # this will hold true when we are in initcode assert immutables_len is not None metadata = ( + compiler_metadata, len(runtime_code), data_section_lengths, immutables_len, diff --git a/vyper/semantics/analysis/base.py b/vyper/semantics/analysis/base.py index 026e0626e7..adfc7540a0 100644 --- a/vyper/semantics/analysis/base.py +++ b/vyper/semantics/analysis/base.py @@ -1,7 +1,7 @@ import enum -from dataclasses import dataclass +from dataclasses import dataclass, fields from functools import cached_property -from typing import TYPE_CHECKING, Any, ClassVar, Dict, Optional, Union +from typing import TYPE_CHECKING, Any, ClassVar, Dict, Optional from vyper import ast as vy_ast from vyper.compiler.input_bundle import CompilerInput @@ -13,7 +13,7 @@ if 
TYPE_CHECKING: from vyper.semantics.types.function import ContractFunctionT - from vyper.semantics.types.module import InterfaceT, ModuleT + from vyper.semantics.types.module import ModuleT class FunctionVisibility(StringEnum): @@ -96,6 +96,7 @@ class AnalysisResult: class ModuleInfo(AnalysisResult): module_t: "ModuleT" alias: str + # import_node: vy_ast._ImportStmt # maybe could be useful ownership: ModuleOwnership = ModuleOwnership.NO_OWNERSHIP ownership_decl: Optional[vy_ast.VyperNode] = None @@ -119,13 +120,19 @@ def __hash__(self): return hash(id(self.module_t)) -@dataclass(frozen=True) +@dataclass class ImportInfo(AnalysisResult): - typ: Union[ModuleInfo, "InterfaceT"] alias: str # the name in the namespace qualified_module_name: str # for error messages compiler_input: CompilerInput # to recover file info for ast export - node: vy_ast.VyperNode + parsed: Any # (json) abi | AST + _typ: Any = None # type to be filled in during analysis + + @property + def typ(self): + if self._typ is None: # pragma: nocover + raise CompilerPanic("unreachable!") + return self._typ def to_dict(self): ret = {"alias": self.alias, "qualified_module_name": self.qualified_module_name} @@ -234,6 +241,17 @@ class VarAccess: # A sentinel indicating a subscript access SUBSCRIPT_ACCESS: ClassVar[Any] = object() + # custom __reduce__ and _produce implementations to work around + # a pickle bug. 
+ # see https://github.com/python/cpython/issues/124937#issuecomment-2392227290 + def __reduce__(self): + dict_obj = {f.name: getattr(self, f.name) for f in fields(self)} + return self.__class__._produce, (dict_obj,) + + @classmethod + def _produce(cls, data): + return cls(**data) + @cached_property def attrs(self): ret = [] @@ -257,7 +275,10 @@ def to_dict(self): # map SUBSCRIPT_ACCESS to `"$subscript_access"` (which is an identifier # which can't be constructed by the user) path = ["$subscript_access" if s is self.SUBSCRIPT_ACCESS else s for s in self.path] - varname = var.decl_node.target.id + if isinstance(var.decl_node, vy_ast.arg): + varname = var.decl_node.arg + else: + varname = var.decl_node.target.id decl_node = var.decl_node.get_id_dict() ret = {"name": varname, "decl_node": decl_node, "access_path": path} @@ -283,7 +304,6 @@ def __post_init__(self): for attr in should_match: if getattr(self.var_info, attr) != getattr(self, attr): raise CompilerPanic(f"Bad analysis: non-matching {attr}: {self}") - self._writes: OrderedSet[VarAccess] = OrderedSet() self._reads: OrderedSet[VarAccess] = OrderedSet() diff --git a/vyper/semantics/analysis/constant_folding.py b/vyper/semantics/analysis/constant_folding.py index 6e4166dc52..98cab0f8cb 100644 --- a/vyper/semantics/analysis/constant_folding.py +++ b/vyper/semantics/analysis/constant_folding.py @@ -180,8 +180,11 @@ def visit_Compare(self, node): raise UnfoldableNode( f"Invalid literal types for {node.op.description} comparison", node ) - - value = node.op._op(left.value, right.value) + lvalue, rvalue = left.value, right.value + if isinstance(left, vy_ast.Hex): + # Hex values are str, convert to be case-insensitive. 
+ lvalue, rvalue = lvalue.lower(), rvalue.lower() + value = node.op._op(lvalue, rvalue) return vy_ast.NameConstant.from_node(node, value=value) def visit_List(self, node) -> vy_ast.ExprNode: diff --git a/vyper/semantics/analysis/import_graph.py b/vyper/semantics/analysis/import_graph.py deleted file mode 100644 index e406878194..0000000000 --- a/vyper/semantics/analysis/import_graph.py +++ /dev/null @@ -1,37 +0,0 @@ -import contextlib -from dataclasses import dataclass, field -from typing import Iterator - -from vyper import ast as vy_ast -from vyper.exceptions import CompilerPanic, ImportCycle - -""" -data structure for collecting import statements and validating the -import graph -""" - - -@dataclass -class ImportGraph: - # the current path in the import graph traversal - _path: list[vy_ast.Module] = field(default_factory=list) - - def push_path(self, module_ast: vy_ast.Module) -> None: - if module_ast in self._path: - cycle = self._path + [module_ast] - raise ImportCycle(" imports ".join(f'"{t.path}"' for t in cycle)) - - self._path.append(module_ast) - - def pop_path(self, expected: vy_ast.Module) -> None: - popped = self._path.pop() - if expected != popped: - raise CompilerPanic("unreachable") - - @contextlib.contextmanager - def enter_path(self, module_ast: vy_ast.Module) -> Iterator[None]: - self.push_path(module_ast) - try: - yield - finally: - self.pop_path(module_ast) diff --git a/vyper/semantics/analysis/imports.py b/vyper/semantics/analysis/imports.py new file mode 100644 index 0000000000..7f02bce79d --- /dev/null +++ b/vyper/semantics/analysis/imports.py @@ -0,0 +1,342 @@ +import contextlib +from dataclasses import dataclass, field +from pathlib import Path, PurePath +from typing import Any, Iterator + +import vyper.builtins.interfaces +from vyper import ast as vy_ast +from vyper.compiler.input_bundle import ( + ABIInput, + CompilerInput, + FileInput, + FilesystemInputBundle, + InputBundle, + PathLike, +) +from vyper.exceptions import ( + 
CompilerPanic, + DuplicateImport, + ImportCycle, + ModuleNotFound, + StructureException, + tag_exceptions, +) +from vyper.semantics.analysis.base import ImportInfo +from vyper.utils import safe_relpath, sha256sum + +""" +collect import statements and validate the import graph. +this module is separated into its own pass so that we can resolve the import +graph quickly (without doing semantic analysis) and for cleanliness, to +segregate the I/O portion of semantic analysis into its own pass. +""" + + +@dataclass +class _ImportGraph: + # the current path in the import graph traversal + _path: list[vy_ast.Module] = field(default_factory=list) + + # stack of dicts, each item in the stack is a dict keeping + # track of imports in the current module + _imports: list[dict] = field(default_factory=list) + + @property + def imported_modules(self): + return self._imports[-1] + + @property + def current_module(self): + return self._path[-1] + + def push_path(self, module_ast: vy_ast.Module) -> None: + if module_ast in self._path: + cycle = self._path + [module_ast] + raise ImportCycle(" imports ".join(f'"{t.path}"' for t in cycle)) + + self._path.append(module_ast) + self._imports.append({}) + + def pop_path(self, expected: vy_ast.Module) -> None: + popped = self._path.pop() + assert expected is popped, "unreachable" + self._imports.pop() + + @contextlib.contextmanager + def enter_path(self, module_ast: vy_ast.Module) -> Iterator[None]: + self.push_path(module_ast) + try: + yield + finally: + self.pop_path(module_ast) + + +class ImportAnalyzer: + def __init__(self, input_bundle: InputBundle, graph: _ImportGraph): + self.input_bundle = input_bundle + self.graph = graph + self._ast_of: dict[int, vy_ast.Module] = {} + + self.seen: set[vy_ast.Module] = set() + + self._integrity_sum = None + + # should be all system paths + topmost module path + self.absolute_search_paths = input_bundle.search_paths.copy() + + def resolve_imports(self, module_ast: vy_ast.Module): + 
self._resolve_imports_r(module_ast) + self._integrity_sum = self._calculate_integrity_sum_r(module_ast) + + def _calculate_integrity_sum_r(self, module_ast: vy_ast.Module): + acc = [sha256sum(module_ast.full_source_code)] + for s in module_ast.get_children((vy_ast.Import, vy_ast.ImportFrom)): + info = s._metadata["import_info"] + + if info.compiler_input.path.suffix in (".vyi", ".json"): + # NOTE: this needs to be redone if interfaces can import other interfaces + acc.append(info.compiler_input.sha256sum) + else: + acc.append(self._calculate_integrity_sum_r(info.parsed)) + + return sha256sum("".join(acc)) + + def _resolve_imports_r(self, module_ast: vy_ast.Module): + if module_ast in self.seen: + return + with self.graph.enter_path(module_ast): + for node in module_ast.body: + with tag_exceptions(node): + if isinstance(node, vy_ast.Import): + self._handle_Import(node) + elif isinstance(node, vy_ast.ImportFrom): + self._handle_ImportFrom(node) + + self.seen.add(module_ast) + + def _handle_Import(self, node: vy_ast.Import): + # import x.y[name] as y[alias] + + alias = node.alias + + if alias is None: + alias = node.name + + # don't handle things like `import x.y` + if "." in alias: + msg = "import requires an accompanying `as` statement" + suggested_alias = node.name[node.name.rfind(".") :] + hint = f"try `import {node.name} as {suggested_alias}`" + raise StructureException(msg, node, hint=hint) + + self._add_import(node, 0, node.name, alias) + + def _handle_ImportFrom(self, node: vy_ast.ImportFrom): + # from m.n[module] import x[name] as y[alias] + + alias = node.alias + + if alias is None: + alias = node.name + + module = node.module or "" + if module: + module += "." 
+ + qualified_module_name = module + node.name + self._add_import(node, node.level, qualified_module_name, alias) + + def _add_import( + self, node: vy_ast.VyperNode, level: int, qualified_module_name: str, alias: str + ) -> None: + compiler_input, ast = self._load_import(node, level, qualified_module_name, alias) + node._metadata["import_info"] = ImportInfo( + alias, qualified_module_name, compiler_input, ast + ) + + # load an InterfaceT or ModuleInfo from an import. + # raises FileNotFoundError + def _load_import( + self, node: vy_ast.VyperNode, level: int, module_str: str, alias: str + ) -> tuple[CompilerInput, Any]: + if _is_builtin(module_str): + return _load_builtin_import(level, module_str) + + path = _import_to_path(level, module_str) + + if path in self.graph.imported_modules: + previous_import_stmt = self.graph.imported_modules[path] + raise DuplicateImport(f"{alias} imported more than once!", previous_import_stmt, node) + + self.graph.imported_modules[path] = node + + err = None + + try: + path_vy = path.with_suffix(".vy") + file = self._load_file(path_vy, level) + assert isinstance(file, FileInput) # mypy hint + + module_ast = self._ast_from_file(file) + self.resolve_imports(module_ast) + + return file, module_ast + + except FileNotFoundError as e: + # escape `e` from the block scope, it can make things + # easier to debug. 
+ err = e + + try: + file = self._load_file(path.with_suffix(".vyi"), level) + assert isinstance(file, FileInput) # mypy hint + module_ast = self._ast_from_file(file) + self.resolve_imports(module_ast) + + # language does not yet allow recursion for vyi files + # self.resolve_imports(module_ast) + + return file, module_ast + + except FileNotFoundError: + pass + + try: + file = self._load_file(path.with_suffix(".json"), level) + assert isinstance(file, ABIInput) # mypy hint + return file, file.abi + except FileNotFoundError: + pass + + hint = None + if module_str.startswith("vyper.interfaces"): + hint = "try renaming `vyper.interfaces` to `ethereum.ercs`" + + # copy search_paths, makes debugging a bit easier + search_paths = self.input_bundle.search_paths.copy() # noqa: F841 + raise ModuleNotFound(module_str, hint=hint) from err + + def _load_file(self, path: PathLike, level: int) -> CompilerInput: + ast = self.graph.current_module + + search_paths: list[PathLike] # help mypy + if level != 0: # relative import + search_paths = [Path(ast.resolved_path).parent] + else: + search_paths = self.absolute_search_paths + + with self.input_bundle.temporary_search_paths(search_paths): + return self.input_bundle.load_file(path) + + def _ast_from_file(self, file: FileInput) -> vy_ast.Module: + # cache ast if we have seen it before. + # this gives us the additional property of object equality on + # two ASTs produced from the same source + ast_of = self._ast_of + if file.source_id not in ast_of: + ast_of[file.source_id] = _parse_ast(file) + + return ast_of[file.source_id] + + +def _parse_ast(file: FileInput) -> vy_ast.Module: + module_path = file.resolved_path # for error messages + try: + # try to get a relative path, to simplify the error message + cwd = Path(".") + if module_path.is_absolute(): + cwd = cwd.resolve() + module_path = module_path.relative_to(cwd) + except ValueError: + # we couldn't get a relative path (cf. 
docs for Path.relative_to), + # use the resolved path given to us by the InputBundle + pass + + ret = vy_ast.parse_to_ast( + file.source_code, + source_id=file.source_id, + module_path=module_path.as_posix(), + resolved_path=file.resolved_path.as_posix(), + ) + return ret + + +# convert an import to a path (without suffix) +def _import_to_path(level: int, module_str: str) -> PurePath: + base_path = "" + if level > 1: + base_path = "../" * (level - 1) + elif level == 1: + base_path = "./" + return PurePath(f"{base_path}{module_str.replace('.', '/')}/") + + +# can add more, e.g. "vyper.builtins.interfaces", etc. +BUILTIN_PREFIXES = ["ethereum.ercs"] + + +# TODO: could move this to analysis/common.py or something +def _is_builtin(module_str): + return any(module_str.startswith(prefix) for prefix in BUILTIN_PREFIXES) + + +_builtins_cache: dict[PathLike, tuple[CompilerInput, vy_ast.Module]] = {} + + +def _load_builtin_import(level: int, module_str: str) -> tuple[CompilerInput, vy_ast.Module]: + if not _is_builtin(module_str): # pragma: nocover + raise CompilerPanic("unreachable!") + + builtins_path = vyper.builtins.interfaces.__path__[0] + # hygiene: convert to relpath to avoid leaking user directory info + # (note Path.relative_to cannot handle absolute to relative path + # conversion, so we must use the `os` module). + builtins_path = safe_relpath(builtins_path) + + search_path = Path(builtins_path).parent.parent.parent + # generate an input bundle just because it knows how to build paths. 
+ input_bundle = FilesystemInputBundle([search_path]) + + # remap builtins directory -- + # ethereum/ercs => vyper/builtins/interfaces + remapped_module = module_str + if remapped_module.startswith("ethereum.ercs"): + remapped_module = remapped_module.removeprefix("ethereum.ercs") + remapped_module = vyper.builtins.interfaces.__package__ + remapped_module + + path = _import_to_path(level, remapped_module).with_suffix(".vyi") + + # builtins are globally the same, so we can safely cache them + # (it is also *correct* to cache them, so that types defined in builtins + # compare correctly using pointer-equality.) + if path in _builtins_cache: + file, ast = _builtins_cache[path] + return file, ast + + try: + file = input_bundle.load_file(path) + assert isinstance(file, FileInput) # mypy hint + except FileNotFoundError as e: + hint = None + components = module_str.split(".") + # common issue for upgrading codebases from v0.3.x to v0.4.x - + # hint: rename ERC20 to IERC20 + if components[-1].startswith("ERC"): + module_prefix = components[-1] + hint = f"try renaming `{module_prefix}` to `I{module_prefix}`" + raise ModuleNotFound(module_str, hint=hint) from e + + interface_ast = _parse_ast(file) + + # no recursion needed since builtins don't have any imports + + _builtins_cache[path] = file, interface_ast + return file, interface_ast + + +def resolve_imports(module_ast: vy_ast.Module, input_bundle: InputBundle): + graph = _ImportGraph() + analyzer = ImportAnalyzer(input_bundle, graph) + analyzer.resolve_imports(module_ast) + + return analyzer diff --git a/vyper/semantics/analysis/levenshtein_utils.py b/vyper/semantics/analysis/levenshtein_utils.py index fc6e497d43..ac4fe4fab3 100644 --- a/vyper/semantics/analysis/levenshtein_utils.py +++ b/vyper/semantics/analysis/levenshtein_utils.py @@ -1,4 +1,4 @@ -from typing import Any, Callable +from typing import Any, Callable, Optional def levenshtein_norm(source: str, target: str) -> float: @@ -79,7 +79,7 @@ def 
get_levenshtein_error_suggestions(*args, **kwargs) -> Callable: def _get_levenshtein_error_suggestions( key: str, namespace: dict[str, Any], threshold: float -) -> str: +) -> Optional[str]: """ Generate an error message snippet for the suggested closest values in the provided namespace with the shortest normalized Levenshtein distance from the given key if that distance @@ -100,11 +100,11 @@ def _get_levenshtein_error_suggestions( """ if key is None or key == "": - return "" + return None distances = sorted([(i, levenshtein_norm(key, i)) for i in namespace], key=lambda k: k[1]) if len(distances) > 0 and distances[0][1] <= threshold: if len(distances) > 1 and distances[1][1] <= threshold: return f"Did you mean '{distances[0][0]}', or maybe '{distances[1][0]}'?" return f"Did you mean '{distances[0][0]}'?" - return "" + return None diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 26c6a4ef9f..461326d72d 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -136,7 +136,7 @@ def _validate_address_code(node: vy_ast.Attribute, value_type: VyperType) -> Non parent = node.get_ancestor() if isinstance(parent, vy_ast.Call): ok_func = isinstance(parent.func, vy_ast.Name) and parent.func.id == "slice" - ok_args = len(parent.args) == 3 and isinstance(parent.args[2], vy_ast.Int) + ok_args = len(parent.args) == 3 and isinstance(parent.args[2].reduced(), vy_ast.Int) if ok_func and ok_args: return @@ -154,7 +154,7 @@ def _validate_msg_data_attribute(node: vy_ast.Attribute) -> None: "msg.data is only allowed inside of the slice, len or raw_call functions", node ) if parent.get("func.id") == "slice": - ok_args = len(parent.args) == 3 and isinstance(parent.args[2], vy_ast.Int) + ok_args = len(parent.args) == 3 and isinstance(parent.args[2].reduced(), vy_ast.Int) if not ok_args: raise StructureException( "slice(msg.data) must use a compile-time constant for length argument", parent @@ -317,7 +317,7 @@ def 
analyze(self): for arg in self.func.arguments: self.namespace[arg.name] = VarInfo( - arg.typ, location=location, modifiability=modifiability + arg.typ, location=location, modifiability=modifiability, decl_node=arg.ast_source ) for node in self.fn_node.body: @@ -363,7 +363,7 @@ def visit_AnnAssign(self, node): # validate the value before adding it to the namespace self.expr_visitor.visit(node.value, typ) - self.namespace[name] = VarInfo(typ, location=DataLocation.MEMORY) + self.namespace[name] = VarInfo(typ, location=DataLocation.MEMORY, decl_node=node) self.expr_visitor.visit(node.target, typ) @@ -575,7 +575,7 @@ def visit_For(self, node): target_name = node.target.target.id # maybe we should introduce a new Modifiability: LOOP_VARIABLE self.namespace[target_name] = VarInfo( - target_type, modifiability=Modifiability.RUNTIME_CONSTANT + target_type, modifiability=Modifiability.RUNTIME_CONSTANT, decl_node=node.target ) self.expr_visitor.visit(node.target.target, target_type) @@ -810,13 +810,17 @@ def visit_Call(self, node: vy_ast.Call, typ: VyperType) -> None: self.visit(kwarg.value, typ) elif is_type_t(func_type, EventT): - # events have no kwargs + # event ctors expected_types = func_type.typedef.arguments.values() # type: ignore - for arg, typ in zip(node.args, expected_types): - self.visit(arg, typ) + # Handle keyword args if present, otherwise use positional args + if len(node.keywords) > 0: + for kwarg, arg_type in zip(node.keywords, expected_types): + self.visit(kwarg.value, arg_type) + else: + for arg, typ in zip(node.args, expected_types): + self.visit(arg, typ) elif is_type_t(func_type, StructT): # struct ctors - # ctors have no kwargs expected_types = func_type.typedef.members.values() # type: ignore for kwarg, arg_type in zip(node.keywords, expected_types): self.visit(kwarg.value, arg_type) diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index d6bbea1b48..534af4d633 100644 --- a/vyper/semantics/analysis/module.py +++ 
b/vyper/semantics/analysis/module.py @@ -1,23 +1,11 @@ -import os -from pathlib import Path, PurePath from typing import Any, Optional -import vyper.builtins.interfaces from vyper import ast as vy_ast -from vyper.compiler.input_bundle import ( - ABIInput, - CompilerInput, - FileInput, - FilesystemInputBundle, - InputBundle, - PathLike, -) from vyper.evm.opcodes import version_check from vyper.exceptions import ( BorrowException, CallViolation, CompilerPanic, - DuplicateImport, EvmVersionException, ExceptionList, ImmutableViolation, @@ -25,7 +13,6 @@ InterfaceViolation, InvalidLiteral, InvalidType, - ModuleNotFound, StateAccessViolation, StructureException, UndeclaredDefinition, @@ -45,7 +32,6 @@ from vyper.semantics.analysis.common import VyperNodeVisitorBase from vyper.semantics.analysis.constant_folding import constant_fold from vyper.semantics.analysis.getters import generate_public_variable_getters -from vyper.semantics.analysis.import_graph import ImportGraph from vyper.semantics.analysis.local import ExprVisitor, analyze_functions, check_module_uses from vyper.semantics.analysis.utils import ( check_modifiability, @@ -54,36 +40,23 @@ ) from vyper.semantics.data_locations import DataLocation from vyper.semantics.namespace import Namespace, get_namespace, override_global_namespace -from vyper.semantics.types import EventT, FlagT, InterfaceT, StructT +from vyper.semantics.types import TYPE_T, EventT, FlagT, InterfaceT, StructT, is_type_t from vyper.semantics.types.function import ContractFunctionT from vyper.semantics.types.module import ModuleT from vyper.semantics.types.utils import type_from_annotation from vyper.utils import OrderedSet -def analyze_module( - module_ast: vy_ast.Module, - input_bundle: InputBundle, - import_graph: ImportGraph = None, - is_interface: bool = False, -) -> ModuleT: +def analyze_module(module_ast: vy_ast.Module) -> ModuleT: """ Analyze a Vyper module AST node, recursively analyze all its imports, add all module-level objects to the 
namespace, type-check/validate semantics and annotate with type and analysis info """ - if import_graph is None: - import_graph = ImportGraph() - - return _analyze_module_r(module_ast, input_bundle, import_graph, is_interface) + return _analyze_module_r(module_ast, module_ast.is_interface) -def _analyze_module_r( - module_ast: vy_ast.Module, - input_bundle: InputBundle, - import_graph: ImportGraph, - is_interface: bool = False, -): +def _analyze_module_r(module_ast: vy_ast.Module, is_interface: bool = False): if "type" in module_ast._metadata: # we don't need to analyse again, skip out assert isinstance(module_ast._metadata["type"], ModuleT) @@ -92,8 +65,8 @@ def _analyze_module_r( # validate semantics and annotate AST with type/semantics information namespace = get_namespace() - with namespace.enter_scope(), import_graph.enter_path(module_ast): - analyzer = ModuleAnalyzer(module_ast, input_bundle, namespace, import_graph, is_interface) + with namespace.enter_scope(): + analyzer = ModuleAnalyzer(module_ast, namespace, is_interface) analyzer.analyze_module_body() _analyze_call_graph(module_ast) @@ -150,15 +123,15 @@ def _compute_reachable_set(fn_t: ContractFunctionT, path: list[ContractFunctionT path = path or [] path.append(fn_t) - root = path[0] for g in fn_t.called_functions: if g in fn_t.reachable_internal_functions: # already seen continue - if g == root: - message = " -> ".join([f.name for f in path]) + if g in path: + extended_path = path + [g] + message = " -> ".join([f.name for f in extended_path]) raise CallViolation(f"Contract contains cyclic function call: {message}") _compute_reachable_set(g, path=path) @@ -176,24 +149,14 @@ class ModuleAnalyzer(VyperNodeVisitorBase): scope_name = "module" def __init__( - self, - module_node: vy_ast.Module, - input_bundle: InputBundle, - namespace: Namespace, - import_graph: ImportGraph, - is_interface: bool = False, + self, module_node: vy_ast.Module, namespace: Namespace, is_interface: bool = False ) -> None: self.ast 
= module_node - self.input_bundle = input_bundle self.namespace = namespace - self._import_graph = import_graph self.is_interface = is_interface - # keep track of imported modules to prevent duplicate imports - self._imported_modules: dict[PurePath, vy_ast.VyperNode] = {} - # keep track of exported functions to prevent duplicate exports - self._exposed_functions: dict[ContractFunctionT, vy_ast.VyperNode] = {} + self._all_functions: dict[ContractFunctionT, vy_ast.VyperNode] = {} self._events: list[EventT] = [] @@ -390,16 +353,6 @@ def validate_initialized_modules(self): err_list.raise_if_not_empty() - def _ast_from_file(self, file: FileInput) -> vy_ast.Module: - # cache ast if we have seen it before. - # this gives us the additional property of object equality on - # two ASTs produced from the same source - ast_of = self.input_bundle._cache._ast_of - if file.source_id not in ast_of: - ast_of[file.source_id] = _parse_ast(file) - - return ast_of[file.source_id] - def visit_ImplementsDecl(self, node): type_ = type_from_annotation(node.annotation) @@ -414,7 +367,7 @@ def visit_ImplementsDecl(self, node): raise StructureException(msg, node.annotation, hint=hint) # grab exposed functions - funcs = self._exposed_functions + funcs = {fn_t: node for fn_t, node in self._all_functions.items() if fn_t.is_external} type_.validate_implements(node, funcs) node._metadata["interface_type"] = type_ @@ -546,9 +499,19 @@ def visit_ExportsDecl(self, node): raise StructureException("not a public variable!", decl, item) funcs = [decl._expanded_getter._metadata["func_type"]] elif isinstance(info.typ, ContractFunctionT): + # e.g. 
lib1.__interface__(self._addr).foo + if not isinstance(get_expr_info(item.value).typ, (ModuleT, TYPE_T)): + raise StructureException( + "invalid export of a value", + item.value, + hint="exports should look like .", + ) + # regular function funcs = [info.typ] - elif isinstance(info.typ, InterfaceT): + elif is_type_t(info.typ, InterfaceT): + interface_t = info.typ.typedef + if not isinstance(item, vy_ast.Attribute): raise StructureException( "invalid export", @@ -559,7 +522,7 @@ def visit_ExportsDecl(self, node): if module_info is None: raise StructureException("not a valid module!", item.value) - if info.typ not in module_info.typ.implemented_interfaces: + if interface_t not in module_info.typ.implemented_interfaces: iface_str = item.node_source_code module_str = item.value.node_source_code msg = f"requested `{iface_str}` but `{module_str}`" @@ -570,9 +533,15 @@ def visit_ExportsDecl(self, node): # find the specific implementation of the function in the module funcs = [ module_exposed_fns[fn.name] - for fn in info.typ.functions.values() + for fn in interface_t.functions.values() if fn.is_external ] + + if len(funcs) == 0: + path = module_info.module_node.path + msg = f"{module_info.alias} (located at `{path}`) has no external functions!" 
+ raise StructureException(msg, item) + else: raise StructureException( f"not a function or interface: `{info.typ}`", info.typ.decl_node, item @@ -608,10 +577,10 @@ def _self_t(self): def _add_exposed_function(self, func_t, node, relax=True): # call this before self._self_t.typ.add_member() for exception raising # priority - if not relax and (prev_decl := self._exposed_functions.get(func_t)) is not None: + if not relax and (prev_decl := self._all_functions.get(func_t)) is not None: raise StructureException("already exported!", node, prev_decl=prev_decl) - self._exposed_functions[func_t] = node + self._all_functions[func_t] = node def visit_VariableDecl(self, node): # postcondition of VariableDecl.validate @@ -740,32 +709,44 @@ def visit_FunctionDef(self, node): self._add_exposed_function(func_t, node) def visit_Import(self, node): - # import x.y[name] as y[alias] + self._add_import(node) - alias = node.alias + def visit_ImportFrom(self, node): + self._add_import(node) - if alias is None: - alias = node.name + def _add_import(self, node: vy_ast.VyperNode) -> None: + import_info = node._metadata["import_info"] + # similar structure to import analyzer + module_info = self._load_import(import_info) - # don't handle things like `import x.y` - if "." 
in alias: - msg = "import requires an accompanying `as` statement" - suggested_alias = node.name[node.name.rfind(".") :] - hint = f"try `import {node.name} as {suggested_alias}`" - raise StructureException(msg, node, hint=hint) + import_info._typ = module_info - self._add_import(node, 0, node.name, alias) + self.namespace[import_info.alias] = module_info - def visit_ImportFrom(self, node): - # from m.n[module] import x[name] as y[alias] - alias = node.alias or node.name + def _load_import(self, import_info: ImportInfo) -> Any: + path = import_info.compiler_input.path + if path.suffix == ".vy": + module_ast = import_info.parsed + with override_global_namespace(Namespace()): + module_t = _analyze_module_r(module_ast, is_interface=False) + return ModuleInfo(module_t, import_info.alias) - module = node.module or "" - if module: - module += "." + if path.suffix == ".vyi": + module_ast = import_info.parsed + with override_global_namespace(Namespace()): + module_t = _analyze_module_r(module_ast, is_interface=True) + + # NOTE: might be cleaner to return the whole module, so we + # have a ModuleInfo, that way we don't need to have different + # code paths for InterfaceT vs ModuleInfo + return module_t.interface - qualified_module_name = module + node.name - self._add_import(node, node.level, qualified_module_name, alias) + if path.suffix == ".json": + abi = import_info.parsed + path = import_info.compiler_input.path + return InterfaceT.from_json_abi(str(path), abi) + + raise CompilerPanic("unreachable") # pragma: nocover def visit_InterfaceDef(self, node): interface_t = InterfaceT.from_InterfaceDef(node) @@ -776,190 +757,3 @@ def visit_StructDef(self, node): struct_t = StructT.from_StructDef(node) node._metadata["struct_type"] = struct_t self.namespace[node.name] = struct_t - - def _add_import( - self, node: vy_ast.VyperNode, level: int, qualified_module_name: str, alias: str - ) -> None: - compiler_input, module_info = self._load_import(node, level, qualified_module_name, 
alias) - node._metadata["import_info"] = ImportInfo( - module_info, alias, qualified_module_name, compiler_input, node - ) - self.namespace[alias] = module_info - - # load an InterfaceT or ModuleInfo from an import. - # raises FileNotFoundError - def _load_import(self, node: vy_ast.VyperNode, level: int, module_str: str, alias: str) -> Any: - # the directory this (currently being analyzed) module is in - self_search_path = Path(self.ast.resolved_path).parent - - with self.input_bundle.poke_search_path(self_search_path): - return self._load_import_helper(node, level, module_str, alias) - - def _load_import_helper( - self, node: vy_ast.VyperNode, level: int, module_str: str, alias: str - ) -> tuple[CompilerInput, Any]: - if _is_builtin(module_str): - return _load_builtin_import(level, module_str) - - path = _import_to_path(level, module_str) - - # this could conceivably be in the ImportGraph but no need at this point - if path in self._imported_modules: - previous_import_stmt = self._imported_modules[path] - raise DuplicateImport(f"{alias} imported more than once!", previous_import_stmt, node) - - self._imported_modules[path] = node - - err = None - - try: - path_vy = path.with_suffix(".vy") - file = self.input_bundle.load_file(path_vy) - assert isinstance(file, FileInput) # mypy hint - - module_ast = self._ast_from_file(file) - - with override_global_namespace(Namespace()): - module_t = _analyze_module_r( - module_ast, - self.input_bundle, - import_graph=self._import_graph, - is_interface=False, - ) - - return file, ModuleInfo(module_t, alias) - - except FileNotFoundError as e: - # escape `e` from the block scope, it can make things - # easier to debug. 
- err = e - - try: - file = self.input_bundle.load_file(path.with_suffix(".vyi")) - assert isinstance(file, FileInput) # mypy hint - module_ast = self._ast_from_file(file) - - with override_global_namespace(Namespace()): - _analyze_module_r( - module_ast, - self.input_bundle, - import_graph=self._import_graph, - is_interface=True, - ) - module_t = module_ast._metadata["type"] - - return file, module_t.interface - - except FileNotFoundError: - pass - - try: - file = self.input_bundle.load_file(path.with_suffix(".json")) - assert isinstance(file, ABIInput) # mypy hint - return file, InterfaceT.from_json_abi(str(file.path), file.abi) - except FileNotFoundError: - pass - - hint = None - if module_str.startswith("vyper.interfaces"): - hint = "try renaming `vyper.interfaces` to `ethereum.ercs`" - - # copy search_paths, makes debugging a bit easier - search_paths = self.input_bundle.search_paths.copy() # noqa: F841 - raise ModuleNotFound(module_str, hint=hint) from err - - -def _parse_ast(file: FileInput) -> vy_ast.Module: - module_path = file.resolved_path # for error messages - try: - # try to get a relative path, to simplify the error message - cwd = Path(".") - if module_path.is_absolute(): - cwd = cwd.resolve() - module_path = module_path.relative_to(cwd) - except ValueError: - # we couldn't get a relative path (cf. docs for Path.relative_to), - # use the resolved path given to us by the InputBundle - pass - - ret = vy_ast.parse_to_ast( - file.source_code, - source_id=file.source_id, - module_path=module_path.as_posix(), - resolved_path=file.resolved_path.as_posix(), - ) - return ret - - -# convert an import to a path (without suffix) -def _import_to_path(level: int, module_str: str) -> PurePath: - base_path = "" - if level > 1: - base_path = "../" * (level - 1) - elif level == 1: - base_path = "./" - return PurePath(f"{base_path}{module_str.replace('.','/')}/") - - -# can add more, e.g. "vyper.builtins.interfaces", etc. 
-BUILTIN_PREFIXES = ["ethereum.ercs"] - - -# TODO: could move this to analysis/common.py or something -def _is_builtin(module_str): - return any(module_str.startswith(prefix) for prefix in BUILTIN_PREFIXES) - - -_builtins_cache: dict[PathLike, tuple[CompilerInput, ModuleT]] = {} - - -def _load_builtin_import(level: int, module_str: str) -> tuple[CompilerInput, InterfaceT]: - if not _is_builtin(module_str): # pragma: nocover - raise CompilerPanic("unreachable!") - - builtins_path = vyper.builtins.interfaces.__path__[0] - # hygiene: convert to relpath to avoid leaking user directory info - # (note Path.relative_to cannot handle absolute to relative path - # conversion, so we must use the `os` module). - builtins_path = os.path.relpath(builtins_path) - - search_path = Path(builtins_path).parent.parent.parent - # generate an input bundle just because it knows how to build paths. - input_bundle = FilesystemInputBundle([search_path]) - - # remap builtins directory -- - # ethereum/ercs => vyper/builtins/interfaces - remapped_module = module_str - if remapped_module.startswith("ethereum.ercs"): - remapped_module = remapped_module.removeprefix("ethereum.ercs") - remapped_module = vyper.builtins.interfaces.__package__ + remapped_module - - path = _import_to_path(level, remapped_module).with_suffix(".vyi") - - # builtins are globally the same, so we can safely cache them - # (it is also *correct* to cache them, so that types defined in builtins - # compare correctly using pointer-equality.) 
- if path in _builtins_cache: - file, module_t = _builtins_cache[path] - return file, module_t.interface - - try: - file = input_bundle.load_file(path) - assert isinstance(file, FileInput) # mypy hint - except FileNotFoundError as e: - hint = None - components = module_str.split(".") - # common issue for upgrading codebases from v0.3.x to v0.4.x - - # hint: rename ERC20 to IERC20 - if components[-1].startswith("ERC"): - module_prefix = components[-1] - hint = f"try renaming `{module_prefix}` to `I{module_prefix}`" - raise ModuleNotFound(module_str, hint=hint) from e - - interface_ast = _parse_ast(file) - - with override_global_namespace(Namespace()): - module_t = _analyze_module_r(interface_ast, input_bundle, ImportGraph(), is_interface=True) - - _builtins_cache[path] = file, module_t - return file, module_t.interface diff --git a/vyper/semantics/analysis/utils.py b/vyper/semantics/analysis/utils.py index be323b1d13..8727f3750d 100644 --- a/vyper/semantics/analysis/utils.py +++ b/vyper/semantics/analysis/utils.py @@ -1,9 +1,11 @@ import itertools -from typing import Callable, Iterable, List +from typing import Any, Callable, Iterable, List from vyper import ast as vy_ast from vyper.exceptions import ( CompilerPanic, + InstantiationException, + InvalidAttribute, InvalidLiteral, InvalidOperation, InvalidReference, @@ -24,7 +26,7 @@ from vyper.semantics.types.bytestrings import BytesT, StringT from vyper.semantics.types.primitives import AddressT, BoolT, BytesM_T, IntegerT from vyper.semantics.types.subscriptable import DArrayT, SArrayT, TupleT -from vyper.utils import checksum_encode, int_to_fourbytes +from vyper.utils import OrderedSet, checksum_encode, int_to_fourbytes def _validate_op(node, types_list, validation_fn_name): @@ -39,7 +41,7 @@ def _validate_op(node, types_list, validation_fn_name): try: _validate_fn(node) ret.append(type_) - except InvalidOperation as e: + except (InvalidOperation, OverflowException) as e: err_list.append(e) if ret: @@ -197,7 +199,7 
@@ def _raise_invalid_reference(name, node): try: s = t.get_member(name, node) - if isinstance(s, (VyperType, TYPE_T)): + if isinstance(s, VyperType): # ex. foo.bar(). bar() is a ContractFunctionT return [s] @@ -681,3 +683,56 @@ def check_modifiability(node: vy_ast.ExprNode, modifiability: Modifiability) -> info = get_expr_info(node) return info.modifiability <= modifiability + + +# TODO: move this into part of regular analysis in `local.py` +def get_expr_writes(node: vy_ast.VyperNode) -> OrderedSet[VarAccess]: + if "writes_r" in node._metadata: + return node._metadata["writes_r"] + ret: OrderedSet = OrderedSet() + if isinstance(node, vy_ast.ExprNode) and node._expr_info is not None: + ret = node._expr_info._writes + for c in node._children: + ret |= get_expr_writes(c) + node._metadata["writes_r"] = ret + return ret + + +def validate_kwargs(node: vy_ast.Call, members: dict[str, VyperType], typeclass: str): + # manually validate kwargs for better error messages instead of + # relying on `validate_call_args` + + seen: dict[str, vy_ast.keyword] = {} + membernames = list(members.keys()) + + # check duplicate kwargs + for i, kwarg in enumerate(node.keywords): + # x=5 => kwarg(arg="x", value=Int(5)) + argname = kwarg.arg + if argname in seen: + prev = seen[argname] + raise InvalidAttribute(f"Duplicate {typeclass} argument", prev, kwarg) + seen[argname] = kwarg + + hint: Any # mypy kludge + if argname not in members: + hint = get_levenshtein_error_suggestions(argname, members, 1.0) + raise UnknownAttribute(f"Unknown {typeclass} argument.", kwarg, hint=hint) + + expect_name = membernames[i] + if argname != expect_name: + # out of order key + msg = f"{typeclass} keys are required to be in order, but got" + msg += f" `{argname}` instead of `{expect_name}`." 
+ hint = "as a reminder, the order of the keys in this" + hint += f" {typeclass} are {list(members)}" + raise InvalidAttribute(msg, kwarg, hint=hint) + + expected_type = members[argname] + validate_expected_type(kwarg.value, expected_type) + + missing = OrderedSet(members.keys()) - OrderedSet(seen.keys()) + if len(missing) > 0: + msg = f"{typeclass} instantiation missing fields:" + msg += f" {', '.join(list(missing))}" + raise InstantiationException(msg, node) diff --git a/vyper/semantics/environment.py b/vyper/semantics/environment.py index 60477ff1c2..9175e518e1 100644 --- a/vyper/semantics/environment.py +++ b/vyper/semantics/environment.py @@ -36,7 +36,13 @@ class _Chain(_EnvType): class _Msg(_EnvType): _id = "msg" - _type_members = {"data": BytesT(), "gas": UINT256_T, "sender": AddressT(), "value": UINT256_T} + _type_members = { + "data": BytesT(), + "gas": UINT256_T, + "mana": UINT256_T, + "sender": AddressT(), + "value": UINT256_T, + } class _Tx(_EnvType): diff --git a/vyper/semantics/types/__init__.py b/vyper/semantics/types/__init__.py index 59a20dd99f..b881f52b2b 100644 --- a/vyper/semantics/types/__init__.py +++ b/vyper/semantics/types/__init__.py @@ -1,8 +1,8 @@ from . 
import primitives, subscriptable, user from .base import TYPE_T, VOID_TYPE, KwargSettings, VyperType, is_type_t, map_void from .bytestrings import BytesT, StringT, _BytestringT -from .function import MemberFunctionT -from .module import InterfaceT +from .function import ContractFunctionT, MemberFunctionT +from .module import InterfaceT, ModuleT from .primitives import AddressT, BoolT, BytesM_T, DecimalT, IntegerT, SelfT from .subscriptable import DArrayT, HashMapT, SArrayT, TupleT from .user import EventT, FlagT, StructT diff --git a/vyper/semantics/types/base.py b/vyper/semantics/types/base.py index 128ede0d5b..aca37b33a3 100644 --- a/vyper/semantics/types/base.py +++ b/vyper/semantics/types/base.py @@ -114,8 +114,13 @@ def __eq__(self, other): ) def __lt__(self, other): + # CMC 2024-10-20 what is this for? return self.abi_type.selector_name() < other.abi_type.selector_name() + def __repr__(self): + # TODO: add `pretty()` to the VyperType API? + return self._id + # return a dict suitable for serializing in the AST def to_dict(self): ret = {"name": self._id} @@ -362,10 +367,7 @@ def get_member(self, key: str, node: vy_ast.VyperNode) -> "VyperType": raise StructureException(f"{self} instance does not have members", node) hint = get_levenshtein_error_suggestions(key, self.members, 0.3) - raise UnknownAttribute(f"{self} has no member '{key}'.", node, hint=hint) - - def __repr__(self): - return self._id + raise UnknownAttribute(f"{repr(self)} has no member '{key}'.", node, hint=hint) class KwargSettings: diff --git a/vyper/semantics/types/bytestrings.py b/vyper/semantics/types/bytestrings.py index cd330681cf..02e3bb213f 100644 --- a/vyper/semantics/types/bytestrings.py +++ b/vyper/semantics/types/bytestrings.py @@ -159,7 +159,7 @@ class BytesT(_BytestringT): typeclass = "bytes" _id = "Bytes" - _valid_literal = (vy_ast.Bytes,) + _valid_literal = (vy_ast.Bytes, vy_ast.HexBytes) @property def abi_type(self) -> ABIType: diff --git a/vyper/semantics/types/function.py 
b/vyper/semantics/types/function.py index ae913b097c..ffeb5b7299 100644 --- a/vyper/semantics/types/function.py +++ b/vyper/semantics/types/function.py @@ -330,7 +330,23 @@ def from_vyi(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT": function_visibility, state_mutability, nonreentrant = _parse_decorators(funcdef) if nonreentrant: - raise FunctionDeclarationException("`@nonreentrant` not allowed in interfaces", funcdef) + # TODO: refactor so parse_decorators returns the AST location + decorator = next(d for d in funcdef.decorator_list if d.id == "nonreentrant") + raise FunctionDeclarationException( + "`@nonreentrant` not allowed in interfaces", decorator + ) + + # it's redundant to specify visibility in vyi - always should be external + if function_visibility is None: + function_visibility = FunctionVisibility.EXTERNAL + + if function_visibility != FunctionVisibility.EXTERNAL: + nonexternal = next( + d for d in funcdef.decorator_list if d.id in FunctionVisibility.values() + ) + raise FunctionDeclarationException( + "Interface functions can only be marked as `@external`", nonexternal + ) if funcdef.name == "__init__": raise FunctionDeclarationException("Constructors cannot appear in interfaces", funcdef) @@ -344,7 +360,11 @@ def from_vyi(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT": return_type = _parse_return_type(funcdef) - if len(funcdef.body) != 1 or not isinstance(funcdef.body[0].get("value"), vy_ast.Ellipsis): + body = funcdef.body + + if len(body) != 1 or not ( + isinstance(body[0], vy_ast.Expr) and isinstance(body[0].value, vy_ast.Ellipsis) + ): raise FunctionDeclarationException( "function body in an interface can only be `...`!", funcdef ) @@ -377,6 +397,10 @@ def from_FunctionDef(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT": """ function_visibility, state_mutability, nonreentrant = _parse_decorators(funcdef) + # it's redundant to specify internal visibility - it's implied by not being external + if function_visibility is 
None: + function_visibility = FunctionVisibility.INTERNAL + positional_args, keyword_args = _parse_args(funcdef) return_type = _parse_return_type(funcdef) @@ -415,6 +439,10 @@ def from_FunctionDef(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT": raise FunctionDeclarationException( "Constructor may not use default arguments", funcdef.args.defaults[0] ) + if nonreentrant: + decorator = next(d for d in funcdef.decorator_list if d.id == "nonreentrant") + msg = "`@nonreentrant` decorator disallowed on `__init__`" + raise FunctionDeclarationException(msg, decorator) return cls( funcdef.name, @@ -491,6 +519,8 @@ def implements(self, other: "ContractFunctionT") -> bool: if not self.is_external: # pragma: nocover raise CompilerPanic("unreachable!") + assert self.visibility == other.visibility + arguments, return_type = self._iface_sig other_arguments, other_return_type = other._iface_sig @@ -696,7 +726,7 @@ def _parse_return_type(funcdef: vy_ast.FunctionDef) -> Optional[VyperType]: def _parse_decorators( funcdef: vy_ast.FunctionDef, -) -> tuple[FunctionVisibility, StateMutability, bool]: +) -> tuple[Optional[FunctionVisibility], StateMutability, bool]: function_visibility = None state_mutability = None nonreentrant_node = None @@ -715,10 +745,6 @@ def _parse_decorators( if nonreentrant_node is not None: raise StructureException("nonreentrant decorator is already set", nonreentrant_node) - if funcdef.name == "__init__": - msg = "`@nonreentrant` decorator disallowed on `__init__`" - raise FunctionDeclarationException(msg, decorator) - nonreentrant_node = decorator elif isinstance(decorator, vy_ast.Name): @@ -729,6 +755,7 @@ def _parse_decorators( decorator, hint="only one visibility decorator is allowed per function", ) + function_visibility = FunctionVisibility(decorator.id) elif StateMutability.is_valid_value(decorator.id): @@ -751,9 +778,6 @@ def _parse_decorators( else: raise StructureException("Bad decorator syntax", decorator) - if function_visibility is None: 
- function_visibility = FunctionVisibility.INTERNAL - if state_mutability is None: # default to nonpayable state_mutability = StateMutability.NONPAYABLE @@ -850,7 +874,7 @@ def _id(self): return self.name def __repr__(self): - return f"{self.underlying_type._id} member function '{self.name}'" + return f"{self.underlying_type} member function '{self.name}'" def fetch_call_return(self, node: vy_ast.Call) -> Optional[VyperType]: validate_call_args(node, len(self.arg_types)) diff --git a/vyper/semantics/types/module.py b/vyper/semantics/types/module.py index e55c4d145f..498757b94e 100644 --- a/vyper/semantics/types/module.py +++ b/vyper/semantics/types/module.py @@ -19,10 +19,10 @@ ) from vyper.semantics.data_locations import DataLocation from vyper.semantics.types.base import TYPE_T, VyperType, is_type_t -from vyper.semantics.types.function import ContractFunctionT +from vyper.semantics.types.function import ContractFunctionT, MemberFunctionT from vyper.semantics.types.primitives import AddressT -from vyper.semantics.types.user import EventT, StructT, _UserType -from vyper.utils import OrderedSet, sha256sum +from vyper.semantics.types.user import EventT, FlagT, StructT, _UserType +from vyper.utils import OrderedSet if TYPE_CHECKING: from vyper.semantics.analysis.base import ImportInfo, ModuleInfo @@ -45,27 +45,29 @@ def __init__( functions: dict, events: dict, structs: dict, + flags: dict, ) -> None: validate_unique_method_ids(list(functions.values())) - members = functions | events | structs + members = functions | events | structs | flags # sanity check: by construction, there should be no duplicates. 
- assert len(members) == len(functions) + len(events) + len(structs) + assert len(members) == len(functions) + len(events) + len(structs) + len(flags) super().__init__(functions) - self._helper = VyperType(events | structs) + self._helper = VyperType(events | structs | flags) self._id = _id self._helper._id = _id self.functions = functions self.events = events self.structs = structs + self.flags = flags self.decl_node = decl_node def get_type_member(self, attr, node): - # get an event or struct from this interface + # get an event, struct or flag from this interface return TYPE_T(self._helper.get_member(attr, node)) @property @@ -76,6 +78,9 @@ def getter_signature(self): def abi_type(self) -> ABIType: return ABI_Address() + def __str__(self): + return self._id + def __repr__(self): return f"interface {self._id}" @@ -107,6 +112,7 @@ def _ctor_modifiability_for_call(self, node: vy_ast.Call, modifiability: Modifia def validate_implements( self, node: vy_ast.ImplementsDecl, functions: dict[ContractFunctionT, vy_ast.VyperNode] ) -> None: + # only external functions can implement interfaces fns_by_name = {fn_t.name: fn_t for fn_t in functions.keys()} unimplemented = [] @@ -116,7 +122,9 @@ def _is_function_implemented(fn_name, fn_type): return False to_compare = fns_by_name[fn_name] + assert to_compare.is_external assert isinstance(to_compare, ContractFunctionT) + assert isinstance(fn_type, ContractFunctionT) return to_compare.implements(fn_type) @@ -153,12 +161,14 @@ def _from_lists( interface_name: str, decl_node: Optional[vy_ast.VyperNode], function_list: list[tuple[str, ContractFunctionT]], - event_list: list[tuple[str, EventT]], - struct_list: list[tuple[str, StructT]], + event_list: Optional[list[tuple[str, EventT]]] = None, + struct_list: Optional[list[tuple[str, StructT]]] = None, + flag_list: Optional[list[tuple[str, FlagT]]] = None, ) -> "InterfaceT": - functions = {} - events = {} - structs = {} + functions: dict[str, ContractFunctionT] = {} + events: dict[str, 
EventT] = {} + structs: dict[str, StructT] = {} + flags: dict[str, FlagT] = {} seen_items: dict = {} @@ -169,19 +179,20 @@ def _mark_seen(name, item): raise NamespaceCollision(msg, item.decl_node, prev_decl=prev_decl) seen_items[name] = item - for name, function in function_list: - _mark_seen(name, function) - functions[name] = function + def _process(dst_dict, items): + if items is None: + return - for name, event in event_list: - _mark_seen(name, event) - events[name] = event + for name, item in items: + _mark_seen(name, item) + dst_dict[name] = item - for name, struct in struct_list: - _mark_seen(name, struct) - structs[name] = struct + _process(functions, function_list) + _process(events, event_list) + _process(structs, struct_list) + _process(flags, flag_list) - return cls(interface_name, decl_node, functions, events, structs) + return cls(interface_name, decl_node, functions, events, structs, flags) @classmethod def from_json_abi(cls, name: str, abi: dict) -> "InterfaceT": @@ -208,8 +219,7 @@ def from_json_abi(cls, name: str, abi: dict) -> "InterfaceT": for item in [i for i in abi if i.get("type") == "event"]: events.append((item["name"], EventT.from_abi(item))) - structs: list = [] # no structs in json ABI (as of yet) - return cls._from_lists(name, None, functions, events, structs) + return cls._from_lists(name, None, functions, events) @classmethod def from_ModuleT(cls, module_t: "ModuleT") -> "InterfaceT": @@ -230,9 +240,6 @@ def from_ModuleT(cls, module_t: "ModuleT") -> "InterfaceT": for fn_t in module_t.exposed_functions: funcs.append((fn_t.name, fn_t)) - if (fn_t := module_t.init_function) is not None: - funcs.append((fn_t.name, fn_t)) - event_set: OrderedSet[EventT] = OrderedSet() event_set.update([node._metadata["event_type"] for node in module_t.event_defs]) event_set.update(module_t.used_events) @@ -241,8 +248,9 @@ def from_ModuleT(cls, module_t: "ModuleT") -> "InterfaceT": # these are accessible via import, but they do not show up # in the ABI json 
structs = [(node.name, node._metadata["struct_type"]) for node in module_t.struct_defs] + flags = [(node.name, node._metadata["flag_type"]) for node in module_t.flag_defs] - return cls._from_lists(module_t._id, module_t.decl_node, funcs, events, structs) + return cls._from_lists(module_t._id, module_t.decl_node, funcs, events, structs, flags) @classmethod def from_InterfaceDef(cls, node: vy_ast.InterfaceDef) -> "InterfaceT": @@ -259,11 +267,20 @@ def from_InterfaceDef(cls, node: vy_ast.InterfaceDef) -> "InterfaceT": ) functions.append((func_ast.name, ContractFunctionT.from_InterfaceDef(func_ast))) - # no structs or events in InterfaceDefs - events: list = [] - structs: list = [] + return cls._from_lists(node.name, node, functions) - return cls._from_lists(node.name, node, functions, events, structs) + +def _module_at(module_t): + return MemberFunctionT( + # set underlying_type to a TYPE_T as a bit of a kludge, since it's + # kind of like a class method (but we don't have classmethod + # abstraction) + underlying_type=TYPE_T(module_t), + name="__at__", + arg_types=[AddressT()], + return_type=module_t.interface, + is_modifying=False, + ) # Datatype to store all module information. 
@@ -323,16 +340,28 @@ def __init__(self, module: vy_ast.Module, name: Optional[str] = None): for i in self.import_stmts: import_info = i._metadata["import_info"] - self.add_member(import_info.alias, import_info.typ) if hasattr(import_info.typ, "module_t"): - self._helper.add_member(import_info.alias, TYPE_T(import_info.typ)) + module_info = import_info.typ + # get_expr_info uses ModuleInfo + self.add_member(import_info.alias, module_info) + # type_from_annotation uses TYPE_T + self._helper.add_member(import_info.alias, TYPE_T(module_info.module_t)) + else: # interfaces + assert isinstance(import_info.typ, InterfaceT) + self.add_member(import_info.alias, TYPE_T(import_info.typ)) for name, interface_t in self.interfaces.items(): # can access interfaces in type position self._helper.add_member(name, TYPE_T(interface_t)) - self.add_member("__interface__", self.interface) + # module.__at__(addr) + self.add_member("__at__", _module_at(self)) + + # allow `module.__interface__` (in exports declarations) + self.add_member("__interface__", TYPE_T(self.interface)) + # allow `module.__interface__` (in type position) + self._helper.add_member("__interface__", TYPE_T(self.interface)) # __eq__ is very strict on ModuleT - object equality! this is because we # don't want to reason about where a module came from (i.e. 
input bundle, @@ -431,21 +460,6 @@ def reachable_imports(self) -> list["ImportInfo"]: return ret - @cached_property - def integrity_sum(self) -> str: - acc = [sha256sum(self._module.full_source_code)] - for s in self.import_stmts: - info = s._metadata["import_info"] - - if isinstance(info.typ, InterfaceT): - # NOTE: this needs to be redone if interfaces can import other interfaces - acc.append(info.compiler_input.sha256sum) - else: - assert isinstance(info.typ.typ, ModuleT) - acc.append(info.typ.typ.integrity_sum) - - return sha256sum("".join(acc)) - def find_module_info(self, needle: "ModuleT") -> Optional["ModuleInfo"]: for s in self.imported_modules.values(): if s.module_t == needle: diff --git a/vyper/semantics/types/primitives.py b/vyper/semantics/types/primitives.py index eea58c6c68..dcc4fe8c8e 100644 --- a/vyper/semantics/types/primitives.py +++ b/vyper/semantics/types/primitives.py @@ -173,11 +173,11 @@ def _get_lr(): if isinstance(left, vy_ast.Int): if left.value >= 2**value_bits: raise OverflowException( - "Base is too large, calculation will always overflow", left + f"Base is too large for {self}, calculation will always overflow", left ) elif left.value < -(2**value_bits): raise OverflowException( - "Base is too small, calculation will always underflow", left + f"Base is too small for {self}, calculation will always underflow", left ) elif isinstance(right, vy_ast.Int): if right.value < 0: @@ -211,9 +211,16 @@ def _add_div_hint(node, e): else: return e + def _get_source(node): + source = node.node_source_code + if isinstance(node, vy_ast.BinOp): + # parenthesize, to preserve precedence + return f"({source})" + return source + if isinstance(node, vy_ast.BinOp): - e._hint = f"did you mean `{node.left.node_source_code} " - e._hint += f"{suggested} {node.right.node_source_code}`?" + e._hint = f"did you mean `{_get_source(node.left)} " + e._hint += f"{suggested} {_get_source(node.right)}`?" 
elif isinstance(node, vy_ast.AugAssign): e._hint = f"did you mean `{node.target.node_source_code} " e._hint += f"{suggested}= {node.value.node_source_code}`?" diff --git a/vyper/semantics/types/user.py b/vyper/semantics/types/user.py index ca8e99bc92..d01ab23299 100644 --- a/vyper/semantics/types/user.py +++ b/vyper/semantics/types/user.py @@ -7,21 +7,23 @@ from vyper.exceptions import ( EventDeclarationException, FlagDeclarationException, - InvalidAttribute, + InstantiationException, NamespaceCollision, StructureException, UnfoldableNode, - UnknownAttribute, VariableDeclarationException, ) from vyper.semantics.analysis.base import Modifiability -from vyper.semantics.analysis.levenshtein_utils import get_levenshtein_error_suggestions -from vyper.semantics.analysis.utils import check_modifiability, validate_expected_type +from vyper.semantics.analysis.utils import ( + check_modifiability, + validate_expected_type, + validate_kwargs, +) from vyper.semantics.data_locations import DataLocation from vyper.semantics.types.base import VyperType from vyper.semantics.types.subscriptable import HashMapT from vyper.semantics.types.utils import type_from_abi, type_from_annotation -from vyper.utils import keccak256 +from vyper.utils import keccak256, vyper_warn # user defined type @@ -75,6 +77,9 @@ def get_type_member(self, key: str, node: vy_ast.VyperNode) -> "VyperType": self._helper.get_member(key, node) return self + def __str__(self): + return f"{self.name}" + def __repr__(self): arg_types = ",".join(repr(a) for a in self._flag_members) return f"flag {self.name}({arg_types})" @@ -281,6 +286,25 @@ def from_EventDef(cls, base_node: vy_ast.EventDef) -> "EventT": return cls(base_node.name, members, indexed, base_node) def _ctor_call_return(self, node: vy_ast.Call) -> None: + # validate keyword arguments if provided + if len(node.keywords) > 0: + if len(node.args) > 0: + raise InstantiationException( + "Event instantiation requires either all keyword arguments " + "or all 
 positional arguments",
+                node,
+            )
+
+            return validate_kwargs(node, self.arguments, self.typeclass)
+
+        # warn about positional argument deprecation
+        msg = "Instantiating events with positional arguments is "
+        msg += "deprecated as of v0.4.1 and will be disallowed "
+        msg += "in a future release. Use kwargs instead eg. "
+        msg += "Foo(a=1, b=2)"
+
+        vyper_warn(msg, node)
+
         validate_call_args(node, len(self.arguments))
         for arg, expected in zip(node.args, self.arguments.values()):
             validate_expected_type(arg, expected)
@@ -415,31 +439,7 @@ def _ctor_call_return(self, node: vy_ast.Call) -> "StructT":
                 "Struct contains a mapping and so cannot be declared as a literal", node
             )
 
-        # manually validate kwargs for better error messages instead of
-        # relying on `validate_call_args`
-        members = self.member_types.copy()
-        keys = list(self.member_types.keys())
-        for i, kwarg in enumerate(node.keywords):
-            # x=5 => kwarg(arg="x", value=Int(5))
-            argname = kwarg.arg
-            if argname not in members:
-                hint = get_levenshtein_error_suggestions(argname, members, 1.0)
-                raise UnknownAttribute("Unknown or duplicate struct member.", kwarg, hint=hint)
-            expected = keys[i]
-            if argname != expected:
-                raise InvalidAttribute(
-                    "Struct keys are required to be in order, but got "
-                    f"`{argname}` instead of `{expected}`. 
(Reminder: the " - f"keys in this struct are {list(self.member_types.items())})", - kwarg, - ) - expected_type = members.pop(argname) - validate_expected_type(kwarg.value, expected_type) - - if members: - raise VariableDeclarationException( - f"Struct declaration does not define all fields: {', '.join(list(members))}", node - ) + validate_kwargs(node, self.member_types, self.typeclass) return self diff --git a/vyper/typing.py b/vyper/typing.py index ad3964dff9..108c0605bb 100644 --- a/vyper/typing.py +++ b/vyper/typing.py @@ -1,7 +1,6 @@ from typing import Dict, Optional, Sequence, Tuple, Union # Parser -ModificationOffsets = Dict[Tuple[int, int], str] ParserPosition = Tuple[int, int] # Compiler diff --git a/vyper/utils.py b/vyper/utils.py index 2b95485f4e..39d3093478 100644 --- a/vyper/utils.py +++ b/vyper/utils.py @@ -4,13 +4,14 @@ import enum import functools import hashlib +import os import sys import time import traceback import warnings -from typing import Generic, List, TypeVar, Union +from typing import Generic, Iterable, Iterator, List, Set, TypeVar, Union -from vyper.exceptions import CompilerPanic, DecimalOverrideException, InvalidLiteral, VyperException +from vyper.exceptions import CompilerPanic, DecimalOverrideException, VyperException _T = TypeVar("_T") @@ -25,9 +26,10 @@ class OrderedSet(Generic[_T]): """ def __init__(self, iterable=None): - self._data = dict() - if iterable is not None: - self.update(iterable) + if iterable is None: + self._data = dict() + else: + self._data = dict.fromkeys(iterable) def __repr__(self): keys = ", ".join(repr(k) for k in self) @@ -36,6 +38,9 @@ def __repr__(self): def __iter__(self): return iter(self._data) + def __reversed__(self): + return reversed(self._data) + def __contains__(self, item): return self._data.__contains__(item) @@ -45,12 +50,16 @@ def __len__(self): def first(self): return next(iter(self)) + def last(self): + return next(reversed(self)) + def pop(self): return self._data.popitem()[0] def add(self, 
item: _T) -> None: self._data[item] = None + # NOTE to refactor: duplicate of self.update() def addmany(self, iterable): for item in iterable: self._data[item] = None @@ -80,6 +89,7 @@ def update(self, other): def union(self, other): return self | other + # set dunders def __ior__(self, other): self.update(other) return self @@ -92,6 +102,15 @@ def __or__(self, other): def __eq__(self, other): return self._data == other._data + def __isub__(self, other): + self.dropmany(other) + return self + + def __sub__(self, other): + ret = self.copy() + ret.dropmany(other) + return ret + def copy(self): cls = self.__class__ ret = cls.__new__(cls) @@ -103,11 +122,25 @@ def intersection(cls, *sets): if len(sets) == 0: raise ValueError("undefined: intersection of no sets") - ret = sets[0].copy() - for e in sets[0]: - if any(e not in s for s in sets[1:]): - ret.remove(e) - return ret + tmp = sets[0]._data.keys() + for s in sets[1:]: + tmp &= s._data.keys() + + return cls(tmp) + + +def uniq(seq: Iterable[_T]) -> Iterator[_T]: + """ + Yield unique items in ``seq`` in original sequence order. + """ + seen: Set[_T] = set() + + for x in seq: + if x in seen: + continue + + seen.add(x) + yield x class StringEnum(enum.Enum): @@ -215,6 +248,13 @@ def int_to_fourbytes(n: int) -> bytes: return n.to_bytes(4, byteorder="big") +def wrap256(val: int, signed=False) -> int: + ret = val % (2**256) + if signed: + ret = unsigned_to_signed(ret, 256, strict=True) + return ret + + def signed_to_unsigned(int_, bits, strict=False): """ Reinterpret a signed integer with n bits as an unsigned integer. 
@@ -224,7 +264,7 @@ def signed_to_unsigned(int_, bits, strict=False): """ if strict: lo, hi = int_bounds(signed=True, bits=bits) - assert lo <= int_ <= hi + assert lo <= int_ <= hi, int_ if int_ < 0: return int_ + 2**bits return int_ @@ -239,7 +279,7 @@ def unsigned_to_signed(int_, bits, strict=False): """ if strict: lo, hi = int_bounds(signed=False, bits=bits) - assert lo <= int_ <= hi + assert lo <= int_ <= hi, int_ if int_ > (2 ** (bits - 1)) - 1: return int_ - (2**bits) return int_ @@ -291,17 +331,6 @@ def round_towards_zero(d: decimal.Decimal) -> int: return int(d.to_integral_exact(decimal.ROUND_DOWN)) -# Converts string to bytes -def string_to_bytes(str): - bytez = b"" - for c in str: - if ord(c) >= 256: - raise InvalidLiteral(f"Cannot insert special character {c} into byte array") - bytez += bytes([ord(c)]) - bytez_length = len(bytez) - return bytez, bytez_length - - # Converts a provided hex string to an integer def hex_to_int(inp): if inp[:2] == "0x": @@ -372,6 +401,11 @@ def evm_twos_complement(x: int) -> int: return ((2**256 - 1) ^ x) + 1 +def evm_not(val: int) -> int: + assert 0 <= val <= SizeLimits.MAX_UINT256, "Value out of bounds" + return SizeLimits.MAX_UINT256 ^ val + + # EVM div semantics as a python function def evm_div(x, y): if y == 0: @@ -500,20 +534,79 @@ def indent(text: str, indent_chars: Union[str, List[str]] = " ", level: int = 1) @contextlib.contextmanager -def timeit(msg): +def timeit(msg): # pragma: nocover start_time = time.perf_counter() yield end_time = time.perf_counter() total_time = end_time - start_time - print(f"{msg}: Took {total_time:.4f} seconds") + print(f"{msg}: Took {total_time:.6f} seconds", file=sys.stderr) + + +_CUMTIMES = None + + +def _dump_cumtime(): # pragma: nocover + global _CUMTIMES + for msg, total_time in _CUMTIMES.items(): + print(f"{msg}: Cumulative time {total_time:.3f} seconds", file=sys.stderr) @contextlib.contextmanager -def timer(msg): - t0 = time.time() +def cumtimeit(msg): # pragma: nocover + import 
atexit + from collections import defaultdict + + global _CUMTIMES + + if _CUMTIMES is None: + warnings.warn("timing code, disable me before pushing!", stacklevel=2) + _CUMTIMES = defaultdict(int) + atexit.register(_dump_cumtime) + + start_time = time.perf_counter() yield - t1 = time.time() - print(f"{msg} took {t1 - t0}s") + end_time = time.perf_counter() + total_time = end_time - start_time + _CUMTIMES[msg] += total_time + + +_PROF = None + + +def _dump_profile(): # pragma: nocover + global _PROF + + _PROF.disable() # don't profile dumping stats + _PROF.dump_stats("stats") + + from pstats import Stats + + stats = Stats("stats", stream=sys.stderr) + stats.sort_stats("time") + stats.print_stats() + + +@contextlib.contextmanager +def profileit(): # pragma: nocover + """ + Helper function for local dev use, is not intended to ever be run in + production build + """ + import atexit + from cProfile import Profile + + global _PROF + if _PROF is None: + warnings.warn("profiling code, disable me before pushing!", stacklevel=2) + _PROF = Profile() + _PROF.disable() + atexit.register(_dump_profile) + + try: + _PROF.enable() + yield + finally: + _PROF.disable() def annotate_source_code( @@ -591,3 +684,12 @@ def annotate_source_code( cleanup_lines += [""] * (num_lines - len(cleanup_lines)) return "\n".join(cleanup_lines) + + +def safe_relpath(path): + try: + return os.path.relpath(path) + except ValueError: + # on Windows, if path and curdir are on different drives, an exception + # can be thrown + return path diff --git a/vyper/venom/README.md b/vyper/venom/README.md index 5d98b22dd6..964f52b524 100644 --- a/vyper/venom/README.md +++ b/vyper/venom/README.md @@ -29,59 +29,43 @@ Venom employs two scopes: global and function level. 
### Example code ```llvm -IRFunction: global - -global: - %1 = calldataload 0 - %2 = shr 224, %1 - jmp label %selector_bucket_0 - -selector_bucket_0: - %3 = xor %2, 1579456981 - %4 = iszero %3 - jnz label %1, label %2, %4 - -1: IN=[selector_bucket_0] OUT=[9] - jmp label %fallback - -2: - %5 = callvalue - %6 = calldatasize - %7 = lt %6, 164 - %8 = or %5, %7 - %9 = iszero %8 - assert %9 - stop - -fallback: - revert 0, 0 +function global { + global: + %1 = calldataload 0 + %2 = shr 224, %1 + jmp @selector_bucket_0 + + selector_bucket_0: + %3 = xor %2, 1579456981 + %4 = iszero %3 + jnz @1, @2, %4 + + 1: + jmp @fallback + + 2: + %5 = callvalue + %6 = calldatasize + %7 = lt %6, 164 + %8 = or %5, %7 + %9 = iszero %8 + assert %9 + stop + + fallback: + revert 0, 0 +} + +[data] ``` ### Grammar -Below is a (not-so-complete) grammar to describe the text format of Venom IR: +To see a definition of grammar see the [venom parser](./parser.py) -```llvm -program ::= function_declaration* - -function_declaration ::= "IRFunction:" identifier input_list? output_list? "=>" block - -input_list ::= "IN=" "[" (identifier ("," identifier)*)? "]" -output_list ::= "OUT=" "[" (identifier ("," identifier)*)? "]" - -block ::= label ":" input_list? output_list? "=>{" operation* "}" - -operation ::= "%" identifier "=" opcode operand ("," operand)* - | opcode operand ("," operand)* +### Compiling Venom -opcode ::= "calldataload" | "shr" | "shl" | "and" | "add" | "codecopy" | "mload" | "jmp" | "xor" | "iszero" | "jnz" | "label" | "lt" | "or" | "assert" | "callvalue" | "calldatasize" | "alloca" | "calldatacopy" | "invoke" | "gt" | ... - -operand ::= "%" identifier | label | integer | "label" "%" identifier -label ::= "%" identifier - -identifier ::= [a-zA-Z_][a-zA-Z0-9_]* -integer ::= [0-9]+ -``` +Vyper ships with a venom compiler which compiles venom code to bytecode directly. It can be run by running `venom`, which is installed as a standalone binary when `vyper` is installed via `pip`. 
## Implementation @@ -160,3 +144,317 @@ A number of passes that are planned to be implemented, or are implemented for im ### Function inlining ### Load-store elimination + +--- + +## Structure of a venom program + +### IRContext +An `IRContext` consists of multiple `IRFunctions`, with one designated as the main entry point of the program. +Additionally, the `IRContext` maintains its own representation of the data segment. + +### IRFunction +An `IRFunction` is composed of a name and multiple `IRBasicBlocks`, with one marked as the entry point to the function. + +### IRBasicBlock +An `IRBasicBlock` contains a label and a sequence of `IRInstructions`. +Each `IRBasicBlock` has a single entry point and exit point. +The exit point must be one of the following terminator instructions: +- `jmp` +- `djmp` +- `jnz` +- `ret` +- `return` +- `stop` +- `exit` + +Normalized basic blocks cannot have multiple predecessors and successors. It has either one (or zero) predecessors and potentially multiple successors or vice versa. + +### IRInstruction +An `IRInstruction` consists of an opcode, a list of operands, and an optional return value. +An operand can be a label, a variable, or a literal. + +By convention, variables have a `%-` prefix, e.g. `%1` is a valid variable. However, the prefix is not required. + +## Instructions +To enable Venom IR in Vyper, use the `--experimental-codegen` CLI flag or its alias `--venom`, or the corresponding pragma statements (e.g. `#pragma experimental-codegen`). To view the Venom IR output, use `-f bb_runtime` for the runtime code, or `-f bb` to see the deploy code. To get a dot file (for use e.g. with `xdot -`), use `-f cfg` or `-f cfg_runtime`. + +Assembly can be inspected with `-f asm`, whereas an opcode view of the final bytecode can be seen with `-f opcodes` or `-f opcodes_runtime`, respectively. + +### Special instructions + +- `invoke` + - ``` + invoke offset, label + ``` + - Causes control flow to jump to a function denoted by the `label`. 
+ - Return values are passed in the return buffer at the `offset` address. + - Used for internal functions. + - Effectively translates to `JUMP`, and marks the call site as a valid return destination (for callee to jump back to) by `JUMPDEST`. +- `alloca` + - ``` + out = alloca size, offset, id + ``` + - Allocates memory of a given `size` at a given `offset` in memory. + - The `id` argument is there to help debugging translation into venom + - The output is the offset value itself. + - Because the SSA form does not allow changing values of registers, handling mutable variables can be tricky. The `alloca` instruction is meant to simplify that. + +- `palloca` + - ``` + out = palloca size, offset, id + ``` + - Like the `alloca` instruction but only used for parameters of internal functions which are passed by memory. +- `iload` + - ``` + out = iload offset + ``` + - Loads value at an immutable section of memory denoted by `offset` into `out` variable. + - The operand can be either a literal, which is a statically computed offset, or a variable. + - Essentially translates to `MLOAD` on an immutable section of memory. So, for example + ``` + %op = 12 + %out = iload %op + ``` + could compile into `PUSH1 12 _mem_deploy_end ADD MLOAD`. + - When `offset` is a literal the location is computed statically during compilation from assembly to bytecode. +- `istore` + - ``` + istore offset value + ``` + - Represents a store into immutable section of memory. + - Like in `iload`, the offset operand can be a literal. + - Essentially translates to `MSTORE` on an immutable section of memory. For example, + ``` + %op = 12 + istore 24 %op + ``` + could compile to + `PUSH1 12 PUSH1 24 _mem_deploy_end ADD MSTORE`. +- `phi` + - ``` + out = phi %var_a, label_a, %var_b, label_b + ``` + - Because in SSA form each variable is assigned just once, it is tricky to handle that variables may be assigned to something different based on which program path was taken. 
+  - Therefore, we use `phi` instructions. They are magic instructions, used in basic blocks where the control flow path merges.
+  - In this example, essentially the `out` variable is set to `%var_a` if the program entered the current block from `label_a` or to `%var_b` when it went through `label_b`.
+- `offset`
+  - ```
+    ret = offset label, op
+    ```
+  - Statically compute offset before compiling into bytecode. Useful for `mstore`, `mload` and such.
+  - Basically `label` + `op`.
+  - The `asm` output could show something like `_OFST _sym_ label`.
+- `param`
+  - ```
+    out = param
+    ```
+  - The `param` instruction is used to represent function arguments passed by the stack.
+  - We assume the argument is on the stack and the `param` instruction is used to ensure we represent the argument by the `out` variable.
+- `store`
+  - ```
+    out = op
+    ```
+  - Store variable value or literal into `out` variable.
+- `dbname`
+  - ```
+    dbname label
+    ```
+  - Mark memory with a `label` in the data segment so it can be referenced.
+- `db`
+  - ```
+    db data
+    ```
+  - Store `data` into data segment.
+- `dloadbytes`
+  - Alias for `codecopy` for legacy reasons. May be removed in future versions.
+  - Translates to `CODECOPY`.
+- `ret`
+  - ```
+    ret op
+    ```
+  - Represents return from an internal call.
+  - Jumps to a location given by `op`.
+  - If `op` is a label it can effectively translate into `op JUMP`.
+- `exit`
+  - ```
+    exit
+    ```
+  - Similar to `stop`, but used for constructor exit. The assembler is expected to jump to a special initcode sequence which returns the runtime code.
+  - Might translate to something like `_sym__ctor_exit JUMP`.
+- `sha3_64`
+  - ```
+    out = sha3_64 x y
+    ```
+  - Shortcut to access the `SHA3` EVM opcode where `out` is the result.
+  - Essentially translates to
+    ```
+    PUSH y PUSH FREE_VAR_SPACE MSTORE
+    PUSH x PUSH FREE_VAR_SPACE2 MSTORE
+    PUSH 64 PUSH FREE_VAR_SPACE SHA3
+    ```
+    where `FREE_VAR_SPACE` and `FREE_VAR_SPACE2` are locations reserved by the compiler, set to 0 and 32 respectively.
+
+- `assert`
+  - ```
+    assert op
+    ```
+  - Assert that `op` is non-zero. If it is zero, revert.
+  - Calls that terminate this way receive a gas refund.
+  - For example
+    ```
+    %op = 13
+    assert %op
+    ```
+    could compile to
+    `PUSH1 13 ISZERO _sym___revert JUMPI`.
+- `assert_unreachable`
+  - ```
+    assert_unreachable op
+    ```
+  - Check that `op` is non-zero. If it is zero, terminate with `0xFE` ("INVALID" opcode).
+  - Calls that end this way do not receive a gas refund.
+  - Could translate to `op reachable JUMPI INVALID reachable JUMPDEST`.
+  - For example
+    ```
+    %op = 13
+    assert_unreachable %op
+    ```
+    could compile to
+    ```
+    PUSH1 13 _sym_reachable1 JUMPI
+    INVALID
+    _sym_reachable1 JUMPDEST
+    ```
+- `log`
+  - ```
+    log offset, size, [topic] * topic_count , topic_count
+    ```
+  - Corresponds to the `LOGX` instruction in EVM.
+  - Depending on the `topic_count` value (which can be only from 0 to 4) translates to `LOG0` ... `LOG4`.
+  - The rest of the operands correspond to the `LOGX` instructions.
+  - For example
+    ```
+    log %53, 32, 64, %56, 2
+    ```
+    could translate to:
+    ```
+    %56, 64, 32, %53 LOG2
+    ```
+- `nop`
+  - ```
+    nop
+    ```
+  - No operation, does nothing.
+- `offset`
+  - ```
+    %2 = offset %1 label1
+    ```
+  - Similar to `add`, but takes a label as the second argument. If the first argument is a literal, the addition will get optimized at assembly time.
+
+### Jump instructions
+
+- `jmp`
+  - ```
+    jmp label
+    ```
+  - Unconditional jump to code denoted by given `label`.
+  - Translates to `label JUMP`.
+- `jnz`
+  - ```
+    jnz label1, label2, op
+    ```
+  - A conditional jump depending on the value of `op`.
+  - Jumps to `label2` when `op` is not zero, otherwise jumps to `label1`.
+ - For example + ``` + %op = 15 + jnz label1, label2, %op + ``` + could translate to: `PUSH1 15 label2 JUMPI label1 JUMP`. +- `djmp` + - ``` + djmp %var, label1, label2, label3, ... + ``` + - Dynamic jump to an address specified by the variable operand, constrained to the provided labels. + - Accepts a variable number of labels. + - The target is not a fixed label but rather a value stored in a variable, making the jump dynamic. + - The jump target can be any of the provided labels. + - Translates to `JUMP`. + +### EVM instructions + +The following instructions map one-to-one with [EVM instructions](https://www.evm.codes/). +Operands correspond to stack inputs in the same order. Stack outputs are the instruction's output. +Instructions have the same effects. +- `return` +- `revert` +- `coinbase` +- `calldatasize` +- `calldatacopy` +- `mcopy` +- `calldataload` +- `gas` +- `gasprice` +- `gaslimit` +- `chainid` +- `address` +- `origin` +- `number` +- `extcodesize` +- `extcodehash` +- `extcodecopy` +- `returndatasize` +- `returndatacopy` +- `callvalue` +- `selfbalance` +- `sload` +- `sstore` +- `mload` +- `mstore` +- `tload` +- `tstore` +- `timestamp` +- `caller` +- `blockhash` +- `selfdestruct` +- `signextend` +- `stop` +- `shr` +- `shl` +- `sar` +- `and` +- `xor` +- `or` +- `add` +- `sub` +- `mul` +- `div` +- `smul` +- `sdiv` +- `mod` +- `smod` +- `exp` +- `addmod` +- `mulmod` +- `eq` +- `iszero` +- `not` +- `lt` +- `gt` +- `slt` +- `sgt` +- `create` +- `create2` +- `msize` +- `balance` +- `call` +- `staticcall` +- `delegatecall` +- `codesize` +- `basefee` +- `blobhash` +- `blobbasefee` +- `prevrandao` +- `difficulty` +- `invalid` +- `sha3` diff --git a/vyper/venom/__init__.py b/vyper/venom/__init__.py index afd79fc44f..bb3fe58a8d 100644 --- a/vyper/venom/__init__.py +++ b/vyper/venom/__init__.py @@ -9,16 +9,23 @@ from vyper.venom.context import IRContext from vyper.venom.function import IRFunction from vyper.venom.ir_node_to_venom import ir_node_to_venom -from 
vyper.venom.passes.algebraic_optimization import AlgebraicOptimizationPass -from vyper.venom.passes.branch_optimization import BranchOptimizationPass -from vyper.venom.passes.dft import DFTPass -from vyper.venom.passes.extract_literals import ExtractLiteralsPass -from vyper.venom.passes.make_ssa import MakeSSA -from vyper.venom.passes.mem2var import Mem2Var -from vyper.venom.passes.remove_unused_variables import RemoveUnusedVariablesPass -from vyper.venom.passes.sccp import SCCP -from vyper.venom.passes.simplify_cfg import SimplifyCFGPass -from vyper.venom.passes.store_elimination import StoreElimination +from vyper.venom.passes import ( + SCCP, + AlgebraicOptimizationPass, + BranchOptimizationPass, + DFTPass, + FloatAllocas, + LoadElimination, + LowerDloadPass, + MakeSSA, + Mem2Var, + MemMergePass, + ReduceLiteralsCodesize, + RemoveUnusedVariablesPass, + SimplifyCFGPass, + StoreElimination, + StoreExpansionPass, +) from vyper.venom.venom_to_assembly import VenomCompiler DEFAULT_OPT_LEVEL = OptimizationLevel.default() @@ -45,24 +52,56 @@ def _run_passes(fn: IRFunction, optimize: OptimizationLevel) -> None: ac = IRAnalysesCache(fn) + FloatAllocas(ac, fn).run_pass() + SimplifyCFGPass(ac, fn).run_pass() MakeSSA(ac, fn).run_pass() + # run algebraic opts before mem2var to reduce some pointer arithmetic + AlgebraicOptimizationPass(ac, fn).run_pass() + StoreElimination(ac, fn).run_pass() Mem2Var(ac, fn).run_pass() MakeSSA(ac, fn).run_pass() SCCP(ac, fn).run_pass() - StoreElimination(ac, fn).run_pass() + SimplifyCFGPass(ac, fn).run_pass() + StoreElimination(ac, fn).run_pass() AlgebraicOptimizationPass(ac, fn).run_pass() + LoadElimination(ac, fn).run_pass() + SCCP(ac, fn).run_pass() + StoreElimination(ac, fn).run_pass() + + SimplifyCFGPass(ac, fn).run_pass() + MemMergePass(ac, fn).run_pass() + + LowerDloadPass(ac, fn).run_pass() + # NOTE: MakeSSA is after algebraic optimization it currently produces + # smaller code by adding some redundant phi nodes. 
This is not a + # problem for us, but we need to be aware of it, and should be + # removed when the dft pass is fixed to produce the smallest code + # without making the code generation more expensive by running + # MakeSSA again. + MakeSSA(ac, fn).run_pass() BranchOptimizationPass(ac, fn).run_pass() - ExtractLiteralsPass(ac, fn).run_pass() + + AlgebraicOptimizationPass(ac, fn).run_pass() RemoveUnusedVariablesPass(ac, fn).run_pass() + + StoreExpansionPass(ac, fn).run_pass() + + if optimize == OptimizationLevel.CODESIZE: + ReduceLiteralsCodesize(ac, fn).run_pass() + DFTPass(ac, fn).run_pass() +def run_passes_on(ctx: IRContext, optimize: OptimizationLevel): + for fn in ctx.functions.values(): + _run_passes(fn, optimize) + + def generate_ir(ir: IRnode, optimize: OptimizationLevel) -> IRContext: # Convert "old" IR to "new" IR ctx = ir_node_to_venom(ir) - for fn in ctx.functions.values(): - _run_passes(fn, optimize) + run_passes_on(ctx, optimize) return ctx diff --git a/vyper/venom/analysis/__init__.py b/vyper/venom/analysis/__init__.py index e69de29bb2..4870de3fb7 100644 --- a/vyper/venom/analysis/__init__.py +++ b/vyper/venom/analysis/__init__.py @@ -0,0 +1,6 @@ +from .analysis import IRAnalysesCache, IRAnalysis +from .cfg import CFGAnalysis +from .dfg import DFGAnalysis +from .dominators import DominatorTreeAnalysis +from .equivalent_vars import VarEquivalenceAnalysis +from .liveness import LivenessAnalysis diff --git a/vyper/venom/analysis/analysis.py b/vyper/venom/analysis/analysis.py index f154993925..7bff6ba555 100644 --- a/vyper/venom/analysis/analysis.py +++ b/vyper/venom/analysis/analysis.py @@ -50,9 +50,9 @@ def request_analysis(self, analysis_cls: Type[IRAnalysis], *args, **kwargs): if analysis_cls in self.analyses_cache: return self.analyses_cache[analysis_cls] analysis = analysis_cls(self, self.function) + self.analyses_cache[analysis_cls] = analysis analysis.analyze(*args, **kwargs) - self.analyses_cache[analysis_cls] = analysis return analysis def 
invalidate_analysis(self, analysis_cls: Type[IRAnalysis]): diff --git a/vyper/venom/analysis/cfg.py b/vyper/venom/analysis/cfg.py index bd2ae34b68..2f90410cd5 100644 --- a/vyper/venom/analysis/cfg.py +++ b/vyper/venom/analysis/cfg.py @@ -1,6 +1,8 @@ +from typing import Iterator + from vyper.utils import OrderedSet -from vyper.venom.analysis.analysis import IRAnalysis -from vyper.venom.basicblock import CFG_ALTERING_INSTRUCTIONS +from vyper.venom.analysis import IRAnalysis +from vyper.venom.basicblock import CFG_ALTERING_INSTRUCTIONS, IRBasicBlock class CFGAnalysis(IRAnalysis): @@ -8,32 +10,53 @@ class CFGAnalysis(IRAnalysis): Compute control flow graph information for each basic block in the function. """ + _dfs: OrderedSet[IRBasicBlock] + def analyze(self) -> None: fn = self.function + self._dfs = OrderedSet() + for bb in fn.get_basic_blocks(): bb.cfg_in = OrderedSet() bb.cfg_out = OrderedSet() bb.out_vars = OrderedSet() + bb.is_reachable = False for bb in fn.get_basic_blocks(): - assert len(bb.instructions) > 0, "Basic block should not be empty" - last_inst = bb.instructions[-1] - assert last_inst.is_bb_terminator, f"Last instruction should be a terminator {bb}" + assert bb.is_terminated, f"not terminating:\n{bb}" - for inst in bb.instructions: - if inst.opcode in CFG_ALTERING_INSTRUCTIONS: - ops = inst.get_label_operands() - for op in ops: - fn.get_basic_block(op.value).add_cfg_in(bb) + term = bb.instructions[-1] + if term.opcode in CFG_ALTERING_INSTRUCTIONS: + ops = term.get_label_operands() + # order of cfg_out matters to performance! 
+ for op in reversed(list(ops)): + next_bb = fn.get_basic_block(op.value) + bb.add_cfg_out(next_bb) + next_bb.add_cfg_in(bb) - # Fill in the "out" set for each basic block - for bb in fn.get_basic_blocks(): - for in_bb in bb.cfg_in: - in_bb.add_cfg_out(bb) + self._compute_dfs_r(self.function.entry) + + def _compute_dfs_r(self, bb): + if bb.is_reachable: + return + bb.is_reachable = True + + for out_bb in bb.cfg_out: + self._compute_dfs_r(out_bb) + + self._dfs.add(bb) + + @property + def dfs_walk(self) -> Iterator[IRBasicBlock]: + return iter(self._dfs) def invalidate(self): - from vyper.venom.analysis.dominators import DominatorTreeAnalysis - from vyper.venom.analysis.liveness import LivenessAnalysis + from vyper.venom.analysis import DFGAnalysis, DominatorTreeAnalysis, LivenessAnalysis self.analyses_cache.invalidate_analysis(DominatorTreeAnalysis) self.analyses_cache.invalidate_analysis(LivenessAnalysis) + + self._dfs = None + + # be conservative - assume cfg invalidation invalidates dfg + self.analyses_cache.invalidate_analysis(DFGAnalysis) diff --git a/vyper/venom/analysis/dfg.py b/vyper/venom/analysis/dfg.py index 328ed47c72..e528284422 100644 --- a/vyper/venom/analysis/dfg.py +++ b/vyper/venom/analysis/dfg.py @@ -1,13 +1,14 @@ from typing import Optional +from vyper.utils import OrderedSet from vyper.venom.analysis.analysis import IRAnalysesCache, IRAnalysis from vyper.venom.analysis.liveness import LivenessAnalysis -from vyper.venom.basicblock import IRInstruction, IRVariable +from vyper.venom.basicblock import IRBasicBlock, IRInstruction, IRVariable from vyper.venom.function import IRFunction class DFGAnalysis(IRAnalysis): - _dfg_inputs: dict[IRVariable, list[IRInstruction]] + _dfg_inputs: dict[IRVariable, OrderedSet[IRInstruction]] _dfg_outputs: dict[IRVariable, IRInstruction] def __init__(self, analyses_cache: IRAnalysesCache, function: IRFunction): @@ -16,19 +17,28 @@ def __init__(self, analyses_cache: IRAnalysesCache, function: IRFunction): 
self._dfg_outputs = dict() # return uses of a given variable - def get_uses(self, op: IRVariable) -> list[IRInstruction]: - return self._dfg_inputs.get(op, []) + def get_uses(self, op: IRVariable) -> OrderedSet[IRInstruction]: + return self._dfg_inputs.get(op, OrderedSet()) + + def get_uses_in_bb(self, op: IRVariable, bb: IRBasicBlock): + """ + Get uses of a given variable in a specific basic block. + """ + return [inst for inst in self.get_uses(op) if inst.parent == bb] # the instruction which produces this variable. def get_producing_instruction(self, op: IRVariable) -> Optional[IRInstruction]: return self._dfg_outputs.get(op) + def set_producing_instruction(self, op: IRVariable, inst: IRInstruction): + self._dfg_outputs[op] = inst + def add_use(self, op: IRVariable, inst: IRInstruction): - uses = self._dfg_inputs.setdefault(op, []) - uses.append(inst) + uses = self._dfg_inputs.setdefault(op, OrderedSet()) + uses.add(inst) def remove_use(self, op: IRVariable, inst: IRInstruction): - uses = self._dfg_inputs.get(op, []) + uses: OrderedSet = self._dfg_inputs.get(op, OrderedSet()) uses.remove(inst) @property @@ -48,10 +58,11 @@ def analyze(self): res = inst.get_outputs() for op in operands: - inputs = self._dfg_inputs.setdefault(op, []) - inputs.append(inst) + inputs = self._dfg_inputs.setdefault(op, OrderedSet()) + inputs.add(inst) for op in res: # type: ignore + assert isinstance(op, IRVariable) self._dfg_outputs[op] = inst def as_graph(self) -> str: diff --git a/vyper/venom/analysis/dominators.py b/vyper/venom/analysis/dominators.py index 129d1d0f22..b60f9bdab9 100644 --- a/vyper/venom/analysis/dominators.py +++ b/vyper/venom/analysis/dominators.py @@ -1,7 +1,8 @@ +from functools import cached_property + from vyper.exceptions import CompilerPanic from vyper.utils import OrderedSet -from vyper.venom.analysis.analysis import IRAnalysis -from vyper.venom.analysis.cfg import CFGAnalysis +from vyper.venom.analysis import CFGAnalysis, IRAnalysis from 
vyper.venom.basicblock import IRBasicBlock from vyper.venom.function import IRFunction @@ -15,8 +16,6 @@ class DominatorTreeAnalysis(IRAnalysis): fn: IRFunction entry_block: IRBasicBlock - dfs_order: dict[IRBasicBlock, int] - dfs_walk: list[IRBasicBlock] dominators: dict[IRBasicBlock, OrderedSet[IRBasicBlock]] immediate_dominators: dict[IRBasicBlock, IRBasicBlock] dominated: dict[IRBasicBlock, OrderedSet[IRBasicBlock]] @@ -28,16 +27,13 @@ def analyze(self): """ self.fn = self.function self.entry_block = self.fn.entry - self.dfs_order = {} - self.dfs_walk = [] self.dominators = {} self.immediate_dominators = {} self.dominated = {} self.dominator_frontiers = {} - self.analyses_cache.request_analysis(CFGAnalysis) + self.cfg = self.analyses_cache.request_analysis(CFGAnalysis) - self._compute_dfs(self.entry_block, OrderedSet()) self._compute_dominators() self._compute_idoms() self._compute_df() @@ -132,21 +128,13 @@ def _intersect(self, bb1, bb2): bb2 = self.immediate_dominators[bb2] return bb1 - def _compute_dfs(self, entry: IRBasicBlock, visited): - """ - Depth-first search to compute the DFS order of the basic blocks. This - is used to compute the dominator tree. The sequence of basic blocks in - the DFS order is stored in `self.dfs_walk`. The DFS order of each basic - block is stored in `self.dfs_order`. 
- """ - visited.add(entry) - - for bb in entry.cfg_out: - if bb not in visited: - self._compute_dfs(bb, visited) + @cached_property + def dfs_walk(self) -> list[IRBasicBlock]: + return list(self.cfg.dfs_walk) - self.dfs_walk.append(entry) - self.dfs_order[entry] = len(self.dfs_walk) + @cached_property + def dfs_order(self) -> dict[IRBasicBlock, int]: + return {bb: idx for idx, bb in enumerate(self.dfs_walk)} def as_graph(self) -> str: """ diff --git a/vyper/venom/analysis/dup_requirements.py b/vyper/venom/analysis/dup_requirements.py deleted file mode 100644 index 7afb315035..0000000000 --- a/vyper/venom/analysis/dup_requirements.py +++ /dev/null @@ -1,15 +0,0 @@ -from vyper.utils import OrderedSet -from vyper.venom.analysis.analysis import IRAnalysis - - -class DupRequirementsAnalysis(IRAnalysis): - def analyze(self): - for bb in self.function.get_basic_blocks(): - last_liveness = bb.out_vars - for inst in reversed(bb.instructions): - inst.dup_requirements = OrderedSet() - ops = inst.get_input_variables() - for op in ops: - if op in last_liveness: - inst.dup_requirements.add(op) - last_liveness = inst.liveness diff --git a/vyper/venom/analysis/equivalent_vars.py b/vyper/venom/analysis/equivalent_vars.py new file mode 100644 index 0000000000..895895651a --- /dev/null +++ b/vyper/venom/analysis/equivalent_vars.py @@ -0,0 +1,40 @@ +from vyper.venom.analysis import DFGAnalysis, IRAnalysis +from vyper.venom.basicblock import IRVariable + + +class VarEquivalenceAnalysis(IRAnalysis): + """ + Generate equivalence sets of variables. This is used to avoid swapping + variables which are the same during venom_to_assembly. Theoretically, + the DFTPass should order variable declarations optimally, but, it is + not aware of the "pickaxe" heuristic in venom_to_assembly, so they can + interfere. 
+ """ + + def analyze(self): + dfg = self.analyses_cache.request_analysis(DFGAnalysis) + + equivalence_set: dict[IRVariable, int] = {} + + for bag, (var, inst) in enumerate(dfg._dfg_outputs.items()): + if inst.opcode != "store": + continue + + source = inst.operands[0] + + assert var not in equivalence_set # invariant + if source in equivalence_set: + equivalence_set[var] = equivalence_set[source] + continue + else: + equivalence_set[var] = bag + equivalence_set[source] = bag + + self._equivalence_set = equivalence_set + + def equivalent(self, var1, var2): + if var1 not in self._equivalence_set: + return False + if var2 not in self._equivalence_set: + return False + return self._equivalence_set[var1] == self._equivalence_set[var2] diff --git a/vyper/venom/analysis/liveness.py b/vyper/venom/analysis/liveness.py index 5d1ac488f1..0ccda3de2c 100644 --- a/vyper/venom/analysis/liveness.py +++ b/vyper/venom/analysis/liveness.py @@ -1,7 +1,8 @@ +from collections import deque + from vyper.exceptions import CompilerPanic from vyper.utils import OrderedSet -from vyper.venom.analysis.analysis import IRAnalysis -from vyper.venom.analysis.cfg import CFGAnalysis +from vyper.venom.analysis import CFGAnalysis, IRAnalysis from vyper.venom.basicblock import IRBasicBlock, IRVariable @@ -11,16 +12,21 @@ class LivenessAnalysis(IRAnalysis): """ def analyze(self): - self.analyses_cache.request_analysis(CFGAnalysis) + cfg = self.analyses_cache.request_analysis(CFGAnalysis) self._reset_liveness() - while True: + + worklist = deque(cfg.dfs_walk) + + while len(worklist) > 0: changed = False - for bb in self.function.get_basic_blocks(): - changed |= self._calculate_out_vars(bb) - changed |= self._calculate_liveness(bb) - if not changed: - break + bb = worklist.popleft() + changed |= self._calculate_out_vars(bb) + changed |= self._calculate_liveness(bb) + # recompute liveness for basic blocks pointing into + # this basic block + if changed: + worklist.extend(bb.cfg_in) def 
_reset_liveness(self) -> None: for bb in self.function.get_basic_blocks(): @@ -54,11 +60,11 @@ def _calculate_out_vars(self, bb: IRBasicBlock) -> bool: Compute out_vars of basic block. Returns True if out_vars changed """ - out_vars = bb.out_vars + out_vars = bb.out_vars.copy() bb.out_vars = OrderedSet() for out_bb in bb.cfg_out: target_vars = self.input_vars_from(bb, out_bb) - bb.out_vars = bb.out_vars.union(target_vars) + bb.out_vars.update(target_vars) return out_vars != bb.out_vars # calculate the input variables into self from source diff --git a/vyper/venom/basicblock.py b/vyper/venom/basicblock.py index d6fb9560cd..8d86da73e7 100644 --- a/vyper/venom/basicblock.py +++ b/vyper/venom/basicblock.py @@ -1,6 +1,10 @@ +import json +import re from typing import TYPE_CHECKING, Any, Iterator, Optional, Union +import vyper.venom.effects as effects from vyper.codegen.ir_node import IRnode +from vyper.exceptions import CompilerPanic from vyper.utils import OrderedSet # instructions which can terminate a basic block @@ -21,8 +25,6 @@ "istore", "tload", "tstore", - "assert", - "assert_unreachable", "mstore", "mload", "calldatacopy", @@ -83,10 +85,23 @@ CFG_ALTERING_INSTRUCTIONS = frozenset(["jmp", "djmp", "jnz"]) +COMMUTATIVE_INSTRUCTIONS = frozenset(["add", "mul", "smul", "or", "xor", "and", "eq"]) + +COMPARATOR_INSTRUCTIONS = ("gt", "lt", "sgt", "slt") + if TYPE_CHECKING: from vyper.venom.function import IRFunction +def flip_comparison_opcode(opcode): + if opcode in ("gt", "sgt"): + return opcode.replace("g", "l") + elif opcode in ("lt", "slt"): + return opcode.replace("l", "g") + + raise CompilerPanic(f"unreachable {opcode}") # pragma: nocover + + class IRDebugInfo: """ IRDebugInfo represents debug information in IR, used to annotate IR @@ -102,7 +117,7 @@ def __init__(self, line_no: int, src: str) -> None: def __repr__(self) -> str: src = self.src if self.src else "" - return f"\t# line {self.line_no}: {src}".expandtabs(20) + return f"\t; line {self.line_no}: 
{src}".expandtabs(20) class IROperand: @@ -112,13 +127,20 @@ class IROperand: """ value: Any + _hash: Optional[int] = None + + def __init__(self, value: Any) -> None: + self.value = value + self._hash = None @property def name(self) -> str: return self.value def __hash__(self) -> int: - return hash(self.value) + if self._hash is None: + self._hash = hash(self.value) + return self._hash def __eq__(self, other) -> bool: if not isinstance(other, type(self)): @@ -138,7 +160,7 @@ class IRLiteral(IROperand): def __init__(self, value: int) -> None: assert isinstance(value, int), "value must be an int" - self.value = value + super().__init__(value) class IRVariable(IROperand): @@ -146,27 +168,25 @@ class IRVariable(IROperand): IRVariable represents a variable in IR. A variable is a string that starts with a %. """ - value: str - - def __init__(self, value: str, version: Optional[str | int] = None) -> None: - assert isinstance(value, str) - assert ":" not in value, "Variable name cannot contain ':'" - if version: - assert isinstance(value, str) or isinstance(value, int), "value must be an str or int" - value = f"{value}:{version}" - if value[0] != "%": - value = f"%{value}" - self.value = value + _name: str + version: Optional[int] + + def __init__(self, name: str, version: int = 0) -> None: + assert isinstance(name, str) + # TODO: allow version to be None + assert isinstance(version, int) + if not name.startswith("%"): + name = f"%{name}" + self._name = name + self.version = version + value = name + if version > 0: + value = f"{name}:{version}" + super().__init__(value) @property def name(self) -> str: - return self.value.split(":")[0] - - @property - def version(self) -> int: - if ":" not in self.value: - return 0 - return int(self.value.split(":")[1]) + return self._name class IRLabel(IROperand): @@ -180,18 +200,18 @@ class IRLabel(IROperand): value: str def __init__(self, value: str, is_symbol: bool = False) -> None: - assert isinstance(value, str), "value must be an 
str" - self.value = value + assert isinstance(value, str), f"not a str: {value} ({type(value)})" + assert len(value) > 0 self.is_symbol = is_symbol + super().__init__(value) - def __eq__(self, other): - # no need for is_symbol to participate in equality - return super().__eq__(other) + _IS_IDENTIFIER = re.compile("[0-9a-zA-Z_]*") - def __hash__(self): - # __hash__ is required when __eq__ is overridden -- - # https://docs.python.org/3/reference/datamodel.html#object.__hash__ - return super().__hash__() + def __repr__(self): + if self.__class__._IS_IDENTIFIER.fullmatch(self.value): + return self.value + + return json.dumps(self.value) # escape it class IRInstruction: @@ -206,12 +226,10 @@ class IRInstruction: opcode: str operands: list[IROperand] - output: Optional[IROperand] + output: Optional[IRVariable] # set of live variables at this instruction liveness: OrderedSet[IRVariable] - dup_requirements: OrderedSet[IRVariable] parent: "IRBasicBlock" - fence_id: int annotation: Optional[str] ast_source: Optional[IRnode] error_msg: Optional[str] @@ -220,7 +238,7 @@ def __init__( self, opcode: str, operands: list[IROperand] | Iterator[IROperand], - output: Optional[IROperand] = None, + output: Optional[IRVariable] = None, ): assert isinstance(opcode, str), "opcode must be an str" assert isinstance(operands, list | Iterator), "operands must be a list" @@ -228,8 +246,6 @@ def __init__( self.operands = list(operands) # in case we get an iterator self.output = output self.liveness = OrderedSet() - self.dup_requirements = OrderedSet() - self.fence_id = -1 self.annotation = None self.ast_source = None self.error_msg = None @@ -238,10 +254,44 @@ def __init__( def is_volatile(self) -> bool: return self.opcode in VOLATILE_INSTRUCTIONS + @property + def is_commutative(self) -> bool: + return self.opcode in COMMUTATIVE_INSTRUCTIONS + + @property + def is_comparator(self) -> bool: + return self.opcode in COMPARATOR_INSTRUCTIONS + + @property + def flippable(self) -> bool: + return 
self.is_commutative or self.is_comparator + @property def is_bb_terminator(self) -> bool: return self.opcode in BB_TERMINATORS + @property + def is_phi(self) -> bool: + return self.opcode == "phi" + + @property + def is_param(self) -> bool: + return self.opcode == "param" + + @property + def is_pseudo(self) -> bool: + """ + Check if instruction is pseudo, i.e. not an actual instruction but + a construct for intermediate representation like phi and param. + """ + return self.is_phi or self.is_param + + def get_read_effects(self): + return effects.reads.get(self.opcode, effects.EMPTY) + + def get_write_effects(self): + return effects.writes.get(self.opcode, effects.EMPTY) + def get_label_operands(self) -> Iterator[IRLabel]: """ Get all labels in instruction. @@ -268,6 +318,19 @@ def get_outputs(self) -> list[IROperand]: """ return [self.output] if self.output else [] + def flip(self): + """ + Flip operands for commutative or comparator opcodes + """ + assert self.flippable + self.operands.reverse() + + if self.is_commutative: + return + + assert self.opcode in COMPARATOR_INSTRUCTIONS # sanity + self.opcode = flip_comparison_opcode(self.opcode) + def replace_operands(self, replacements: dict) -> None: """ Update operands with replacements. 
@@ -328,19 +391,17 @@ def __repr__(self) -> str: opcode = f"{self.opcode} " if self.opcode != "store" else "" s += opcode operands = self.operands - if opcode not in ["jmp", "jnz", "invoke"]: + if self.opcode == "invoke": + operands = [operands[0]] + list(reversed(operands[1:])) + elif self.opcode not in ("jmp", "jnz", "phi"): operands = reversed(operands) # type: ignore - s += ", ".join( - [(f"label %{op}" if isinstance(op, IRLabel) else str(op)) for op in operands] - ) - if self.annotation: - s += f" <{self.annotation}>" + s += ", ".join([(f"@{op}" if isinstance(op, IRLabel) else str(op)) for op in operands]) - if self.liveness: - return f"{s: <30} # {self.liveness}" + if self.annotation: + s += f" ; {self.annotation}" - return s + return f"{s: <30}" def _ir_operand_from_value(val: Any) -> IROperand: @@ -384,7 +445,6 @@ class IRBasicBlock: # stack items which this basic block produces out_vars: OrderedSet[IRVariable] - reachable: OrderedSet["IRBasicBlock"] is_reachable: bool = False def __init__(self, label: IRLabel, parent: "IRFunction") -> None: @@ -395,9 +455,10 @@ def __init__(self, label: IRLabel, parent: "IRFunction") -> None: self.cfg_in = OrderedSet() self.cfg_out = OrderedSet() self.out_vars = OrderedSet() - self.reachable = OrderedSet() self.is_reachable = False + self._garbage_instructions: set[IRInstruction] = set() + def add_cfg_in(self, bb: "IRBasicBlock") -> None: self.cfg_in.add(bb) @@ -416,7 +477,7 @@ def remove_cfg_out(self, bb: "IRBasicBlock") -> None: self.cfg_out.remove(bb) def append_instruction( - self, opcode: str, *args: Union[IROperand, int], ret: IRVariable = None + self, opcode: str, *args: Union[IROperand, int], ret: Optional[IRVariable] = None ) -> Optional[IRVariable]: """ Append an instruction to the basic block @@ -465,19 +526,54 @@ def insert_instruction(self, instruction: IRInstruction, index: Optional[int] = assert isinstance(instruction, IRInstruction), "instruction must be an IRInstruction" if index is None: - assert not 
self.is_terminated, self + assert not self.is_terminated, (self, instruction) index = len(self.instructions) instruction.parent = self instruction.ast_source = self.parent.ast_source instruction.error_msg = self.parent.error_msg self.instructions.insert(index, instruction) + def mark_for_removal(self, instruction: IRInstruction) -> None: + self._garbage_instructions.add(instruction) + + def clear_dead_instructions(self) -> None: + if len(self._garbage_instructions) > 0: + self.instructions = [ + inst for inst in self.instructions if inst not in self._garbage_instructions + ] + self._garbage_instructions.clear() + def remove_instruction(self, instruction: IRInstruction) -> None: assert isinstance(instruction, IRInstruction), "instruction must be an IRInstruction" self.instructions.remove(instruction) - def clear_instructions(self) -> None: - self.instructions = [] + @property + def phi_instructions(self) -> Iterator[IRInstruction]: + for inst in self.instructions: + if inst.opcode == "phi": + yield inst + else: + return + + @property + def non_phi_instructions(self) -> Iterator[IRInstruction]: + return (inst for inst in self.instructions if inst.opcode != "phi") + + @property + def param_instructions(self) -> Iterator[IRInstruction]: + for inst in self.instructions: + if inst.opcode == "param": + yield inst + else: + return + + @property + def pseudo_instructions(self) -> Iterator[IRInstruction]: + return (inst for inst in self.instructions if inst.is_pseudo) + + @property + def body_instructions(self) -> Iterator[IRInstruction]: + return (inst for inst in self.instructions[:-1] if not inst.is_pseudo) def replace_operands(self, replacements: dict) -> None: """ @@ -486,6 +582,32 @@ def replace_operands(self, replacements: dict) -> None: for instruction in self.instructions: instruction.replace_operands(replacements) + def fix_phi_instructions(self): + cfg_in_labels = tuple(bb.label for bb in self.cfg_in) + + needs_sort = False + for inst in self.instructions: + if 
inst.opcode != "phi": + continue + + labels = inst.get_label_operands() + for label in labels: + if label not in cfg_in_labels: + needs_sort = True + inst.remove_phi_operand(label) + + op_len = len(inst.operands) + if op_len == 2: + inst.opcode = "store" + inst.operands = [inst.operands[1]] + elif op_len == 0: + inst.opcode = "nop" + inst.output = None + inst.operands = [] + + if needs_sort: + self.instructions.sort(key=lambda inst: inst.opcode != "phi") + def get_assignments(self): """ Get all assignments in basic block. @@ -542,10 +664,12 @@ def copy(self): return bb def __repr__(self) -> str: - s = ( - f"{repr(self.label)}: IN={[bb.label for bb in self.cfg_in]}" - f" OUT={[bb.label for bb in self.cfg_out]} => {self.out_vars}\n" - ) + s = f"{self.label}: ; IN={[bb.label for bb in self.cfg_in]}" + s += f" OUT={[bb.label for bb in self.cfg_out]} => {self.out_vars}\n" for instruction in self.instructions: - s += f" {str(instruction).strip()}\n" + s += f" {str(instruction).strip()}\n" + if len(self.instructions) > 30: + s += f" ; {self.label}\n" + if len(self.instructions) > 30 or self.parent.num_basic_blocks > 5: + s += f" ; ({self.parent.name})\n\n" return s diff --git a/vyper/venom/context.py b/vyper/venom/context.py index 0b0252d976..0c5cbc379c 100644 --- a/vyper/venom/context.py +++ b/vyper/venom/context.py @@ -1,14 +1,40 @@ +import textwrap +from dataclasses import dataclass, field from typing import Optional -from vyper.venom.basicblock import IRInstruction, IRLabel, IROperand +from vyper.venom.basicblock import IRLabel from vyper.venom.function import IRFunction +@dataclass +class DataItem: + data: IRLabel | bytes # can be raw data or bytes + + def __str__(self): + if isinstance(self.data, IRLabel): + return f"@{self.data}" + else: + assert isinstance(self.data, bytes) + return f'x"{self.data.hex()}"' + + +@dataclass +class DataSection: + label: IRLabel + data_items: list[DataItem] = field(default_factory=list) + + def __str__(self): + ret = [f"dbsection 
{self.label.value}:"] + for item in self.data_items: + ret.append(f" db {item}") + return "\n".join(ret) + + class IRContext: functions: dict[IRLabel, IRFunction] ctor_mem_size: Optional[int] immutables_len: Optional[int] - data_segment: list[IRInstruction] + data_segment: list[DataSection] last_label: int def __init__(self) -> None: @@ -47,11 +73,16 @@ def chain_basic_blocks(self) -> None: for fn in self.functions.values(): fn.chain_basic_blocks() - def append_data(self, opcode: str, args: list[IROperand]) -> None: + def append_data_section(self, name: IRLabel) -> None: + self.data_segment.append(DataSection(name)) + + def append_data_item(self, data: IRLabel | bytes) -> None: """ - Append data + Append data to current data section """ - self.data_segment.append(IRInstruction(opcode, args)) # type: ignore + assert len(self.data_segment) > 0 + data_section = self.data_segment[-1] + data_section.data_items.append(DataItem(data)) def as_graph(self) -> str: s = ["digraph G {"] @@ -62,14 +93,15 @@ def as_graph(self) -> str: return "\n".join(s) def __repr__(self) -> str: - s = ["IRContext:"] + s = [] for fn in self.functions.values(): - s.append(fn.__repr__()) + s.append(IRFunction.__repr__(fn)) s.append("\n") if len(self.data_segment) > 0: - s.append("\nData segment:") - for inst in self.data_segment: - s.append(f"{inst}") + s.append("data readonly {") + for data_section in self.data_segment: + s.append(textwrap.indent(DataSection.__str__(data_section), " ")) + s.append("}") return "\n".join(s) diff --git a/vyper/venom/effects.py b/vyper/venom/effects.py new file mode 100644 index 0000000000..bbda481e14 --- /dev/null +++ b/vyper/venom/effects.py @@ -0,0 +1,94 @@ +from enum import Flag, auto + + +class Effects(Flag): + STORAGE = auto() + TRANSIENT = auto() + MEMORY = auto() + MSIZE = auto() + IMMUTABLES = auto() + RETURNDATA = auto() + LOG = auto() + BALANCE = auto() + EXTCODE = auto() + + def __iter__(self): + # python3.10 doesn't have an iter implementation. 
we can + # remove this once we drop python3.10 support. + return (m for m in self.__class__.__members__.values() if m in self) + + +EMPTY = Effects(0) +ALL = ~EMPTY +STORAGE = Effects.STORAGE +TRANSIENT = Effects.TRANSIENT +MEMORY = Effects.MEMORY +MSIZE = Effects.MSIZE +IMMUTABLES = Effects.IMMUTABLES +RETURNDATA = Effects.RETURNDATA +LOG = Effects.LOG +BALANCE = Effects.BALANCE +EXTCODE = Effects.EXTCODE + + +_writes = { + "sstore": STORAGE, + "tstore": TRANSIENT, + "mstore": MEMORY, + "istore": IMMUTABLES, + "call": ALL ^ IMMUTABLES, + "delegatecall": ALL ^ IMMUTABLES, + "staticcall": MEMORY | RETURNDATA, + "create": ALL ^ (MEMORY | IMMUTABLES), + "create2": ALL ^ (MEMORY | IMMUTABLES), + "invoke": ALL, # could be smarter, look up the effects of the invoked function + "log": LOG, + "dloadbytes": MEMORY, + "dload": MEMORY, + "returndatacopy": MEMORY, + "calldatacopy": MEMORY, + "codecopy": MEMORY, + "extcodecopy": MEMORY, + "mcopy": MEMORY, +} + +_reads = { + "sload": STORAGE, + "tload": TRANSIENT, + "iload": IMMUTABLES, + "mload": MEMORY, + "mcopy": MEMORY, + "call": ALL, + "delegatecall": ALL, + "staticcall": ALL, + "create": ALL, + "create2": ALL, + "invoke": ALL, + "returndatasize": RETURNDATA, + "returndatacopy": RETURNDATA, + "balance": BALANCE, + "selfbalance": BALANCE, + "extcodecopy": EXTCODE, + "extcodesize": EXTCODE, + "extcodehash": EXTCODE, + "selfdestruct": BALANCE, # may modify code, but after the transaction + "log": MEMORY, + "revert": MEMORY, + "return": MEMORY, + "sha3": MEMORY, + "sha3_64": MEMORY, + "msize": MSIZE, +} + +reads = _reads.copy() +writes = _writes.copy() + +for k, v in reads.items(): + if MEMORY in v: + if k not in writes: + writes[k] = EMPTY + writes[k] |= MSIZE + +for k, v in writes.items(): + if MEMORY in v: + writes[k] |= MSIZE diff --git a/vyper/venom/function.py b/vyper/venom/function.py index fb0dabc99a..f02da77fe3 100644 --- a/vyper/venom/function.py +++ b/vyper/venom/function.py @@ -1,8 +1,8 @@ +import textwrap from 
typing import Iterator, Optional from vyper.codegen.ir_node import IRnode -from vyper.utils import OrderedSet -from vyper.venom.basicblock import CFG_ALTERING_INSTRUCTIONS, IRBasicBlock, IRLabel, IRVariable +from vyper.venom.basicblock import IRBasicBlock, IRLabel, IRVariable class IRFunction: @@ -13,7 +13,6 @@ class IRFunction: name: IRLabel # symbol name ctx: "IRContext" # type: ignore # noqa: F821 args: list - last_label: int last_variable: int _basic_block_dict: dict[str, IRBasicBlock] @@ -43,7 +42,7 @@ def append_basic_block(self, bb: IRBasicBlock): Append basic block to function. """ assert isinstance(bb, IRBasicBlock), bb - assert bb.label.name not in self._basic_block_dict + assert bb.label.name not in self._basic_block_dict, bb.label self._basic_block_dict[bb.label.name] = bb def remove_basic_block(self, bb: IRBasicBlock): @@ -89,60 +88,31 @@ def get_last_variable(self) -> str: return f"%{self.last_variable}" def remove_unreachable_blocks(self) -> int: - self._compute_reachability() + # Remove unreachable basic blocks + # pre: requires CFG analysis! + # NOTE: should this be a pass? - removed = [] + removed = set() - # Remove unreachable basic blocks for bb in self.get_basic_blocks(): if not bb.is_reachable: - removed.append(bb) + removed.add(bb) for bb in removed: self.remove_basic_block(bb) # Remove phi instructions that reference removed basic blocks - for bb in removed: - for out_bb in bb.cfg_out: - out_bb.remove_cfg_in(bb) - for inst in out_bb.instructions: - if inst.opcode != "phi": - continue - in_labels = inst.get_label_operands() - if bb.label in in_labels: - inst.remove_phi_operand(bb.label) - op_len = len(inst.operands) - if op_len == 2: - inst.opcode = "store" - inst.operands = [inst.operands[1]] - elif op_len == 0: - out_bb.remove_instruction(inst) - - return len(removed) - - def _compute_reachability(self) -> None: - """ - Compute reachability of basic blocks. 
- """ for bb in self.get_basic_blocks(): - bb.reachable = OrderedSet() - bb.is_reachable = False + for in_bb in list(bb.cfg_in): + if in_bb not in removed: + continue - self._compute_reachability_from(self.entry) + bb.remove_cfg_in(in_bb) - def _compute_reachability_from(self, bb: IRBasicBlock) -> None: - """ - Compute reachability of basic blocks from bb. - """ - if bb.is_reachable: - return - bb.is_reachable = True - for inst in bb.instructions: - if inst.opcode in CFG_ALTERING_INSTRUCTIONS: - for op in inst.get_label_operands(): - out_bb = self.get_basic_block(op.value) - bb.reachable.add(out_bb) - self._compute_reachability_from(out_bb) + # TODO: only run this if cfg_in changed + bb.fix_phi_instructions() + + return len(removed) @property def normalized(self) -> bool: @@ -195,22 +165,23 @@ def chain_basic_blocks(self) -> None: """ bbs = list(self.get_basic_blocks()) for i, bb in enumerate(bbs): - if not bb.is_terminated: - if i < len(bbs) - 1: - # TODO: revisit this. When contructor calls internal functions they - # are linked to the last ctor block. Should separate them before this - # so we don't have to handle this here - if bbs[i + 1].label.value.startswith("internal"): - bb.append_instruction("stop") - else: - bb.append_instruction("jmp", bbs[i + 1].label) + if bb.is_terminated: + continue + + if i < len(bbs) - 1: + # TODO: revisit this. When contructor calls internal functions + # they are linked to the last ctor block. 
Should separate them + # before this so we don't have to handle this here + if bbs[i + 1].label.value.startswith("internal"): + bb.append_instruction("stop") else: - bb.append_instruction("exit") + bb.append_instruction("jmp", bbs[i + 1].label) + else: + bb.append_instruction("stop") def copy(self): new = IRFunction(self.name) new._basic_block_dict = self._basic_block_dict.copy() - new.last_label = self.last_label new.last_variable = self.last_variable return new @@ -252,7 +223,10 @@ def _make_label(bb): return "\n".join(ret) def __repr__(self) -> str: - str = f"IRFunction: {self.name}\n" + ret = f"function {self.name} {{\n" for bb in self.get_basic_blocks(): - str += f"{bb}\n" - return str.strip() + bb_str = textwrap.indent(str(bb), " ") + ret += f"{bb_str}\n" + ret = ret.strip() + "\n}" + ret += f" ; close function {self.name}" + return ret diff --git a/vyper/venom/ir_node_to_venom.py b/vyper/venom/ir_node_to_venom.py index 85172c70e1..f46457b77f 100644 --- a/vyper/venom/ir_node_to_venom.py +++ b/vyper/venom/ir_node_to_venom.py @@ -4,7 +4,6 @@ from vyper.codegen.ir_node import IRnode from vyper.evm.opcodes import get_opcodes -from vyper.utils import MemoryPositions from vyper.venom.basicblock import ( IRBasicBlock, IRInstruction, @@ -67,6 +66,8 @@ "mload", "iload", "istore", + "dload", + "dloadbytes", "sload", "sstore", "tload", @@ -107,18 +108,16 @@ NOOP_INSTRUCTIONS = frozenset(["pass", "cleanup_repeat", "var_list", "unique_symbol"]) SymbolTable = dict[str, Optional[IROperand]] -_global_symbols: SymbolTable = None # type: ignore +_alloca_table: SymbolTable = None # type: ignore MAIN_ENTRY_LABEL_NAME = "__main_entry" -_external_functions: dict[int, SymbolTable] = None # type: ignore # convert IRnode directly to venom def ir_node_to_venom(ir: IRnode) -> IRContext: _ = ir.unique_symbols # run unique symbols check - global _global_symbols, _external_functions - _global_symbols = {} - _external_functions = {} + global _alloca_table + _alloca_table = {} ctx = 
IRContext() fn = ctx.create_function(MAIN_ENTRY_LABEL_NAME) @@ -233,7 +232,7 @@ def pop_source(*args, **kwargs): def _convert_ir_bb(fn, ir, symbols): assert isinstance(ir, IRnode), ir # TODO: refactor these to not be globals - global _break_target, _continue_target, _global_symbols, _external_functions + global _break_target, _continue_target, _alloca_table # keep a map from external functions to all possible entry points @@ -255,6 +254,7 @@ def _convert_ir_bb(fn, ir, symbols): elif ir.value == "deploy": ctx.ctor_mem_size = ir.args[0].value ctx.immutables_len = ir.args[2].value + fn.get_basic_block().append_instruction("exit") return None elif ir.value == "seq": if len(ir.args) == 0: @@ -268,8 +268,8 @@ def _convert_ir_bb(fn, ir, symbols): if is_internal or len(re.findall(r"external.*__init__\(.*_deploy", current_func)) > 0: # Internal definition var_list = ir.args[0].args[1] + assert var_list.value == "var_list" does_return_data = IRnode.from_list(["return_buffer"]) in var_list.args - _global_symbols = {} symbols = {} new_fn = _handle_internal_func(fn, ir, does_return_data, symbols) for ir_node in ir.args[1:]: @@ -297,8 +297,6 @@ def _convert_ir_bb(fn, ir, symbols): cont_ret = _convert_ir_bb(fn, cond, symbols) cond_block = fn.get_basic_block() - saved_global_symbols = _global_symbols.copy() - then_block = IRBasicBlock(ctx.get_next_label("then"), fn) else_block = IRBasicBlock(ctx.get_next_label("else"), fn) @@ -313,7 +311,6 @@ def _convert_ir_bb(fn, ir, symbols): # convert "else" cond_symbols = symbols.copy() - _global_symbols = saved_global_symbols.copy() fn.append_basic_block(else_block) else_ret_val = None if len(ir.args) == 3: @@ -342,8 +339,6 @@ def _convert_ir_bb(fn, ir, symbols): if not then_block_finish.is_terminated: then_block_finish.append_instruction("jmp", exit_bb.label) - _global_symbols = saved_global_symbols - return if_ret elif ir.value == "with": @@ -372,25 +367,16 @@ def _convert_ir_bb(fn, ir, symbols): elif ir.value == "symbol": return 
IRLabel(ir.args[0].value, True) elif ir.value == "data": - label = IRLabel(ir.args[0].value) - ctx.append_data("dbname", [label]) + label = IRLabel(ir.args[0].value, True) + ctx.append_data_section(label) for c in ir.args[1:]: - if isinstance(c, int): - assert 0 <= c <= 255, "data with invalid size" - ctx.append_data("db", [c]) # type: ignore - elif isinstance(c.value, bytes): - ctx.append_data("db", [c.value]) # type: ignore + if isinstance(c.value, bytes): + ctx.append_data_item(c.value) elif isinstance(c, IRnode): data = _convert_ir_bb(fn, c, symbols) - ctx.append_data("db", [data]) # type: ignore + assert isinstance(data, IRLabel) # help mypy + ctx.append_data_item(data) elif ir.value == "label": - function_id_pattern = r"external (\d+)" - function_name = ir.args[0].value - m = re.match(function_id_pattern, function_name) - if m is not None: - function_id = m.group(1) - _global_symbols = _external_functions.setdefault(function_id, {}) - label = IRLabel(ir.args[0].value, True) bb = fn.get_basic_block() if not bb.is_terminated: @@ -398,10 +384,7 @@ def _convert_ir_bb(fn, ir, symbols): bb = IRBasicBlock(label, fn) fn.append_basic_block(bb) code = ir.args[2] - if code.value == "pass": - bb.append_instruction("exit") - else: - _convert_ir_bb(fn, code, symbols) + _convert_ir_bb(fn, code, symbols) elif ir.value == "exit_to": args = _convert_ir_bb_list(fn, ir.args[1:], symbols) var_list = args @@ -419,22 +402,6 @@ def _convert_ir_bb(fn, ir, symbols): else: bb.append_instruction("jmp", label) - elif ir.value == "dload": - arg_0 = _convert_ir_bb(fn, ir.args[0], symbols) - bb = fn.get_basic_block() - src = bb.append_instruction("add", arg_0, IRLabel("code_end")) - - bb.append_instruction("dloadbytes", 32, src, MemoryPositions.FREE_VAR_SPACE) - return bb.append_instruction("mload", MemoryPositions.FREE_VAR_SPACE) - - elif ir.value == "dloadbytes": - dst, src_offset, len_ = _convert_ir_bb_list(fn, ir.args, symbols) - - bb = fn.get_basic_block() - src = 
bb.append_instruction("add", src_offset, IRLabel("code_end")) - bb.append_instruction("dloadbytes", len_, src, dst) - return None - elif ir.value == "mstore": # some upstream code depends on reversed order of evaluation -- # to fix upstream. @@ -465,13 +432,11 @@ def _convert_ir_bb(fn, ir, symbols): elif ir.value == "repeat": def emit_body_blocks(): - global _break_target, _continue_target, _global_symbols + global _break_target, _continue_target old_targets = _break_target, _continue_target _break_target, _continue_target = exit_block, incr_block - saved_global_symbols = _global_symbols.copy() _convert_ir_bb(fn, body, symbols.copy()) _break_target, _continue_target = old_targets - _global_symbols = saved_global_symbols sym = ir.args[0] start, end, _ = _convert_ir_bb_list(fn, ir.args[1:4], symbols) @@ -542,16 +507,25 @@ def emit_body_blocks(): elif isinstance(ir.value, str) and ir.value.upper() in get_opcodes(): _convert_ir_opcode(fn, ir, symbols) elif isinstance(ir.value, str): - if ir.value.startswith("$alloca") and ir.value not in _global_symbols: + if ir.value.startswith("$alloca"): alloca = ir.passthrough_metadata["alloca"] - ptr = fn.get_basic_block().append_instruction("alloca", alloca.offset, alloca.size) - _global_symbols[ir.value] = ptr - elif ir.value.startswith("$palloca") and ir.value not in _global_symbols: + if alloca._id not in _alloca_table: + ptr = fn.get_basic_block().append_instruction( + "alloca", alloca.offset, alloca.size, alloca._id + ) + _alloca_table[alloca._id] = ptr + return _alloca_table[alloca._id] + + elif ir.value.startswith("$palloca"): alloca = ir.passthrough_metadata["alloca"] - ptr = fn.get_basic_block().append_instruction("store", alloca.offset) - _global_symbols[ir.value] = ptr - - return _global_symbols.get(ir.value) or symbols.get(ir.value) + if alloca._id not in _alloca_table: + ptr = fn.get_basic_block().append_instruction( + "palloca", alloca.offset, alloca.size, alloca._id + ) + _alloca_table[alloca._id] = ptr + return 
_alloca_table[alloca._id] + + return symbols.get(ir.value) elif ir.is_literal: return IRLiteral(ir.value) else: diff --git a/vyper/venom/parser.py b/vyper/venom/parser.py new file mode 100644 index 0000000000..5ccc29b7a4 --- /dev/null +++ b/vyper/venom/parser.py @@ -0,0 +1,238 @@ +import json + +from lark import Lark, Transformer + +from vyper.venom.basicblock import ( + IRBasicBlock, + IRInstruction, + IRLabel, + IRLiteral, + IROperand, + IRVariable, +) +from vyper.venom.context import DataItem, DataSection, IRContext +from vyper.venom.function import IRFunction + +VENOM_GRAMMAR = """ + %import common.CNAME + %import common.DIGIT + %import common.HEXDIGIT + %import common.LETTER + %import common.WS + %import common.INT + %import common.SIGNED_INT + %import common.ESCAPED_STRING + + # Allow multiple comment styles + COMMENT: ";" /[^\\n]*/ | "//" /[^\\n]*/ | "#" /[^\\n]*/ + + start: function* data_segment? + + # TODO: consider making entry block implicit, e.g. + # `"{" instruction+ block* "}"` + function: "function" LABEL_IDENT "{" block* "}" + + data_segment: "data" "readonly" "{" data_section* "}" + data_section: "dbsection" LABEL_IDENT ":" data_item+ + data_item: "db" (HEXSTR | LABEL) + + block: LABEL_IDENT ":" "\\n" statement* + + statement: (instruction | assignment) "\\n" + assignment: VAR_IDENT "=" expr + expr: instruction | operand + instruction: OPCODE operands_list? 
+ + operands_list: operand ("," operand)* + + operand: VAR_IDENT | CONST | LABEL + + CONST: SIGNED_INT + OPCODE: CNAME + VAR_IDENT: "%" (DIGIT|LETTER|"_"|":")+ + + # handy for identifier to be an escaped string sometimes + # (especially for machine-generated labels) + LABEL_IDENT: (NAME | ESCAPED_STRING) + LABEL: "@" LABEL_IDENT + + DOUBLE_QUOTE: "\\"" + NAME: (DIGIT|LETTER|"_")+ + HEXSTR: "x" DOUBLE_QUOTE (HEXDIGIT|"_")+ DOUBLE_QUOTE + + %ignore WS + %ignore COMMENT + """ + +VENOM_PARSER = Lark(VENOM_GRAMMAR) + + +def _set_last_var(fn: IRFunction): + for bb in fn.get_basic_blocks(): + for inst in bb.instructions: + if inst.output is None: + continue + value = inst.output.value + assert value.startswith("%") + varname = value[1:] + if varname.isdigit(): + fn.last_variable = max(fn.last_variable, int(varname)) + + +def _set_last_label(ctx: IRContext): + for fn in ctx.functions.values(): + for bb in fn.get_basic_blocks(): + label = bb.label.value + label_head, *_ = label.split("_", maxsplit=1) + if label_head.isdigit(): + ctx.last_label = max(int(label_head), ctx.last_label) + + +def _ensure_terminated(bb): + # Since "revert" is not considered terminal explicitly check for it to ensure basic + # blocks are terminating + if not bb.is_terminated: + if any(inst.opcode == "revert" for inst in bb.instructions): + bb.append_instruction("stop") + # TODO: raise error if still not terminated. + + +def _unescape(s: str): + """ + Unescape the escaped string. This is the inverse of `IRLabel.__repr__()`. 
+ """ + if s.startswith('"'): + return json.loads(s) + return s + + +class _TypedItem: + def __init__(self, children): + self.children = children + + +class _DataSegment(_TypedItem): + pass + + +class VenomTransformer(Transformer): + def start(self, children) -> IRContext: + ctx = IRContext() + if len(children) > 0 and isinstance(children[-1], _DataSegment): + ctx.data_segment = children.pop().children + + funcs = children + for fn_name, blocks in funcs: + fn = ctx.create_function(fn_name) + fn._basic_block_dict.clear() + + for block_name, instructions in blocks: + bb = IRBasicBlock(IRLabel(block_name, True), fn) + fn.append_basic_block(bb) + + for instruction in instructions: + assert isinstance(instruction, IRInstruction) # help mypy + bb.insert_instruction(instruction) + + _ensure_terminated(bb) + + _set_last_var(fn) + _set_last_label(ctx) + + return ctx + + def function(self, children) -> tuple[str, list[tuple[str, list[IRInstruction]]]]: + name, *blocks = children + return name, blocks + + def statement(self, children): + return children[0] + + def data_segment(self, children): + return _DataSegment(children) + + def data_section(self, children): + label = IRLabel(children[0], True) + data_items = children[1:] + assert all(isinstance(item, DataItem) for item in data_items) + return DataSection(label, data_items) + + def data_item(self, children): + item = children[0] + if isinstance(item, IRLabel): + return DataItem(item) + assert item.startswith('x"') + assert item.endswith('"') + item = item.removeprefix('x"').removesuffix('"') + item = item.replace("_", "") + return DataItem(bytes.fromhex(item)) + + def block(self, children) -> tuple[str, list[IRInstruction]]: + label, *instructions = children + return label, instructions + + def assignment(self, children) -> IRInstruction: + to, value = children + if isinstance(value, IRInstruction): + value.output = to + return value + if isinstance(value, (IRLiteral, IRVariable)): + return IRInstruction("store", [value], 
output=to) + raise TypeError(f"Unexpected value {value} of type {type(value)}") + + def expr(self, children): + return children[0] + + def instruction(self, children) -> IRInstruction: + if len(children) == 1: + opcode = children[0] + operands = [] + else: + assert len(children) == 2 + opcode, operands = children + + # reverse operands, venom internally represents top of stack + # as rightmost operand + if opcode == "invoke": + # reverse stack arguments but not label arg + # invoke + operands = [operands[0]] + list(reversed(operands[1:])) + # special cases: operands with labels look better un-reversed + elif opcode not in ("jmp", "jnz", "phi"): + operands.reverse() + return IRInstruction(opcode, operands) + + def operands_list(self, children) -> list[IROperand]: + return children + + def operand(self, children) -> IROperand: + return children[0] + + def OPCODE(self, token): + return token.value + + def LABEL_IDENT(self, label) -> str: + return _unescape(label) + + def LABEL(self, label) -> IRLabel: + label = _unescape(label[1:]) + return IRLabel(label, True) + + def VAR_IDENT(self, var_ident) -> IRVariable: + return IRVariable(var_ident[1:]) + + def CONST(self, val) -> IRLiteral: + return IRLiteral(int(val)) + + def CNAME(self, val) -> str: + return val.value + + def NAME(self, val) -> str: + return val.value + + +def parse_venom(source: str) -> IRContext: + tree = VENOM_PARSER.parse(source) + ctx = VenomTransformer().transform(tree) + assert isinstance(ctx, IRContext) # help mypy + return ctx diff --git a/vyper/venom/passes/__init__.py b/vyper/venom/passes/__init__.py new file mode 100644 index 0000000000..a3227dcf4b --- /dev/null +++ b/vyper/venom/passes/__init__.py @@ -0,0 +1,16 @@ +from .algebraic_optimization import AlgebraicOptimizationPass +from .branch_optimization import BranchOptimizationPass +from .dft import DFTPass +from .float_allocas import FloatAllocas +from .literals_codesize import ReduceLiteralsCodesize +from .load_elimination import 
LoadElimination +from .lower_dload import LowerDloadPass +from .make_ssa import MakeSSA +from .mem2var import Mem2Var +from .memmerging import MemMergePass +from .normalization import NormalizationPass +from .remove_unused_variables import RemoveUnusedVariablesPass +from .sccp import SCCP +from .simplify_cfg import SimplifyCFGPass +from .store_elimination import StoreElimination +from .store_expansion import StoreExpansionPass diff --git a/vyper/venom/passes/algebraic_optimization.py b/vyper/venom/passes/algebraic_optimization.py index 4094219a6d..b4f4104d5f 100644 --- a/vyper/venom/passes/algebraic_optimization.py +++ b/vyper/venom/passes/algebraic_optimization.py @@ -1,17 +1,105 @@ +from vyper.utils import SizeLimits, int_bounds, int_log2, is_power_of_two, wrap256 from vyper.venom.analysis.dfg import DFGAnalysis from vyper.venom.analysis.liveness import LivenessAnalysis -from vyper.venom.basicblock import IRInstruction, IROperand +from vyper.venom.basicblock import ( + COMPARATOR_INSTRUCTIONS, + IRInstruction, + IRLabel, + IRLiteral, + IROperand, + IRVariable, + flip_comparison_opcode, +) from vyper.venom.passes.base_pass import IRPass +TRUTHY_INSTRUCTIONS = ("iszero", "jnz", "assert", "assert_unreachable") + + +def lit_eq(op: IROperand, val: int) -> bool: + return isinstance(op, IRLiteral) and wrap256(op.value) == wrap256(val) + + +class InstructionUpdater: + """ + A helper class for updating instructions which also updates the + basic block and dfg in place + """ + + def __init__(self, dfg: DFGAnalysis): + self.dfg = dfg + + def _update_operands(self, inst: IRInstruction, replace_dict: dict[IROperand, IROperand]): + old_operands = inst.operands + new_operands = [replace_dict[op] if op in replace_dict else op for op in old_operands] + self._update(inst, inst.opcode, new_operands) + + def _update(self, inst: IRInstruction, opcode: str, new_operands: list[IROperand]): + assert opcode != "phi" + # sanity + assert all(isinstance(op, IROperand) for op in 
new_operands) + + old_operands = inst.operands + + for op in old_operands: + if not isinstance(op, IRVariable): + continue + uses = self.dfg.get_uses(op) + if inst in uses: + uses.remove(inst) + + for op in new_operands: + if isinstance(op, IRVariable): + self.dfg.add_use(op, inst) + + inst.opcode = opcode + inst.operands = new_operands + + def _store(self, inst: IRInstruction, op: IROperand): + self._update(inst, "store", [op]) + + def _add_before(self, inst: IRInstruction, opcode: str, args: list[IROperand]) -> IRVariable: + """ + Insert another instruction before the given instruction + """ + assert opcode != "phi" + index = inst.parent.instructions.index(inst) + var = inst.parent.parent.get_next_variable() + operands = list(args) + new_inst = IRInstruction(opcode, operands, output=var) + inst.parent.insert_instruction(new_inst, index) + for op in new_inst.operands: + if isinstance(op, IRVariable): + self.dfg.add_use(op, new_inst) + self.dfg.add_use(var, inst) + self.dfg.set_producing_instruction(var, new_inst) + return var + class AlgebraicOptimizationPass(IRPass): """ This pass reduces algebraic evaluatable expressions. 
It currently optimizes: - * iszero chains + - iszero chains + - binops + - offset adds """ + dfg: DFGAnalysis + updater: InstructionUpdater + + def run_pass(self): + self.dfg = self.analyses_cache.request_analysis(DFGAnalysis) # type: ignore + self.updater = InstructionUpdater(self.dfg) + self._handle_offset() + + self._algebraic_opt() + self._optimize_iszero_chains() + self._algebraic_opt() + + self.analyses_cache.invalidate_analysis(DFGAnalysis) + self.analyses_cache.invalidate_analysis(LivenessAnalysis) + def _optimize_iszero_chains(self) -> None: fn = self.function for bb in fn.get_basic_blocks(): @@ -24,7 +112,8 @@ def _optimize_iszero_chains(self) -> None: if iszero_count == 0: continue - for use_inst in self.dfg.get_uses(inst.output): + assert isinstance(inst.output, IRVariable) + for use_inst in self.dfg.get_uses(inst.output).copy(): opcode = use_inst.opcode if opcode == "iszero": @@ -43,12 +132,14 @@ def _optimize_iszero_chains(self) -> None: continue out_var = iszero_chain[keep_count].operands[0] - use_inst.replace_operands({inst.output: out_var}) + self.updater._update_operands(use_inst, {inst.output: out_var}) def _get_iszero_chain(self, op: IROperand) -> list[IRInstruction]: chain: list[IRInstruction] = [] while True: + if not isinstance(op, IRVariable): + break inst = self.dfg.get_producing_instruction(op) if inst is None or inst.opcode != "iszero": break @@ -58,10 +149,302 @@ def _get_iszero_chain(self, op: IROperand) -> list[IRInstruction]: chain.reverse() return chain - def run_pass(self): - self.dfg = self.analyses_cache.request_analysis(DFGAnalysis) + def _handle_offset(self): + for bb in self.function.get_basic_blocks(): + for inst in bb.instructions: + if ( + inst.opcode == "add" + and self._is_lit(inst.operands[0]) + and isinstance(inst.operands[1], IRLabel) + ): + inst.opcode = "offset" - self._optimize_iszero_chains() + def _is_lit(self, operand: IROperand) -> bool: + return isinstance(operand, IRLiteral) - 
    def _algebraic_opt(self):
        # thin wrapper; kept separate so extra fixpoint iterations could be
        # added here later without touching run_pass
        self._algebraic_opt_pass()

    def _algebraic_opt_pass(self):
        # apply the peephole rewrites to every instruction, then normalize
        # operand order of commutative instructions for the DFT pass
        for bb in self.function.get_basic_blocks():
            for inst in bb.instructions:
                self._handle_inst_peephole(inst)
                self._flip_inst(inst)

    def _flip_inst(self, inst: IRInstruction):
        ops = inst.operands
        # improve code. this seems like it should be properly handled by
        # better heuristics in DFT pass.
        if inst.flippable and self._is_lit(ops[0]) and not self._is_lit(ops[1]):
            inst.flip()

    # "peephole", weakening algebraic optimizations
    def _handle_inst_peephole(self, inst: IRInstruction):
        """
        Rewrite `inst` in place to a cheaper equivalent form when one of
        the algebraic identities below applies. Rewrites go through
        `self.updater` so the DFG stays consistent. Instructions with no
        output, volatile/pseudo instructions and plain `store`s are left
        untouched.
        """
        if inst.output is None:
            return
        if inst.is_volatile:
            return
        if inst.opcode == "store":
            return
        if inst.is_pseudo:
            return

        # TODO nice to have rules:
        # -1 * x => 0 - x
        # x // -1 => 0 - x (?)
        # x + (-1) => x - 1 # save codesize, maybe for all negative numbers)
        # 1 // x => x == 1(?)
        # 1 % x => x > 1(?)
        # !!x => x > 0 # saves 1 gas as of shanghai

        operands = inst.operands

        # make logic easier for commutative instructions.
        # (canonicalize the literal into operands[0])
        if inst.flippable and self._is_lit(operands[1]) and not self._is_lit(operands[0]):
            inst.flip()
            operands = inst.operands

        if inst.opcode in {"shl", "shr", "sar"}:
            # (x >> 0) == (x << 0) == x
            if lit_eq(operands[1], 0):
                self.updater._store(inst, operands[0])
                return
            # no more cases for these instructions
            return

        if inst.opcode == "exp":
            # x ** 0 -> 1
            if lit_eq(operands[0], 0):
                self.updater._store(inst, IRLiteral(1))
                return

            # 1 ** x -> 1
            if lit_eq(operands[1], 1):
                self.updater._store(inst, IRLiteral(1))
                return

            # 0 ** x -> iszero x
            if lit_eq(operands[1], 0):
                self.updater._update(inst, "iszero", [operands[0]])
                return

            # x ** 1 -> x
            if lit_eq(operands[0], 1):
                self.updater._store(inst, operands[1])
                return

            # no more cases for this instruction
            return

        if inst.opcode in {"add", "sub", "xor"}:
            # (x - x) == (x ^ x) == 0
            if inst.opcode in ("xor", "sub") and operands[0] == operands[1]:
                self.updater._store(inst, IRLiteral(0))
                return

            # (x + 0) == (0 + x) -> x
            # x - 0 -> x
            # (x ^ 0) == (0 ^ x) -> x
            if lit_eq(operands[0], 0):
                self.updater._store(inst, operands[1])
                return

            # (-1) - x -> ~x
            # from two's complement
            if inst.opcode == "sub" and lit_eq(operands[1], -1):
                self.updater._update(inst, "not", [operands[0]])
                return

            # x ^ 0xFFFF..FF -> ~x
            if inst.opcode == "xor" and lit_eq(operands[0], -1):
                self.updater._update(inst, "not", [operands[1]])
                return

            return

        # x & 0xFF..FF -> x
        if inst.opcode == "and" and lit_eq(operands[0], -1):
            self.updater._store(inst, operands[1])
            return

        if inst.opcode in ("mul", "and", "div", "sdiv", "mod", "smod"):
            # (x * 0) == (x & 0) == (x // 0) == (x % 0) -> 0
            if any(lit_eq(op, 0) for op in operands):
                self.updater._store(inst, IRLiteral(0))
                return

        if inst.opcode in {"mul", "div", "sdiv", "mod", "smod"}:
            if inst.opcode in ("mod", "smod") and lit_eq(operands[0], 1):
                # x % 1 -> 0
                self.updater._store(inst, IRLiteral(0))
                return

            # (x * 1) == (1 * x) == (x // 1) -> x
            if inst.opcode in ("mul", "div", "sdiv") and lit_eq(operands[0], 1):
                self.updater._store(inst, operands[1])
                return

            # strength-reduce mul/div/mod by a power of two into shifts/masks
            if self._is_lit(operands[0]) and is_power_of_two(operands[0].value):
                val = operands[0].value
                # x % (2^n) -> x & (2^n - 1)
                if inst.opcode == "mod":
                    self.updater._update(inst, "and", [IRLiteral(val - 1), operands[1]])
                    return
                # x / (2^n) -> x >> n
                if inst.opcode == "div":
                    self.updater._update(inst, "shr", [operands[1], IRLiteral(int_log2(val))])
                    return
                # x * (2^n) -> x << n
                if inst.opcode == "mul":
                    self.updater._update(inst, "shl", [operands[1], IRLiteral(int_log2(val))])
                    return
            return

        assert inst.output is not None
        uses = self.dfg.get_uses(inst.output)

        # the remaining rules depend on how the result is consumed:
        # only its truthiness, or only via assert/iszero
        is_truthy = all(i.opcode in TRUTHY_INSTRUCTIONS for i in uses)
        prefer_iszero = all(i.opcode in ("assert", "iszero") for i in uses)

        # TODO rules like:
        # not x | not y => not (x & y)
        # x | not y => not (not x & y)

        if inst.opcode == "or":
            # x | 0xff..ff == 0xff..ff
            if any(lit_eq(op, SizeLimits.MAX_UINT256) for op in operands):
                self.updater._store(inst, IRLiteral(SizeLimits.MAX_UINT256))
                return

            # x | n -> 1 in truthy positions (if n is non zero)
            if is_truthy and self._is_lit(operands[0]) and operands[0].value != 0:
                self.updater._store(inst, IRLiteral(1))
                return

            # x | 0 -> x
            if lit_eq(operands[0], 0):
                self.updater._store(inst, operands[1])
                return

        if inst.opcode == "eq":
            # x == x -> 1
            if operands[0] == operands[1]:
                self.updater._store(inst, IRLiteral(1))
                return

            # x == 0 -> iszero x
            if lit_eq(operands[0], 0):
                self.updater._update(inst, "iszero", [operands[1]])
                return

            # eq x -1 -> iszero(~x)
            # (saves codesize, not gas)
            if lit_eq(operands[0], -1):
                var = self.updater._add_before(inst, "not", [operands[1]])
                self.updater._update(inst, "iszero", [var])
                return

            if prefer_iszero:
                # (eq x y) has the same truthyness as (iszero (xor x y))
                tmp = self.updater._add_before(inst, "xor", [operands[0], operands[1]])

                self.updater._update(inst, "iszero", [tmp])
                return

        if inst.opcode in COMPARATOR_INSTRUCTIONS:
            self._optimize_comparator_instruction(inst, prefer_iszero)

    def _optimize_comparator_instruction(self, inst, prefer_iszero):
        """
        Simplify a single comparator (`gt`/`lt`/`sgt`/`slt`) instruction,
        handling the degenerate literal boundary cases and, when profitable,
        absorbing a following `iszero` by flipping the comparison.
        """
        opcode, operands = inst.opcode, inst.operands
        assert opcode in COMPARATOR_INSTRUCTIONS  # sanity
        assert isinstance(inst.output, IRVariable)  # help mypy

        # (x > x) == (x < x) -> 0
        if operands[0] == operands[1]:
            self.updater._store(inst, IRLiteral(0))
            return

        is_gt = "g" in opcode
        signed = "s" in opcode

        lo, hi = int_bounds(bits=256, signed=signed)

        # remaining rules require a literal in operands[0]
        if not isinstance(operands[0], IRLiteral):
            return

        # for comparison operators, we have three special boundary cases:
        # almost always, never and almost never.
        # almost_always is always true for the non-strict ("ge" and co)
        # comparators. for strict comparators ("gt" and co), almost_always
        # is true except for one case. never is never true for the strict
        # comparators. never is almost always false for the non-strict
        # comparators, except for one case. and almost_never is almost
        # never true (except one case) for the strict comparators.
        if is_gt:
            almost_always, never = lo, hi
            almost_never = hi - 1
        else:
            almost_always, never = hi, lo
            almost_never = lo + 1

        if lit_eq(operands[0], never):
            self.updater._store(inst, IRLiteral(0))
            return

        if lit_eq(operands[0], almost_never):
            # (lt x 1), (gt x (MAX_UINT256 - 1)), (slt x (MIN_INT256 + 1))
            self.updater._update(inst, "eq", [operands[1], IRLiteral(never)])
            return

        # rewrites. in positions where iszero is preferred, (gt x 5) => (ge x 6)
        if prefer_iszero and lit_eq(operands[0], almost_always):
            # e.g. gt x 0, slt x MAX_INT256
            tmp = self.updater._add_before(inst, "eq", operands)
            self.updater._update(inst, "iszero", [tmp])
            return

        # since push0 was introduced in shanghai, it's potentially
        # better to actually reverse this optimization -- i.e.
        # replace iszero(iszero(x)) with (gt x 0)
        if opcode == "gt" and lit_eq(operands[0], 0):
            tmp = self.updater._add_before(inst, "iszero", [operands[1]])
            self.updater._update(inst, "iszero", [tmp])
            return

        # rewrite comparisons by removing an `iszero`, e.g.
        # `x > N` -> `x >= (N + 1)`
        assert inst.output is not None
        uses = self.dfg.get_uses(inst.output)
        if len(uses) != 1:
            return

        after = uses.first()
        if not after.opcode == "iszero":
            return

        # peer down the iszero chain to see if it makes sense
        # to remove the iszero. (can we simplify this?)
        n_uses = self.dfg.get_uses(after.output)
        # "assert" inserts an iszero in assembly
        if len(n_uses) != 1 or n_uses.first().opcode == "assert":
            return

        val = wrap256(operands[0].value, signed=signed)
        assert val != never, "unreachable"  # sanity

        if is_gt:
            val += 1
        else:
            # TODO: if resulting val is -1 (0xFF..FF), disable this
            # when optimization level == codesize
            val -= 1

        # sanity -- implied by precondition that `val != never`
        assert wrap256(val, signed=signed) == val

        new_opcode = flip_comparison_opcode(opcode)

        self.updater._update(inst, new_opcode, [IRLiteral(val), operands[1]])

        # the flipped comparison absorbs the iszero; reduce it to a store
        assert len(after.operands) == 1
        self.updater._update(after, "store", after.operands)
b/vyper/venom/passes/branch_optimization.py index 354aab7900..920dc5e431 100644 --- a/vyper/venom/passes/branch_optimization.py +++ b/vyper/venom/passes/branch_optimization.py @@ -1,4 +1,5 @@ -from vyper.venom.analysis.dfg import DFGAnalysis +from vyper.venom.analysis import CFGAnalysis, DFGAnalysis, LivenessAnalysis +from vyper.venom.basicblock import IRInstruction from vyper.venom.passes.base_pass import IRPass @@ -14,17 +15,30 @@ def _optimize_branches(self) -> None: if term_inst.opcode != "jnz": continue - prev_inst = self.dfg.get_producing_instruction(term_inst.operands[0]) - if prev_inst.opcode == "iszero": + fst, snd = bb.cfg_out + + fst_liveness = fst.instructions[0].liveness + snd_liveness = snd.instructions[0].liveness + + cost_a, cost_b = len(fst_liveness), len(snd_liveness) + + cond = term_inst.operands[0] + prev_inst = self.dfg.get_producing_instruction(cond) + if cost_a >= cost_b and prev_inst.opcode == "iszero": new_cond = prev_inst.operands[0] term_inst.operands = [new_cond, term_inst.operands[2], term_inst.operands[1]] - - # Since the DFG update is simple we do in place to avoid invalidating the DFG - # and having to recompute it (which is expensive(er)) - self.dfg.remove_use(prev_inst.output, term_inst) - self.dfg.add_use(new_cond, term_inst) + elif cost_a > cost_b: + new_cond = fn.get_next_variable() + inst = IRInstruction("iszero", [term_inst.operands[0]], output=new_cond) + bb.insert_instruction(inst, index=-1) + term_inst.operands = [new_cond, term_inst.operands[2], term_inst.operands[1]] def run_pass(self): + self.liveness = self.analyses_cache.request_analysis(LivenessAnalysis) + self.cfg = self.analyses_cache.request_analysis(CFGAnalysis) self.dfg = self.analyses_cache.request_analysis(DFGAnalysis) self._optimize_branches() + + self.analyses_cache.invalidate_analysis(LivenessAnalysis) + self.analyses_cache.invalidate_analysis(CFGAnalysis) diff --git a/vyper/venom/passes/dft.py b/vyper/venom/passes/dft.py index f45a60079c..a8d68ad676 100644 
--- a/vyper/venom/passes/dft.py +++ b/vyper/venom/passes/dft.py @@ -1,81 +1,130 @@ +from collections import defaultdict + +import vyper.venom.effects as effects from vyper.utils import OrderedSet -from vyper.venom.analysis.dfg import DFGAnalysis -from vyper.venom.basicblock import IRBasicBlock, IRInstruction, IRVariable +from vyper.venom.analysis import DFGAnalysis, LivenessAnalysis +from vyper.venom.basicblock import IRBasicBlock, IRInstruction from vyper.venom.function import IRFunction from vyper.venom.passes.base_pass import IRPass class DFTPass(IRPass): function: IRFunction - inst_order: dict[IRInstruction, int] - inst_order_num: int + data_offspring: dict[IRInstruction, OrderedSet[IRInstruction]] + visited_instructions: OrderedSet[IRInstruction] + # "data dependency analysis" + dda: dict[IRInstruction, OrderedSet[IRInstruction]] + # "effect dependency analysis" + eda: dict[IRInstruction, OrderedSet[IRInstruction]] + + def run_pass(self) -> None: + self.data_offspring = {} + self.visited_instructions: OrderedSet[IRInstruction] = OrderedSet() - def _process_instruction_r(self, bb: IRBasicBlock, inst: IRInstruction, offset: int = 0): - for op in inst.get_outputs(): - assert isinstance(op, IRVariable), f"expected variable, got {op}" - uses = self.dfg.get_uses(op) + self.dfg = self.analyses_cache.request_analysis(DFGAnalysis) + + for bb in self.function.get_basic_blocks(): + self._process_basic_block(bb) - for uses_this in uses: - if uses_this.parent != inst.parent or uses_this.fence_id != inst.fence_id: - # don't reorder across basic block or fence boundaries - continue + self.analyses_cache.invalidate_analysis(LivenessAnalysis) - # if the instruction is a terminator, we need to place - # it at the end of the basic block - # along with all the instructions that "lead" to it - self._process_instruction_r(bb, uses_this, offset) + def _process_basic_block(self, bb: IRBasicBlock) -> None: + self._calculate_dependency_graphs(bb) + self.instructions = 
list(bb.pseudo_instructions) + non_phi_instructions = list(bb.non_phi_instructions) + self.visited_instructions = OrderedSet() + for inst in bb.instructions: + self._calculate_data_offspring(inst) + + # Compute entry points in the graph of instruction dependencies + entry_instructions: OrderedSet[IRInstruction] = OrderedSet(non_phi_instructions) + for inst in non_phi_instructions: + to_remove = self.dda.get(inst, OrderedSet()) | self.eda.get(inst, OrderedSet()) + entry_instructions.dropmany(to_remove) + + entry_instructions_list = list(entry_instructions) + + self.visited_instructions = OrderedSet() + for inst in entry_instructions_list: + self._process_instruction_r(self.instructions, inst) + + bb.instructions = self.instructions + assert bb.is_terminated, f"Basic block should be terminated {bb}" + + def _process_instruction_r(self, instructions: list[IRInstruction], inst: IRInstruction): if inst in self.visited_instructions: return self.visited_instructions.add(inst) - self.inst_order_num += 1 - - if inst.is_bb_terminator: - offset = len(bb.instructions) - if inst.opcode == "phi": - # phi instructions stay at the beginning of the basic block - # and no input processing is needed - # bb.instructions.append(inst) - self.inst_order[inst] = 0 + if inst.is_pseudo: return - for op in inst.get_input_variables(): - target = self.dfg.get_producing_instruction(op) - assert target is not None, f"no producing instruction for {op}" - if target.parent != inst.parent or target.fence_id != inst.fence_id: - # don't reorder across basic block or fence boundaries - continue - self._process_instruction_r(bb, target, offset) + children = list(self.dda[inst] | self.eda[inst]) - self.inst_order[inst] = self.inst_order_num + offset + def cost(x: IRInstruction) -> int | float: + if x in self.eda[inst] or inst.flippable: + ret = -1 * int(len(self.data_offspring[x]) > 0) + else: + assert x in self.dda[inst] # sanity check + assert x.output is not None # help mypy + ret = 
inst.operands.index(x.output) + return ret - def _process_basic_block(self, bb: IRBasicBlock) -> None: - self.function.append_basic_block(bb) + # heuristic: sort by size of child dependency graph + orig_children = children.copy() + children.sort(key=cost) - for inst in bb.instructions: - inst.fence_id = self.fence_id - if inst.is_volatile: - self.fence_id += 1 - - # We go throught the instructions and calculate the order in which they should be executed - # based on the data flow graph. This order is stored in the inst_order dictionary. - # We then sort the instructions based on this order. - self.inst_order = {} - self.inst_order_num = 0 - for inst in bb.instructions: - self._process_instruction_r(bb, inst) + if inst.flippable and (orig_children != children): + inst.flip() - bb.instructions.sort(key=lambda x: self.inst_order[x]) + for dep_inst in children: + self._process_instruction_r(instructions, dep_inst) - def run_pass(self) -> None: - self.dfg = self.analyses_cache.request_analysis(DFGAnalysis) + instructions.append(inst) - self.fence_id = 0 - self.visited_instructions: OrderedSet[IRInstruction] = OrderedSet() + def _calculate_dependency_graphs(self, bb: IRBasicBlock) -> None: + # ida: instruction dependency analysis + self.dda = defaultdict(OrderedSet) + self.eda = defaultdict(OrderedSet) - basic_blocks = list(self.function.get_basic_blocks()) + non_phis = list(bb.non_phi_instructions) - self.function.clear_basic_blocks() - for bb in basic_blocks: - self._process_basic_block(bb) + # + # Compute dependency graph + # + last_write_effects: dict[effects.Effects, IRInstruction] = {} + last_read_effects: dict[effects.Effects, IRInstruction] = {} + + for inst in non_phis: + for op in inst.operands: + dep = self.dfg.get_producing_instruction(op) + if dep is not None and dep.parent == bb: + self.dda[inst].add(dep) + + write_effects = inst.get_write_effects() + read_effects = inst.get_read_effects() + + for write_effect in write_effects: + if write_effect in 
class FloatAllocas(IRPass):
    """
    Hoist alloca/palloca instructions into the function's entry basic block.

    We could probably move them to the immediate dominator of the basic
    block defining the alloca instead of the entry (which dominates all
    basic blocks), but this is done for expedience.
    Without this step, sccp fails, possibly because dominators are not
    guaranteed to be traversed first.
    """

    def run_pass(self):
        entry = self.function.entry
        assert entry.is_terminated
        # detach the terminator so hoisted allocas land before it
        terminator = entry.instructions.pop()

        for bb in self.function.get_basic_blocks():
            if bb is entry:
                continue

            # move allocas to the entry block; keep everything else
            kept = []
            for inst in bb.instructions:
                if inst.opcode not in ("alloca", "palloca"):
                    kept.append(inst)
                    continue
                # note: order of allocas impacts bytecode.
                # TODO: investigate.
                entry.insert_instruction(inst)

            bb.instructions = kept

        # reattach the terminator after all hoisted allocas
        entry.instructions.append(terminator)
class LoadElimination(IRPass):
    """
    Eliminate sloads, mloads and tloads by forwarding the value of a
    prior store (or prior load) of the same location within a basic block.
    """

    # should this be renamed to EffectsElimination?

    def run_pass(self):
        self.equivalence = self.analyses_cache.request_analysis(VarEquivalenceAnalysis)

        groups = (
            (Effects.MEMORY, "mload", "mstore"),
            (Effects.TRANSIENT, "tload", "tstore"),
            (Effects.STORAGE, "sload", "sstore"),
        )
        for bb in self.function.get_basic_blocks():
            for eff, load_op, store_op in groups:
                self._process_bb(bb, eff, load_op, store_op)

        self.analyses_cache.invalidate_analysis(LivenessAnalysis)
        self.analyses_cache.invalidate_analysis(DFGAnalysis)

    def equivalent(self, op1, op2):
        # literally equal, or proven equal by the equivalence analysis
        if op1 == op2:
            return True
        return self.equivalence.equivalent(op1, op2)

    def _process_bb(self, bb, eff, load_opcode, store_opcode):
        # single known (ptr, value) fact for this effect.
        # not really a lattice even though it is not really inter-basic block;
        # we may generalize in the future
        known = ()

        for inst in bb.instructions:
            # any write to this effect invalidates what we know
            if eff in inst.get_write_effects():
                known = ()

            if inst.opcode == store_opcode:
                # e.g. mstore [val, ptr]
                val, ptr = inst.operands
                known = (ptr, val)

            if inst.opcode == load_opcode:
                previous = known
                (ptr,) = inst.operands
                # this load's output is now the known value at ptr
                known = (ptr, inst.output)
                if previous and self.equivalent(ptr, previous[0]):
                    # forward the previously known value; the load
                    # becomes a plain store (copy)
                    inst.opcode = "store"
                    inst.operands = [previous[1]]
class LowerDloadPass(IRPass):
    """
    Lower dload and dloadbytes instructions into codecopy-based sequences
    (source offsets are rewritten relative to the `code_end` label).
    """

    def run_pass(self):
        for bb in self.function.get_basic_blocks():
            self._handle_bb(bb)
        self.analyses_cache.invalidate_analysis(LivenessAnalysis)
        self.analyses_cache.invalidate_analysis(DFGAnalysis)

    def _handle_bb(self, bb: IRBasicBlock):
        fn = bb.parent
        # note: we insert while enumerating; the freshly inserted
        # instructions get re-visited, but none of them are
        # dload/dloadbytes, so they are simply skipped.
        for idx, inst in enumerate(bb.instructions):
            if inst.opcode == "dload":
                self._lower_dload(fn, bb, idx, inst)
            elif inst.opcode == "dloadbytes":
                self._lower_dloadbytes(fn, bb, idx, inst)

    def _lower_dload(self, fn, bb, idx, inst):
        # codecopy 32 bytes from code_end+ptr into FREE_VAR_SPACE,
        # then turn the dload into an mload of that location
        (ptr,) = inst.operands
        code_ptr = fn.get_next_variable()
        add_inst = IRInstruction("add", [ptr, IRLabel("code_end")], output=code_ptr)
        bb.insert_instruction(add_inst, index=idx)

        dst = IRLiteral(MemoryPositions.FREE_VAR_SPACE)
        copy_inst = IRInstruction("codecopy", [IRLiteral(32), code_ptr, dst])
        bb.insert_instruction(copy_inst, index=idx + 1)

        inst.opcode = "mload"
        inst.operands = [dst]

    def _lower_dloadbytes(self, fn, bb, idx, inst):
        # operands are (_, src, _); only src needs rebasing onto code_end,
        # after which the instruction is a plain codecopy
        _, src, _ = inst.operands
        code_ptr = fn.get_next_variable()
        add_inst = IRInstruction("add", [src, IRLabel("code_end")], output=code_ptr)
        bb.insert_instruction(add_inst, index=idx)

        inst.opcode = "codecopy"
        inst.operands[1] = code_ptr
IRBasicBlock, IRInstruction, IROperand, IRVariable from vyper.venom.passes.base_pass import IRPass @@ -37,8 +35,8 @@ def _add_phi_nodes(self): Add phi nodes to the function. """ self._compute_defs() - work = {var: 0 for var in self.dom.dfs_walk} - has_already = {var: 0 for var in self.dom.dfs_walk} + work = {bb: 0 for bb in self.dom.dfs_walk} + has_already = {bb: 0 for bb in self.dom.dfs_walk} i = 0 # Iterate over all variables @@ -98,7 +96,6 @@ def _rename_vars(self, basic_block: IRBasicBlock): self.var_name_counters[v_name] = i + 1 inst.output = IRVariable(v_name, version=i) - # note - after previous line, inst.output.name != v_name outs.append(inst.output.name) for bb in basic_block.cfg_out: @@ -108,8 +105,9 @@ def _rename_vars(self, basic_block: IRBasicBlock): assert inst.output is not None, "Phi instruction without output" for i, op in enumerate(inst.operands): if op == basic_block.label: + var = inst.operands[i + 1] inst.operands[i + 1] = IRVariable( - inst.output.name, version=self.var_name_stacks[inst.output.name][-1] + var.name, version=self.var_name_stacks[var.name][-1] ) for bb in self.dom.dominated[basic_block]: diff --git a/vyper/venom/passes/mem2var.py b/vyper/venom/passes/mem2var.py index f4a37f5abb..9f985e2b0b 100644 --- a/vyper/venom/passes/mem2var.py +++ b/vyper/venom/passes/mem2var.py @@ -1,8 +1,5 @@ -from vyper.utils import OrderedSet -from vyper.venom.analysis.cfg import CFGAnalysis -from vyper.venom.analysis.dfg import DFGAnalysis -from vyper.venom.analysis.liveness import LivenessAnalysis -from vyper.venom.basicblock import IRBasicBlock, IRInstruction, IRVariable +from vyper.venom.analysis import CFGAnalysis, DFGAnalysis, LivenessAnalysis +from vyper.venom.basicblock import IRInstruction, IRVariable from vyper.venom.function import IRFunction from vyper.venom.passes.base_pass import IRPass @@ -14,7 +11,6 @@ class Mem2Var(IRPass): """ function: IRFunction - defs: dict[IRVariable, OrderedSet[IRBasicBlock]] def run_pass(self): 
self.analyses_cache.request_analysis(CFGAnalysis) @@ -22,37 +18,69 @@ def run_pass(self): self.var_name_count = 0 for var, inst in dfg.outputs.items(): - if inst.opcode != "alloca": - continue - self._process_alloca_var(dfg, var) + if inst.opcode == "alloca": + self._process_alloca_var(dfg, var) + elif inst.opcode == "palloca": + self._process_palloca_var(dfg, inst, var) self.analyses_cache.invalidate_analysis(DFGAnalysis) self.analyses_cache.invalidate_analysis(LivenessAnalysis) + def _mk_varname(self, varname: str): + varname = varname.removeprefix("%") + varname = f"var{varname}_{self.var_name_count}" + self.var_name_count += 1 + return varname + def _process_alloca_var(self, dfg: DFGAnalysis, var: IRVariable): """ - Process alloca allocated variable. If it is only used by mstore/mload/return - instructions, it is promoted to a stack variable. Otherwise, it is left as is. + Process alloca allocated variable. If it is only used by + mstore/mload/return instructions, it is promoted to a stack variable. + Otherwise, it is left as is. """ uses = dfg.get_uses(var) - if all([inst.opcode == "mload" for inst in uses]): + if not all([inst.opcode in ["mstore", "mload", "return"] for inst in uses]): return - elif all([inst.opcode == "mstore" for inst in uses]): + + var_name = self._mk_varname(var.name) + var = IRVariable(var_name) + for inst in uses: + if inst.opcode == "mstore": + inst.opcode = "store" + inst.output = var + inst.operands = [inst.operands[0]] + elif inst.opcode == "mload": + inst.opcode = "store" + inst.operands = [var] + elif inst.opcode == "return": + bb = inst.parent + idx = len(bb.instructions) - 1 + assert inst == bb.instructions[idx] # sanity + new_inst = IRInstruction("mstore", [var, inst.operands[1]]) + bb.insert_instruction(new_inst, idx) + + def _process_palloca_var(self, dfg: DFGAnalysis, palloca_inst: IRInstruction, var: IRVariable): + """ + Process alloca allocated variable. 
If it is only used by mstore/mload + instructions, it is promoted to a stack variable. Otherwise, it is left as is. + """ + uses = dfg.get_uses(var) + if not all(inst.opcode in ["mstore", "mload"] for inst in uses): return - elif all([inst.opcode in ["mstore", "mload", "return"] for inst in uses]): - var_name = f"addr{var.name}_{self.var_name_count}" - self.var_name_count += 1 - for inst in uses: - if inst.opcode == "mstore": - inst.opcode = "store" - inst.output = IRVariable(var_name) - inst.operands = [inst.operands[0]] - elif inst.opcode == "mload": - inst.opcode = "store" - inst.operands = [IRVariable(var_name)] - elif inst.opcode == "return": - bb = inst.parent - idx = bb.instructions.index(inst) - bb.insert_instruction( - IRInstruction("mstore", [IRVariable(var_name), inst.operands[1]]), idx - ) + + var_name = self._mk_varname(var.name) + var = IRVariable(var_name) + + # some value given to us by the calling convention + palloca_inst.opcode = "mload" + palloca_inst.operands = [palloca_inst.operands[0]] + palloca_inst.output = var + + for inst in uses: + if inst.opcode == "mstore": + inst.opcode = "store" + inst.output = var + inst.operands = [inst.operands[0]] + elif inst.opcode == "mload": + inst.opcode = "store" + inst.operands = [var] diff --git a/vyper/venom/passes/memmerging.py b/vyper/venom/passes/memmerging.py new file mode 100644 index 0000000000..2e5ee46b84 --- /dev/null +++ b/vyper/venom/passes/memmerging.py @@ -0,0 +1,358 @@ +from bisect import bisect_left +from dataclasses import dataclass + +from vyper.evm.opcodes import version_check +from vyper.venom.analysis import DFGAnalysis, LivenessAnalysis +from vyper.venom.basicblock import IRBasicBlock, IRInstruction, IRLiteral, IRVariable +from vyper.venom.effects import Effects +from vyper.venom.passes.base_pass import IRPass + + +@dataclass +class _Interval: + start: int + length: int + + @property + def end(self): + return self.start + self.length + + +@dataclass +class _Copy: + # abstract "copy" 
operation which contains a list of copy instructions + # and can fuse them into a single copy operation. + dst: int + src: int + length: int + insts: list[IRInstruction] + + @classmethod + def memzero(cls, dst, length, insts): + # factory method to simplify creation of memory zeroing operations + # (which are similar to Copy operations but src is always + # `calldatasize`). choose src=dst, so that can_merge returns True + # for overlapping memzeros. + return cls(dst, dst, length, insts) + + @property + def src_end(self) -> int: + return self.src + self.length + + @property + def dst_end(self) -> int: + return self.dst + self.length + + def src_interval(self) -> _Interval: + return _Interval(self.src, self.length) + + def dst_interval(self) -> _Interval: + return _Interval(self.dst, self.length) + + def overwrites_self_src(self) -> bool: + # return true if dst overlaps src. this is important for blocking + # mcopy batching in certain cases. + return self.overwrites(self.src_interval()) + + def overwrites(self, interval: _Interval) -> bool: + # return true if dst of self overwrites the interval + a = max(self.dst, interval.start) + b = min(self.dst_end, interval.end) + return a < b + + def can_merge(self, other: "_Copy"): + # both source and destination have to be offset by same amount, + # otherwise they do not represent the same copy. e.g. + # Copy(0, 64, 16) + # Copy(11, 74, 16) + if self.src - other.src != self.dst - other.dst: + return False + + # the copies must at least touch each other + if other.dst > self.dst_end: + return False + + return True + + def merge(self, other: "_Copy"): + # merge other into self. e.g. 
+ # Copy(0, 64, 16); Copy(16, 80, 8) => Copy(0, 64, 24) + + assert self.dst <= other.dst, "bad bisect_left" + assert self.can_merge(other) + + new_length = max(self.dst_end, other.dst_end) - self.dst + self.length = new_length + self.insts.extend(other.insts) + + def __repr__(self) -> str: + return f"({self.src}, {self.src_end}, {self.length}, {self.dst}, {self.dst_end})" + + +class MemMergePass(IRPass): + dfg: DFGAnalysis + _copies: list[_Copy] + _loads: dict[IRVariable, int] + + def run_pass(self): + self.dfg = self.analyses_cache.request_analysis(DFGAnalysis) # type: ignore + + for bb in self.function.get_basic_blocks(): + self._handle_bb_memzero(bb) + self._handle_bb(bb, "calldataload", "calldatacopy", allow_dst_overlaps_src=True) + self._handle_bb(bb, "dload", "dloadbytes", allow_dst_overlaps_src=True) + + if version_check(begin="cancun"): + # mcopy is available + self._handle_bb(bb, "mload", "mcopy") + + self.analyses_cache.invalidate_analysis(DFGAnalysis) + self.analyses_cache.invalidate_analysis(LivenessAnalysis) + + def _optimize_copy(self, bb: IRBasicBlock, copy_opcode: str, load_opcode: str): + for copy in self._copies: + copy.insts.sort(key=bb.instructions.index) + + if copy_opcode == "mcopy": + assert not copy.overwrites_self_src() + + pin_inst = None + inst = copy.insts[-1] + if copy.length != 32 or load_opcode == "dload": + inst.output = None + inst.opcode = copy_opcode + inst.operands = [IRLiteral(copy.length), IRLiteral(copy.src), IRLiteral(copy.dst)] + elif inst.opcode == "mstore": + # we already have a load which is the val for this mstore; + # leave it in place. + var, _ = inst.operands + assert isinstance(var, IRVariable) # help mypy + pin_inst = self.dfg.get_producing_instruction(var) + assert pin_inst is not None # help mypy + + else: + # we are converting an mcopy into an mload+mstore (mload+mstore + # is 1 byte smaller than mcopy). 
+ index = inst.parent.instructions.index(inst) + var = bb.parent.get_next_variable() + load = IRInstruction(load_opcode, [IRLiteral(copy.src)], output=var) + inst.parent.insert_instruction(load, index) + + inst.output = None + inst.opcode = "mstore" + inst.operands = [var, IRLiteral(copy.dst)] + + for inst in copy.insts[:-1]: + if inst.opcode == load_opcode: + if inst is pin_inst: + continue + + # if the load is used by any instructions besides the ones + # we are removing, we can't delete it. (in the future this + # may be handled by "remove unused effects" pass). + assert isinstance(inst.output, IRVariable) # help mypy + uses = self.dfg.get_uses(inst.output) + if not all(use in copy.insts for use in uses): + continue + + bb.mark_for_removal(inst) + + self._copies.clear() + self._loads.clear() + + def _write_after_write_hazard(self, new_copy: _Copy) -> bool: + for copy in self._copies: + # note, these are the same: + # - new_copy.overwrites(copy.dst_interval()) + # - copy.overwrites(new_copy.dst_interval()) + if new_copy.overwrites(copy.dst_interval()) and not ( + copy.can_merge(new_copy) or new_copy.can_merge(copy) + ): + return True + return False + + def _read_after_write_hazard(self, new_copy: _Copy) -> bool: + new_copies = self._copies + [new_copy] + + # new copy would overwrite memory that + # needs to be read to optimize copy + if any(new_copy.overwrites(copy.src_interval()) for copy in new_copies): + return True + + # existing copies would overwrite memory that the + # new copy would need + if self._overwrites(new_copy.src_interval()): + return True + + return False + + def _find_insertion_point(self, new_copy: _Copy): + return bisect_left(self._copies, new_copy.dst, key=lambda c: c.dst) + + def _add_copy(self, new_copy: _Copy): + index = self._find_insertion_point(new_copy) + self._copies.insert(index, new_copy) + + i = max(index - 1, 0) + while i < min(index + 1, len(self._copies) - 1): + if self._copies[i].can_merge(self._copies[i + 1]): + 
            self._copies[i].merge(self._copies[i + 1]) + del self._copies[i + 1] + else: + i += 1 + + def _overwrites(self, read_interval: _Interval) -> bool: + # check if any of self._copies tramples the interval + + # could use bisect_left to optimize, but it's harder to reason about + return any(c.overwrites(read_interval) for c in self._copies) + + def _handle_bb( + self, + bb: IRBasicBlock, + load_opcode: str, + copy_opcode: str, + allow_dst_overlaps_src: bool = False, + ): + self._loads = {} + self._copies = [] + + def _barrier(): + self._optimize_copy(bb, copy_opcode, load_opcode) + + # copy is necessary because there is a possibility + # of insertion in optimizations + for inst in bb.instructions.copy(): + if inst.opcode == load_opcode: + src_op = inst.operands[0] + if not isinstance(src_op, IRLiteral): + _barrier() + continue + + read_interval = _Interval(src_op.value, 32) + + # we will read from this memory so we need to put a barrier + if not allow_dst_overlaps_src and self._overwrites(read_interval): + _barrier() + + assert inst.output is not None + self._loads[inst.output] = src_op.value + + elif inst.opcode == "mstore": + var, dst = inst.operands + + if not isinstance(var, IRVariable) or not isinstance(dst, IRLiteral): + _barrier() + continue + + if var not in self._loads: + _barrier() + continue + + src_ptr = self._loads[var] + load_inst = self.dfg.get_producing_instruction(var) + assert load_inst is not None # help mypy + n_copy = _Copy(dst.value, src_ptr, 32, [inst, load_inst]) + + if self._write_after_write_hazard(n_copy): + _barrier() + # no continue needed, we have not invalidated the loads dict + + # check if the new copy does not overwrite existing data + if not allow_dst_overlaps_src and self._read_after_write_hazard(n_copy): + _barrier() + # this continue is necessary because we have invalidated + # the _loads dict, so src_ptr is no longer valid. 
+ continue + self._add_copy(n_copy) + + elif inst.opcode == copy_opcode: + if not all(isinstance(op, IRLiteral) for op in inst.operands): + _barrier() + continue + + length, src, dst = inst.operands + n_copy = _Copy(dst.value, src.value, length.value, [inst]) + + if self._write_after_write_hazard(n_copy): + _barrier() + # check if the new copy does not overwrite existing data + if not allow_dst_overlaps_src and self._read_after_write_hazard(n_copy): + _barrier() + self._add_copy(n_copy) + + elif _volatile_memory(inst): + _barrier() + + _barrier() + bb.clear_dead_instructions() + + # optimize memzeroing operations + def _optimize_memzero(self, bb: IRBasicBlock): + for copy in self._copies: + inst = copy.insts[-1] + if copy.length == 32: + inst.opcode = "mstore" + inst.operands = [IRLiteral(0), IRLiteral(copy.dst)] + else: + index = bb.instructions.index(inst) + calldatasize = bb.parent.get_next_variable() + bb.insert_instruction(IRInstruction("calldatasize", [], output=calldatasize), index) + + inst.output = None + inst.opcode = "calldatacopy" + inst.operands = [IRLiteral(copy.length), calldatasize, IRLiteral(copy.dst)] + + for inst in copy.insts[:-1]: + bb.mark_for_removal(inst) + + self._copies.clear() + self._loads.clear() + + def _handle_bb_memzero(self, bb: IRBasicBlock): + self._loads = {} + self._copies = [] + + def _barrier(): + self._optimize_memzero(bb) + + # copy is necessary because there is a possibility + # of insertion in optimizations + for inst in bb.instructions.copy(): + if inst.opcode == "mstore": + val = inst.operands[0] + dst = inst.operands[1] + is_zero_literal = isinstance(val, IRLiteral) and val.value == 0 + if not (isinstance(dst, IRLiteral) and is_zero_literal): + _barrier() + continue + n_copy = _Copy.memzero(dst.value, 32, [inst]) + assert not self._write_after_write_hazard(n_copy) + self._add_copy(n_copy) + elif inst.opcode == "calldatacopy": + length, var, dst = inst.operands + if not isinstance(var, IRVariable): + _barrier() + 
continue + if not isinstance(dst, IRLiteral) or not isinstance(length, IRLiteral): + _barrier() + continue + src_inst = self.dfg.get_producing_instruction(var) + assert src_inst is not None, f"bad variable {var}" + if src_inst.opcode != "calldatasize": + _barrier() + continue + n_copy = _Copy.memzero(dst.value, length.value, [inst]) + assert not self._write_after_write_hazard(n_copy) + self._add_copy(n_copy) + elif _volatile_memory(inst): + _barrier() + continue + + _barrier() + bb.clear_dead_instructions() + + +def _volatile_memory(inst): + inst_effects = inst.get_read_effects() | inst.get_write_effects() + return Effects.MEMORY in inst_effects or Effects.MSIZE in inst_effects diff --git a/vyper/venom/passes/normalization.py b/vyper/venom/passes/normalization.py index cf44c3cf89..37ba1023c9 100644 --- a/vyper/venom/passes/normalization.py +++ b/vyper/venom/passes/normalization.py @@ -1,5 +1,5 @@ from vyper.exceptions import CompilerPanic -from vyper.venom.analysis.cfg import CFGAnalysis +from vyper.venom.analysis import CFGAnalysis from vyper.venom.basicblock import IRBasicBlock, IRLabel from vyper.venom.passes.base_pass import IRPass @@ -45,9 +45,10 @@ def _insert_split_basicblock(self, bb: IRBasicBlock, in_bb: IRBasicBlock) -> IRB inst.operands[i] = split_bb.label # Update the labels in the data segment - for inst in fn.ctx.data_segment: - if inst.opcode == "db" and inst.operands[0] == bb.label: - inst.operands[0] = split_bb.label + for data_section in fn.ctx.data_segment: + for item in data_section.data_items: + if item.data == bb.label: + item.data = split_bb.label return split_bb diff --git a/vyper/venom/passes/remove_unused_variables.py b/vyper/venom/passes/remove_unused_variables.py index be9c1ed535..73fe2112d7 100644 --- a/vyper/venom/passes/remove_unused_variables.py +++ b/vyper/venom/passes/remove_unused_variables.py @@ -1,6 +1,5 @@ -from vyper.utils import OrderedSet -from vyper.venom.analysis.dfg import DFGAnalysis -from vyper.venom.analysis.liveness 
import LivenessAnalysis +from vyper.utils import OrderedSet, uniq +from vyper.venom.analysis import DFGAnalysis, LivenessAnalysis from vyper.venom.basicblock import IRInstruction from vyper.venom.passes.base_pass import IRPass @@ -27,6 +26,7 @@ def run_pass(self): self._process_instruction(inst) self.analyses_cache.invalidate_analysis(LivenessAnalysis) + self.analyses_cache.invalidate_analysis(DFGAnalysis) def _process_instruction(self, inst): if inst.output is None: @@ -37,7 +37,7 @@ def _process_instruction(self, inst): if len(uses) > 0: return - for operand in inst.get_input_variables(): + for operand in uniq(inst.get_input_variables()): self.dfg.remove_use(operand, inst) new_uses = self.dfg.get_uses(operand) self.work_list.addmany(new_uses) diff --git a/vyper/venom/passes/sccp/eval.py b/vyper/venom/passes/sccp/eval.py index b5786bb304..4481c77c07 100644 --- a/vyper/venom/passes/sccp/eval.py +++ b/vyper/venom/passes/sccp/eval.py @@ -5,11 +5,12 @@ SizeLimits, evm_div, evm_mod, + evm_not, evm_pow, signed_to_unsigned, unsigned_to_signed, ) -from vyper.venom.basicblock import IROperand +from vyper.venom.basicblock import IRLiteral def _unsigned_to_signed(value: int) -> int: @@ -23,7 +24,7 @@ def _signed_to_unsigned(value: int) -> int: def _wrap_signed_binop(operation): - def wrapper(ops: list[IROperand]) -> int: + def wrapper(ops: list[IRLiteral]) -> int: assert len(ops) == 2 first = _unsigned_to_signed(ops[1].value) second = _unsigned_to_signed(ops[0].value) @@ -33,21 +34,23 @@ def wrapper(ops: list[IROperand]) -> int: def _wrap_binop(operation): - def wrapper(ops: list[IROperand]) -> int: + def wrapper(ops: list[IRLiteral]) -> int: assert len(ops) == 2 first = _signed_to_unsigned(ops[1].value) second = _signed_to_unsigned(ops[0].value) ret = operation(first, second) + # TODO: use wrap256 here return ret & SizeLimits.MAX_UINT256 return wrapper def _wrap_unop(operation): - def wrapper(ops: list[IROperand]) -> int: + def wrapper(ops: list[IRLiteral]) -> int: assert 
len(ops) == 1 value = _signed_to_unsigned(ops[0].value) ret = operation(value) + # TODO: use wrap256 here return ret & SizeLimits.MAX_UINT256 return wrapper @@ -86,6 +89,7 @@ def _evm_shl(shift_len: int, value: int) -> int: if shift_len >= 256: return 0 assert shift_len >= 0 + # TODO: refactor to use wrap256 return (value << shift_len) & SizeLimits.MAX_UINT256 @@ -95,12 +99,7 @@ def _evm_sar(shift_len: int, value: int) -> int: return value >> shift_len -def _evm_not(value: int) -> int: - assert 0 <= value <= SizeLimits.MAX_UINT256, "Value out of bounds" - return SizeLimits.MAX_UINT256 ^ value - - -ARITHMETIC_OPS: dict[str, Callable[[list[IROperand]], int]] = { +ARITHMETIC_OPS: dict[str, Callable[[list[IRLiteral]], int]] = { "add": _wrap_binop(operator.add), "sub": _wrap_binop(operator.sub), "mul": _wrap_binop(operator.mul), @@ -110,23 +109,23 @@ def _evm_not(value: int) -> int: "smod": _wrap_signed_binop(evm_mod), "exp": _wrap_binop(evm_pow), "eq": _wrap_binop(operator.eq), - "ne": _wrap_binop(operator.ne), "lt": _wrap_binop(operator.lt), - "le": _wrap_binop(operator.le), "gt": _wrap_binop(operator.gt), - "ge": _wrap_binop(operator.ge), "slt": _wrap_signed_binop(operator.lt), - "sle": _wrap_signed_binop(operator.le), "sgt": _wrap_signed_binop(operator.gt), - "sge": _wrap_signed_binop(operator.ge), "or": _wrap_binop(operator.or_), "and": _wrap_binop(operator.and_), "xor": _wrap_binop(operator.xor), - "not": _wrap_unop(_evm_not), + "not": _wrap_unop(evm_not), "signextend": _wrap_binop(_evm_signextend), "iszero": _wrap_unop(_evm_iszero), "shr": _wrap_binop(_evm_shr), "shl": _wrap_binop(_evm_shl), "sar": _wrap_signed_binop(_evm_sar), - "store": lambda ops: ops[0].value, + "store": _wrap_unop(lambda ops: ops[0].value), } + + +def eval_arith(opcode: str, ops: list[IRLiteral]) -> int: + fn = ARITHMETIC_OPS[opcode] + return fn(ops) diff --git a/vyper/venom/passes/sccp/sccp.py b/vyper/venom/passes/sccp/sccp.py index 164d8e241d..cee455e031 100644 --- 
a/vyper/venom/passes/sccp/sccp.py +++ b/vyper/venom/passes/sccp/sccp.py @@ -5,9 +5,7 @@ from vyper.exceptions import CompilerPanic, StaticAssertionException from vyper.utils import OrderedSet -from vyper.venom.analysis.analysis import IRAnalysesCache -from vyper.venom.analysis.cfg import CFGAnalysis -from vyper.venom.analysis.dominators import DominatorTreeAnalysis +from vyper.venom.analysis import CFGAnalysis, DFGAnalysis, IRAnalysesCache, LivenessAnalysis from vyper.venom.basicblock import ( IRBasicBlock, IRInstruction, @@ -18,7 +16,7 @@ ) from vyper.venom.function import IRFunction from vyper.venom.passes.base_pass import IRPass -from vyper.venom.passes.sccp.eval import ARITHMETIC_OPS +from vyper.venom.passes.sccp.eval import ARITHMETIC_OPS, eval_arith class LatticeEnum(Enum): @@ -52,13 +50,13 @@ class SCCP(IRPass): """ fn: IRFunction - dom: DominatorTreeAnalysis - uses: dict[IRVariable, OrderedSet[IRInstruction]] + dfg: DFGAnalysis lattice: Lattice work_list: list[WorkListItem] - cfg_dirty: bool cfg_in_exec: dict[IRBasicBlock, OrderedSet[IRBasicBlock]] + cfg_dirty: bool + def __init__(self, analyses_cache: IRAnalysesCache, function: IRFunction): super().__init__(analyses_cache, function) self.lattice = {} @@ -67,14 +65,15 @@ def __init__(self, analyses_cache: IRAnalysesCache, function: IRFunction): def run_pass(self): self.fn = self.function - self.dom = self.analyses_cache.request_analysis(DominatorTreeAnalysis) - self._compute_uses() + self.dfg = self.analyses_cache.request_analysis(DFGAnalysis) # type: ignore + self._calculate_sccp(self.fn.entry) self._propagate_constants() - - # self._propagate_variables() - - self.analyses_cache.invalidate_analysis(CFGAnalysis) + if self.cfg_dirty: + self.analyses_cache.force_analysis(CFGAnalysis) + self.fn.remove_unreachable_blocks() + self.analyses_cache.invalidate_analysis(DFGAnalysis) + self.analyses_cache.invalidate_analysis(LivenessAnalysis) def _calculate_sccp(self, entry: IRBasicBlock): """ @@ -93,7 +92,7 @@ def 
_calculate_sccp(self, entry: IRBasicBlock): self.work_list.append(FlowWorkItem(dummy, entry)) # Initialize the lattice with TOP values for all variables - for v in self.uses.keys(): + for v in self.dfg._dfg_outputs: self.lattice[v] = LatticeEnum.TOP # Iterate over the work list until it is empty @@ -144,7 +143,7 @@ def _handle_SSA_work_item(self, work_item: SSAWorkListItem): self._visit_expr(work_item.inst) def _lookup_from_lattice(self, op: IROperand) -> LatticeItem: - assert isinstance(op, IRVariable), "Can't get lattice for non-variable" + assert isinstance(op, IRVariable), f"Can't get lattice for non-variable ({op})" lat = self.lattice[op] assert lat is not None, f"Got undefined var {op}" return lat @@ -178,7 +177,7 @@ def _visit_phi(self, inst: IRInstruction): def _visit_expr(self, inst: IRInstruction): opcode = inst.opcode - if opcode in ["store", "alloca"]: + if opcode in ["store", "alloca", "palloca"]: assert inst.output is not None, "Got store/alloca without output" out = self._eval_from_lattice(inst.operands[0]) self._set_lattice(inst.output, out) @@ -228,63 +227,52 @@ def _eval(self, inst) -> LatticeItem: instruction to the SSA work list if the knowledge about the variable changed. 
""" - opcode = inst.opcode - ops = [] + def finalize(ret): + # Update the lattice if the value changed + old_val = self.lattice.get(inst.output, LatticeEnum.TOP) + if old_val != ret: + self.lattice[inst.output] = ret + self._add_ssa_work_items(inst) + return ret + + opcode = inst.opcode + ops: list[IRLiteral] = [] for op in inst.operands: - if isinstance(op, IRVariable): - ops.append(self.lattice[op]) - elif isinstance(op, IRLabel): - return LatticeEnum.BOTTOM + # Evaluate the operand according to the lattice + if isinstance(op, IRLabel): + return finalize(LatticeEnum.BOTTOM) + elif isinstance(op, IRVariable): + eval_result = self.lattice[op] else: - ops.append(op) + eval_result = op - ret = None - if LatticeEnum.BOTTOM in ops: - ret = LatticeEnum.BOTTOM - else: - if opcode in ARITHMETIC_OPS: - fn = ARITHMETIC_OPS[opcode] - ret = IRLiteral(fn(ops)) # type: ignore - elif len(ops) > 0: - ret = ops[0] # type: ignore - else: - raise CompilerPanic("Bad constant evaluation") + # The value from the lattice should have evaluated to BOTTOM + # or a literal by now. 
+ # If any operand is BOTTOM, the whole operation is BOTTOM + # and we can stop the evaluation early + if eval_result is LatticeEnum.BOTTOM: + return finalize(LatticeEnum.BOTTOM) - old_val = self.lattice.get(inst.output, LatticeEnum.TOP) - if old_val != ret: - self.lattice[inst.output] = ret # type: ignore - self._add_ssa_work_items(inst) + assert isinstance(eval_result, IRLiteral), (inst.parent.label, op, inst) + ops.append(eval_result) - return ret # type: ignore + # If we haven't found BOTTOM yet, evaluate the operation + assert all(isinstance(op, IRLiteral) for op in ops) + res = IRLiteral(eval_arith(opcode, ops)) + return finalize(res) def _add_ssa_work_items(self, inst: IRInstruction): - for target_inst in self._get_uses(inst.output): # type: ignore + for target_inst in self.dfg.get_uses(inst.output): # type: ignore self.work_list.append(SSAWorkListItem(target_inst)) - def _compute_uses(self): - """ - This method computes the uses for each variable in the IR. - It iterates over the dominator tree and collects all the - instructions that use each variable. - """ - self.uses = {} - for bb in self.dom.dfs_walk: - for var, insts in bb.get_uses().items(): - self._get_uses(var).update(insts) - - def _get_uses(self, var: IRVariable): - if var not in self.uses: - self.uses[var] = OrderedSet() - return self.uses[var] - def _propagate_constants(self): """ This method iterates over the IR and replaces constant values with their actual values. It also replaces conditional jumps with unconditional jumps if the condition is a constant value. 
""" - for bb in self.dom.dfs_walk: + for bb in self.function.get_basic_blocks(): for inst in bb.instructions: self._replace_constants(inst) @@ -304,6 +292,7 @@ def _replace_constants(self, inst: IRInstruction): target = inst.operands[1] inst.opcode = "jmp" inst.operands = [target] + self.cfg_dirty = True elif inst.opcode in ("assert", "assert_unreachable"): @@ -312,14 +301,13 @@ def _replace_constants(self, inst: IRInstruction): if isinstance(lat, IRLiteral): if lat.value > 0: inst.opcode = "nop" + inst.operands = [] else: raise StaticAssertionException( f"assertion found to fail at compile time ({inst.error_msg}).", inst.get_ast_source(), ) - inst.operands = [] - elif inst.opcode == "phi": return diff --git a/vyper/venom/passes/simplify_cfg.py b/vyper/venom/passes/simplify_cfg.py index 08582fee96..10535c2144 100644 --- a/vyper/venom/passes/simplify_cfg.py +++ b/vyper/venom/passes/simplify_cfg.py @@ -1,6 +1,6 @@ from vyper.exceptions import CompilerPanic from vyper.utils import OrderedSet -from vyper.venom.analysis.cfg import CFGAnalysis +from vyper.venom.analysis import CFGAnalysis from vyper.venom.basicblock import IRBasicBlock, IRLabel from vyper.venom.passes.base_pass import IRPass @@ -9,23 +9,21 @@ class SimplifyCFGPass(IRPass): visited: OrderedSet def _merge_blocks(self, a: IRBasicBlock, b: IRBasicBlock): - a.instructions.pop() + a.instructions.pop() # pop terminating instruction for inst in b.instructions: - assert inst.opcode != "phi", "Not implemented yet" - if inst.opcode == "phi": - a.instructions.insert(0, inst) - else: - inst.parent = a - a.instructions.append(inst) + assert inst.opcode != "phi", f"Instruction should never be phi {b}" + inst.parent = a + a.instructions.append(inst) # Update CFG a.cfg_out = b.cfg_out - if len(b.cfg_out) > 0: - next_bb = b.cfg_out.first() + + for next_bb in a.cfg_out: next_bb.remove_cfg_in(b) next_bb.add_cfg_in(a) for inst in next_bb.instructions: + # assume phi instructions are at beginning of bb if inst.opcode != 
"phi": break inst.operands[inst.operands.index(b.label)] = a.label @@ -124,16 +122,16 @@ def run_pass(self): for _ in range(fn.num_basic_blocks): changes = self._optimize_empty_basicblocks() + self.analyses_cache.force_analysis(CFGAnalysis) changes += fn.remove_unreachable_blocks() if changes == 0: break else: raise CompilerPanic("Too many iterations removing empty basic blocks") - self.analyses_cache.force_analysis(CFGAnalysis) - for _ in range(fn.num_basic_blocks): # essentially `while True` self._collapse_chained_blocks(entry) + self.analyses_cache.force_analysis(CFGAnalysis) if fn.remove_unreachable_blocks() == 0: break else: diff --git a/vyper/venom/passes/stack_reorder.py b/vyper/venom/passes/stack_reorder.py deleted file mode 100644 index a92fe0e626..0000000000 --- a/vyper/venom/passes/stack_reorder.py +++ /dev/null @@ -1,23 +0,0 @@ -from vyper.utils import OrderedSet -from vyper.venom.basicblock import IRBasicBlock -from vyper.venom.passes.base_pass import IRPass - - -class StackReorderPass(IRPass): - visited: OrderedSet - - def _reorder_stack(self): - pass - - def _visit(self, bb: IRBasicBlock): - if bb in self.visited: - return - self.visited.add(bb) - - for bb_out in bb.cfg_out: - self._visit(bb_out) - - def _run_pass(self): - entry = self.function.entry - self.visited = OrderedSet() - self._visit(entry) diff --git a/vyper/venom/passes/store_elimination.py b/vyper/venom/passes/store_elimination.py index 17b9ce995a..a4f217505b 100644 --- a/vyper/venom/passes/store_elimination.py +++ b/vyper/venom/passes/store_elimination.py @@ -1,6 +1,4 @@ -from vyper.venom.analysis.cfg import CFGAnalysis -from vyper.venom.analysis.dfg import DFGAnalysis -from vyper.venom.analysis.liveness import LivenessAnalysis +from vyper.venom.analysis import DFGAnalysis, LivenessAnalysis from vyper.venom.basicblock import IRVariable from vyper.venom.passes.base_pass import IRPass @@ -11,37 +9,37 @@ class StoreElimination(IRPass): and removes the `store` instruction. 
""" + # TODO: consider renaming `store` instruction, since it is confusing + # with LoadElimination + def run_pass(self): - self.analyses_cache.request_analysis(CFGAnalysis) - dfg = self.analyses_cache.request_analysis(DFGAnalysis) + self.dfg = self.analyses_cache.request_analysis(DFGAnalysis) - for var, inst in dfg.outputs.items(): + for var, inst in self.dfg.outputs.items(): if inst.opcode != "store": continue - if not isinstance(inst.operands[0], IRVariable): - continue - if inst.operands[0].name in ["%ret_ofst", "%ret_size"]: - continue - if inst.output.name in ["%ret_ofst", "%ret_size"]: - continue - self._process_store(dfg, inst, var, inst.operands[0]) + self._process_store(inst, var, inst.operands[0]) self.analyses_cache.invalidate_analysis(LivenessAnalysis) self.analyses_cache.invalidate_analysis(DFGAnalysis) - def _process_store(self, dfg, inst, var, new_var): + def _process_store(self, inst, var: IRVariable, new_var: IRVariable): """ Process store instruction. If the variable is only used by a load instruction, forward the variable to the load instruction. 
""" - uses = dfg.get_uses(var) + if any([inst.opcode == "phi" for inst in self.dfg.get_uses(new_var)]): + return + uses = self.dfg.get_uses(var) if any([inst.opcode == "phi" for inst in uses]): return - - for use_inst in uses: + for use_inst in uses.copy(): for i, operand in enumerate(use_inst.operands): if operand == var: use_inst.operands[i] = new_var + self.dfg.add_use(new_var, use_inst) + self.dfg.remove_use(var, use_inst) + inst.parent.remove_instruction(inst) diff --git a/vyper/venom/passes/extract_literals.py b/vyper/venom/passes/store_expansion.py similarity index 68% rename from vyper/venom/passes/extract_literals.py rename to vyper/venom/passes/store_expansion.py index b8e042b357..be5eb3d95d 100644 --- a/vyper/venom/passes/extract_literals.py +++ b/vyper/venom/passes/store_expansion.py @@ -1,12 +1,12 @@ -from vyper.venom.analysis.dfg import DFGAnalysis -from vyper.venom.analysis.liveness import LivenessAnalysis -from vyper.venom.basicblock import IRInstruction, IRLiteral +from vyper.venom.analysis import DFGAnalysis, LivenessAnalysis +from vyper.venom.basicblock import IRInstruction, IRLiteral, IRVariable from vyper.venom.passes.base_pass import IRPass -class ExtractLiteralsPass(IRPass): +class StoreExpansionPass(IRPass): """ - This pass extracts literals so that they can be reordered by the DFT pass + This pass extracts literals and variables so that they can be + reordered by the DFT pass """ def run_pass(self): @@ -20,7 +20,7 @@ def _process_bb(self, bb): i = 0 while i < len(bb.instructions): inst = bb.instructions[i] - if inst.opcode == "store": + if inst.opcode in ("store", "offset", "phi", "param"): i += 1 continue @@ -29,9 +29,11 @@ def _process_bb(self, bb): if inst.opcode == "log" and j == 0: continue - if isinstance(op, IRLiteral): + if isinstance(op, (IRVariable, IRLiteral)): var = self.function.get_next_variable() to_insert = IRInstruction("store", [op], var) bb.insert_instruction(to_insert, index=i) inst.operands[j] = var + i += 1 + i += 1 
diff --git a/vyper/venom/stack_model.py b/vyper/venom/stack_model.py index a98e5bb25b..e284b41fb2 100644 --- a/vyper/venom/stack_model.py +++ b/vyper/venom/stack_model.py @@ -30,7 +30,7 @@ def push(self, op: IROperand) -> None: def pop(self, num: int = 1) -> None: del self._stack[len(self._stack) - num :] - def get_depth(self, op: IROperand, n: int = 1) -> int: + def get_depth(self, op: IROperand) -> int: """ Returns the depth of the n-th matching operand in the stack map. If the operand is not in the stack map, returns NOT_IN_STACK. @@ -39,10 +39,7 @@ def get_depth(self, op: IROperand, n: int = 1) -> int: for i, stack_op in enumerate(reversed(self._stack)): if stack_op.value == op.value: - if n <= 1: - return -i - else: - n -= 1 + return -i return StackModel.NOT_IN_STACK # type: ignore diff --git a/vyper/venom/venom_to_assembly.py b/vyper/venom/venom_to_assembly.py index 51fac10134..59734773af 100644 --- a/vyper/venom/venom_to_assembly.py +++ b/vyper/venom/venom_to_assembly.py @@ -1,4 +1,3 @@ -from collections import Counter from typing import Any from vyper.exceptions import CompilerPanic, StackTooDeep @@ -10,10 +9,13 @@ mksymbol, optimize_assembly, ) -from vyper.utils import MemoryPositions, OrderedSet -from vyper.venom.analysis.analysis import IRAnalysesCache -from vyper.venom.analysis.dup_requirements import DupRequirementsAnalysis -from vyper.venom.analysis.liveness import LivenessAnalysis +from vyper.utils import MemoryPositions, OrderedSet, wrap256 +from vyper.venom.analysis import ( + CFGAnalysis, + IRAnalysesCache, + LivenessAnalysis, + VarEquivalenceAnalysis, +) from vyper.venom.basicblock import ( IRBasicBlock, IRInstruction, @@ -23,9 +25,13 @@ IRVariable, ) from vyper.venom.context import IRContext -from vyper.venom.passes.normalization import NormalizationPass +from vyper.venom.passes import NormalizationPass from vyper.venom.stack_model import StackModel +DEBUG_SHOW_COST = False +if DEBUG_SHOW_COST: + import sys + # instructions which map one-to-one 
from venom to EVM _ONE_TO_ONE_INSTRUCTIONS = frozenset( [ @@ -35,6 +41,7 @@ "calldatacopy", "mcopy", "calldataload", + "codecopy", "gas", "gasprice", "gaslimit", @@ -102,9 +109,6 @@ ] ) -COMMUTATIVE_INSTRUCTIONS = frozenset(["add", "mul", "smul", "or", "xor", "and", "eq"]) - - _REVERT_POSTAMBLE = ["_sym___revert", "JUMPDEST", *PUSH(0), "DUP1", "REVERT"] @@ -118,6 +122,11 @@ def apply_line_numbers(inst: IRInstruction, asm) -> list[str]: return ret # type: ignore +def _as_asm_symbol(label: IRLabel) -> str: + # Lower an IRLabel to an assembly symbol + return f"_sym_{label.value}" + + # TODO: "assembly" gets into the recursion due to how the original # IR was structured recursively in regards with the deploy instruction. # There, recursing into the deploy instruction was by design, and @@ -153,7 +162,8 @@ def generate_evm(self, no_optimize: bool = False) -> list[str]: NormalizationPass(ac, fn).run_pass() self.liveness_analysis = ac.request_analysis(LivenessAnalysis) - ac.request_analysis(DupRequirementsAnalysis) + self.equivalence = ac.request_analysis(VarEquivalenceAnalysis) + ac.request_analysis(CFGAnalysis) assert fn.normalized, "Non-normalized CFG!" 
@@ -178,19 +188,19 @@ def generate_evm(self, no_optimize: bool = False) -> list[str]: asm.extend(_REVERT_POSTAMBLE) # Append data segment - data_segments: dict = dict() - for inst in ctx.data_segment: - if inst.opcode == "dbname": - label = inst.operands[0].value - data_segments[label] = [DataHeader(f"_sym_{label}")] - elif inst.opcode == "db": - data = inst.operands[0] + for data_section in ctx.data_segment: + label = data_section.label + asm_data_section: list[Any] = [] + asm_data_section.append(DataHeader(_as_asm_symbol(label))) + for item in data_section.data_items: + data = item.data if isinstance(data, IRLabel): - data_segments[label].append(f"_sym_{data.value}") + asm_data_section.append(_as_asm_symbol(data)) else: - data_segments[label].append(data) + assert isinstance(data, bytes) + asm_data_section.append(data) - asm.extend(list(data_segments.values())) + asm.append(asm_data_section) if no_optimize is False: optimize_assembly(top_asm) @@ -200,21 +210,19 @@ def generate_evm(self, no_optimize: bool = False) -> list[str]: def _stack_reorder( self, assembly: list, stack: StackModel, stack_ops: list[IROperand], dry_run: bool = False ) -> int: - cost = 0 - if dry_run: assert len(assembly) == 0, "Dry run should not work on assembly" stack = stack.copy() - stack_ops_count = len(stack_ops) + if len(stack_ops) == 0: + return 0 - counts = Counter(stack_ops) + assert len(stack_ops) == len(set(stack_ops)) # precondition - for i in range(stack_ops_count): - op = stack_ops[i] - final_stack_depth = -(stack_ops_count - i - 1) - depth = stack.get_depth(op, counts[op]) # type: ignore - counts[op] -= 1 + cost = 0 + for i, op in enumerate(stack_ops): + final_stack_depth = -(len(stack_ops) - i - 1) + depth = stack.get_depth(op) if depth == StackModel.NOT_IN_STACK: raise CompilerPanic(f"Variable {op} not in stack") @@ -222,36 +230,42 @@ def _stack_reorder( if depth == final_stack_depth: continue - if op == stack.peek(final_stack_depth): + to_swap = stack.peek(final_stack_depth) 
+ if self.equivalence.equivalent(op, to_swap): + # perform a "virtual" swap + stack.poke(final_stack_depth, op) + stack.poke(depth, to_swap) continue cost += self.swap(assembly, stack, depth) cost += self.swap(assembly, stack, final_stack_depth) + assert stack._stack[-len(stack_ops) :] == stack_ops, (stack, stack_ops) + return cost def _emit_input_operands( - self, assembly: list, inst: IRInstruction, ops: list[IROperand], stack: StackModel + self, + assembly: list, + inst: IRInstruction, + ops: list[IROperand], + stack: StackModel, + next_liveness: OrderedSet[IRVariable], ) -> None: # PRE: we already have all the items on the stack that have # been scheduled to be killed. now it's just a matter of emitting # SWAPs, DUPs and PUSHes until we match the `ops` argument - # dumb heuristic: if the top of stack is not wanted here, swap - # it with something that is wanted - if ops and stack.height > 0 and stack.peek(0) not in ops: - for op in ops: - if isinstance(op, IRVariable) and op not in inst.dup_requirements: - self.swap_op(assembly, stack, op) - break + # to validate store expansion invariant - + # each op is emitted at most once. 
+ seen: set[IROperand] = set() - emitted_ops = OrderedSet[IROperand]() for op in ops: if isinstance(op, IRLabel): - # invoke emits the actual instruction itself so we don't need to emit it here - # but we need to add it to the stack map + # invoke emits the actual instruction itself so we don't need + # to emit it here but we need to add it to the stack map if inst.opcode != "invoke": - assembly.append(f"_sym_{op.value}") + assembly.append(_as_asm_symbol(op)) stack.push(op) continue @@ -260,17 +274,16 @@ def _emit_input_operands( raise Exception(f"Value too low: {op.value}") elif op.value >= 2**256: raise Exception(f"Value too high: {op.value}") - assembly.extend(PUSH(op.value % 2**256)) + assembly.extend(PUSH(wrap256(op.value))) stack.push(op) continue - if op in inst.dup_requirements and op not in emitted_ops: - self.dup_op(assembly, stack, op) - - if op in emitted_ops: + if op in next_liveness: self.dup_op(assembly, stack, op) - emitted_ops.add(op) + # guaranteed by store expansion + assert op not in seen, (op, seen) + seen.add(op) def _generate_evm_for_basicblock_r( self, asm: list, basicblock: IRBasicBlock, stack: StackModel @@ -279,39 +292,36 @@ def _generate_evm_for_basicblock_r( return self.visited_basicblocks.add(basicblock) + if DEBUG_SHOW_COST: + print(basicblock, file=sys.stderr) + + ref = asm + asm = [] + # assembly entry point into the block - asm.append(f"_sym_{basicblock.label}") + asm.append(_as_asm_symbol(basicblock.label)) asm.append("JUMPDEST") - self.clean_stack_from_cfg_in(asm, basicblock, stack) + if len(basicblock.cfg_in) == 1: + self.clean_stack_from_cfg_in(asm, basicblock, stack) - param_insts = [inst for inst in basicblock.instructions if inst.opcode == "param"] - main_insts = [inst for inst in basicblock.instructions if inst.opcode != "param"] + all_insts = sorted(basicblock.instructions, key=lambda x: x.opcode != "param") - for inst in param_insts: - asm.extend(self._generate_evm_for_instruction(inst, stack)) + for i, inst in 
enumerate(all_insts): + next_liveness = ( + all_insts[i + 1].liveness if i + 1 < len(all_insts) else basicblock.out_vars + ) - self._clean_unused_params(asm, basicblock, stack) + asm.extend(self._generate_evm_for_instruction(inst, stack, next_liveness)) - for i, inst in enumerate(main_insts): - next_liveness = main_insts[i + 1].liveness if i + 1 < len(main_insts) else OrderedSet() + if DEBUG_SHOW_COST: + print(" ".join(map(str, asm)), file=sys.stderr) + print("\n", file=sys.stderr) - asm.extend(self._generate_evm_for_instruction(inst, stack, next_liveness)) + ref.extend(asm) - for bb in basicblock.reachable: - self._generate_evm_for_basicblock_r(asm, bb, stack.copy()) - - def _clean_unused_params(self, asm: list, bb: IRBasicBlock, stack: StackModel) -> None: - for i, inst in enumerate(bb.instructions): - if inst.opcode != "param": - break - if inst.is_volatile and i + 1 < len(bb.instructions): - liveness = bb.instructions[i + 1].liveness - if inst.output is not None and inst.output not in liveness: - depth = stack.get_depth(inst.output) - if depth != 0: - self.swap(asm, stack, depth) - self.pop(asm, stack) + for bb in basicblock.cfg_out: + self._generate_evm_for_basicblock_r(ref, bb, stack.copy()) # pop values from stack at entry to bb # note this produces the same result(!) no matter which basic block @@ -319,36 +329,37 @@ def _clean_unused_params(self, asm: list, bb: IRBasicBlock, stack: StackModel) - def clean_stack_from_cfg_in( self, asm: list, basicblock: IRBasicBlock, stack: StackModel ) -> None: - if len(basicblock.cfg_in) == 0: - return - - to_pop = OrderedSet[IRVariable]() - for in_bb in basicblock.cfg_in: - # inputs is the input variables we need from in_bb - inputs = self.liveness_analysis.input_vars_from(in_bb, basicblock) - - # layout is the output stack layout for in_bb (which works - # for all possible cfg_outs from the in_bb). - layout = in_bb.out_vars - - # pop all the stack items which in_bb produced which we don't need. 
- to_pop |= layout.difference(inputs) - + # the input block is a splitter block, like jnz or djmp + assert len(basicblock.cfg_in) == 1 + in_bb = basicblock.cfg_in.first() + assert len(in_bb.cfg_out) > 1 + + # inputs is the input variables we need from in_bb + inputs = self.liveness_analysis.input_vars_from(in_bb, basicblock) + + # layout is the output stack layout for in_bb (which works + # for all possible cfg_outs from the in_bb, in_bb is responsible + # for making sure its output stack layout works no matter which + # bb it jumps into). + layout = in_bb.out_vars + to_pop = list(layout.difference(inputs)) + + # small heuristic: pop from shallowest first. + to_pop.sort(key=lambda var: -stack.get_depth(var)) + + # NOTE: we could get more fancy and try to optimize the swap + # operations here, there is probably some more room for optimization. for var in to_pop: depth = stack.get_depth(var) - # don't pop phantom phi inputs - if depth is StackModel.NOT_IN_STACK: - continue if depth != 0: self.swap(asm, stack, depth) self.pop(asm, stack) def _generate_evm_for_instruction( - self, inst: IRInstruction, stack: StackModel, next_liveness: OrderedSet = None + self, inst: IRInstruction, stack: StackModel, next_liveness: OrderedSet ) -> list[str]: assembly: list[str | int] = [] - next_liveness = next_liveness or OrderedSet() opcode = inst.opcode # @@ -359,7 +370,7 @@ def _generate_evm_for_instruction( if opcode in ["jmp", "djmp", "jnz", "invoke"]: operands = list(inst.get_non_label_operands()) - elif opcode == "alloca": + elif opcode in ("alloca", "palloca"): offset, _size = inst.operands operands = [offset] @@ -393,7 +404,8 @@ def _generate_evm_for_instruction( # example, for `%56 = %label1 %13 %label2 %14`, we will # find an instance of %13 *or* %14 in the stack and replace it with %56. 
to_be_replaced = stack.peek(depth) - if to_be_replaced in inst.dup_requirements: + if to_be_replaced in next_liveness: + # this branch seems unreachable (maybe due to make_ssa) # %13/%14 is still live(!), so we make a copy of it self.dup(assembly, stack, depth) stack.poke(0, ret) @@ -401,29 +413,49 @@ def _generate_evm_for_instruction( stack.poke(depth, ret) return apply_line_numbers(inst, assembly) + if opcode == "offset": + ofst, label = inst.operands + assert isinstance(label, IRLabel) # help mypy + assembly.extend(["_OFST", _as_asm_symbol(label), ofst.value]) + assert isinstance(inst.output, IROperand), "Offset must have output" + stack.push(inst.output) + return apply_line_numbers(inst, assembly) + # Step 2: Emit instruction's input operands - self._emit_input_operands(assembly, inst, operands, stack) - - # Step 3: Reorder stack - if opcode in ["jnz", "djmp", "jmp"]: - # prepare stack for jump into another basic block - assert inst.parent and isinstance(inst.parent.cfg_out, OrderedSet) - b = next(iter(inst.parent.cfg_out)) - target_stack = self.liveness_analysis.input_vars_from(inst.parent, b) - # TODO optimize stack reordering at entry and exit from basic blocks - # NOTE: stack in general can contain multiple copies of the same variable, - # however we are safe in the case of jmp/djmp/jnz as it's not going to - # have multiples. - target_stack_list = list(target_stack) - self._stack_reorder(assembly, stack, target_stack_list) - - if opcode in COMMUTATIVE_INSTRUCTIONS: + self._emit_input_operands(assembly, inst, operands, stack, next_liveness) + + # Step 3: Reorder stack before join points + if opcode == "jmp": + # prepare stack for jump into a join point + # we only need to reorder stack before join points, which after + # cfg normalization, join points can only be led into by + # jmp instructions. 
+ assert isinstance(inst.parent.cfg_out, OrderedSet) + assert len(inst.parent.cfg_out) == 1 + next_bb = inst.parent.cfg_out.first() + + # guaranteed by cfg normalization+simplification + assert len(next_bb.cfg_in) > 1 + + target_stack = self.liveness_analysis.input_vars_from(inst.parent, next_bb) + # NOTE: in general the stack can contain multiple copies of + # the same variable, however, before a jump that is not possible + self._stack_reorder(assembly, stack, list(target_stack)) + + if inst.is_commutative: cost_no_swap = self._stack_reorder([], stack, operands, dry_run=True) operands[-1], operands[-2] = operands[-2], operands[-1] cost_with_swap = self._stack_reorder([], stack, operands, dry_run=True) if cost_with_swap > cost_no_swap: operands[-1], operands[-2] = operands[-2], operands[-1] + cost = self._stack_reorder([], stack, operands, dry_run=True) + if DEBUG_SHOW_COST and cost: + print("ENTER", inst, file=sys.stderr) + print(" HAVE", stack, file=sys.stderr) + print(" WANT", operands, file=sys.stderr) + print(" COST", cost, file=sys.stderr) + # final step to get the inputs to this instruction ordered # correctly on the stack self._stack_reorder(assembly, stack, operands) @@ -440,32 +472,32 @@ def _generate_evm_for_instruction( # Step 5: Emit the EVM instruction(s) if opcode in _ONE_TO_ONE_INSTRUCTIONS: assembly.append(opcode.upper()) - elif opcode == "alloca": + elif opcode in ("alloca", "palloca"): pass elif opcode == "param": pass elif opcode == "store": pass - elif opcode == "dbname": - pass elif opcode in ["codecopy", "dloadbytes"]: assembly.append("CODECOPY") + elif opcode == "dbname": + pass elif opcode == "jnz": # jump if not zero - if_nonzero_label = inst.operands[1] - if_zero_label = inst.operands[2] - assembly.append(f"_sym_{if_nonzero_label.value}") + if_nonzero_label, if_zero_label = inst.get_label_operands() + assembly.append(_as_asm_symbol(if_nonzero_label)) assembly.append("JUMPI") # make sure the if_zero_label will be optimized out # assert 
if_zero_label == next(iter(inst.parent.cfg_out)).label - assembly.append(f"_sym_{if_zero_label.value}") + assembly.append(_as_asm_symbol(if_zero_label)) assembly.append("JUMP") elif opcode == "jmp": - assert isinstance(inst.operands[0], IRLabel) - assembly.append(f"_sym_{inst.operands[0].value}") + (target,) = inst.operands + assert isinstance(target, IRLabel) + assembly.append(_as_asm_symbol(target)) assembly.append("JUMP") elif opcode == "djmp": assert isinstance( @@ -480,7 +512,7 @@ def _generate_evm_for_instruction( assembly.extend( [ f"_sym_label_ret_{self.label_counter}", - f"_sym_{target.value}", + _as_asm_symbol(target), "JUMP", f"_sym_label_ret_{self.label_counter}", "JUMPDEST", @@ -537,13 +569,24 @@ def _generate_evm_for_instruction( # Step 6: Emit instructions output operands (if any) if inst.output is not None: - if "call" in inst.opcode and inst.output not in next_liveness: + if inst.output not in next_liveness: self.pop(assembly, stack) - elif inst.output in next_liveness: - # peek at next_liveness to find the next scheduled item, - # and optimistically swap with it - next_scheduled = list(next_liveness)[-1] - self.swap_op(assembly, stack, next_scheduled) + else: + # heuristic: peek at next_liveness to find the next scheduled + # item, and optimistically swap with it + if DEBUG_SHOW_COST: + stack0 = stack.copy() + + next_scheduled = next_liveness.last() + cost = 0 + if not self.equivalence.equivalent(inst.output, next_scheduled): + cost = self.swap_op(assembly, stack, next_scheduled) + + if DEBUG_SHOW_COST and cost != 0: + print("ENTER", inst, file=sys.stderr) + print(" HAVE", stack0, file=sys.stderr) + print(" NEXT LIVENESS", next_liveness, file=sys.stderr) + print(" NEW_STACK", stack, file=sys.stderr) return apply_line_numbers(inst, assembly) @@ -565,10 +608,14 @@ def dup(self, assembly, stack, depth): assembly.append(_evm_dup_for(depth)) def swap_op(self, assembly, stack, op): - self.swap(assembly, stack, stack.get_depth(op)) + depth = 
stack.get_depth(op) + assert depth is not StackModel.NOT_IN_STACK, f"Cannot swap non-existent operand {op}" + return self.swap(assembly, stack, depth) def dup_op(self, assembly, stack, op): - self.dup(assembly, stack, stack.get_depth(op)) + depth = stack.get_depth(op) + assert depth is not StackModel.NOT_IN_STACK, f"Cannot dup non-existent operand {op}" + self.dup(assembly, stack, depth) def _evm_swap_for(depth: int) -> str: