diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100755 index 0000000..11fb8ce --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,7 @@ +version: 2 +updates: + - package-ecosystem: "github-actions" + directory: "/" + schedule: + # Check for updates to GitHub Actions monthly + interval: "monthly" diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f5bac20..b97c77c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,6 +1,13 @@ -name: Python package +name: unittests -on: [ push, pull_request ] +on: + push: + tags: + - "*" + branches: + - master + - "ci/*" + pull_request: jobs: unittests: @@ -9,33 +16,34 @@ jobs: run: shell: "bash -l {0}" strategy: + fail-fast: false matrix: python-version: - 3.6 - 3.7 - 3.8 - 3.9 + - 3.10 name: Tests for Python ${{ matrix.python-version }} steps: - uses: actions/checkout@v2 - uses: conda-incubator/setup-miniconda@v2 with: - miniforge-variant: Mambaforge - use-mamba: true + mamba-version: "*" + channels: conda-forge,defaults + channel-priority: true python-version: ${{ matrix.python-version }} activate-environment: formulate-env - name: Install ROOT run: | - conda config --add channels conda-forge - conda config --set channel_priority strict - conda install root -y + mamba install root -y - name: Install test dependencies run: | python -m pip install --upgrade pip - conda install coveralls -y - conda install pytest-cov -y + mamba install coveralls -y + mamba install pytest-cov -y pip install -e .[dev,test] - name: Test with pytest diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 394d746..8e52631 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v3.4.0 + rev: v4.2.0 hooks: - id: check-added-large-files - id: check-case-conflict @@ -12,10 +12,58 @@ repos: - id: mixed-line-ending - id: requirements-txt-fixer - id: trailing-whitespace - - id: fix-encoding-pragma - repo: https://github.com/mgedmin/check-manifest - rev: "0.46" + rev: "0.48" hooks: - id: check-manifest stages: [ manual ] + + - repo: https://github.com/myint/docformatter + rev: v1.4 + hooks: + - id: docformatter + args: [ -r, --in-place, --wrap-descriptions, '120', --wrap-summaries, '120', -- ] + + - repo: https://github.com/mattlqx/pre-commit-sign + rev: v1.1.3 + hooks: + - id: sign-commit + + - repo: https://github.com/pre-commit/pygrep-hooks + rev: v1.9.0 + hooks: + - id: python-use-type-annotations + - id: python-check-mock-methods + - id: python-no-eval + - id: rst-directive-colons + + - repo: https://github.com/asottile/pyupgrade + rev: v2.32.1 + hooks: + - id: pyupgrade + args: [ --py36-plus ] + + - repo: https://github.com/asottile/setup-cfg-fmt + rev: v1.20.1 + hooks: + - id: setup-cfg-fmt + args: [--max-py-version=3.10] + + - repo: https://github.com/roy-ht/pre-commit-jupyter + rev: v1.2.1 + hooks: + - id: jupyter-notebook-cleanup + +# TODO: for Python 3.7+ +# - repo: https://github.com/sondrelg/pep585-upgrade +# rev: 'v1.0' +# hooks: +# - id: upgrade-type-hints +# args: [ '--futures=true' ] + + - repo: https://github.com/ambv/black + rev: 22.3.0 + hooks: + - id: black + language_version: python3 diff --git a/conftest.py b/conftest.py index 5f7c41b..4795d54 100644 --- a/conftest.py +++ b/conftest.py @@ -1,14 +1,15 @@ # Licensed under a 3-clause BSD style license, see LICENSE. -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function import pytest def pytest_addoption(parser): - parser.addoption("--run-slow", action="store_true", - default=False, help="Enable running slow tests") + parser.addoption( + "--run-slow", + action="store_true", + default=False, + help="Enable running slow tests", + ) def pytest_collection_modifyitems(config, items): diff --git a/formulate/__init__.py b/formulate/__init__.py index dcf521a..a4a0a4c 100644 --- a/formulate/__init__.py +++ b/formulate/__init__.py @@ -1,33 +1,37 @@ # Licensed under a 3-clause BSD style license, see LICENSE. -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function from pyparsing import ParserElement from .backends import from_auto, from_numexpr, to_numexpr from .backends import from_root, to_root -from .expression import ExpressionComponent, SingleComponent, Expression, Variable, NamedConstant, UnnamedConstant +from .expression import ( + ExpressionComponent, + SingleComponent, + Expression, + Variable, + NamedConstant, + UnnamedConstant, +) from .parser import ParsingException from .version import version as __version__ __all__ = [ - 'ExpressionComponent', - 'SingleComponent', - 'Expression', - 'Variable', - 'NamedConstant', - 'UnnamedConstant', - 'ParsingException', - 'from_auto', + "ExpressionComponent", + "SingleComponent", + "Expression", + "Variable", + "NamedConstant", + "UnnamedConstant", + "ParsingException", + "from_auto", # numexpr - 'from_numexpr', - 'to_numexpr', + "from_numexpr", + "to_numexpr", # ROOT - 'from_root', - 'to_root', - '__version__', + "from_root", + "to_root", + "__version__", ] diff --git a/formulate/__main__.py b/formulate/__main__.py index 5dace8a..6e92191 100644 --- a/formulate/__main__.py +++ b/formulate/__main__.py @@ -1,7 +1,4 @@ # Licensed under a 3-clause BSD style license, see LICENSE. -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function import argparse import sys @@ -11,18 +8,20 @@ def parse_args(args): - parser = argparse.ArgumentParser(description='Convert between different types of formulae') + parser = argparse.ArgumentParser( + description="Convert between different types of formulae" + ) from_group = parser.add_mutually_exclusive_group(required=True) - from_group.add_argument('--from-root') - from_group.add_argument('--from-numexpr') + from_group.add_argument("--from-root") + from_group.add_argument("--from-numexpr") to_group = parser.add_mutually_exclusive_group(required=True) - to_group.add_argument('--to-root', action='store_true') - to_group.add_argument('--to-numexpr', action='store_true') - to_group.add_argument('--variables', action='store_true') - to_group.add_argument('--named-constants', action='store_true') - to_group.add_argument('--unnamed-constants', action='store_true') + to_group.add_argument("--to-root", action="store_true") + to_group.add_argument("--to-numexpr", action="store_true") + to_group.add_argument("--variables", action="store_true") + to_group.add_argument("--named-constants", action="store_true") + to_group.add_argument("--unnamed-constants", action="store_true") args = parser.parse_args(args) if args.from_root is not None: @@ -37,16 +36,16 @@ def parse_args(args): elif args.to_numexpr: result = to_numexpr(expression) elif args.variables: - result = '\n'.join(sorted(expression.variables)) + result = "\n".join(sorted(expression.variables)) elif args.named_constants: - result = '\n'.join(sorted(expression.named_constants)) + result = "\n".join(sorted(expression.named_constants)) elif args.unnamed_constants: - result = '\n'.join(sorted(expression.unnamed_constants)) + result = "\n".join(sorted(expression.unnamed_constants)) else: raise NotImplementedError() return result -if __name__ == '__main__': +if __name__ == "__main__": print(parse_args(sys.argv[1:])) diff --git a/formulate/backends/ROOT.py b/formulate/backends/ROOT.py index 7c1b6c1..fac763f 100644 --- a/formulate/backends/ROOT.py +++ b/formulate/backends/ROOT.py @@ -1,139 +1,124 @@ # Licensed under a 3-clause BSD style license, see LICENSE. -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function from ..identifiers import IDs, ConstantIDs from ..parser import POperator, PFunction, Parser, PConstant __all__ = [ - 'root_parser', + "root_parser", ] config = [ - POperator(IDs.MINUS, '-', rhs_only=True), - POperator(IDs.PLUS, '+', rhs_only=True), - POperator(IDs.ADD, '+'), - POperator(IDs.SUB, '-'), - POperator(IDs.MUL, '*'), - POperator(IDs.DIV, '/'), - POperator(IDs.MOD, '%'), - POperator(IDs.POW, '**'), - POperator(IDs.LSHIFT, '<<'), - POperator(IDs.RSHIFT, '>>'), - - POperator(IDs.EQ, '=='), - POperator(IDs.NEQ, '!='), - POperator(IDs.GT, '>'), - POperator(IDs.GTEQ, '>='), - POperator(IDs.LT, '<'), - POperator(IDs.LTEQ, '<='), - - POperator(IDs.AND, '&&'), - POperator(IDs.OR, '||'), - POperator(IDs.XOR, '^'), - POperator(IDs.NOT, '!', rhs_only=True), - - PFunction(IDs.SQRT, 'sqrt'), - PFunction(IDs.SQRT, 'TMath::Sqrt'), - PFunction(IDs.ABS, 'TMath::Abs'), - - PFunction(IDs.LOG, 'log'), - PFunction(IDs.LOG, 'TMath::Log'), - PFunction(IDs.LOG2, 'log2'), - PFunction(IDs.LOG2, 'TMath::Log2'), - PFunction(IDs.LOG10, 'log10'), - PFunction(IDs.LOG10, 'TMath::Log10'), - - PFunction(IDs.EXP, 'exp'), - PFunction(IDs.EXP, 'TMath::Exp'), - - PFunction(IDs.SIN, 'sin'), - PFunction(IDs.SIN, 'TMath::Sin'), - PFunction(IDs.ASIN, 'arcsin'), - PFunction(IDs.ASIN, 'TMath::ASin'), - PFunction(IDs.COS, 'cos'), - PFunction(IDs.COS, 'TMath::Cos'), - PFunction(IDs.ACOS, 'arccos'), - PFunction(IDs.ACOS, 'TMath::ACos'), - PFunction(IDs.TAN, 'tan'), - PFunction(IDs.TAN, 'TMath::Tan'), - PFunction(IDs.ATAN, 'arctan'), - PFunction(IDs.ATAN, 'TMath::ATan'), - PFunction(IDs.ATAN2, 'arctan2', 2), - PFunction(IDs.ATAN2, 'TMath::ATan2', 2), - - PFunction(IDs.COSH, 'TMath::CosH'), - PFunction(IDs.ACOSH, 'TMath::ACosH'), - PFunction(IDs.SINH, 'TMath::SinH'), - PFunction(IDs.ASINH, 'TMath::ASinH'), - PFunction(IDs.TANH, 'TMath::TanH'), - PFunction(IDs.ATANH, 'TMath::ATanH'), - - PFunction(IDs.BESSELI0, 'TMath::BesselI0'), - PFunction(IDs.BESSELI1, 'TMath::BesselI1'), - PFunction(IDs.BESSELJ0, 'TMath::BesselJ0'), - PFunction(IDs.BESSELJ1, 'TMath::BesselJ1'), - PFunction(IDs.BESSELK0, 'TMath::BesselK0'), - PFunction(IDs.BESSELK1, 'TMath::BesselK1'), - PFunction(IDs.BESSELY0, 'TMath::BesselY0'), - PFunction(IDs.BESSELY1, 'TMath::BesselY1'), - PFunction(IDs.CEIL, 'TMath::Ceil'), - PFunction(IDs.CEILNINT, 'TMath::CeilNint'), - PFunction(IDs.DILOG, 'TMath::DiLog'), - PFunction(IDs.ERF, 'TMath::Erf'), - PFunction(IDs.ERFC, 'TMath::Erfc'), - PFunction(IDs.ERFCINVERSE, 'TMath::ErfcInverse'), - PFunction(IDs.ERFINVERSE, 'TMath::ErfInverse'), - PFunction(IDs.EVEN, 'TMath::Even'), - PFunction(IDs.FACTORIAL, 'TMath::Factorial'), - PFunction(IDs.FLOOR, 'TMath::Floor'), - PFunction(IDs.FLOORNINT, 'TMath::FloorNint'), - PFunction(IDs.FREQ, 'TMath::Freq'), - PFunction(IDs.KOLMOGOROVPROB, 'TMath::KolmogorovProb'), - PFunction(IDs.LANDAUI, 'TMath::LandauI'), - PFunction(IDs.LNGAMMA, 'TMath::LnGamma'), - PFunction(IDs.NEXTPRIME, 'TMath::NextPrime'), - PFunction(IDs.NORMQUANTILE, 'TMath::NormQuantile'), - PFunction(IDs.ODD, 'TMath::Odd'), - PFunction(IDs.SQUARE, 'TMath::Sq'), - PFunction(IDs.STRUVEH0, 'TMath::StruveH0'), - PFunction(IDs.STRUVEH1, 'TMath::StruveH1'), - PFunction(IDs.STRUVEL0, 'TMath::StruveL0'), - PFunction(IDs.STRUVEL1, 'TMath::StruveL1'), - - PFunction(IDs.BESSELI, 'TMath::BesselI', 2), - PFunction(IDs.BESSELK, 'TMath::BesselK', 2), - PFunction(IDs.BETA, 'TMath::Beta', 2), - PFunction(IDs.BINOMIAL, 'TMath::Binomial', 2), - PFunction(IDs.CHISQUAREQUANTILE, 'TMath::ChisquareQuantile', 2), - PFunction(IDs.LDEXP, 'TMath::Ldexp', 2), - PFunction(IDs.PERMUTE, 'TMath::Permute', 2), - PFunction(IDs.POISSON, 'TMath::Poisson', 2), - PFunction(IDs.POISSONI, 'TMath::PoissonI', 2), - PFunction(IDs.PROB, 'TMath::Prob', 2), - PFunction(IDs.STUDENT, 'TMath::Student', 2), - PFunction(IDs.STUDENTI, 'TMath::StudentI', 2), - - PFunction(IDs.AREEQUALABS, 'TMath::AreEqualAbs', 3), - PFunction(IDs.AREEQUALREL, 'TMath::AreEqualRel', 3), - PFunction(IDs.BETACF, 'TMath::BetaCf', 3), - PFunction(IDs.BETADIST, 'TMath::BetaDist', 3), - PFunction(IDs.BETADISTI, 'TMath::BetaDistI', 3), - PFunction(IDs.BETAINCOMPLETE, 'TMath::BetaIncomplete', 3), - PFunction(IDs.BINOMIALI, 'TMath::BinomialI', 3), - PFunction(IDs.BUBBLEHIGH, 'TMath::BubbleHigh', 3), - PFunction(IDs.BUBBLELOW, 'TMath::BubbleLow', 3), - PFunction(IDs.FDIST, 'TMath::FDist', 3), - PFunction(IDs.FDISTI, 'TMath::FDistI', 3), - PFunction(IDs.VAVILOV, 'TMath::Vavilov', 3), - PFunction(IDs.VAVILOVI, 'TMath::VavilovI', 3), - - PFunction(IDs.ROOTSCUBIC, 'TMath::RootsCubic', 4), - PFunction(IDs.QUANTILES, 'TMath::Quantiles', 5), - + POperator(IDs.MINUS, "-", rhs_only=True), + POperator(IDs.PLUS, "+", rhs_only=True), + POperator(IDs.ADD, "+"), + POperator(IDs.SUB, "-"), + POperator(IDs.MUL, "*"), + POperator(IDs.DIV, "/"), + POperator(IDs.MOD, "%"), + POperator(IDs.POW, "**"), + POperator(IDs.LSHIFT, "<<"), + POperator(IDs.RSHIFT, ">>"), + POperator(IDs.EQ, "=="), + POperator(IDs.NEQ, "!="), + POperator(IDs.GT, ">"), + POperator(IDs.GTEQ, ">="), + POperator(IDs.LT, "<"), + POperator(IDs.LTEQ, "<="), + POperator(IDs.AND, "&&"), + POperator(IDs.OR, "||"), + POperator(IDs.XOR, "^"), + POperator(IDs.NOT, "!", rhs_only=True), + PFunction(IDs.SQRT, "sqrt"), + PFunction(IDs.SQRT, "TMath::Sqrt"), + PFunction(IDs.ABS, "TMath::Abs"), + PFunction(IDs.LOG, "log"), + PFunction(IDs.LOG, "TMath::Log"), + PFunction(IDs.LOG2, "log2"), + PFunction(IDs.LOG2, "TMath::Log2"), + PFunction(IDs.LOG10, "log10"), + PFunction(IDs.LOG10, "TMath::Log10"), + PFunction(IDs.EXP, "exp"), + PFunction(IDs.EXP, "TMath::Exp"), + PFunction(IDs.SIN, "sin"), + PFunction(IDs.SIN, "TMath::Sin"), + PFunction(IDs.ASIN, "arcsin"), + PFunction(IDs.ASIN, "TMath::ASin"), + PFunction(IDs.COS, "cos"), + PFunction(IDs.COS, "TMath::Cos"), + PFunction(IDs.ACOS, "arccos"), + PFunction(IDs.ACOS, "TMath::ACos"), + PFunction(IDs.TAN, "tan"), + PFunction(IDs.TAN, "TMath::Tan"), + PFunction(IDs.ATAN, "arctan"), + PFunction(IDs.ATAN, "TMath::ATan"), + PFunction(IDs.ATAN2, "arctan2", 2), + PFunction(IDs.ATAN2, "TMath::ATan2", 2), + PFunction(IDs.COSH, "TMath::CosH"), + PFunction(IDs.ACOSH, "TMath::ACosH"), + PFunction(IDs.SINH, "TMath::SinH"), + PFunction(IDs.ASINH, "TMath::ASinH"), + PFunction(IDs.TANH, "TMath::TanH"), + PFunction(IDs.ATANH, "TMath::ATanH"), + PFunction(IDs.BESSELI0, "TMath::BesselI0"), + PFunction(IDs.BESSELI1, "TMath::BesselI1"), + PFunction(IDs.BESSELJ0, "TMath::BesselJ0"), + PFunction(IDs.BESSELJ1, "TMath::BesselJ1"), + PFunction(IDs.BESSELK0, "TMath::BesselK0"), + PFunction(IDs.BESSELK1, "TMath::BesselK1"), + PFunction(IDs.BESSELY0, "TMath::BesselY0"), + PFunction(IDs.BESSELY1, "TMath::BesselY1"), + PFunction(IDs.CEIL, "TMath::Ceil"), + PFunction(IDs.CEILNINT, "TMath::CeilNint"), + PFunction(IDs.DILOG, "TMath::DiLog"), + PFunction(IDs.ERF, "TMath::Erf"), + PFunction(IDs.ERFC, "TMath::Erfc"), + PFunction(IDs.ERFCINVERSE, "TMath::ErfcInverse"), + PFunction(IDs.ERFINVERSE, "TMath::ErfInverse"), + PFunction(IDs.EVEN, "TMath::Even"), + PFunction(IDs.FACTORIAL, "TMath::Factorial"), + PFunction(IDs.FLOOR, "TMath::Floor"), + PFunction(IDs.FLOORNINT, "TMath::FloorNint"), + PFunction(IDs.FREQ, "TMath::Freq"), + PFunction(IDs.KOLMOGOROVPROB, "TMath::KolmogorovProb"), + PFunction(IDs.LANDAUI, "TMath::LandauI"), + PFunction(IDs.LNGAMMA, "TMath::LnGamma"), + PFunction(IDs.NEXTPRIME, "TMath::NextPrime"), + PFunction(IDs.NORMQUANTILE, "TMath::NormQuantile"), + PFunction(IDs.ODD, "TMath::Odd"), + PFunction(IDs.SQUARE, "TMath::Sq"), + PFunction(IDs.STRUVEH0, "TMath::StruveH0"), + PFunction(IDs.STRUVEH1, "TMath::StruveH1"), + PFunction(IDs.STRUVEL0, "TMath::StruveL0"), + PFunction(IDs.STRUVEL1, "TMath::StruveL1"), + PFunction(IDs.BESSELI, "TMath::BesselI", 2), + PFunction(IDs.BESSELK, "TMath::BesselK", 2), + PFunction(IDs.BETA, "TMath::Beta", 2), + PFunction(IDs.BINOMIAL, "TMath::Binomial", 2), + PFunction(IDs.CHISQUAREQUANTILE, "TMath::ChisquareQuantile", 2), + PFunction(IDs.LDEXP, "TMath::Ldexp", 2), + PFunction(IDs.PERMUTE, "TMath::Permute", 2), + PFunction(IDs.POISSON, "TMath::Poisson", 2), + PFunction(IDs.POISSONI, "TMath::PoissonI", 2), + PFunction(IDs.PROB, "TMath::Prob", 2), + PFunction(IDs.STUDENT, "TMath::Student", 2), + PFunction(IDs.STUDENTI, "TMath::StudentI", 2), + PFunction(IDs.AREEQUALABS, "TMath::AreEqualAbs", 3), + PFunction(IDs.AREEQUALREL, "TMath::AreEqualRel", 3), + PFunction(IDs.BETACF, "TMath::BetaCf", 3), + PFunction(IDs.BETADIST, "TMath::BetaDist", 3), + PFunction(IDs.BETADISTI, "TMath::BetaDistI", 3), + PFunction(IDs.BETAINCOMPLETE, "TMath::BetaIncomplete", 3), + PFunction(IDs.BINOMIALI, "TMath::BinomialI", 3), + PFunction(IDs.BUBBLEHIGH, "TMath::BubbleHigh", 3), + PFunction(IDs.BUBBLELOW, "TMath::BubbleLow", 3), + PFunction(IDs.FDIST, "TMath::FDist", 3), + PFunction(IDs.FDISTI, "TMath::FDistI", 3), + PFunction(IDs.VAVILOV, "TMath::Vavilov", 3), + PFunction(IDs.VAVILOVI, "TMath::VavilovI", 3), + PFunction(IDs.ROOTSCUBIC, "TMath::RootsCubic", 4), + PFunction(IDs.QUANTILES, "TMath::Quantiles", 5), # PFunction(IDs., 'TMath::BreitWigner()'), # PFunction(IDs., 'TMath::CauchyDist()'), # PFunction(IDs., 'TMath::Finite()'), @@ -162,60 +147,58 @@ ] constants = [ - PConstant(ConstantIDs.TRUE, 'true'), - PConstant(ConstantIDs.FALSE, 'false'), - PConstant(ConstantIDs.INFINITY, 'TMath::Infinity()'), - PConstant(ConstantIDs.NAN, 'TMath::QuietNaN()'), + PConstant(ConstantIDs.TRUE, "true"), + PConstant(ConstantIDs.FALSE, "false"), + PConstant(ConstantIDs.INFINITY, "TMath::Infinity()"), + PConstant(ConstantIDs.NAN, "TMath::QuietNaN()"), # PConstant(ConstantIDs., 'TMath::SignalingNaN()'), - - PConstant(ConstantIDs.SQRT2, 'sqrt2'), - PConstant(ConstantIDs.SQRT2, 'TMath::Sqrt2()'), - PConstant(ConstantIDs.E, 'e'), - PConstant(ConstantIDs.E, 'TMath::E()'), - PConstant(ConstantIDs.PI, 'pi'), - PConstant(ConstantIDs.PI, 'TMath::Pi()'), - PConstant(ConstantIDs.INVPI, 'TMath::InvPi()'), - PConstant(ConstantIDs.PIOVER2, 'TMath::PiOver2()'), - PConstant(ConstantIDs.PIOVER4, 'TMath::PiOver4()'), - PConstant(ConstantIDs.TAU, 'TMath::TwoPi()'), - PConstant(ConstantIDs.LN10, 'ln10'), - PConstant(ConstantIDs.LN10, 'TMath::Ln10()'), - PConstant(ConstantIDs.LOG10E, 'TMath::LogE()'), - PConstant(ConstantIDs.DEG2RAD, 'TMath::DegToRad()'), - PConstant(ConstantIDs.RAD2DEG, 'TMath::RadToDeg()'), - - PConstant(ConstantIDs.AVOGADRO, 'TMath::Na()'), - PConstant(ConstantIDs.AVOGADRO_ERR, 'TMath::NaUncertainty()'), - PConstant(ConstantIDs.BOLTZMANN, 'TMath::K()'), - PConstant(ConstantIDs.BOLTZMANN_CGS, 'TMath::Kcgs()'), - PConstant(ConstantIDs.BOLTZMANN_ERR, 'TMath::KUncertainty()'), - PConstant(ConstantIDs.C, 'TMath::C()'), - PConstant(ConstantIDs.C_CGS, 'TMath::Ccgs()'), - PConstant(ConstantIDs.C_ERR, 'TMath::CUncertainty()'), - PConstant(ConstantIDs.DRY_AIR_GAS, 'TMath::Rgair()'), - PConstant(ConstantIDs.ELEMENTARY_CHARGE, 'TMath::Qe()'), - PConstant(ConstantIDs.ELEMENTARY_CHARGE_ERR, 'TMath::QeUncertainty()'), - PConstant(ConstantIDs.EULER_MASCHERONI, 'TMath::EulerGamma()'), - PConstant(ConstantIDs.G, 'TMath::G()'), - PConstant(ConstantIDs.G_CGS, 'TMath::Gcgs()'), - PConstant(ConstantIDs.G_ERR, 'TMath::GUncertainty()'), - PConstant(ConstantIDs.G_OVER_HBARC, 'TMath::GhbarC()'), - PConstant(ConstantIDs.G_OVER_HBARC_ERR, 'TMath::GhbarCUncertainty()'), - PConstant(ConstantIDs.GRAV_ACCEL, 'TMath::Gn()'), - PConstant(ConstantIDs.GRAV_ACCEL_ERR, 'TMath::GnUncertainty()'), - PConstant(ConstantIDs.H, 'TMath::H()'), - PConstant(ConstantIDs.H_CGS, 'TMath::Hcgs()'), - PConstant(ConstantIDs.H_ERR, 'TMath::HUncertainty()'), - PConstant(ConstantIDs.HBAR, 'TMath::Hbar()'), - PConstant(ConstantIDs.HBAR_CGS, 'TMath::Hbarcgs()'), - PConstant(ConstantIDs.HBAR_ERR, 'TMath::HbarUncertainty()'), - PConstant(ConstantIDs.HxC, 'TMath::HC()'), - PConstant(ConstantIDs.HxC_CGS, 'TMath::HCcgs()'), - PConstant(ConstantIDs.MOL_WEIGHT_DRY_AIR, 'TMath::MWair()'), - PConstant(ConstantIDs.STEFAN_BOLTZMANN, 'TMath::Sigma()'), - PConstant(ConstantIDs.STEFAN_BOLTZMANN_ERR, 'TMath::SigmaUncertainty()'), - PConstant(ConstantIDs.UNIVERSAL_GAS, 'TMath::R()'), - PConstant(ConstantIDs.UNIVERSAL_GAS_ERR, 'TMath::RUncertainty()'), + PConstant(ConstantIDs.SQRT2, "sqrt2"), + PConstant(ConstantIDs.SQRT2, "TMath::Sqrt2()"), + PConstant(ConstantIDs.E, "e"), + PConstant(ConstantIDs.E, "TMath::E()"), + PConstant(ConstantIDs.PI, "pi"), + PConstant(ConstantIDs.PI, "TMath::Pi()"), + PConstant(ConstantIDs.INVPI, "TMath::InvPi()"), + PConstant(ConstantIDs.PIOVER2, "TMath::PiOver2()"), + PConstant(ConstantIDs.PIOVER4, "TMath::PiOver4()"), + PConstant(ConstantIDs.TAU, "TMath::TwoPi()"), + PConstant(ConstantIDs.LN10, "ln10"), + PConstant(ConstantIDs.LN10, "TMath::Ln10()"), + PConstant(ConstantIDs.LOG10E, "TMath::LogE()"), + PConstant(ConstantIDs.DEG2RAD, "TMath::DegToRad()"), + PConstant(ConstantIDs.RAD2DEG, "TMath::RadToDeg()"), + PConstant(ConstantIDs.AVOGADRO, "TMath::Na()"), + PConstant(ConstantIDs.AVOGADRO_ERR, "TMath::NaUncertainty()"), + PConstant(ConstantIDs.BOLTZMANN, "TMath::K()"), + PConstant(ConstantIDs.BOLTZMANN_CGS, "TMath::Kcgs()"), + PConstant(ConstantIDs.BOLTZMANN_ERR, "TMath::KUncertainty()"), + PConstant(ConstantIDs.C, "TMath::C()"), + PConstant(ConstantIDs.C_CGS, "TMath::Ccgs()"), + PConstant(ConstantIDs.C_ERR, "TMath::CUncertainty()"), + PConstant(ConstantIDs.DRY_AIR_GAS, "TMath::Rgair()"), + PConstant(ConstantIDs.ELEMENTARY_CHARGE, "TMath::Qe()"), + PConstant(ConstantIDs.ELEMENTARY_CHARGE_ERR, "TMath::QeUncertainty()"), + PConstant(ConstantIDs.EULER_MASCHERONI, "TMath::EulerGamma()"), + PConstant(ConstantIDs.G, "TMath::G()"), + PConstant(ConstantIDs.G_CGS, "TMath::Gcgs()"), + PConstant(ConstantIDs.G_ERR, "TMath::GUncertainty()"), + PConstant(ConstantIDs.G_OVER_HBARC, "TMath::GhbarC()"), + PConstant(ConstantIDs.G_OVER_HBARC_ERR, "TMath::GhbarCUncertainty()"), + PConstant(ConstantIDs.GRAV_ACCEL, "TMath::Gn()"), + PConstant(ConstantIDs.GRAV_ACCEL_ERR, "TMath::GnUncertainty()"), + PConstant(ConstantIDs.H, "TMath::H()"), + PConstant(ConstantIDs.H_CGS, "TMath::Hcgs()"), + PConstant(ConstantIDs.H_ERR, "TMath::HUncertainty()"), + PConstant(ConstantIDs.HBAR, "TMath::Hbar()"), + PConstant(ConstantIDs.HBAR_CGS, "TMath::Hbarcgs()"), + PConstant(ConstantIDs.HBAR_ERR, "TMath::HbarUncertainty()"), + PConstant(ConstantIDs.HxC, "TMath::HC()"), + PConstant(ConstantIDs.HxC_CGS, "TMath::HCcgs()"), + PConstant(ConstantIDs.MOL_WEIGHT_DRY_AIR, "TMath::MWair()"), + PConstant(ConstantIDs.STEFAN_BOLTZMANN, "TMath::Sigma()"), + PConstant(ConstantIDs.STEFAN_BOLTZMANN_ERR, "TMath::SigmaUncertainty()"), + PConstant(ConstantIDs.UNIVERSAL_GAS, "TMath::R()"), + PConstant(ConstantIDs.UNIVERSAL_GAS_ERR, "TMath::RUncertainty()"), ] -root_parser = Parser('ROOT', config, constants) +root_parser = Parser("ROOT", config, constants) diff --git a/formulate/backends/__init__.py b/formulate/backends/__init__.py index 3a49132..5ee3df5 100644 --- a/formulate/backends/__init__.py +++ b/formulate/backends/__init__.py @@ -1,7 +1,4 @@ # Licensed under a 3-clause BSD style license, see LICENSE. -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function import re @@ -11,11 +8,11 @@ __all__ = [ - 'from_auto', - 'from_numexpr', - 'to_numexpr', - 'from_root', - 'to_root', + "from_auto", + "from_numexpr", + "to_numexpr", + "from_root", + "to_root", ] @@ -28,10 +25,13 @@ def from_auto(string): # Intelligently detect which kind of string is passed - if any(x in string for x in ['&&', '||', 'TMath::', 'true', 'false']): + if any(x in string for x in ["&&", "||", "TMath::", "true", "false"]): return from_root(string) - elif (re.findall(r'([^\&]\&[^\&])|([^\|]\|[^\|])', string) or - 'True' in string or 'False' in string): + elif ( + re.findall(r"([^\&]\&[^\&])|([^\|]\|[^\|])", string) + or "True" in string + or "False" in string + ): return from_numexpr(string) # Intelligently detecting failed so fall back to brute force @@ -45,4 +45,4 @@ def from_auto(string): except ParsingException: pass - raise ParsingException('No available backend which can parse: '+string) + raise ParsingException("No available backend which can parse: " + string) diff --git a/formulate/backends/numexpr.py b/formulate/backends/numexpr.py index 9004ce6..e2d32fb 100644 --- a/formulate/backends/numexpr.py +++ b/formulate/backends/numexpr.py @@ -1,7 +1,4 @@ # Licensed under a 3-clause BSD style license, see LICENSE. -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function import math import scipy.constants as sc @@ -9,100 +6,94 @@ from ..identifiers import IDs, ConstantIDs from ..parser import POperator, PFunction, Parser, PConstant +# TODO: constants still needed? Or in hepunits? __all__ = [ - 'numexpr_parser', + "numexpr_parser", ] config = [ - POperator(IDs.MINUS, '-', rhs_only=True), - POperator(IDs.PLUS, '+', rhs_only=True), - POperator(IDs.ADD, '+'), - POperator(IDs.SUB, '-'), - POperator(IDs.MUL, '*'), - POperator(IDs.DIV, '/'), - POperator(IDs.MOD, '%'), - POperator(IDs.POW, '**'), - POperator(IDs.LSHIFT, '<<'), - POperator(IDs.RSHIFT, '>>'), - - POperator(IDs.EQ, '=='), - POperator(IDs.NEQ, '!='), - POperator(IDs.GT, '>'), - POperator(IDs.GTEQ, '>='), - POperator(IDs.LT, '<'), - POperator(IDs.LTEQ, '<='), - - POperator(IDs.AND, '&'), - POperator(IDs.OR, '|'), - POperator(IDs.XOR, '^'), - POperator(IDs.NOT, '~', rhs_only=True), - - PFunction(IDs.SQRT, 'sqrt'), - PFunction(IDs.ABS, 'abs'), - PFunction(IDs.WHERE, 'where', 3), - - PFunction(IDs.LOG, 'log'), - PFunction(IDs.LOG10, 'log10'), - PFunction(IDs.LOG1p, 'log1p'), - - PFunction(IDs.EXP, 'exp'), - PFunction(IDs.EXPM1, 'expm1'), - - PFunction(IDs.SIN, 'sin'), - PFunction(IDs.ASIN, 'arcsin'), - PFunction(IDs.COS, 'cos'), - PFunction(IDs.ACOS, 'arccos'), - PFunction(IDs.TAN, 'tan'), - PFunction(IDs.ATAN, 'arctan'), - PFunction(IDs.ATAN2, 'arctan2', 2), - - PFunction(IDs.SINH, 'sinh'), - PFunction(IDs.ASINH, 'arcsinh'), - PFunction(IDs.COSH, 'cosh'), - PFunction(IDs.ACOSH, 'arccosh'), - PFunction(IDs.TANH, 'tanh'), - PFunction(IDs.ATANH, 'arctanh'), - - POperator(IDs.SQUARE, '**2', lhs_only=True), + POperator(IDs.MINUS, "-", rhs_only=True), + POperator(IDs.PLUS, "+", rhs_only=True), + POperator(IDs.ADD, "+"), + POperator(IDs.SUB, "-"), + POperator(IDs.MUL, "*"), + POperator(IDs.DIV, "/"), + POperator(IDs.MOD, "%"), + POperator(IDs.POW, "**"), + POperator(IDs.LSHIFT, "<<"), + POperator(IDs.RSHIFT, ">>"), + POperator(IDs.EQ, "=="), + POperator(IDs.NEQ, "!="), + POperator(IDs.GT, ">"), + POperator(IDs.GTEQ, ">="), + POperator(IDs.LT, "<"), + POperator(IDs.LTEQ, "<="), + POperator(IDs.AND, "&"), + POperator(IDs.OR, "|"), + POperator(IDs.XOR, "^"), + POperator(IDs.NOT, "~", rhs_only=True), + PFunction(IDs.SQRT, "sqrt"), + PFunction(IDs.ABS, "abs"), + PFunction(IDs.WHERE, "where", 3), + PFunction(IDs.LOG, "log"), + PFunction(IDs.LOG10, "log10"), + PFunction(IDs.LOG1p, "log1p"), + PFunction(IDs.EXP, "exp"), + PFunction(IDs.EXPM1, "expm1"), + PFunction(IDs.SIN, "sin"), + PFunction(IDs.ASIN, "arcsin"), + PFunction(IDs.COS, "cos"), + PFunction(IDs.ACOS, "arccos"), + PFunction(IDs.TAN, "tan"), + PFunction(IDs.ATAN, "arctan"), + PFunction(IDs.ATAN2, "arctan2", 2), + PFunction(IDs.SINH, "sinh"), + PFunction(IDs.ASINH, "arcsinh"), + PFunction(IDs.COSH, "cosh"), + PFunction(IDs.ACOSH, "arccosh"), + PFunction(IDs.TANH, "tanh"), + PFunction(IDs.ATANH, "arctanh"), + POperator(IDs.SQUARE, "**2", lhs_only=True), ] constants = [ - PConstant(ConstantIDs.TRUE, 'True'), - PConstant(ConstantIDs.FALSE, 'False'), + PConstant(ConstantIDs.TRUE, "True"), + PConstant(ConstantIDs.FALSE, "False"), # PConstant(ConstantIDs.INFINITY, np.inf), # PConstant(ConstantIDs.NAN, np.nan), - PConstant(ConstantIDs.SQRT2, math.sqrt(2)), PConstant(ConstantIDs.E, math.e), PConstant(ConstantIDs.PI, math.pi), - PConstant(ConstantIDs.INVPI, 1/math.pi), - PConstant(ConstantIDs.PIOVER2, math.pi/2), - PConstant(ConstantIDs.PIOVER4, math.pi/4), - PConstant(ConstantIDs.TAU, 2*math.pi), + PConstant(ConstantIDs.INVPI, 1 / math.pi), + PConstant(ConstantIDs.PIOVER2, math.pi / 2), + PConstant(ConstantIDs.PIOVER4, math.pi / 4), + PConstant(ConstantIDs.TAU, 2 * math.pi), PConstant(ConstantIDs.LN10, math.log(10)), PConstant(ConstantIDs.LOG10E, math.log10(math.e)), - PConstant(ConstantIDs.DEG2RAD, math.pi/180), - PConstant(ConstantIDs.RAD2DEG, 180/math.pi), - + PConstant(ConstantIDs.DEG2RAD, math.pi / 180), + PConstant(ConstantIDs.RAD2DEG, 180 / math.pi), PConstant(ConstantIDs.AVOGADRO, sc.Avogadro), # PConstant(ConstantIDs.AVOGADRO_ERR, 'TMath::NaUncertainty()'), PConstant(ConstantIDs.BOLTZMANN, sc.Boltzmann), PConstant(ConstantIDs.BOLTZMANN_CGS, 1.0e7 * sc.Boltzmann), # PConstant(ConstantIDs.BOLTZMANN_ERR, 'TMath::KUncertainty()'), PConstant(ConstantIDs.C, sc.speed_of_light), - PConstant(ConstantIDs.C_CGS, 100*sc.speed_of_light), + PConstant(ConstantIDs.C_CGS, 100 * sc.speed_of_light), PConstant(ConstantIDs.C_ERR, 0.0), PConstant(ConstantIDs.DRY_AIR_GAS, 0.577216), # TODO: Taken from ROOT PConstant(ConstantIDs.ELEMENTARY_CHARGE, sc.elementary_charge), # PConstant(ConstantIDs.ELEMENTARY_CHARGE_ERR, 'TMath::QeUncertainty()'), PConstant(ConstantIDs.EULER_MASCHERONI, 28.964400), # TODO: Taken from ROOT PConstant(ConstantIDs.G, sc.gravitational_constant), - PConstant(ConstantIDs.G_CGS, sc.gravitational_constant/1000), + PConstant(ConstantIDs.G_CGS, sc.gravitational_constant / 1000), # PConstant(ConstantIDs.G_ERR, 'TMath::GUncertainty()'), - PConstant(ConstantIDs.G_OVER_HBARC, sc.gravitational_constant/(sc.hbar*sc.speed_of_light)), + PConstant( + ConstantIDs.G_OVER_HBARC, + sc.gravitational_constant / (sc.hbar * sc.speed_of_light), + ), # PConstant(ConstantIDs.G_OVER_HBARC_ERR, 'TMath::GhbarCUncertainty()'), PConstant(ConstantIDs.GRAV_ACCEL, sc.g), # PConstant(ConstantIDs.GRAV_ACCEL_ERR, 'TMath::GnUncertainty()'), @@ -113,7 +104,7 @@ PConstant(ConstantIDs.HBAR_CGS, 100 * sc.hbar), # PConstant(ConstantIDs.HBAR_ERR, 'TMath::HbarUncertainty()'), PConstant(ConstantIDs.HxC, sc.Planck * sc.speed_of_light), - PConstant(ConstantIDs.HxC_CGS, 100*sc.Planck * 100*sc.speed_of_light), + PConstant(ConstantIDs.HxC_CGS, 100 * sc.Planck * 100 * sc.speed_of_light), PConstant(ConstantIDs.MOL_WEIGHT_DRY_AIR, 287.058325), # TODO: Taken from ROOT PConstant(ConstantIDs.STEFAN_BOLTZMANN, sc.Stefan_Boltzmann), # PConstant(ConstantIDs.STEFAN_BOLTZMANN_ERR, 'TMath::SigmaUncertainty()'), @@ -122,4 +113,4 @@ ] -numexpr_parser = Parser('numexpr', config, constants) +numexpr_parser = Parser("numexpr", config, constants) diff --git a/formulate/expression.py b/formulate/expression.py index b427ff2..cd731bd 100644 --- a/formulate/expression.py +++ b/formulate/expression.py @@ -1,6 +1,4 @@ # Licensed under a 3-clause BSD style license, see LICENSE. -from __future__ import absolute_import, division, print_function - import numbers import sys @@ -8,22 +6,24 @@ from .logging import add_logging __all__ = [ - 'ExpressionComponent', - 'SingleComponent', - 'NamedConstant', - 'UnnamedConstant', - 'Expression', - 'Variable', + "ExpressionComponent", + "SingleComponent", + "NamedConstant", + "UnnamedConstant", + "Expression", + "Variable", ] -class ExpressionComponent(object): +class ExpressionComponent: def to_numexpr(self, *args, **kwargs): from .backends.numexpr import numexpr_parser + return numexpr_parser.to_string(self, *args, **kwargs) def to_root(self, *args, **kwargs): from .backends.ROOT import root_parser + return root_parser.to_string(self, *args, **kwargs) # Binary arithmetic operators @@ -228,7 +228,7 @@ def __init__(self, id, *args): elif isinstance(arg, ExpressionComponent): checked_args.append(arg) else: - raise ValueError(repr(arg) + ' is not a valid type') + raise ValueError(f"{repr(arg)} is not a valid type") self._id = id self._args = checked_args @@ -239,9 +239,7 @@ def __repr__(self): # Python < 3.3 doesn't have __qualname__ class_name = self.__class__.__name__ - return '{class_name}<{id_name}>({args})'.format( - class_name=class_name, id_name=self.id.name, - args=", ".join(map(repr, self.args))) + return f'{class_name}<{self.id.name}>({", ".join(map(repr, self.args))})' def __str__(self): return repr(self) @@ -263,7 +261,11 @@ def _search_for(self, class_to_find): elif isinstance(current_component, Expression): components_to_check.extend(current_component.args) else: - raise ValueError('Unrecognised component "' + repr(current_component) + '" in expression') + raise ValueError( + 'Unrecognised component "' + + repr(current_component) + + '" in expression' + ) return result @property @@ -278,8 +280,9 @@ def named_constants(self): def unnamed_constants(self): return self._search_for(UnnamedConstant) + # TODO: fix this name and method! def equivilent(self, other): - """Check if two expression objects are the same""" + """Check if two expression objects are the same.""" raise NotImplementedError() if isinstance(other, self.__class__): return self.id == other.id and self._args == other._args @@ -298,7 +301,7 @@ def to_string(self, config, constants): try: return config[self.id].to_string(self, config, constants) except KeyError: - raise NotImplementedError('No known conversion for: ' + str(self)) + raise NotImplementedError(f"No known conversion for: {self}") class Variable(SingleComponent): @@ -306,8 +309,7 @@ def __init__(self, name): self._name = name def __repr__(self): - return '{class_name}({name})'.format( - class_name=self.__class__.__name__, name=self.name) + return f"{self.__class__.__name__}({self.name})" def __str__(self): return self.name @@ -338,8 +340,9 @@ def __init__(self, id): self._id = id def __repr__(self): - return '{class_name}({id})'.format( - class_name=self.__class__.__name__, id=self.id) + return "{class_name}({id})".format( + class_name=self.__class__.__name__, id=self.id + ) def __str__(self): return self.id.name @@ -365,7 +368,7 @@ def to_string(self, config, constants): try: return str(constants[self.id].value) except KeyError: - raise NotImplementedError('No known conversion for constant: ' + str(self)) + raise NotImplementedError(f"No known conversion for constant: {self}") class UnnamedConstant(SingleComponent): @@ -374,8 +377,7 @@ def __init__(self, value): self._value = value def __repr__(self): - return '{class_name}({value})'.format( - class_name=self.__class__.__name__, value=self.value) + return f"{self.__class__.__name__}({self.value})" def __str__(self): return self.value diff --git a/formulate/identifiers.py b/formulate/identifiers.py index d5bafbd..f7c0ef7 100644 --- a/formulate/identifiers.py +++ b/formulate/identifiers.py @@ -1,22 +1,16 @@ # Licensed under a 3-clause BSD style license, see LICENSE. -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -try: - from enum import Enum, auto -except ImportError: - from aenum import Enum, auto - +from enum import Enum, auto __all__ = [ - 'IDs', - 'order_of_operations', + "IDs", + "order_of_operations", ] class IDs(Enum): - FIXED = auto() # Something which can't change such as: constants, numbers and variables + FIXED = ( + auto() + ) # Something which can't change such as: constants, numbers and variables MINUS = auto() PLUS = auto() @@ -164,11 +158,12 @@ class IDs(Enum): class ConstantIDs(Enum): - """Identifiers for constants + """Identifiers for constants. - - CGS => Use cm, g & s for units - - ERR => Uncertainty on quantity + - CGS => Use cm, g & s for units + - ERR => Uncertainty on quantity """ + TRUE = auto() FALSE = auto() INFINITY = auto() diff --git a/formulate/logging.py b/formulate/logging.py index 73e6923..d3100c5 100644 --- a/formulate/logging.py +++ b/formulate/logging.py @@ -1,8 +1,4 @@ # Licensed under a 3-clause BSD style license, see LICENSE. -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - from functools import wraps import logging import os @@ -11,19 +7,20 @@ __all__ = [ - 'logger', - 'add_logging', + "logger", + "add_logging", ] -def get_identifier(): - """Generate an identifier for keeping track of return values when logging""" - return''.join(random.choice(string.ascii_uppercase + string.digits) - for i in range(5)) +def get_identifier() -> str: + """Generate an identifier for keeping track of return values when logging.""" + return "".join( + random.choice(string.ascii_uppercase + string.digits) for _ in range(5) + ) -def add_logging(*args, **kwargs): - """Decorator to add logging to a method +def add_logging(func=None, *args, ignore_args=None, ignore_kwargs=None, **kwargs): + """Decorator to add logging to a method. Parameters ---------- @@ -34,28 +31,30 @@ def add_logging(*args, **kwargs): ignore_kwargs : list Names of the keyword arguments to ignore """ - # Workaround for Python 2 - func = args[0] if args else kwargs.pop('func', None) - ignore_args = kwargs.pop('ignore_args', None) - ignore_kwargs = kwargs.pop('ignore_kwargs', None) - assert len(args) in [0, 1] and not kwargs, (args, kwargs) def real_decorator(func): @wraps(func) def new_func(*args, **kwargs): my_id = get_identifier() - try: - func_name = func.__qualname__ - except AttributeError: - # Python < 3.3 doesn't have __qualname__ - func_name = func.__name__ + func_name = func.__qualname__ + # Don't log arguments which should be ignored - _args = [a for i, a in enumerate(args) if ignore_args and i not in ignore_args] - _kwargs = {k: v for k, v in kwargs.items() if ignore_kwargs and k not in ignore_kwargs} - logger.debug(my_id+' Calling '+func_name+' with '+repr(_args)+' and '+repr(_kwargs)) + _args = [ + a for i, a in enumerate(args) if ignore_args and i not in ignore_args + ] + _kwargs = { + k: v + for k, v in kwargs.items() + if ignore_kwargs and k not in ignore_kwargs + } + logger.debug( + f"{my_id} Calling {func_name} with {repr(_args)} and {repr(_kwargs)}" + ) + result = func(*args, **kwargs) - logger.debug(my_id+' - Got result '+repr(result)) + logger.debug(f"{my_id} - Got result {repr(result)}") return result + return new_func if func is None: @@ -64,13 +63,15 @@ def new_func(*args, **kwargs): return real_decorator(func) -LOGGER_NAME = 'formulate' +LOGGER_NAME = "formulate" try: import colorlog + handler = colorlog.StreamHandler() - handler.setFormatter(colorlog.ColoredFormatter( - '%(log_color)s%(levelname)s:%(name)s:%(message)s')) + handler.setFormatter( + colorlog.ColoredFormatter("%(log_color)s%(levelname)s:%(name)s:%(message)s") + ) logger = colorlog.getLogger(LOGGER_NAME) logger.addHandler(handler) @@ -80,14 +81,16 @@ def new_func(*args, **kwargs): try: - logger.setLevel({ - 'CRITICAL': logging.CRITICAL, - 'FATAL': logging.CRITICAL, - 'ERROR': logging.ERROR, - 'WARNING': logging.WARNING, - 'WARN': logging.WARNING, - 'INFO': logging.INFO, - 'DEBUG': logging.DEBUG, - }[os.environ['FOMULATE_LOG_LEVEL'].upper()]) + logger.setLevel( + { + "CRITICAL": logging.CRITICAL, + "FATAL": logging.CRITICAL, + "ERROR": logging.ERROR, + "WARNING": logging.WARNING, + "WARN": logging.WARNING, + "INFO": logging.INFO, + "DEBUG": logging.DEBUG, + }[os.environ["FOMULATE_LOG_LEVEL"].upper()] + ) except KeyError: logger.setLevel(logging.WARNING) diff --git a/formulate/parser.py b/formulate/parser.py index 2c8a1d4..97cacfc 100644 --- a/formulate/parser.py +++ b/formulate/parser.py @@ -1,31 +1,32 @@ -# -*- coding: utf-8 -*- # Licensed under a 3-clause BSD style license, see LICENSE. -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - from collections import defaultdict import pyparsing from pyparsing import Literal, Suppress, pyparsing_common, opAssoc, Word, delimitedList -from .expression import Expression, Variable, NamedConstant, UnnamedConstant, ExpressionComponent +from .expression import ( + Expression, + Variable, + NamedConstant, + UnnamedConstant, + ExpressionComponent, +) from .identifiers import order_of_operations from .logging import logger, add_logging __all__ = [ - 'PConstant', - 'PFunction', - 'POperator', - 'Parser', - 'ParsingException', + "PConstant", + "PFunction", + "POperator", + "Parser", + "ParsingException", ] -class PConstant(object): +class PConstant: def __init__(self, id, value): - """Represents a named constant + """Represents a named constant. Parameters ---------- @@ -59,7 +60,7 @@ def value(self): def get_parser(self, EXPRESSION): if isinstance(self.value, str): result = Suppress(self.value) - result.setName('NamedConstant({value})'.format(value=self.value)) + result.setName(f"NamedConstant({self.value})") result.setParseAction(self) return result else: @@ -71,9 +72,9 @@ def to_string(self): return str(self.value) -class PFunction(object): +class PFunction: def __init__(self, id, name, n_args=1): - """Represents a function call with augments + """Represents a function call with augments. Parameters ---------- @@ -97,18 +98,17 @@ def __init__(self, id, name, n_args=1): self._n_args = n_args def __str__(self): - return '{name}<{n_args}>'.format(name=self._name, n_args=self._n_args) + return f"{self._name}<{self._n_args}>" def __repr__(self): - return '{class_name}<{id_name},{name},n_args={n_args}>'.format( - class_name=self.__class__.__name__, id_name=self._id.name, - name=self._name, n_args=self._n_args) + return f"{self.__class__.__name__}<{self._id.name},{self._name},n_args={self._n_args}>" @add_logging def __call__(self, string, location, result): if len(result) != self._n_args: - raise TypeError('Function({name}) requires {n} arguments, {x} given' - .format(name=self._name, n=self._n_args, x=len(result))) + raise TypeError( + f"Function({self._name}) requires {self._n_args} arguments, {len(result)} given" + ) return Expression(self._id, *result) @property @@ -120,11 +120,11 @@ def name(self): return self._name def get_parser(self, EXPRESSION): - result = Suppress(self._name) + Suppress('(') + EXPRESSION - for i in range(1, self._n_args): - result += Suppress(',') + EXPRESSION - result += Suppress(')') - result.setName('Function({name})'.format(name=self._name)) + result = Suppress(self._name) + Suppress("(") + EXPRESSION + for _ in range(1, self._n_args): + result += Suppress(",") + EXPRESSION + result += Suppress(")") + result.setName(f"Function({self._name})") result.setParseAction(self) return result @@ -137,12 +137,12 @@ def to_string(self, expression, config, constants): else: arg = str(arg) args.append(arg) - return self.name+'('+", ".join(args)+')' + return f"{self.name}(" + ", ".join(args) + ")" -class POperator(object): +class POperator: def __init__(self, id, op, rhs_only=False, lhs_only=False): - """Represents an operator of the form "A x B" + """Represents an operator of the form "A x B". Parameters ---------- @@ -167,12 +167,14 @@ def __init__(self, id, op, rhs_only=False, lhs_only=False): self._lhs_only = lhs_only def __str__(self): - return self._name+'<'+str(self._n_args)+'>' + return self._name + "<" + str(self._n_args) + ">" def __repr__(self): - return '{class_name}<{id_name},{op_name},rhs_only={rhs_only},lhs_only={lhs_only}>'.format( - class_name=self.__class__.__name__, id_name=self._id.name, - op_name=self._op, rhs_only=self._rhs_only, lhs_only=self._lhs_only) + return ( + f"{self.__class__.__name__}<{self._id.name}," + f"{self._op},rhs_only={self._rhs_only}," + f"lhs_only={self._lhs_only}>" + ) @add_logging def __call__(self, *result): @@ -215,10 +217,10 @@ def to_string(self, expression, config, constants): return args[0] + self.op else: assert len(args) >= 2, args - return '('+(' '+self.op+' ').join(args)+')' + return "(" + (" " + self.op + " ").join(args) + ")" -class Parser(object): +class Parser: def __init__(self, name, config, constants): self._name = name self._config = config @@ -227,36 +229,39 @@ def __init__(self, name, config, constants): def to_expression(self, string): if not isinstance(string, str): - raise ValueError('Can only convert string objects to strings but '+str(type(string))+' was passed') + raise ValueError( + f"Can only convert string objects to strings but {str(type(string))} was passed" + ) try: result = self._parser.parseString(string, parseAll=True) assert len(result) == 1, result result = result[0] except pyparsing.ParseException as e: - logger.error('TODO TRACEBACK: '+repr(e.args)) - logger.error('Error parsing: '+e.line) - logger.error(' '+' '*e.loc + '▲') - logger.error(' '+' '*e.loc + '┃') - logger.error(' '+' '*e.loc + '┗━━━━━━ Error here or shortly after') - # Remove the context from the exception - # Can't use "raise X from None" with Python 2 - exception = ParsingException() - exception.__context__ = None - raise exception + logger.error(f"TODO TRACEBACK: {repr(e.args)}") + logger.error(f"Error parsing: {e.line}") + logger.error(" " + " " * e.loc + "▲") + logger.error(" " + " " * e.loc + "┃") + logger.error( + " " + " " * e.loc + "┗━━━━━━ Error here or shortly after" + ) + raise ParsingException from None else: return result def to_string(self, expression): if not isinstance(expression, ExpressionComponent): - raise ValueError('Can only convert ExpressionComponent objects to strings but ' + - str(type(expression)) + ' was passed') + raise ValueError( + "Can only convert ExpressionComponent objects to strings but " + + str(type(expression)) + + " was passed" + ) result = expression.to_string( {x.id: x for x in self._config}, {c.id: c for c in self._constants}, ) - if result.startswith('(') and result.endswith(')'): + if result.startswith("(") and result.endswith(")"): result = result[1:-1] return result @@ -268,22 +273,38 @@ class ParsingException(Exception): def create_parser(config, constants): EXPRESSION = pyparsing.Forward() - VARIABLE = delimitedList(Word(pyparsing.alphas+'_', pyparsing.alphanums+'_-'), delim='.', combine=True) - VARIABLE.setName('Variable') - VARIABLE.setParseAction(add_logging(lambda string, location, result: Variable(result[0]))) + VARIABLE = delimitedList( + Word(pyparsing.alphas + "_", pyparsing.alphanums + "_-"), + delim=".", + combine=True, + ) + VARIABLE.setName("Variable") + VARIABLE.setParseAction( + add_logging(lambda string, location, result: Variable(result[0])) + ) REAL = pyparsing_common.real - REAL.setParseAction(add_logging(lambda string, location, result: UnnamedConstant(result[0]))) + REAL.setParseAction( + add_logging(lambda string, location, result: UnnamedConstant(result[0])) + ) SCI_REAL = pyparsing_common.sci_real - SCI_REAL.setParseAction(add_logging(lambda string, location, result: UnnamedConstant(result[0]))) + SCI_REAL.setParseAction( + add_logging(lambda string, location, result: UnnamedConstant(result[0])) + ) SIGNED_INTEGER = pyparsing_common.signed_integer - SIGNED_INTEGER.setParseAction(add_logging(lambda string, location, result: UnnamedConstant(result[0]))) + SIGNED_INTEGER.setParseAction( + add_logging(lambda string, location, result: UnnamedConstant(result[0])) + ) NUMBER = pyparsing.Or([REAL, SCI_REAL, SIGNED_INTEGER]) COMPONENT = pyparsing.Or( - [f.get_parser(EXPRESSION) for f in config if isinstance(f, PFunction)] + - [p for p in map(lambda c: c.get_parser(EXPRESSION), constants) if p is not None] + - [NUMBER, VARIABLE] + [f.get_parser(EXPRESSION) for f in config if isinstance(f, PFunction)] + + [ + p + for p in map(lambda c: c.get_parser(EXPRESSION), constants) + if p is not None + ] + + [NUMBER, VARIABLE] ) # TODO Generating operators_config should be rewritten @@ -299,30 +320,40 @@ def create_parser(config, constants): # TODO This is a hack, is there a nicer way? from .identifiers import IDs + if ops[0].id in (IDs.MINUS, IDs.PLUS): assert ops[0]._rhs_only - parser = pyparsing.Or([Literal(o.op) + ~pyparsing.FollowedBy(NUMBER) for o in ops]) + parser = pyparsing.Or( + [Literal(o.op) + ~pyparsing.FollowedBy(NUMBER) for o in ops] + ) elif ops[0].id in (IDs.SQUARE,): assert ops[0]._lhs_only - parser = pyparsing.Or([Literal(o.op) + ~pyparsing.FollowedBy(NUMBER) for o in ops]) + parser = pyparsing.Or( + [Literal(o.op) + ~pyparsing.FollowedBy(NUMBER) for o in ops] + ) else: parser = pyparsing.Or([Literal(o.op) for o in ops]) if ops[0].rhs_only: + def parse_action(string, location, result, op_map={o.op: o for o in ops}): assert len(result) == 1, result result = result[0] assert len(result) == 2, result return op_map[result[0]](result[1]) + operators_config.append((parser, 1, opAssoc.RIGHT, parse_action)) elif ops[0].lhs_only: + def parse_action(string, location, result, op_map={o.op: o for o in ops}): assert len(result) == 1, result result = result[0] assert len(result) == 2, result return op_map[result[0]](result[1]) + operators_config.append((parser, 1, opAssoc.LEFT, parse_action)) else: + def parse_action(string, location, result, op_map={o.op: o for o in ops}): assert len(result) == 1, result result = result[0] @@ -334,14 +365,19 @@ def parse_action(string, location, result, op_map={o.op: o for o in ops}): if op_name == last_op_name: expression_args.append(value) else: - expression = Expression(op_map[last_op_name].id, expression, *expression_args) + expression = Expression( + op_map[last_op_name].id, expression, *expression_args + ) expression_args = [value] last_op_name = op_name - expression = Expression(op_map[last_op_name].id, expression, *expression_args) + expression = Expression( + op_map[last_op_name].id, expression, *expression_args + ) # for operator, value in zip(result[1::2], result[2::2]): # operator = op_map[operator] # expression = Expression(operator.id, expression, value) return expression + operators_config.append((parser, 2, opAssoc.LEFT, parse_action)) EXPRESSION << pyparsing.infixNotation(COMPONENT, operators_config) diff --git a/setup.cfg b/setup.cfg index b7c6bb7..9814beb 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,20 +1,15 @@ [metadata] name = formulate -author = Chris Burr -author_email = c.b@cern.ch -maintainer = The Scikit-HEP admins -maintainer_email = scikit-hep-admins@googlegroups.com description = Convert between different style of formulae long_description = file: README.rst long_description_content_type = text/x-rst url = https://github.com/scikit-hep/formulate -license = BSD 3-Clause License -keywords = - formula - conversion - ROOT - numexpr - HEP +author = Chris Burr +author_email = c.b@cern.ch +maintainer = The Scikit-HEP admins +maintainer_email = scikit-hep-admins@googlegroups.com +license = BSD-3-Clause +license_file = LICENSE platforms = Any classifiers = @@ -23,37 +18,45 @@ classifiers = Intended Audience :: Information Technology Intended Audience :: Science/Research License :: OSI Approved :: BSD License - Operating System :: Microsoft :: Windows Operating System :: MacOS + Operating System :: Microsoft :: Windows Operating System :: POSIX Operating System :: Unix Programming Language :: Python - Programming Language :: Python :: 2.7 + Programming Language :: Python :: 3 + Programming Language :: Python :: 3 :: Only Programming Language :: Python :: 3.6 Programming Language :: Python :: 3.7 Programming Language :: Python :: 3.8 Programming Language :: Python :: 3.9 + Programming Language :: Python :: 3.10 Topic :: Scientific/Engineering Topic :: Scientific/Engineering :: Information Analysis Topic :: Scientific/Engineering :: Mathematics Topic :: Scientific/Engineering :: Physics Topic :: Software Development Topic :: Utilities +keywords = + formula + conversion + ROOT + numexpr + HEP [options] -python_requires = >=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.* packages = find: - install_requires = - numpy >=1.13.3 - pyparsing>=2.1.9 - colorlog aenum + colorlog + numpy>=1.13.3 + pyparsing>=2.1.9 scipy +python_requires = >=3.6 [options.packages.find] exclude = tests + [tool:pytest] addopts = -ra -s -Wd testpaths = diff --git a/setup.py b/setup.py index ac0f40b..b2502b9 100644 --- a/setup.py +++ b/setup.py @@ -1,16 +1,15 @@ #!/usr/bin/env python # Licensed under a 3-clause BSD style license, see LICENSE. -from __future__ import absolute_import, division, print_function - # testing if packages are installed import setuptools_scm # noqa: F401 -import toml # noqa: F401 + +# import toml # noqa: F401 from setuptools import setup extras = { - "dev": ["pytest>=4.6", 'numexpr'], - "test": ["pytest>=4.6", 'numexpr'], + "test": ["pytest>=4.6", "numexpr", "pytest-helpers-namespace"], } +extras["dev"] = extras["test"] extras["all"] = sum(extras.values(), []) setup(extras_require=extras) diff --git a/tests/__init__.py b/tests/__init__.py index 3c8d032..b20e1c0 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,4 +1 @@ # Licensed under a 3-clause BSD style license, see LICENSE. -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function diff --git a/tests/backends/__init__.py b/tests/backends/__init__.py index 3c8d032..b20e1c0 100644 --- a/tests/backends/__init__.py +++ b/tests/backends/__init__.py @@ -1,4 +1 @@ # Licensed under a 3-clause BSD style license, see LICENSE. -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function diff --git a/tests/backends/test_ROOT.py b/tests/backends/test_ROOT.py deleted file mode 100644 index 351704f..0000000 --- a/tests/backends/test_ROOT.py +++ /dev/null @@ -1,108 +0,0 @@ -# Licensed under a 3-clause BSD style license, see LICENSE. -# This file is automatically created by "make root_test" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from formulate import Expression, Variable -from formulate import UnnamedConstant as UC -from formulate import from_root, to_root -from formulate.identifiers import IDs - -from ..utils import make_check_result - - -check_result = make_check_result(from_root, to_root) - - -def _create_test_type(name, A, B, C, D): - kwargs = {'A': str(A), 'B': str(B), 'C': str(C), 'D': str(D)} - - class NewTestClass(object): - def test_basic_math(self): - if isinstance(A, UC): - # TODO - pass - else: - check_result('+{A}', Expression(IDs.PLUS, A), **kwargs) - check_result('-{A}', Expression(IDs.MINUS, A), **kwargs) - check_result('{A} + {B}', Expression(IDs.ADD, A, B), **kwargs) - check_result('{A} - {B}', Expression(IDs.SUB, A, B), **kwargs) - check_result('{A} * {B}', Expression(IDs.MUL, A, B), **kwargs) - check_result('{A} / {B}', Expression(IDs.DIV, A, B), **kwargs) - check_result('{A} % {B}', Expression(IDs.MOD, A, B), **kwargs) - - def test_chain_math(self): - check_result('{A} + {B} + {C} + {D}', Expression(IDs.ADD, A, B, C, D), **kwargs) - check_result('{A} - {B} - {C} - {D}', Expression(IDs.SUB, A, B, C, D), **kwargs) - check_result('{A} * {B} * {C} * {D}', Expression(IDs.MUL, A, B, C, D), **kwargs) - check_result('{A} / {B} / {C} / {D}', Expression(IDs.DIV, A, B, C, D), **kwargs) - check_result('{A} % {B} % {C} % {D}', Expression(IDs.MOD, A, B, C, D), **kwargs) - - def test_basic_boolean_operations(self): - check_result('!{A}', Expression(IDs.NOT, A), **kwargs) - check_result('{A} && {B}', Expression(IDs.AND, A, B), **kwargs) - check_result('{A} || {B}', Expression(IDs.OR, A, B), **kwargs) - check_result('{A} == {B}', Expression(IDs.EQ, A, B), **kwargs) - check_result('{A} != {B}', Expression(IDs.NEQ, A, B), **kwargs) - check_result('{A} > {B}', Expression(IDs.GT, A, B), **kwargs) - check_result('{A} >= {B}', Expression(IDs.GTEQ, A, B), **kwargs) - check_result('{A} < {B}', Expression(IDs.LT, A, B), **kwargs) - check_result('{A} <= {B}', Expression(IDs.LTEQ, A, B), **kwargs) - - # def test_chain_boolean_operations(self): - # check_result('{A} && {B} && {C}', Expression(IDs.AND, A, Expression(IDs.AND, B, C)), **kwargs) - - def test_basic_functions(self): - check_result('TMath::Sqrt({A})', Expression(IDs.SQRT, A), **kwargs) - check_result('TMath::ATan2({A}, {B})', Expression(IDs.ATAN2, A, B), **kwargs) - - def test_signed_functions(self): - check_result('TMath::Sqrt({A})', Expression(IDs.SQRT, A), **kwargs) - check_result('TMath::ATan2({A}, {B})', Expression(IDs.ATAN2, A, B), **kwargs) - check_result('-TMath::Sqrt({A})', Expression(IDs.MINUS, Expression(IDs.SQRT, A)), **kwargs) - check_result('+TMath::Sqrt({A})', Expression(IDs.PLUS, Expression(IDs.SQRT, A)), **kwargs) - check_result('- TMath::ATan2({A}, {B})', Expression(IDs.MINUS, Expression(IDs.ATAN2, A, B)), **kwargs) - check_result(' + TMath::ATan2({A}, {B})', Expression(IDs.PLUS, Expression(IDs.ATAN2, A, B)), **kwargs) - - def test_math_with_functions(self): - if isinstance(A, UC): - # TODO - pass - else: - check_result('TMath::Sqrt(-{A})', Expression(IDs.SQRT, Expression(IDs.MINUS, A)), **kwargs) - check_result('TMath::Sqrt(+ {A})', Expression(IDs.SQRT, Expression(IDs.PLUS, A)), **kwargs) - check_result('TMath::Sqrt({A} + {B})', Expression(IDs.SQRT, Expression(IDs.ADD, A, B)), **kwargs) - check_result('TMath::ATan2({A} - {B}, {B} % {A})', Expression(IDs.ATAN2, Expression(IDs.SUB, A, B), Expression(IDs.MOD, B, A)), **kwargs) - check_result('TMath::Sqrt({A})+TMath::Sqrt({B})', Expression(IDs.ADD, Expression(IDs.SQRT, A), Expression(IDs.SQRT, B)), **kwargs) - check_result('TMath::Sqrt({A})*TMath::Sqrt({A})', Expression(IDs.MUL, Expression(IDs.SQRT, A), Expression(IDs.SQRT, A)), **kwargs) - check_result('TMath::ATan2({A}, {B})/TMath::Sqrt({B})', Expression(IDs.DIV, Expression(IDs.ATAN2, A, B), Expression(IDs.SQRT, B)), **kwargs) - - def test_functions_of_functions(self): - check_result('TMath::Sqrt(TMath::Sqrt({A}))', Expression(IDs.SQRT, Expression(IDs.SQRT, A)), **kwargs) - check_result('TMath::Sqrt(TMath::ATan2({A}, {B}))', Expression(IDs.SQRT, Expression(IDs.ATAN2, A, B)), **kwargs) - check_result('TMath::ATan2(TMath::Sqrt({A}), {B})', Expression(IDs.ATAN2, Expression(IDs.SQRT, A), B), **kwargs) - - def test_nested(self): - check_result('TMath::Sqrt(-TMath::Sqrt({A}))', Expression(IDs.SQRT, Expression(IDs.MINUS, Expression(IDs.SQRT, A))), **kwargs) - check_result('TMath::Sqrt(TMath::Sqrt({A}) + {B})', Expression(IDs.SQRT, Expression(IDs.ADD, Expression(IDs.SQRT, A), B)), **kwargs) - check_result('TMath::Sqrt(TMath::Sqrt({A}) + TMath::Sqrt({B}))', Expression(IDs.SQRT, Expression(IDs.ADD, Expression(IDs.SQRT, A), Expression(IDs.SQRT, B))), **kwargs) - - NewTestClass.__name__ = name - - return NewTestClass - - -TestPosInts = _create_test_type('TestPosInts', UC('1'), UC('2'), UC('3'), UC('4')) -TestNegInts = _create_test_type('TestNegInts', UC('-1'), UC('-2'), UC('-3'), UC('-4')) -TestMixInts = _create_test_type('TestMixInts', UC('1'), UC('-2'), UC('3'), UC('-4')) - -TestPosFloats = _create_test_type('TestPosFloats', UC('1.2'), UC('3.4'), UC('4.5'), UC('6.7')) -TestNegFloats = _create_test_type('TestNegFloats', UC('-1.2'), UC('-3.4'), UC('-4.5'), UC('-6.7')) -TestMixFloats = _create_test_type('TestMixFloats', UC('1.2'), UC('-3.4'), UC('4.5'), UC('-6.7')) - -TestPosScientific = _create_test_type('TestPosScientific', UC('1e-2'), UC('3.4e5'), UC('6.7e8'), UC('9e10')) -TestNegScientific = _create_test_type('TestNegScientific', UC('-1e-2'), UC('-3.4e5'), UC('-6.7e8'), UC('-9e10')) -TestMixScientific = _create_test_type('TestMixScientific', UC('1e-2'), UC('-3.4e5'), UC('6.7e8'), UC('-9e10')) - -TestVariables = _create_test_type('TestVariables', Variable('A'), Variable('Bee'), Variable('C_is_4'), Variable('_Dxyz')) diff --git a/tests/backends/test_backends.py b/tests/backends/test_backends.py index 9c49a7b..2a6a169 100644 --- a/tests/backends/test_backends.py +++ b/tests/backends/test_backends.py @@ -1,8 +1,4 @@ # Licensed under a 3-clause BSD style license, see LICENSE. -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - import pytest import numpy as np import numexpr @@ -25,34 +21,38 @@ def test(): return test -test_001 = do_checks('True', 'true') -test_002 = do_checks('False', 'false') -test_003 = do_checks('sqrt(2)', 'sqrt(2)') -test_004 = do_checks('sqrt(2)', 'TMath::Sqrt(2)') -test_005 = do_checks('sqrt(abs(-4))', 'TMath::Sqrt(TMath::Abs(-4))') -test_006 = do_checks('A & B & C & D', 'A && B && C && D') -test_007 = do_checks('A & B | C & D', 'A && B || C && D') -test_008 = do_checks('A & ~B | C & D', 'A && !B || C && D') +test_001 = do_checks("True", "true") +test_002 = do_checks("False", "false") +test_003 = do_checks("sqrt(2)", "sqrt(2)") +test_004 = do_checks("sqrt(2)", "TMath::Sqrt(2)") +test_005 = do_checks("sqrt(abs(-4))", "TMath::Sqrt(TMath::Abs(-4))") +test_006 = do_checks("A & B & C & D", "A && B && C && D") +test_007 = do_checks("A & B | C & D", "A && B || C && D") +test_008 = do_checks("A & ~B | C & D", "A && !B || C && D") def test_readme(): - momentum = from_root('TMath::Sqrt(X_PX**2 + X_PY**2 + X_PZ**2)') - assert momentum.to_numexpr() == 'sqrt(((X_PX ** 2) + (X_PY ** 2) + (X_PZ ** 2)))' - assert momentum.to_root() == 'TMath::Sqrt(((X_PX ** 2) + (X_PY ** 2) + (X_PZ ** 2)))' - my_selection = from_numexpr('X_PT > 5 & (Mu_NHits > 3 | Mu_PT > 10)') - assert my_selection.to_root() == '(X_PT > 5) && ((Mu_NHits > 3) || (Mu_PT > 10))' - assert my_selection.to_numexpr() == '(X_PT > 5) & ((Mu_NHits > 3) | (Mu_PT > 10))' - my_sum = from_auto('True + False') - assert my_sum.to_root() == 'true + false' - assert my_sum.to_numexpr() == 'True + False' - my_check = from_auto('(X_THETA*TMath::DegToRad() > pi/4) && D_PE > 9.2') - assert my_check.variables == {'D_PE', 'X_THETA'} - assert my_check.named_constants == {'DEG2RAD', 'PI'} - assert my_check.unnamed_constants == {'4', '9.2'} + momentum = from_root("TMath::Sqrt(X_PX**2 + X_PY**2 + X_PZ**2)") + assert momentum.to_numexpr() == "sqrt(((X_PX ** 2) + (X_PY ** 2) + (X_PZ ** 2)))" + assert ( + momentum.to_root() == "TMath::Sqrt(((X_PX ** 2) + (X_PY ** 2) + (X_PZ ** 2)))" + ) + my_selection = from_numexpr("X_PT > 5 & (Mu_NHits > 3 | Mu_PT > 10)") + assert my_selection.to_root() == "(X_PT > 5) && ((Mu_NHits > 3) || (Mu_PT > 10))" + assert my_selection.to_numexpr() == "(X_PT > 5) & ((Mu_NHits > 3) | (Mu_PT > 10))" + my_sum = from_auto("True + False") + assert my_sum.to_root() == "true + false" + assert my_sum.to_numexpr() == "True + False" + my_check = from_auto("(X_THETA*TMath::DegToRad() > pi/4) && D_PE > 9.2") + assert my_check.variables == {"D_PE", "X_THETA"} + assert my_check.named_constants == {"DEG2RAD", "PI"} + assert my_check.unnamed_constants == {"4", "9.2"} new_selection = (momentum > 100) and (my_check or (np.sqrt(my_sum) < 1)) def numexpr_eval(string): return numexpr.evaluate(string, local_dict=dict(X_THETA=1234, D_PE=678)) - assert pytest.approx(numexpr_eval(new_selection.to_numexpr()), - numexpr_eval('((X_THETA * 0.017453292519943295) > (3.141592653589793 / 4)) & (D_PE > 9.2)')) + true_numexpr_eval = numexpr_eval( + "((X_THETA * 0.017453292519943295) > (3.141592653589793 / 4)) & (D_PE > 9.2)" + ) + assert numexpr_eval(new_selection.to_numexpr()) == pytest.approx(true_numexpr_eval) diff --git a/tests/backends/test_numexpr.py b/tests/backends/test_numexpr.py index 397561c..50990c0 100644 --- a/tests/backends/test_numexpr.py +++ b/tests/backends/test_numexpr.py @@ -1,8 +1,4 @@ # Licensed under a 3-clause BSD style license, see LICENSE. -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - from formulate import Expression, Variable from formulate import UnnamedConstant as UC from formulate import from_numexpr, to_numexpr @@ -15,93 +11,194 @@ def _create_test_type(name, A, B, C, D): - kwargs = {'A': str(A), 'B': str(B), 'C': str(C), 'D': str(D)} + kwargs = {"A": str(A), "B": str(B), "C": str(C), "D": str(D)} - class NewTestClass(object): + class NewTestClass: def test_basic_math(self): if isinstance(A, UC): # TODO pass else: - check_result('+{A}', Expression(IDs.PLUS, A), **kwargs) - check_result('-{A}', Expression(IDs.MINUS, A), **kwargs) - check_result('{A} + {B}', Expression(IDs.ADD, A, B), **kwargs) - check_result('{A} - {B}', Expression(IDs.SUB, A, B), **kwargs) - check_result('{A} * {B}', Expression(IDs.MUL, A, B), **kwargs) - check_result('{A} / {B}', Expression(IDs.DIV, A, B), **kwargs) - check_result('{A} % {B}', Expression(IDs.MOD, A, B), **kwargs) + check_result("+{A}", Expression(IDs.PLUS, A), **kwargs) + check_result("-{A}", Expression(IDs.MINUS, A), **kwargs) + check_result("{A} + {B}", Expression(IDs.ADD, A, B), **kwargs) + check_result("{A} - {B}", Expression(IDs.SUB, A, B), **kwargs) + check_result("{A} * {B}", Expression(IDs.MUL, A, B), **kwargs) + check_result("{A} / {B}", Expression(IDs.DIV, A, B), **kwargs) + check_result("{A} % {B}", Expression(IDs.MOD, A, B), **kwargs) def test_chain_math(self): - check_result('{A} + {B} + {C} + {D}', Expression(IDs.ADD, A, B, C, D), **kwargs) - check_result('{A} - {B} - {C} - {D}', Expression(IDs.SUB, A, B, C, D), **kwargs) - check_result('{A} * {B} * {C} * {D}', Expression(IDs.MUL, A, B, C, D), **kwargs) - check_result('{A} / {B} / {C} / {D}', Expression(IDs.DIV, A, B, C, D), **kwargs) - check_result('{A} % {B} % {C} % {D}', Expression(IDs.MOD, A, B, C, D), **kwargs) + check_result( + "{A} + {B} + {C} + {D}", Expression(IDs.ADD, A, B, C, D), **kwargs + ) + check_result( + "{A} - {B} - {C} - {D}", Expression(IDs.SUB, A, B, C, D), **kwargs + ) + check_result( + "{A} * {B} * {C} * {D}", Expression(IDs.MUL, A, B, C, D), **kwargs + ) + check_result( + "{A} / {B} / {C} / {D}", Expression(IDs.DIV, A, B, C, D), **kwargs + ) + check_result( + "{A} % {B} % {C} % {D}", Expression(IDs.MOD, A, B, C, D), **kwargs + ) def test_basic_boolean_operations(self): - check_result('~{A}', Expression(IDs.NOT, A), **kwargs) - check_result('{A} & {B}', Expression(IDs.AND, A, B), **kwargs) - check_result('{A} | {B}', Expression(IDs.OR, A, B), **kwargs) - check_result('{A} == {B}', Expression(IDs.EQ, A, B), **kwargs) - check_result('{A} != {B}', Expression(IDs.NEQ, A, B), **kwargs) - check_result('{A} > {B}', Expression(IDs.GT, A, B), **kwargs) - check_result('{A} >= {B}', Expression(IDs.GTEQ, A, B), **kwargs) - check_result('{A} < {B}', Expression(IDs.LT, A, B), **kwargs) - check_result('{A} <= {B}', Expression(IDs.LTEQ, A, B), **kwargs) + check_result("~{A}", Expression(IDs.NOT, A), **kwargs) + check_result("{A} & {B}", Expression(IDs.AND, A, B), **kwargs) + check_result("{A} | {B}", Expression(IDs.OR, A, B), **kwargs) + check_result("{A} == {B}", Expression(IDs.EQ, A, B), **kwargs) + check_result("{A} != {B}", Expression(IDs.NEQ, A, B), **kwargs) + check_result("{A} > {B}", Expression(IDs.GT, A, B), **kwargs) + check_result("{A} >= {B}", Expression(IDs.GTEQ, A, B), **kwargs) + check_result("{A} < {B}", Expression(IDs.LT, A, B), **kwargs) + check_result("{A} <= {B}", Expression(IDs.LTEQ, A, B), **kwargs) # def test_chain_boolean_operations(self): # check_result('{A} & {B} & {C}', Expression(IDs.AND, A, Expression(IDs.AND, B, C)), **kwargs) def test_basic_functions(self): - check_result('sqrt({A})', Expression(IDs.SQRT, A), **kwargs) - check_result('arctan2({A}, {B})', Expression(IDs.ATAN2, A, B), **kwargs) + check_result("sqrt({A})", Expression(IDs.SQRT, A), **kwargs) + check_result("arctan2({A}, {B})", Expression(IDs.ATAN2, A, B), **kwargs) def test_signed_functions(self): - check_result('sqrt({A})', Expression(IDs.SQRT, A), **kwargs) - check_result('arctan2({A}, {B})', Expression(IDs.ATAN2, A, B), **kwargs) - check_result('-sqrt({A})', Expression(IDs.MINUS, Expression(IDs.SQRT, A)), **kwargs) - check_result('+sqrt({A})', Expression(IDs.PLUS, Expression(IDs.SQRT, A)), **kwargs) - check_result('- arctan2({A}, {B})', Expression(IDs.MINUS, Expression(IDs.ATAN2, A, B)), **kwargs) - check_result(' + arctan2({A}, {B})', Expression(IDs.PLUS, Expression(IDs.ATAN2, A, B)), **kwargs) + check_result("sqrt({A})", Expression(IDs.SQRT, A), **kwargs) + check_result("arctan2({A}, {B})", Expression(IDs.ATAN2, A, B), **kwargs) + check_result( + "-sqrt({A})", Expression(IDs.MINUS, Expression(IDs.SQRT, A)), **kwargs + ) + check_result( + "+sqrt({A})", Expression(IDs.PLUS, Expression(IDs.SQRT, A)), **kwargs + ) + check_result( + "- arctan2({A}, {B})", + Expression(IDs.MINUS, Expression(IDs.ATAN2, A, B)), + **kwargs + ) + check_result( + " + arctan2({A}, {B})", + Expression(IDs.PLUS, Expression(IDs.ATAN2, A, B)), + **kwargs + ) def test_math_with_functions(self): if isinstance(A, UC): # TODO pass else: - check_result('sqrt(-{A})', Expression(IDs.SQRT, Expression(IDs.MINUS, A)), **kwargs) - check_result('sqrt(+ {A})', Expression(IDs.SQRT, Expression(IDs.PLUS, A)), **kwargs) - check_result('sqrt({A} + {B})', Expression(IDs.SQRT, Expression(IDs.ADD, A, B)), **kwargs) - check_result('arctan2({A} - {B}, {B} % {A})', Expression(IDs.ATAN2, Expression(IDs.SUB, A, B), Expression(IDs.MOD, B, A)), **kwargs) - check_result('sqrt({A})+sqrt({B})', Expression(IDs.ADD, Expression(IDs.SQRT, A), Expression(IDs.SQRT, B)), **kwargs) - check_result('sqrt({A})*sqrt({A})', Expression(IDs.MUL, Expression(IDs.SQRT, A), Expression(IDs.SQRT, A)), **kwargs) - check_result('arctan2({A}, {B})/sqrt({B})', Expression(IDs.DIV, Expression(IDs.ATAN2, A, B), Expression(IDs.SQRT, B)), **kwargs) + check_result( + "sqrt(-{A})", + Expression(IDs.SQRT, Expression(IDs.MINUS, A)), + **kwargs + ) + check_result( + "sqrt(+ {A})", + Expression(IDs.SQRT, Expression(IDs.PLUS, A)), + **kwargs + ) + check_result( + "sqrt({A} + {B})", + Expression(IDs.SQRT, Expression(IDs.ADD, A, B)), + **kwargs + ) + check_result( + "arctan2({A} - {B}, {B} % {A})", + Expression( + IDs.ATAN2, Expression(IDs.SUB, A, B), Expression(IDs.MOD, B, A) + ), + **kwargs + ) + check_result( + "sqrt({A})+sqrt({B})", + Expression(IDs.ADD, Expression(IDs.SQRT, A), Expression(IDs.SQRT, B)), + **kwargs + ) + check_result( + "sqrt({A})*sqrt({A})", + Expression(IDs.MUL, Expression(IDs.SQRT, A), Expression(IDs.SQRT, A)), + **kwargs + ) + check_result( + "arctan2({A}, {B})/sqrt({B})", + Expression( + IDs.DIV, Expression(IDs.ATAN2, A, B), Expression(IDs.SQRT, B) + ), + **kwargs + ) def test_functions_of_functions(self): - check_result('sqrt(sqrt({A}))', Expression(IDs.SQRT, Expression(IDs.SQRT, A)), **kwargs) - check_result('sqrt(arctan2({A}, {B}))', Expression(IDs.SQRT, Expression(IDs.ATAN2, A, B)), **kwargs) - check_result('arctan2(sqrt({A}), {B})', Expression(IDs.ATAN2, Expression(IDs.SQRT, A), B), **kwargs) + check_result( + "sqrt(sqrt({A}))", + Expression(IDs.SQRT, Expression(IDs.SQRT, A)), + **kwargs + ) + check_result( + "sqrt(arctan2({A}, {B}))", + Expression(IDs.SQRT, Expression(IDs.ATAN2, A, B)), + **kwargs + ) + check_result( + "arctan2(sqrt({A}), {B})", + Expression(IDs.ATAN2, Expression(IDs.SQRT, A), B), + **kwargs + ) def test_nested(self): - check_result('sqrt(-sqrt({A}))', Expression(IDs.SQRT, Expression(IDs.MINUS, Expression(IDs.SQRT, A))), **kwargs) - check_result('sqrt(sqrt({A}) + {B})', Expression(IDs.SQRT, Expression(IDs.ADD, Expression(IDs.SQRT, A), B)), **kwargs) - check_result('sqrt(sqrt({A}) + sqrt({B}))', Expression(IDs.SQRT, Expression(IDs.ADD, Expression(IDs.SQRT, A), Expression(IDs.SQRT, B))), **kwargs) + check_result( + "sqrt(-sqrt({A}))", + Expression(IDs.SQRT, Expression(IDs.MINUS, Expression(IDs.SQRT, A))), + **kwargs + ) + check_result( + "sqrt(sqrt({A}) + {B})", + Expression(IDs.SQRT, Expression(IDs.ADD, Expression(IDs.SQRT, A), B)), + **kwargs + ) + check_result( + "sqrt(sqrt({A}) + sqrt({B}))", + Expression( + IDs.SQRT, + Expression( + IDs.ADD, Expression(IDs.SQRT, A), Expression(IDs.SQRT, B) + ), + ), + **kwargs + ) NewTestClass.__name__ = name return NewTestClass -TestPosInts = _create_test_type('TestPosInts', UC('1'), UC('2'), UC('3'), UC('4')) -TestNegInts = _create_test_type('TestNegInts', UC('-1'), UC('-2'), UC('-3'), UC('-4')) -TestMixInts = _create_test_type('TestMixInts', UC('1'), UC('-2'), UC('3'), UC('-4')) - -TestPosFloats = _create_test_type('TestPosFloats', UC('1.2'), UC('3.4'), UC('4.5'), UC('6.7')) -TestNegFloats = _create_test_type('TestNegFloats', UC('-1.2'), UC('-3.4'), UC('-4.5'), UC('-6.7')) -TestMixFloats = _create_test_type('TestMixFloats', UC('1.2'), UC('-3.4'), UC('4.5'), UC('-6.7')) - -TestPosScientific = _create_test_type('TestPosScientific', UC('1e-2'), UC('3.4e5'), UC('6.7e8'), UC('9e10')) -TestNegScientific = _create_test_type('TestNegScientific', UC('-1e-2'), UC('-3.4e5'), UC('-6.7e8'), UC('-9e10')) -TestMixScientific = _create_test_type('TestMixScientific', UC('1e-2'), UC('-3.4e5'), UC('6.7e8'), UC('-9e10')) - -TestVariables = _create_test_type('TestVariables', Variable('A'), Variable('Bee'), Variable('C_is_4'), Variable('_Dxyz')) +TestPosInts = _create_test_type("TestPosInts", UC("1"), UC("2"), UC("3"), UC("4")) +TestNegInts = _create_test_type("TestNegInts", UC("-1"), UC("-2"), UC("-3"), UC("-4")) +TestMixInts = _create_test_type("TestMixInts", UC("1"), UC("-2"), UC("3"), UC("-4")) + +TestPosFloats = _create_test_type( + "TestPosFloats", UC("1.2"), UC("3.4"), UC("4.5"), UC("6.7") +) +TestNegFloats = _create_test_type( + "TestNegFloats", UC("-1.2"), UC("-3.4"), UC("-4.5"), UC("-6.7") +) +TestMixFloats = _create_test_type( + "TestMixFloats", UC("1.2"), UC("-3.4"), UC("4.5"), UC("-6.7") +) + +TestPosScientific = _create_test_type( + "TestPosScientific", UC("1e-2"), UC("3.4e5"), UC("6.7e8"), UC("9e10") +) +TestNegScientific = _create_test_type( + "TestNegScientific", UC("-1e-2"), UC("-3.4e5"), UC("-6.7e8"), UC("-9e10") +) +TestMixScientific = _create_test_type( + "TestMixScientific", UC("1e-2"), UC("-3.4e5"), UC("6.7e8"), UC("-9e10") +) + +TestVariables = _create_test_type( + "TestVariables", + Variable("A"), + Variable("Bee"), + Variable("C_is_4"), + Variable("_Dxyz"), +) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..832dee8 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,76 @@ +import pytest + + +def root_eval(string, x=None, y=None, z=None, t=None): + import ROOT + + f = ROOT.TFormula("", string) + f.Compile() + if x is None: + assert y is None and z is None and t is None + return f.Eval(0) + elif y is None: + assert z is None and t is None + return f.Eval(x) + elif z is None: + assert t is None + return f.Eval(x, y) + elif t is None: + return f.Eval(x, y, z) + else: + return f.Eval(x, y, z, t) + + +def numexpr_eval(string, **kwargs): + import numexpr + + return numexpr.evaluate(string, local_dict=kwargs) + + +@pytest.helpers.register +def create_formula_test( + input_string, input_backend="root", root_raises=None, numexpr_raises=None +): + assert input_backend in ("root", "numexpr"), "Unrecognised backend specified" + from formulate import from_root, from_numexpr + + input_from_method = { + "root": from_root, + "numexpr": from_numexpr, + }[input_backend] + + def test_constant(): + from formulate import to_root, to_numexpr + + expression = input_from_method(input_string) + + if input_backend == "root": + from formulate import to_root + + root_result = to_root(expression) + assert input_string, root_result + + if numexpr_raises: + with pytest.raises(numexpr_raises): + from formulate import to_numexpr + + to_numexpr(expression) + else: + numexpr_result = to_numexpr(expression) + assert root_eval(root_result) == pytest.approx( + numexpr_eval(numexpr_result) + ) + else: + numexpr_result = to_numexpr(expression) + assert input_string, numexpr_result + + if root_raises: + with pytest.raises(root_raises): + to_root(expression) + else: + root_result = to_root(expression) + assert numexpr_eval(numexpr_result) == pytest.approx( + root_eval(root_result) + ) + + return test_constant diff --git a/tests/test_constants.py b/tests/test_constants.py index cba9480..dc90432 100644 --- a/tests/test_constants.py +++ b/tests/test_constants.py @@ -1,117 +1,87 @@ # Licensed under a 3-clause BSD style license, see LICENSE. -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - import pytest -from formulate import from_numexpr, to_numexpr -from formulate import from_root, to_root - -numexpr = pytest.importorskip("numexpr") -ROOT = pytest.importorskip("ROOT") - - -def root_eval(string, x=None, y=None, z=None, t=None): - f = ROOT.TFormula('', string) - f.Compile() - if x is None: - assert y is None and z is None and t is None - return f.Eval(0) - elif y is None: - assert z is None and t is None - return f.Eval(x) - elif z is None: - assert t is None - return f.Eval(x, y) - elif t is None: - return f.Eval(x, y, z) - else: - return f.Eval(x, y, z, t) - - -def numexpr_eval(string, **kwargs): - return numexpr.evaluate(string, local_dict=kwargs) - - -def create_constant_test(input_string, input_backend='root', numexpr_raises=None): - assert input_backend in ('root', 'numexpr'), 'Unrecognised backend specified' - input_from_method = { - 'root': from_root, - 'numexpr': from_numexpr, - }[input_backend] - - def test_constant(): - expression = input_from_method(input_string) - - root_result = to_root(expression) - assert pytest.approx(root_eval(input_string), root_eval(root_result)) - - if numexpr_raises: - with pytest.raises(numexpr_raises): - numexpr_result = to_numexpr(expression) - else: - numexpr_result = to_numexpr(expression) - assert pytest.approx(root_eval(root_result), numexpr_eval(numexpr_result)) - - return test_constant - - # Test basic numexpr constants -test_numexpr_true = create_constant_test('True', input_backend='numexpr') -test_numexpr_false = create_constant_test('False', input_backend='numexpr') +test_numexpr_true = pytest.helpers.create_formula_test("True", input_backend="numexpr") +test_numexpr_false = pytest.helpers.create_formula_test( + "False", input_backend="numexpr" +) # Test basic ROOT constants -test_true = create_constant_test('true') -test_false = create_constant_test('false') -test_infinity = create_constant_test('TMath::Infinity()', numexpr_raises=NotImplementedError) -test_nan = create_constant_test('TMath::QuietNaN()', numexpr_raises=NotImplementedError) - -test_sqrt2_1 = create_constant_test('sqrt2') -test_sqrt2_2 = create_constant_test('TMath::Sqrt2()') -test_e_1 = create_constant_test('e') -test_e_2 = create_constant_test('TMath::E()') -test_pi_1 = create_constant_test('pi') -test_pi_2 = create_constant_test('TMath::Pi()') -test_pi_over_2 = create_constant_test('TMath::PiOver2()') -test_pi_over_4 = create_constant_test('TMath::PiOver4()') -test_two_pi = create_constant_test('TMath::TwoPi()') -test_inv_pi = create_constant_test('TMath::InvPi()') -test_ln10_1 = create_constant_test('ln10') -test_ln10_2 = create_constant_test('TMath::Ln10()') -test_log10e = create_constant_test('TMath::LogE()') -test_deg2rad = create_constant_test('TMath::DegToRad()') -test_rad2deg = create_constant_test('TMath::RadToDeg()') - -test_na = create_constant_test('TMath::Na()') -test_nauncertainty = create_constant_test('TMath::NaUncertainty()', numexpr_raises=NotImplementedError) -test_k = create_constant_test('TMath::K()') -test_kcgs = create_constant_test('TMath::Kcgs()') -test_kuncertainty = create_constant_test('TMath::KUncertainty()', numexpr_raises=NotImplementedError) -test_c = create_constant_test('TMath::C()') -test_ccgs = create_constant_test('TMath::Ccgs()') -test_cuncertainty = create_constant_test('TMath::CUncertainty()') -test_rgair = create_constant_test('TMath::Rgair()') -test_qe = create_constant_test('TMath::Qe()') -test_qeuncertainty = create_constant_test('TMath::QeUncertainty()', numexpr_raises=NotImplementedError) -test_eulergamma = create_constant_test('TMath::EulerGamma()') -test_g = create_constant_test('TMath::G()') -test_gcgs = create_constant_test('TMath::Gcgs()') -test_guncertainty = create_constant_test('TMath::GUncertainty()', numexpr_raises=NotImplementedError) -test_ghbarc = create_constant_test('TMath::GhbarC()') -test_ghbarcuncertainty = create_constant_test('TMath::GhbarCUncertainty()', numexpr_raises=NotImplementedError) -test_gn = create_constant_test('TMath::Gn()') -test_gnuncertainty = create_constant_test('TMath::GnUncertainty()', numexpr_raises=NotImplementedError) -test_h = create_constant_test('TMath::H()') -test_hcgs = create_constant_test('TMath::Hcgs()') -test_huncertainty = create_constant_test('TMath::HUncertainty()', numexpr_raises=NotImplementedError) -test_hbar = create_constant_test('TMath::Hbar()') -test_hbarcgs = create_constant_test('TMath::Hbarcgs()') -test_hbaruncertainty = create_constant_test('TMath::HbarUncertainty()', numexpr_raises=NotImplementedError) -test_hc = create_constant_test('TMath::HC()') -test_hccgs = create_constant_test('TMath::HCcgs()') -test_mwair = create_constant_test('TMath::MWair()') -test_sigma = create_constant_test('TMath::Sigma()') -test_sigmauncertainty = create_constant_test('TMath::SigmaUncertainty()', numexpr_raises=NotImplementedError) -test_r = create_constant_test('TMath::R()') -test_runcertainty = create_constant_test('TMath::RUncertainty()', numexpr_raises=NotImplementedError) +test_true = pytest.helpers.create_formula_test("true") +test_false = pytest.helpers.create_formula_test("false") +test_infinity = pytest.helpers.create_formula_test( + "TMath::Infinity()", numexpr_raises=NotImplementedError +) +test_nan = pytest.helpers.create_formula_test( + "TMath::QuietNaN()", numexpr_raises=NotImplementedError +) + +test_sqrt2_1 = pytest.helpers.create_formula_test("sqrt2") +test_sqrt2_2 = pytest.helpers.create_formula_test("TMath::Sqrt2()") +test_e_1 = pytest.helpers.create_formula_test("e") +test_e_2 = pytest.helpers.create_formula_test("TMath::E()") +test_pi_1 = pytest.helpers.create_formula_test("pi") +test_pi_2 = pytest.helpers.create_formula_test("TMath::Pi()") +test_pi_over_2 = pytest.helpers.create_formula_test("TMath::PiOver2()") +test_pi_over_4 = pytest.helpers.create_formula_test("TMath::PiOver4()") +test_two_pi = pytest.helpers.create_formula_test("TMath::TwoPi()") +test_inv_pi = pytest.helpers.create_formula_test("TMath::InvPi()") +test_ln10_1 = pytest.helpers.create_formula_test("ln10") +test_ln10_2 = pytest.helpers.create_formula_test("TMath::Ln10()") +test_log10e = pytest.helpers.create_formula_test("TMath::LogE()") +test_deg2rad = pytest.helpers.create_formula_test("TMath::DegToRad()") +test_rad2deg = pytest.helpers.create_formula_test("TMath::RadToDeg()") + +test_na = pytest.helpers.create_formula_test("TMath::Na()") +test_nauncertainty = pytest.helpers.create_formula_test( + "TMath::NaUncertainty()", numexpr_raises=NotImplementedError +) +test_k = pytest.helpers.create_formula_test("TMath::K()") +test_kcgs = pytest.helpers.create_formula_test("TMath::Kcgs()") +test_kuncertainty = pytest.helpers.create_formula_test( + "TMath::KUncertainty()", numexpr_raises=NotImplementedError +) +test_c = pytest.helpers.create_formula_test("TMath::C()") +test_ccgs = pytest.helpers.create_formula_test("TMath::Ccgs()") +test_cuncertainty = pytest.helpers.create_formula_test("TMath::CUncertainty()") +test_rgair = pytest.helpers.create_formula_test("TMath::Rgair()") +test_qe = pytest.helpers.create_formula_test("TMath::Qe()") +test_qeuncertainty = pytest.helpers.create_formula_test( + "TMath::QeUncertainty()", numexpr_raises=NotImplementedError +) +test_eulergamma = pytest.helpers.create_formula_test("TMath::EulerGamma()") +test_g = pytest.helpers.create_formula_test("TMath::G()") +test_gcgs = pytest.helpers.create_formula_test("TMath::Gcgs()") +test_guncertainty = pytest.helpers.create_formula_test( + "TMath::GUncertainty()", numexpr_raises=NotImplementedError +) +test_ghbarc = pytest.helpers.create_formula_test("TMath::GhbarC()") +test_ghbarcuncertainty = pytest.helpers.create_formula_test( + "TMath::GhbarCUncertainty()", numexpr_raises=NotImplementedError +) +test_gn = pytest.helpers.create_formula_test("TMath::Gn()") +test_gnuncertainty = pytest.helpers.create_formula_test( + "TMath::GnUncertainty()", numexpr_raises=NotImplementedError +) +test_h = pytest.helpers.create_formula_test("TMath::H()") +test_hcgs = pytest.helpers.create_formula_test("TMath::Hcgs()") +test_huncertainty = pytest.helpers.create_formula_test( + "TMath::HUncertainty()", numexpr_raises=NotImplementedError +) +test_hbar = pytest.helpers.create_formula_test("TMath::Hbar()") +test_hbarcgs = pytest.helpers.create_formula_test("TMath::Hbarcgs()") +test_hbaruncertainty = pytest.helpers.create_formula_test( + "TMath::HbarUncertainty()", numexpr_raises=NotImplementedError +) +test_hc = pytest.helpers.create_formula_test("TMath::HC()") +test_hccgs = pytest.helpers.create_formula_test("TMath::HCcgs()") +test_mwair = pytest.helpers.create_formula_test("TMath::MWair()") +test_sigma = pytest.helpers.create_formula_test("TMath::Sigma()") +test_sigmauncertainty = pytest.helpers.create_formula_test( + "TMath::SigmaUncertainty()", numexpr_raises=NotImplementedError +) +test_r = pytest.helpers.create_formula_test("TMath::R()") +test_runcertainty = pytest.helpers.create_formula_test( + "TMath::RUncertainty()", numexpr_raises=NotImplementedError +) diff --git a/tests/test_expression.py b/tests/test_expression.py index be524b5..1ebe253 100644 --- a/tests/test_expression.py +++ b/tests/test_expression.py @@ -1,29 +1,38 @@ # Licensed under a 3-clause BSD style license, see LICENSE. -from __future__ import absolute_import, division, print_function - import sys import numpy as np import pytest -from formulate import from_numexpr, from_root, Expression, Variable, NamedConstant as NC, UnnamedConstant as UC +from formulate import ( + from_numexpr, + from_root, + Expression, + Variable, + NamedConstant as NC, + UnnamedConstant as UC, +) from formulate.identifiers import IDs, ConstantIDs from .utils import assert_equal_expressions as aee def test_get_variables(): - assert from_root('pi').variables == set() - assert from_numexpr('2').variables == set() - assert from_numexpr('2e-3').variables == set() - assert from_numexpr('A').variables == {'A'} - assert from_numexpr('A + A').variables == {'A'} - assert from_numexpr('A + B').variables == {'A', 'B'} - assert from_numexpr('A + A*A - 3e7').variables == {'A'} - assert from_numexpr('arctan2(A, A)').variables == {'A'} - assert from_numexpr('arctan2(A, B)').variables == {'A', 'B'} - assert from_root('arctan2(A, pi)').variables == {'A'} - assert from_numexpr('arctan2(arctan2(A, B), C)').variables == {'A', 'B', 'C'} - for base, expect in [(UC('2'), set()), (Variable('A'), {'A'}), (NC(ConstantIDs.PI), set())]: + assert from_root("pi").variables == set() + assert from_numexpr("2").variables == set() + assert from_numexpr("2e-3").variables == set() + assert from_numexpr("A").variables == {"A"} + assert from_numexpr("A + A").variables == {"A"} + assert from_numexpr("A + B").variables == {"A", "B"} + assert from_numexpr("A + A*A - 3e7").variables == {"A"} + assert from_numexpr("arctan2(A, A)").variables == {"A"} + assert from_numexpr("arctan2(A, B)").variables == {"A", "B"} + assert from_root("arctan2(A, pi)").variables == {"A"} + assert from_numexpr("arctan2(arctan2(A, B), C)").variables == {"A", "B", "C"} + for base, expect in [ + (UC("2"), set()), + (Variable("A"), {"A"}), + (NC(ConstantIDs.PI), set()), + ]: expr = base for i in list(range(100)): expr = Expression(IDs.SQRT, expr) @@ -31,315 +40,550 @@ def test_get_variables(): def test_named_constants(): - assert from_root('pi').named_constants == {'PI'} - assert from_numexpr('2').named_constants == set() - assert from_numexpr('2e-3').named_constants == set() - assert from_numexpr('A').named_constants == set() - assert from_numexpr('A + A').named_constants == set() - assert from_numexpr('A + B').named_constants == set() - assert from_numexpr('A + A*A - 3e7').named_constants == set() - assert from_numexpr('arctan2(A, A)').named_constants == set() - assert from_numexpr('arctan2(A, B)').named_constants == set() - assert from_root('arctan2(A, pi)').named_constants == {'PI'} - assert from_numexpr('arctan2(arctan2(A, B), C)').named_constants == set() - for base, expect in [(UC('2'), set()), (Variable('A'), set()), (NC(ConstantIDs.PI), {'PI'})]: + assert from_root("pi").named_constants == {"PI"} + assert from_numexpr("2").named_constants == set() + assert from_numexpr("2e-3").named_constants == set() + assert from_numexpr("A").named_constants == set() + assert from_numexpr("A + A").named_constants == set() + assert from_numexpr("A + B").named_constants == set() + assert from_numexpr("A + A*A - 3e7").named_constants == set() + assert from_numexpr("arctan2(A, A)").named_constants == set() + assert from_numexpr("arctan2(A, B)").named_constants == set() + assert from_root("arctan2(A, pi)").named_constants == {"PI"} + assert from_numexpr("arctan2(arctan2(A, B), C)").named_constants == set() + for base, expect in [ + (UC("2"), set()), + (Variable("A"), set()), + (NC(ConstantIDs.PI), {"PI"}), + ]: expr = base - for i in list(range(100)): + for _ in list(range(100)): expr = Expression(IDs.SQRT, expr) assert expr.named_constants == expect def test_unnamed_constants(): - assert from_root('pi').unnamed_constants == set() - assert from_numexpr('2').unnamed_constants == {'2'} - assert from_numexpr('2e-3').unnamed_constants == {'2e-3'} - assert from_numexpr('A').unnamed_constants == set() - assert from_numexpr('A + A').unnamed_constants == set() - assert from_numexpr('A + B').unnamed_constants == set() - assert from_numexpr('A + A*A - 3e7').unnamed_constants == {'3e7'} - assert from_numexpr('arctan2(A, A)').unnamed_constants == set() - assert from_numexpr('arctan2(A, B)').unnamed_constants == set() - assert from_root('arctan2(A, pi)').unnamed_constants == set() - assert from_numexpr('arctan2(arctan2(A, B), C)').unnamed_constants == set() - for base, expect in [(UC('2'), {'2'}), (Variable('A'), set()), (NC(ConstantIDs.PI), set())]: + assert from_root("pi").unnamed_constants == set() + assert from_numexpr("2").unnamed_constants == {"2"} + assert from_numexpr("2e-3").unnamed_constants == {"2e-3"} + assert from_numexpr("A").unnamed_constants == set() + assert from_numexpr("A + A").unnamed_constants == set() + assert from_numexpr("A + B").unnamed_constants == set() + assert from_numexpr("A + A*A - 3e7").unnamed_constants == {"3e7"} + assert from_numexpr("arctan2(A, A)").unnamed_constants == set() + assert from_numexpr("arctan2(A, B)").unnamed_constants == set() + assert from_root("arctan2(A, pi)").unnamed_constants == set() + assert from_numexpr("arctan2(arctan2(A, B), C)").unnamed_constants == set() + for base, expect in [ + (UC("2"), {"2"}), + (Variable("A"), set()), + (NC(ConstantIDs.PI), set()), + ]: expr = base - for i in list(range(100)): + for _ in list(range(100)): expr = Expression(IDs.SQRT, expr) assert expr.unnamed_constants == expect def test_addition(): - aee(Expression(IDs.SQRT, UC('2')) + Expression(IDs.SQRT, UC('3')), - Expression(IDs.ADD, Expression(IDs.SQRT, UC('2')), Expression(IDs.SQRT, UC('3')))) - aee(Expression(IDs.SQRT, UC('2')) + UC('3'), - Expression(IDs.ADD, Expression(IDs.SQRT, UC('2')), UC('3'))) - aee(UC('3') + Expression(IDs.SQRT, UC('2')), - Expression(IDs.ADD, UC('3'), Expression(IDs.SQRT, UC('2')))) - - expression = Expression(IDs.SQRT, UC('2')) - expression += Expression(IDs.SQRT, UC('3')) - aee(expression, Expression(IDs.ADD, Expression(IDs.SQRT, UC('2')), Expression(IDs.SQRT, UC('3')))) - expression = Expression(IDs.SQRT, UC('2')) - expression += UC('3') - aee(expression, Expression(IDs.ADD, Expression(IDs.SQRT, UC('2')), UC('3'))) - expression = UC('3') - expression += Expression(IDs.SQRT, UC('2')) - aee(expression, Expression(IDs.ADD, UC('3'), Expression(IDs.SQRT, UC('2')))) - - aee(np.add(Expression(IDs.SQRT, UC('2')), Expression(IDs.SQRT, UC('3'))), - Expression(IDs.ADD, Expression(IDs.SQRT, UC('2')), Expression(IDs.SQRT, UC('3')))) - aee(np.add(Expression(IDs.SQRT, UC('2')), UC('3')), - Expression(IDs.ADD, Expression(IDs.SQRT, UC('2')), UC('3'))) - aee(np.add(UC('3'), Expression(IDs.SQRT, UC('2'))), - Expression(IDs.ADD, UC('3'), Expression(IDs.SQRT, UC('2')))) + aee( + Expression(IDs.SQRT, UC("2")) + Expression(IDs.SQRT, UC("3")), + Expression( + IDs.ADD, Expression(IDs.SQRT, UC("2")), Expression(IDs.SQRT, UC("3")) + ), + ) + aee( + Expression(IDs.SQRT, UC("2")) + UC("3"), + Expression(IDs.ADD, Expression(IDs.SQRT, UC("2")), UC("3")), + ) + aee( + UC("3") + Expression(IDs.SQRT, UC("2")), + Expression(IDs.ADD, UC("3"), Expression(IDs.SQRT, UC("2"))), + ) + + expression = Expression(IDs.SQRT, UC("2")) + expression += Expression(IDs.SQRT, UC("3")) + aee( + expression, + Expression( + IDs.ADD, Expression(IDs.SQRT, UC("2")), Expression(IDs.SQRT, UC("3")) + ), + ) + expression = Expression(IDs.SQRT, UC("2")) + expression += UC("3") + aee(expression, Expression(IDs.ADD, Expression(IDs.SQRT, UC("2")), UC("3"))) + expression = UC("3") + expression += Expression(IDs.SQRT, UC("2")) + aee(expression, Expression(IDs.ADD, UC("3"), Expression(IDs.SQRT, UC("2")))) + + aee( + np.add(Expression(IDs.SQRT, UC("2")), Expression(IDs.SQRT, UC("3"))), + Expression( + IDs.ADD, Expression(IDs.SQRT, UC("2")), Expression(IDs.SQRT, UC("3")) + ), + ) + aee( + np.add(Expression(IDs.SQRT, UC("2")), UC("3")), + Expression(IDs.ADD, Expression(IDs.SQRT, UC("2")), UC("3")), + ) + aee( + np.add(UC("3"), Expression(IDs.SQRT, UC("2"))), + Expression(IDs.ADD, UC("3"), Expression(IDs.SQRT, UC("2"))), + ) def test_subtraction(): - aee(Expression(IDs.SQRT, UC('2')) - Expression(IDs.SQRT, UC('3')), - Expression(IDs.SUB, Expression(IDs.SQRT, UC('2')), Expression(IDs.SQRT, UC('3')))) - expression = Expression(IDs.SQRT, UC('2')) - expression -= Expression(IDs.SQRT, UC('3')) - aee(expression, Expression(IDs.SUB, Expression(IDs.SQRT, UC('2')), Expression(IDs.SQRT, UC('3')))) - aee(np.subtract(Expression(IDs.SQRT, UC('2')), Expression(IDs.SQRT, UC('3'))), - Expression(IDs.SUB, Expression(IDs.SQRT, UC('2')), Expression(IDs.SQRT, UC('3')))) + aee( + Expression(IDs.SQRT, UC("2")) - Expression(IDs.SQRT, UC("3")), + Expression( + IDs.SUB, Expression(IDs.SQRT, UC("2")), Expression(IDs.SQRT, UC("3")) + ), + ) + expression = Expression(IDs.SQRT, UC("2")) + expression -= Expression(IDs.SQRT, UC("3")) + aee( + expression, + Expression( + IDs.SUB, Expression(IDs.SQRT, UC("2")), Expression(IDs.SQRT, UC("3")) + ), + ) + aee( + np.subtract(Expression(IDs.SQRT, UC("2")), Expression(IDs.SQRT, UC("3"))), + Expression( + IDs.SUB, Expression(IDs.SQRT, UC("2")), Expression(IDs.SQRT, UC("3")) + ), + ) def test_multiplication(): - aee(Expression(IDs.SQRT, UC('2')) * Expression(IDs.SQRT, UC('3')), - Expression(IDs.MUL, Expression(IDs.SQRT, UC('2')), Expression(IDs.SQRT, UC('3')))) - expression = Expression(IDs.SQRT, UC('2')) - expression *= Expression(IDs.SQRT, UC('3')) - aee(expression, Expression(IDs.MUL, Expression(IDs.SQRT, UC('2')), Expression(IDs.SQRT, UC('3')))) - aee(np.multiply(Expression(IDs.SQRT, UC('2')), Expression(IDs.SQRT, UC('3'))), - Expression(IDs.MUL, Expression(IDs.SQRT, UC('2')), Expression(IDs.SQRT, UC('3')))) - - -@pytest.mark.skipif(sys.version_info.major < 3, reason="Python < 3 not supported for division.") + aee( + Expression(IDs.SQRT, UC("2")) * Expression(IDs.SQRT, UC("3")), + Expression( + IDs.MUL, Expression(IDs.SQRT, UC("2")), Expression(IDs.SQRT, UC("3")) + ), + ) + expression = Expression(IDs.SQRT, UC("2")) + expression *= Expression(IDs.SQRT, UC("3")) + aee( + expression, + Expression( + IDs.MUL, Expression(IDs.SQRT, UC("2")), Expression(IDs.SQRT, UC("3")) + ), + ) + aee( + np.multiply(Expression(IDs.SQRT, UC("2")), Expression(IDs.SQRT, UC("3"))), + Expression( + IDs.MUL, Expression(IDs.SQRT, UC("2")), Expression(IDs.SQRT, UC("3")) + ), + ) + + +@pytest.mark.skipif( + sys.version_info.major < 3, reason="Python < 3 not supported for division." +) def test_division(): - aee(Expression(IDs.SQRT, UC('2')) / Expression(IDs.SQRT, UC('3')), - Expression(IDs.DIV, Expression(IDs.SQRT, UC('2')), Expression(IDs.SQRT, UC('3')))) - expression = Expression(IDs.SQRT, UC('2')) - expression /= Expression(IDs.SQRT, UC('3')) - aee(expression, Expression(IDs.DIV, Expression(IDs.SQRT, UC('2')), Expression(IDs.SQRT, UC('3')))) - aee(np.divide(Expression(IDs.SQRT, UC('2')), Expression(IDs.SQRT, UC('3'))), - Expression(IDs.DIV, Expression(IDs.SQRT, UC('2')), Expression(IDs.SQRT, UC('3')))) + aee( + Expression(IDs.SQRT, UC("2")) / Expression(IDs.SQRT, UC("3")), + Expression( + IDs.DIV, Expression(IDs.SQRT, UC("2")), Expression(IDs.SQRT, UC("3")) + ), + ) + expression = Expression(IDs.SQRT, UC("2")) + expression /= Expression(IDs.SQRT, UC("3")) + aee( + expression, + Expression( + IDs.DIV, Expression(IDs.SQRT, UC("2")), Expression(IDs.SQRT, UC("3")) + ), + ) + aee( + np.divide(Expression(IDs.SQRT, UC("2")), Expression(IDs.SQRT, UC("3"))), + Expression( + IDs.DIV, Expression(IDs.SQRT, UC("2")), Expression(IDs.SQRT, UC("3")) + ), + ) def test_power(): - aee(Expression(IDs.SQRT, UC('2')) ** Expression(IDs.SQRT, UC('3')), - Expression(IDs.POW, Expression(IDs.SQRT, UC('2')), Expression(IDs.SQRT, UC('3')))) - expression = Expression(IDs.SQRT, UC('2')) - expression **= Expression(IDs.SQRT, UC('3')) - aee(expression, Expression(IDs.POW, Expression(IDs.SQRT, UC('2')), Expression(IDs.SQRT, UC('3')))) - aee(np.power(Expression(IDs.SQRT, UC('2')), Expression(IDs.SQRT, UC('3'))), - Expression(IDs.POW, Expression(IDs.SQRT, UC('2')), Expression(IDs.SQRT, UC('3')))) + aee( + Expression(IDs.SQRT, UC("2")) ** Expression(IDs.SQRT, UC("3")), + Expression( + IDs.POW, Expression(IDs.SQRT, UC("2")), Expression(IDs.SQRT, UC("3")) + ), + ) + expression = Expression(IDs.SQRT, UC("2")) + expression **= Expression(IDs.SQRT, UC("3")) + aee( + expression, + Expression( + IDs.POW, Expression(IDs.SQRT, UC("2")), Expression(IDs.SQRT, UC("3")) + ), + ) + aee( + np.power(Expression(IDs.SQRT, UC("2")), Expression(IDs.SQRT, UC("3"))), + Expression( + IDs.POW, Expression(IDs.SQRT, UC("2")), Expression(IDs.SQRT, UC("3")) + ), + ) def test_mod(): - aee(Expression(IDs.SQRT, UC('2')) % Expression(IDs.SQRT, UC('3')), - Expression(IDs.MOD, Expression(IDs.SQRT, UC('2')), Expression(IDs.SQRT, UC('3')))) - expression = Expression(IDs.SQRT, UC('2')) - expression %= Expression(IDs.SQRT, UC('3')) - aee(expression, Expression(IDs.MOD, Expression(IDs.SQRT, UC('2')), Expression(IDs.SQRT, UC('3')))) - aee(np.mod(Expression(IDs.SQRT, UC('2')), Expression(IDs.SQRT, UC('3'))), - Expression(IDs.MOD, Expression(IDs.SQRT, UC('2')), Expression(IDs.SQRT, UC('3')))) + aee( + Expression(IDs.SQRT, UC("2")) % Expression(IDs.SQRT, UC("3")), + Expression( + IDs.MOD, Expression(IDs.SQRT, UC("2")), Expression(IDs.SQRT, UC("3")) + ), + ) + expression = Expression(IDs.SQRT, UC("2")) + expression %= Expression(IDs.SQRT, UC("3")) + aee( + expression, + Expression( + IDs.MOD, Expression(IDs.SQRT, UC("2")), Expression(IDs.SQRT, UC("3")) + ), + ) + aee( + np.mod(Expression(IDs.SQRT, UC("2")), Expression(IDs.SQRT, UC("3"))), + Expression( + IDs.MOD, Expression(IDs.SQRT, UC("2")), Expression(IDs.SQRT, UC("3")) + ), + ) def test_and(): - aee(Expression(IDs.SQRT, UC('2')) & Expression(IDs.SQRT, UC('3')), - Expression(IDs.AND, Expression(IDs.SQRT, UC('2')), Expression(IDs.SQRT, UC('3')))) - expression = Expression(IDs.SQRT, UC('2')) - expression &= Expression(IDs.SQRT, UC('3')) - aee(expression, Expression(IDs.AND, Expression(IDs.SQRT, UC('2')), Expression(IDs.SQRT, UC('3')))) - aee(np.bitwise_and(Expression(IDs.SQRT, UC('2')), Expression(IDs.SQRT, UC('3'))), - Expression(IDs.AND, Expression(IDs.SQRT, UC('2')), Expression(IDs.SQRT, UC('3')))) + aee( + Expression(IDs.SQRT, UC("2")) & Expression(IDs.SQRT, UC("3")), + Expression( + IDs.AND, Expression(IDs.SQRT, UC("2")), Expression(IDs.SQRT, UC("3")) + ), + ) + expression = Expression(IDs.SQRT, UC("2")) + expression &= Expression(IDs.SQRT, UC("3")) + aee( + expression, + Expression( + IDs.AND, Expression(IDs.SQRT, UC("2")), Expression(IDs.SQRT, UC("3")) + ), + ) + aee( + np.bitwise_and(Expression(IDs.SQRT, UC("2")), Expression(IDs.SQRT, UC("3"))), + Expression( + IDs.AND, Expression(IDs.SQRT, UC("2")), Expression(IDs.SQRT, UC("3")) + ), + ) def test_or(): - aee(Expression(IDs.SQRT, UC('2')) | Expression(IDs.SQRT, UC('3')), - Expression(IDs.OR, Expression(IDs.SQRT, UC('2')), Expression(IDs.SQRT, UC('3')))) - expression = Expression(IDs.SQRT, UC('2')) - expression |= Expression(IDs.SQRT, UC('3')) - aee(expression, Expression(IDs.OR, Expression(IDs.SQRT, UC('2')), Expression(IDs.SQRT, UC('3')))) - aee(np.bitwise_or(Expression(IDs.SQRT, UC('2')), Expression(IDs.SQRT, UC('3'))), - Expression(IDs.OR, Expression(IDs.SQRT, UC('2')), Expression(IDs.SQRT, UC('3')))) + aee( + Expression(IDs.SQRT, UC("2")) | Expression(IDs.SQRT, UC("3")), + Expression( + IDs.OR, Expression(IDs.SQRT, UC("2")), Expression(IDs.SQRT, UC("3")) + ), + ) + expression = Expression(IDs.SQRT, UC("2")) + expression |= Expression(IDs.SQRT, UC("3")) + aee( + expression, + Expression( + IDs.OR, Expression(IDs.SQRT, UC("2")), Expression(IDs.SQRT, UC("3")) + ), + ) + aee( + np.bitwise_or(Expression(IDs.SQRT, UC("2")), Expression(IDs.SQRT, UC("3"))), + Expression( + IDs.OR, Expression(IDs.SQRT, UC("2")), Expression(IDs.SQRT, UC("3")) + ), + ) def test_xor(): - aee(Expression(IDs.SQRT, UC('2')) ^ Expression(IDs.SQRT, UC('3')), - Expression(IDs.XOR, Expression(IDs.SQRT, UC('2')), Expression(IDs.SQRT, UC('3')))) - expression = Expression(IDs.SQRT, UC('2')) - expression ^= Expression(IDs.SQRT, UC('3')) - aee(expression, Expression(IDs.XOR, Expression(IDs.SQRT, UC('2')), Expression(IDs.SQRT, UC('3')))) - aee(np.bitwise_xor(Expression(IDs.SQRT, UC('2')), Expression(IDs.SQRT, UC('3'))), - Expression(IDs.XOR, Expression(IDs.SQRT, UC('2')), Expression(IDs.SQRT, UC('3')))) + aee( + Expression(IDs.SQRT, UC("2")) ^ Expression(IDs.SQRT, UC("3")), + Expression( + IDs.XOR, Expression(IDs.SQRT, UC("2")), Expression(IDs.SQRT, UC("3")) + ), + ) + expression = Expression(IDs.SQRT, UC("2")) + expression ^= Expression(IDs.SQRT, UC("3")) + aee( + expression, + Expression( + IDs.XOR, Expression(IDs.SQRT, UC("2")), Expression(IDs.SQRT, UC("3")) + ), + ) + aee( + np.bitwise_xor(Expression(IDs.SQRT, UC("2")), Expression(IDs.SQRT, UC("3"))), + Expression( + IDs.XOR, Expression(IDs.SQRT, UC("2")), Expression(IDs.SQRT, UC("3")) + ), + ) def test_left_shift(): - aee(Expression(IDs.SQRT, UC('2')) << Expression(IDs.SQRT, UC('3')), - Expression(IDs.LSHIFT, Expression(IDs.SQRT, UC('2')), Expression(IDs.SQRT, UC('3')))) - expression = Expression(IDs.SQRT, UC('2')) - expression <<= Expression(IDs.SQRT, UC('3')) - aee(expression, Expression(IDs.LSHIFT, Expression(IDs.SQRT, UC('2')), Expression(IDs.SQRT, UC('3')))) - aee(np.left_shift(Expression(IDs.SQRT, UC('2')), Expression(IDs.SQRT, UC('3'))), - Expression(IDs.LSHIFT, Expression(IDs.SQRT, UC('2')), Expression(IDs.SQRT, UC('3')))) + aee( + Expression(IDs.SQRT, UC("2")) << Expression(IDs.SQRT, UC("3")), + Expression( + IDs.LSHIFT, Expression(IDs.SQRT, UC("2")), Expression(IDs.SQRT, UC("3")) + ), + ) + expression = Expression(IDs.SQRT, UC("2")) + expression <<= Expression(IDs.SQRT, UC("3")) + aee( + expression, + Expression( + IDs.LSHIFT, Expression(IDs.SQRT, UC("2")), Expression(IDs.SQRT, UC("3")) + ), + ) + aee( + np.left_shift(Expression(IDs.SQRT, UC("2")), Expression(IDs.SQRT, UC("3"))), + Expression( + IDs.LSHIFT, Expression(IDs.SQRT, UC("2")), Expression(IDs.SQRT, UC("3")) + ), + ) def test_right_shift(): - aee(Expression(IDs.SQRT, UC('2')) >> Expression(IDs.SQRT, UC('3')), - Expression(IDs.RSHIFT, Expression(IDs.SQRT, UC('2')), Expression(IDs.SQRT, UC('3')))) - expression = Expression(IDs.SQRT, UC('2')) - expression >>= Expression(IDs.SQRT, UC('3')) - aee(expression, Expression(IDs.RSHIFT, Expression(IDs.SQRT, UC('2')), Expression(IDs.SQRT, UC('3')))) - aee(np.right_shift(Expression(IDs.SQRT, UC('2')), Expression(IDs.SQRT, UC('3'))), - Expression(IDs.RSHIFT, Expression(IDs.SQRT, UC('2')), Expression(IDs.SQRT, UC('3')))) + aee( + Expression(IDs.SQRT, UC("2")) >> Expression(IDs.SQRT, UC("3")), + Expression( + IDs.RSHIFT, Expression(IDs.SQRT, UC("2")), Expression(IDs.SQRT, UC("3")) + ), + ) + expression = Expression(IDs.SQRT, UC("2")) + expression >>= Expression(IDs.SQRT, UC("3")) + aee( + expression, + Expression( + IDs.RSHIFT, Expression(IDs.SQRT, UC("2")), Expression(IDs.SQRT, UC("3")) + ), + ) + aee( + np.right_shift(Expression(IDs.SQRT, UC("2")), Expression(IDs.SQRT, UC("3"))), + Expression( + IDs.RSHIFT, Expression(IDs.SQRT, UC("2")), Expression(IDs.SQRT, UC("3")) + ), + ) def test_negative(): - aee(-Expression(IDs.SQRT, UC('2')), - Expression(IDs.MINUS, Expression(IDs.SQRT, UC('2')))) + aee( + -Expression(IDs.SQRT, UC("2")), + Expression(IDs.MINUS, Expression(IDs.SQRT, UC("2"))), + ) def test_positive(): - aee(+Expression(IDs.SQRT, UC('2')), - Expression(IDs.PLUS, Expression(IDs.SQRT, UC('2')))) + aee( + +Expression(IDs.SQRT, UC("2")), + Expression(IDs.PLUS, Expression(IDs.SQRT, UC("2"))), + ) def test_not(): - aee(~Expression(IDs.SQRT, UC('2')), - Expression(IDs.NOT, Expression(IDs.SQRT, UC('2')))) + aee( + ~Expression(IDs.SQRT, UC("2")), + Expression(IDs.NOT, Expression(IDs.SQRT, UC("2"))), + ) def test_abs(): - aee(abs(Expression(IDs.SQRT, UC('2'))), - Expression(IDs.ABS, Expression(IDs.SQRT, UC('2')))) - aee(np.abs(Expression(IDs.SQRT, UC('2'))), - Expression(IDs.ABS, Expression(IDs.SQRT, UC('2')))) + aee( + abs(Expression(IDs.SQRT, UC("2"))), + Expression(IDs.ABS, Expression(IDs.SQRT, UC("2"))), + ) + aee( + np.abs(Expression(IDs.SQRT, UC("2"))), + Expression(IDs.ABS, Expression(IDs.SQRT, UC("2"))), + ) def test_equals(): - aee(Expression(IDs.SQRT, UC('2')) == Expression(IDs.SQRT, UC('3')), - Expression(IDs.EQ, Expression(IDs.SQRT, UC('2')), Expression(IDs.SQRT, UC('3')))) + aee( + Expression(IDs.SQRT, UC("2")) == Expression(IDs.SQRT, UC("3")), + Expression( + IDs.EQ, Expression(IDs.SQRT, UC("2")), Expression(IDs.SQRT, UC("3")) + ), + ) def test_not_equals(): - aee(Expression(IDs.SQRT, UC('2')) != Expression(IDs.SQRT, UC('3')), - Expression(IDs.NEQ, Expression(IDs.SQRT, UC('2')), Expression(IDs.SQRT, UC('3')))) + aee( + Expression(IDs.SQRT, UC("2")) != Expression(IDs.SQRT, UC("3")), + Expression( + IDs.NEQ, Expression(IDs.SQRT, UC("2")), Expression(IDs.SQRT, UC("3")) + ), + ) def test_greater_than(): - aee(Expression(IDs.SQRT, UC('2')) > Expression(IDs.SQRT, UC('3')), - Expression(IDs.GT, Expression(IDs.SQRT, UC('2')), Expression(IDs.SQRT, UC('3')))) + aee( + Expression(IDs.SQRT, UC("2")) > Expression(IDs.SQRT, UC("3")), + Expression( + IDs.GT, Expression(IDs.SQRT, UC("2")), Expression(IDs.SQRT, UC("3")) + ), + ) def test_greater_than_or_equal(): - aee(Expression(IDs.SQRT, UC('2')) >= Expression(IDs.SQRT, UC('3')), - Expression(IDs.GTEQ, Expression(IDs.SQRT, UC('2')), Expression(IDs.SQRT, UC('3')))) + aee( + Expression(IDs.SQRT, UC("2")) >= Expression(IDs.SQRT, UC("3")), + Expression( + IDs.GTEQ, Expression(IDs.SQRT, UC("2")), Expression(IDs.SQRT, UC("3")) + ), + ) def test_less_than(): - aee(Expression(IDs.SQRT, UC('2')) < Expression(IDs.SQRT, UC('3')), - Expression(IDs.LT, Expression(IDs.SQRT, UC('2')), Expression(IDs.SQRT, UC('3')))) + aee( + Expression(IDs.SQRT, UC("2")) < Expression(IDs.SQRT, UC("3")), + Expression( + IDs.LT, Expression(IDs.SQRT, UC("2")), Expression(IDs.SQRT, UC("3")) + ), + ) def test_less_than_or_equal(): - aee(Expression(IDs.SQRT, UC('2')) <= Expression(IDs.SQRT, UC('3')), - Expression(IDs.LTEQ, Expression(IDs.SQRT, UC('2')), Expression(IDs.SQRT, UC('3')))) + aee( + Expression(IDs.SQRT, UC("2")) <= Expression(IDs.SQRT, UC("3")), + Expression( + IDs.LTEQ, Expression(IDs.SQRT, UC("2")), Expression(IDs.SQRT, UC("3")) + ), + ) # TODO where def test_sin(): - aee(np.sin(Expression(IDs.SQRT, UC('2'))), - Expression(IDs.SIN, Expression(IDs.SQRT, UC('2')))) + aee( + np.sin(Expression(IDs.SQRT, UC("2"))), + Expression(IDs.SIN, Expression(IDs.SQRT, UC("2"))), + ) def test_cos(): - aee(np.cos(Expression(IDs.SQRT, UC('2'))), - Expression(IDs.COS, Expression(IDs.SQRT, UC('2')))) + aee( + np.cos(Expression(IDs.SQRT, UC("2"))), + Expression(IDs.COS, Expression(IDs.SQRT, UC("2"))), + ) def test_tan(): - aee(np.tan(Expression(IDs.SQRT, UC('2'))), - Expression(IDs.TAN, Expression(IDs.SQRT, UC('2')))) + aee( + np.tan(Expression(IDs.SQRT, UC("2"))), + Expression(IDs.TAN, Expression(IDs.SQRT, UC("2"))), + ) def test_arcsin(): - aee(np.arcsin(Expression(IDs.SQRT, UC('2'))), - Expression(IDs.ASIN, Expression(IDs.SQRT, UC('2')))) + aee( + np.arcsin(Expression(IDs.SQRT, UC("2"))), + Expression(IDs.ASIN, Expression(IDs.SQRT, UC("2"))), + ) def test_arccos(): - aee(np.arccos(Expression(IDs.SQRT, UC('2'))), - Expression(IDs.ACOS, Expression(IDs.SQRT, UC('2')))) + aee( + np.arccos(Expression(IDs.SQRT, UC("2"))), + Expression(IDs.ACOS, Expression(IDs.SQRT, UC("2"))), + ) def test_arctan(): - aee(np.arctan(Expression(IDs.SQRT, UC('2'))), - Expression(IDs.ATAN, Expression(IDs.SQRT, UC('2')))) + aee( + np.arctan(Expression(IDs.SQRT, UC("2"))), + Expression(IDs.ATAN, Expression(IDs.SQRT, UC("2"))), + ) def test_arctan2(): - aee(np.arctan2(Expression(IDs.SQRT, UC('2')), Expression(IDs.SQRT, UC('3'))), - Expression(IDs.ATAN2, Expression(IDs.SQRT, UC('2')), Expression(IDs.SQRT, UC('3')))) + aee( + np.arctan2(Expression(IDs.SQRT, UC("2")), Expression(IDs.SQRT, UC("3"))), + Expression( + IDs.ATAN2, Expression(IDs.SQRT, UC("2")), Expression(IDs.SQRT, UC("3")) + ), + ) def test_sinh(): - aee(np.sinh(Expression(IDs.SQRT, UC('2'))), - Expression(IDs.ASINH, Expression(IDs.SQRT, UC('2')))) + aee( + np.sinh(Expression(IDs.SQRT, UC("2"))), + Expression(IDs.ASINH, Expression(IDs.SQRT, UC("2"))), + ) def test_cosh(): - aee(np.cosh(Expression(IDs.SQRT, UC('2'))), - Expression(IDs.COSH, Expression(IDs.SQRT, UC('2')))) + aee( + np.cosh(Expression(IDs.SQRT, UC("2"))), + Expression(IDs.COSH, Expression(IDs.SQRT, UC("2"))), + ) def test_tanh(): - aee(np.tanh(Expression(IDs.SQRT, UC('2'))), - Expression(IDs.TANH, Expression(IDs.SQRT, UC('2')))) + aee( + np.tanh(Expression(IDs.SQRT, UC("2"))), + Expression(IDs.TANH, Expression(IDs.SQRT, UC("2"))), + ) def test_arcsinh(): - aee(np.arcsinh(Expression(IDs.SQRT, UC('2'))), - Expression(IDs.ASINH, Expression(IDs.SQRT, UC('2')))) + aee( + np.arcsinh(Expression(IDs.SQRT, UC("2"))), + Expression(IDs.ASINH, Expression(IDs.SQRT, UC("2"))), + ) def test_arccosh(): - aee(np.arccosh(Expression(IDs.SQRT, UC('2'))), - Expression(IDs.ACOSH, Expression(IDs.SQRT, UC('2')))) + aee( + np.arccosh(Expression(IDs.SQRT, UC("2"))), + Expression(IDs.ACOSH, Expression(IDs.SQRT, UC("2"))), + ) def test_arctanh(): - aee(np.arctanh(Expression(IDs.SQRT, UC('2'))), - Expression(IDs.ATANH, Expression(IDs.SQRT, UC('2')))) + aee( + np.arctanh(Expression(IDs.SQRT, UC("2"))), + Expression(IDs.ATANH, Expression(IDs.SQRT, UC("2"))), + ) def test_log(): - aee(np.log(Expression(IDs.SQRT, UC('2'))), - Expression(IDs.LOG, Expression(IDs.SQRT, UC('2')))) + aee( + np.log(Expression(IDs.SQRT, UC("2"))), + Expression(IDs.LOG, Expression(IDs.SQRT, UC("2"))), + ) def test_log10(): - aee(np.log10(Expression(IDs.SQRT, UC('2'))), - Expression(IDs.LOG10, Expression(IDs.SQRT, UC('2')))) + aee( + np.log10(Expression(IDs.SQRT, UC("2"))), + Expression(IDs.LOG10, Expression(IDs.SQRT, UC("2"))), + ) def test_log1p(): - aee(np.log1p(Expression(IDs.SQRT, UC('2'))), - Expression(IDs.LOG1p, Expression(IDs.SQRT, UC('2')))) + aee( + np.log1p(Expression(IDs.SQRT, UC("2"))), + Expression(IDs.LOG1p, Expression(IDs.SQRT, UC("2"))), + ) def test_exp(): - aee(np.exp(Expression(IDs.SQRT, UC('2'))), - Expression(IDs.EXP, Expression(IDs.SQRT, UC('2')))) + aee( + np.exp(Expression(IDs.SQRT, UC("2"))), + Expression(IDs.EXP, Expression(IDs.SQRT, UC("2"))), + ) def test_expm1(): - aee(np.expm1(Expression(IDs.SQRT, UC('2'))), - Expression(IDs.EXPM1, Expression(IDs.SQRT, UC('2')))) + aee( + np.expm1(Expression(IDs.SQRT, UC("2"))), + Expression(IDs.EXPM1, Expression(IDs.SQRT, UC("2"))), + ) def test_sqrt(): - aee(np.sqrt(Expression(IDs.SQRT, UC('2'))), - Expression(IDs.SQRT, Expression(IDs.SQRT, UC('2')))) + aee( + np.sqrt(Expression(IDs.SQRT, UC("2"))), + Expression(IDs.SQRT, Expression(IDs.SQRT, UC("2"))), + ) diff --git a/tests/test_formulas.py b/tests/test_formulas.py index 9af44b3..664f461 100644 --- a/tests/test_formulas.py +++ b/tests/test_formulas.py @@ -1,120 +1,175 @@ # Licensed under a 3-clause BSD style license, see LICENSE. -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - import pytest -from formulate import from_root, to_root -from formulate import from_numexpr, to_numexpr - -numexpr = pytest.importorskip("numexpr") -ROOT = pytest.importorskip("ROOT") - - -def root_eval(string, x=None, y=None, z=None, t=None): - f = ROOT.TFormula('', string) - f.Compile() - if x is None: - assert y is None and z is None and t is None - return f.Eval(0) - elif y is None: - assert z is None and t is None - return f.Eval(x) - elif z is None: - assert t is None - return f.Eval(x, y) - elif t is None: - return f.Eval(x, y, z) - else: - return f.Eval(x, y, z, t) - - -def numexpr_eval(string, **kwargs): - return numexpr.evaluate(string, local_dict=kwargs) - - -def create_formula_test(input_string, input_backend='root', numexpr_raises=None): - assert input_backend in ('root', 'numexpr'), 'Unrecognised backend specified' - input_from_method = { - 'root': from_root, - 'numexpr': from_numexpr, - }[input_backend] - - def test_constant(): - expression = input_from_method(input_string) - - if input_backend == 'root': - root_result = to_root(expression) - assert input_string, root_result - - if numexpr_raises: - with pytest.raises(numexpr_raises): - to_numexpr(expression) - else: - numexpr_result = to_numexpr(expression) - assert pytest.approx(root_eval(root_result), numexpr_eval(numexpr_result)) - else: - raise NotImplementedError() - - return test_constant - - -test_root_BesselI0 = create_formula_test('TMath::BesselI0(A)', numexpr_raises=NotImplementedError) -test_root_BesselI1 = create_formula_test('TMath::BesselI1(A)', numexpr_raises=NotImplementedError) -test_root_BesselJ0 = create_formula_test('TMath::BesselJ0(A)', numexpr_raises=NotImplementedError) -test_root_BesselJ1 = create_formula_test('TMath::BesselJ1(A)', numexpr_raises=NotImplementedError) -test_root_BesselK0 = create_formula_test('TMath::BesselK0(A)', numexpr_raises=NotImplementedError) -test_root_BesselK1 = create_formula_test('TMath::BesselK1(A)', numexpr_raises=NotImplementedError) -test_root_BesselY0 = create_formula_test('TMath::BesselY0(A)', numexpr_raises=NotImplementedError) -test_root_BesselY1 = create_formula_test('TMath::BesselY1(A)', numexpr_raises=NotImplementedError) -test_root_Ceil = create_formula_test('TMath::Ceil(A)', numexpr_raises=NotImplementedError) -test_root_CeilNint = create_formula_test('TMath::CeilNint(A)', numexpr_raises=NotImplementedError) -test_root_DiLog = create_formula_test('TMath::DiLog(A)', numexpr_raises=NotImplementedError) -test_root_Erf = create_formula_test('TMath::Erf(A)', numexpr_raises=NotImplementedError) -test_root_Erfc = create_formula_test('TMath::Erfc(A)', numexpr_raises=NotImplementedError) -test_root_ErfcInverse = create_formula_test('TMath::ErfcInverse(A)', numexpr_raises=NotImplementedError) -test_root_ErfInverse = create_formula_test('TMath::ErfInverse(A)', numexpr_raises=NotImplementedError) -test_root_Even = create_formula_test('TMath::Even(A)', numexpr_raises=NotImplementedError) -test_root_Factorial = create_formula_test('TMath::Factorial(A)', numexpr_raises=NotImplementedError) -test_root_Floor = create_formula_test('TMath::Floor(A)', numexpr_raises=NotImplementedError) -test_root_FloorNint = create_formula_test('TMath::FloorNint(A)', numexpr_raises=NotImplementedError) -test_root_Freq = create_formula_test('TMath::Freq(A)', numexpr_raises=NotImplementedError) -test_root_KolmogorovProb = create_formula_test('TMath::KolmogorovProb(A)', numexpr_raises=NotImplementedError) -test_root_LandauI = create_formula_test('TMath::LandauI(A)', numexpr_raises=NotImplementedError) -test_root_LnGamma = create_formula_test('TMath::LnGamma(A)', numexpr_raises=NotImplementedError) -test_root_NextPrime = create_formula_test('TMath::NextPrime(A)', numexpr_raises=NotImplementedError) -test_root_NormQuantile = create_formula_test('TMath::NormQuantile(A)', numexpr_raises=NotImplementedError) -test_root_Odd = create_formula_test('TMath::Odd(A)', numexpr_raises=NotImplementedError) -test_root_Sq = create_formula_test('TMath::Sq(1.234)') -test_root_StruveH0 = create_formula_test('TMath::StruveH0(A)', numexpr_raises=NotImplementedError) -test_root_StruveH1 = create_formula_test('TMath::StruveH1(A)', numexpr_raises=NotImplementedError) -test_root_StruveL0 = create_formula_test('TMath::StruveL0(A)', numexpr_raises=NotImplementedError) -test_root_StruveL1 = create_formula_test('TMath::StruveL1(A)', numexpr_raises=NotImplementedError) -test_root_BesselI = create_formula_test('TMath::BesselI(A, B)', numexpr_raises=NotImplementedError) -test_root_BesselK = create_formula_test('TMath::BesselK(A, B)', numexpr_raises=NotImplementedError) -test_root_Beta = create_formula_test('TMath::Beta(A, B)', numexpr_raises=NotImplementedError) -test_root_Binomial = create_formula_test('TMath::Binomial(A, B)', numexpr_raises=NotImplementedError) -test_root_ChisquareQuantile = create_formula_test('TMath::ChisquareQuantile(A, B)', numexpr_raises=NotImplementedError) -test_root_Ldexp = create_formula_test('TMath::Ldexp(A, B)', numexpr_raises=NotImplementedError) -test_root_Permute = create_formula_test('TMath::Permute(A, B)', numexpr_raises=NotImplementedError) -test_root_Poisson = create_formula_test('TMath::Poisson(A, B)', numexpr_raises=NotImplementedError) -test_root_PoissonI = create_formula_test('TMath::PoissonI(A, B)', numexpr_raises=NotImplementedError) -test_root_Prob = create_formula_test('TMath::Prob(A, B)', numexpr_raises=NotImplementedError) -test_root_Student = create_formula_test('TMath::Student(A, B)', numexpr_raises=NotImplementedError) -test_root_StudentI = create_formula_test('TMath::StudentI(A, B)', numexpr_raises=NotImplementedError) -test_root_AreEqualAbs = create_formula_test('TMath::AreEqualAbs(A, B, C)', numexpr_raises=NotImplementedError) -test_root_AreEqualRel = create_formula_test('TMath::AreEqualRel(A, B, C)', numexpr_raises=NotImplementedError) -test_root_BetaCf = create_formula_test('TMath::BetaCf(A, B, C)', numexpr_raises=NotImplementedError) -test_root_BetaDist = create_formula_test('TMath::BetaDist(A, B, C)', numexpr_raises=NotImplementedError) -test_root_BetaDistI = create_formula_test('TMath::BetaDistI(A, B, C)', numexpr_raises=NotImplementedError) -test_root_BetaIncomplete = create_formula_test('TMath::BetaIncomplete(A, B, C)', numexpr_raises=NotImplementedError) -test_root_BinomialI = create_formula_test('TMath::BinomialI(A, B, C)', numexpr_raises=NotImplementedError) -test_root_BubbleHigh = create_formula_test('TMath::BubbleHigh(A, B, C)', numexpr_raises=NotImplementedError) -test_root_BubbleLow = create_formula_test('TMath::BubbleLow(A, B, C)', numexpr_raises=NotImplementedError) -test_root_FDist = create_formula_test('TMath::FDist(A, B, C)', numexpr_raises=NotImplementedError) -test_root_FDistI = create_formula_test('TMath::FDistI(A, B, C)', numexpr_raises=NotImplementedError) -test_root_Vavilov = create_formula_test('TMath::Vavilov(A, B, C)', numexpr_raises=NotImplementedError) -test_root_VavilovI = create_formula_test('TMath::VavilovI(A, B, C)', numexpr_raises=NotImplementedError) -test_root_RootsCubic = create_formula_test('TMath::RootsCubic(A, B, C, D)', numexpr_raises=NotImplementedError) -test_root_Quantiles = create_formula_test('TMath::Quantiles(A, B, C, D, E)', numexpr_raises=NotImplementedError) +test_root_BesselI0 = pytest.helpers.create_formula_test( + "TMath::BesselI0(A)", numexpr_raises=NotImplementedError +) +test_root_BesselI1 = pytest.helpers.create_formula_test( + "TMath::BesselI1(A)", numexpr_raises=NotImplementedError +) +test_root_BesselJ0 = pytest.helpers.create_formula_test( + "TMath::BesselJ0(A)", numexpr_raises=NotImplementedError +) +test_root_BesselJ1 = pytest.helpers.create_formula_test( + "TMath::BesselJ1(A)", numexpr_raises=NotImplementedError +) +test_root_BesselK0 = pytest.helpers.create_formula_test( + "TMath::BesselK0(A)", numexpr_raises=NotImplementedError +) +test_root_BesselK1 = pytest.helpers.create_formula_test( + "TMath::BesselK1(A)", numexpr_raises=NotImplementedError +) +test_root_BesselY0 = pytest.helpers.create_formula_test( + "TMath::BesselY0(A)", numexpr_raises=NotImplementedError +) +test_root_BesselY1 = pytest.helpers.create_formula_test( + "TMath::BesselY1(A)", numexpr_raises=NotImplementedError +) +test_root_Ceil = pytest.helpers.create_formula_test( + "TMath::Ceil(A)", numexpr_raises=NotImplementedError +) +test_root_CeilNint = pytest.helpers.create_formula_test( + "TMath::CeilNint(A)", numexpr_raises=NotImplementedError +) +test_root_DiLog = pytest.helpers.create_formula_test( + "TMath::DiLog(A)", numexpr_raises=NotImplementedError +) +test_root_Erf = pytest.helpers.create_formula_test( + "TMath::Erf(A)", numexpr_raises=NotImplementedError +) +test_root_Erfc = pytest.helpers.create_formula_test( + "TMath::Erfc(A)", numexpr_raises=NotImplementedError +) +test_root_ErfcInverse = pytest.helpers.create_formula_test( + "TMath::ErfcInverse(A)", numexpr_raises=NotImplementedError +) +test_root_ErfInverse = pytest.helpers.create_formula_test( + "TMath::ErfInverse(A)", numexpr_raises=NotImplementedError +) +test_root_Even = pytest.helpers.create_formula_test( + "TMath::Even(A)", numexpr_raises=NotImplementedError +) +test_root_Factorial = pytest.helpers.create_formula_test( + "TMath::Factorial(A)", numexpr_raises=NotImplementedError +) +test_root_Floor = pytest.helpers.create_formula_test( + "TMath::Floor(A)", numexpr_raises=NotImplementedError +) +test_root_FloorNint = pytest.helpers.create_formula_test( + "TMath::FloorNint(A)", numexpr_raises=NotImplementedError +) +test_root_Freq = pytest.helpers.create_formula_test( + "TMath::Freq(A)", numexpr_raises=NotImplementedError +) +test_root_KolmogorovProb = pytest.helpers.create_formula_test( + "TMath::KolmogorovProb(A)", numexpr_raises=NotImplementedError +) +test_root_LandauI = pytest.helpers.create_formula_test( + "TMath::LandauI(A)", numexpr_raises=NotImplementedError +) +test_root_LnGamma = pytest.helpers.create_formula_test( + "TMath::LnGamma(A)", numexpr_raises=NotImplementedError +) +test_root_NextPrime = pytest.helpers.create_formula_test( + "TMath::NextPrime(A)", numexpr_raises=NotImplementedError +) +test_root_NormQuantile = pytest.helpers.create_formula_test( + "TMath::NormQuantile(A)", numexpr_raises=NotImplementedError +) +test_root_Odd = pytest.helpers.create_formula_test( + "TMath::Odd(A)", numexpr_raises=NotImplementedError +) +test_root_Sq = pytest.helpers.create_formula_test("TMath::Sq(1.234)") +test_root_StruveH0 = pytest.helpers.create_formula_test( + "TMath::StruveH0(A)", numexpr_raises=NotImplementedError +) +test_root_StruveH1 = pytest.helpers.create_formula_test( + "TMath::StruveH1(A)", numexpr_raises=NotImplementedError +) +test_root_StruveL0 = pytest.helpers.create_formula_test( + "TMath::StruveL0(A)", numexpr_raises=NotImplementedError +) +test_root_StruveL1 = pytest.helpers.create_formula_test( + "TMath::StruveL1(A)", numexpr_raises=NotImplementedError +) +test_root_BesselI = pytest.helpers.create_formula_test( + "TMath::BesselI(A, B)", numexpr_raises=NotImplementedError +) +test_root_BesselK = pytest.helpers.create_formula_test( + "TMath::BesselK(A, B)", numexpr_raises=NotImplementedError +) +test_root_Beta = pytest.helpers.create_formula_test( + "TMath::Beta(A, B)", numexpr_raises=NotImplementedError +) +test_root_Binomial = pytest.helpers.create_formula_test( + "TMath::Binomial(A, B)", numexpr_raises=NotImplementedError +) +test_root_ChisquareQuantile = pytest.helpers.create_formula_test( + "TMath::ChisquareQuantile(A, B)", numexpr_raises=NotImplementedError +) +test_root_Ldexp = pytest.helpers.create_formula_test( + "TMath::Ldexp(A, B)", numexpr_raises=NotImplementedError +) +test_root_Permute = pytest.helpers.create_formula_test( + "TMath::Permute(A, B)", numexpr_raises=NotImplementedError +) +test_root_Poisson = pytest.helpers.create_formula_test( + "TMath::Poisson(A, B)", numexpr_raises=NotImplementedError +) +test_root_PoissonI = pytest.helpers.create_formula_test( + "TMath::PoissonI(A, B)", numexpr_raises=NotImplementedError +) +test_root_Prob = pytest.helpers.create_formula_test( + "TMath::Prob(A, B)", numexpr_raises=NotImplementedError +) +test_root_Student = pytest.helpers.create_formula_test( + "TMath::Student(A, B)", numexpr_raises=NotImplementedError +) +test_root_StudentI = pytest.helpers.create_formula_test( + "TMath::StudentI(A, B)", numexpr_raises=NotImplementedError +) +test_root_AreEqualAbs = pytest.helpers.create_formula_test( + "TMath::AreEqualAbs(A, B, C)", numexpr_raises=NotImplementedError +) +test_root_AreEqualRel = pytest.helpers.create_formula_test( + "TMath::AreEqualRel(A, B, C)", numexpr_raises=NotImplementedError +) +test_root_BetaCf = pytest.helpers.create_formula_test( + "TMath::BetaCf(A, B, C)", numexpr_raises=NotImplementedError +) +test_root_BetaDist = pytest.helpers.create_formula_test( + "TMath::BetaDist(A, B, C)", numexpr_raises=NotImplementedError +) +test_root_BetaDistI = pytest.helpers.create_formula_test( + "TMath::BetaDistI(A, B, C)", numexpr_raises=NotImplementedError +) +test_root_BetaIncomplete = pytest.helpers.create_formula_test( + "TMath::BetaIncomplete(A, B, C)", numexpr_raises=NotImplementedError +) +test_root_BinomialI = pytest.helpers.create_formula_test( + "TMath::BinomialI(A, B, C)", numexpr_raises=NotImplementedError +) +test_root_BubbleHigh = pytest.helpers.create_formula_test( + "TMath::BubbleHigh(A, B, C)", numexpr_raises=NotImplementedError +) +test_root_BubbleLow = pytest.helpers.create_formula_test( + "TMath::BubbleLow(A, B, C)", numexpr_raises=NotImplementedError +) +test_root_FDist = pytest.helpers.create_formula_test( + "TMath::FDist(A, B, C)", numexpr_raises=NotImplementedError +) +test_root_FDistI = pytest.helpers.create_formula_test( + "TMath::FDistI(A, B, C)", numexpr_raises=NotImplementedError +) +test_root_Vavilov = pytest.helpers.create_formula_test( + "TMath::Vavilov(A, B, C)", numexpr_raises=NotImplementedError +) +test_root_VavilovI = pytest.helpers.create_formula_test( + "TMath::VavilovI(A, B, C)", numexpr_raises=NotImplementedError +) +test_root_RootsCubic = pytest.helpers.create_formula_test( + "TMath::RootsCubic(A, B, C, D)", numexpr_raises=NotImplementedError +) +test_root_Quantiles = pytest.helpers.create_formula_test( + "TMath::Quantiles(A, B, C, D, E)", numexpr_raises=NotImplementedError +) diff --git a/tests/test_main.py b/tests/test_main.py index d5c310a..8d796f7 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,54 +1,66 @@ # Licensed under a 3-clause BSD style license, see LICENSE. -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - import pytest from formulate.__main__ import parse_args def test_root2numexpr_conversion(): - result = parse_args(['--from-root', '(A && B) || TMath::Sqrt(A)', '--to-numexpr']) - assert result == '(A & B) | sqrt(A)' + result = parse_args(["--from-root", "(A && B) || TMath::Sqrt(A)", "--to-numexpr"]) + assert result == "(A & B) | sqrt(A)" def test_numexpr2root_conversion(): - result = parse_args(['--from-numexpr', '(A & B) | sqrt(A)', '--to-root']) - assert result == '(A && B) || TMath::Sqrt(A)' + result = parse_args(["--from-numexpr", "(A & B) | sqrt(A)", "--to-root"]) + assert result == "(A && B) || TMath::Sqrt(A)" def test_get_variables(): - result = parse_args(['--from-numexpr', '(A & B) | sqrt(A) + 5.4**3.141592 ', '--variables']) - assert result == 'A\nB' - result = parse_args(['--from-root', '(A && B) || TMath::Sqrt(A) + 5.4**pi', '--variables']) - assert result == 'A\nB' + result = parse_args( + ["--from-numexpr", "(A & B) | sqrt(A) + 5.4**3.141592 ", "--variables"] + ) + assert result == "A\nB" + result = parse_args( + ["--from-root", "(A && B) || TMath::Sqrt(A) + 5.4**pi", "--variables"] + ) + assert result == "A\nB" def test_get_named_constants(): - result = parse_args(['--from-numexpr', '(A & B) | sqrt(A)', '--named-constants']) - assert result == '' - result = parse_args(['--from-root', '(A && B) || TMath::Sqrt(A)', '--named-constants']) - assert result == '' - - result = parse_args(['--from-numexpr', '(A & B) | sqrt(A) + 5.4**3.141592', '--named-constants']) - assert result == '' - result = parse_args(['--from-root', '(A && B) || TMath::Sqrt(A) + 5.4**pi', '--named-constants']) - assert result == 'PI' + result = parse_args(["--from-numexpr", "(A & B) | sqrt(A)", "--named-constants"]) + assert result == "" + result = parse_args( + ["--from-root", "(A && B) || TMath::Sqrt(A)", "--named-constants"] + ) + assert result == "" + + result = parse_args( + ["--from-numexpr", "(A & B) | sqrt(A) + 5.4**3.141592", "--named-constants"] + ) + assert result == "" + result = parse_args( + ["--from-root", "(A && B) || TMath::Sqrt(A) + 5.4**pi", "--named-constants"] + ) + assert result == "PI" def test_get_unnamed_constants(): - result = parse_args(['--from-numexpr', '(A & B) | sqrt(A)', '--unnamed-constants']) - assert result == '' - result = parse_args(['--from-root', '(A && B) || TMath::Sqrt(A)', '--unnamed-constants']) - assert result == '' - - result = parse_args(['--from-numexpr', '(A & B) | sqrt(A) + 5.4**3.141592', '--unnamed-constants']) - assert result == '3.141592\n5.4' - result = parse_args(['--from-root', '(A && B) || TMath::Sqrt(A) + 5.4**pi', '--unnamed-constants']) - assert result == '5.4' + result = parse_args(["--from-numexpr", "(A & B) | sqrt(A)", "--unnamed-constants"]) + assert result == "" + result = parse_args( + ["--from-root", "(A && B) || TMath::Sqrt(A)", "--unnamed-constants"] + ) + assert result == "" + + result = parse_args( + ["--from-numexpr", "(A & B) | sqrt(A) + 5.4**3.141592", "--unnamed-constants"] + ) + assert result == "3.141592\n5.4" + result = parse_args( + ["--from-root", "(A && B) || TMath::Sqrt(A) + 5.4**pi", "--unnamed-constants"] + ) + assert result == "5.4" def test_invalid_args(): with pytest.raises(SystemExit): - parse_args(['--dsadasdsada']) + parse_args(["--dsadasdsada"]) diff --git a/tests/test_operators.py b/tests/test_operators.py index d0971ac..0a14e49 100644 --- a/tests/test_operators.py +++ b/tests/test_operators.py @@ -1,168 +1,157 @@ # Licensed under a 3-clause BSD style license, see LICENSE. -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - import pytest -from formulate import from_root, to_root -from formulate import from_numexpr, to_numexpr - -numexpr = pytest.importorskip("numexpr") -ROOT = pytest.importorskip("ROOT") - - -def root_eval(string, x=None, y=None, z=None, t=None): - f = ROOT.TFormula('', string) - f.Compile() - if x is None: - assert y is None and z is None and t is None - return f.Eval(0) - elif y is None: - assert z is None and t is None - return f.Eval(x) - elif z is None: - assert t is None - return f.Eval(x, y) - elif t is None: - return f.Eval(x, y, z) - else: - return f.Eval(x, y, z, t) - - -def numexpr_eval(string, **kwargs): - return numexpr.evaluate(string, local_dict=kwargs) - - -def create_formula_test(input_string, input_backend='root', root_raises=None, numexpr_raises=None): - assert input_backend in ('root', 'numexpr'), 'Unrecognised backend specified' - input_from_method = { - 'root': from_root, - 'numexpr': from_numexpr, - }[input_backend] - - def test_constant(): - expression = input_from_method(input_string) - - if input_backend == 'root': - root_result = to_root(expression) - assert input_string, root_result - - if numexpr_raises: - with pytest.raises(numexpr_raises): - to_numexpr(expression) - else: - numexpr_result = to_numexpr(expression) - assert pytest.approx(root_eval(root_result), numexpr_eval(numexpr_result)) - else: - numexpr_result = to_numexpr(expression) - assert input_string, numexpr_result - - if root_raises: - with pytest.raises(root_raises): - to_root(expression) - else: - root_result = to_root(expression) - assert pytest.approx(numexpr_eval(numexpr_result), root_eval(root_result)) - - return test_constant - - -test_add_root = create_formula_test('3.4 + 5e-7', input_backend='root') -test_add_numexpr = create_formula_test('3.4 + 5e-7', input_backend='numexpr') -test_sub_root = create_formula_test('3.4 - 5e-7', input_backend='root') -test_sub_numexpr = create_formula_test('3.4 - 5e-7', input_backend='numexpr') -test_mul_root = create_formula_test('3.4 * 5e-7', input_backend='root') -test_mul_numexpr = create_formula_test('3.4 * 5e-7', input_backend='numexpr') -test_div_root = create_formula_test('3.4 / 5e-7', input_backend='root') -test_div_numexpr = create_formula_test('3.4 / 5e-7', input_backend='numexpr') -test_mod_root = create_formula_test('3 % 5', input_backend='root') -test_mod_numexpr = create_formula_test('3 % 5', input_backend='numexpr') -test_pow_1_root = create_formula_test('3 ** 5', input_backend='root') -test_pow_1_numexpr = create_formula_test('3 ** 5', input_backend='numexpr') -test_pow_2_root = create_formula_test('3 ** -1.5', input_backend='root') -test_pow_2_numexpr = create_formula_test('3 ** -1.5', input_backend='numexpr') -test_pow_3_root = create_formula_test('3 **2', input_backend='root') -# test_pow_3_numexpr = create_formula_test('3 **2', input_backend='numexpr') -test_lshift_root = create_formula_test('3 << 5', input_backend='root') -test_lshift_numexpr = create_formula_test('3 << 5', input_backend='numexpr') -test_rshift_root = create_formula_test('3 >> 5', input_backend='root') -test_rshift_numexpr = create_formula_test('3 >> 5', input_backend='numexpr') - -test_eq_1_root = create_formula_test('3 == 5', input_backend='root') -test_eq_1_numexpr = create_formula_test('3 == 5', input_backend='numexpr') -test_eq_2_root = create_formula_test('3 == 3', input_backend='root') -test_eq_2_numexpr = create_formula_test('3 == 3', input_backend='numexpr') - -test_neq_1_root = create_formula_test('3 != 5', input_backend='root') -test_neq_1_numexpr = create_formula_test('3 != 5', input_backend='numexpr') -test_neq_2_root = create_formula_test('3 != 3', input_backend='root') -test_neq_2_numexpr = create_formula_test('3 != 3', input_backend='numexpr') -test_gt_1_root = create_formula_test('3 > 1', input_backend='root') -test_gt_1_numexpr = create_formula_test('3 > 1', input_backend='numexpr') -test_gt_2_root = create_formula_test('3 > 3', input_backend='root') -test_gt_2_numexpr = create_formula_test('3 > 3', input_backend='numexpr') -test_gt_3_root = create_formula_test('5 > 3', input_backend='root') -test_gt_3_numexpr = create_formula_test('5 > 3', input_backend='numexpr') -test_gteq_1_root = create_formula_test('3 >= 1', input_backend='root') -test_gteq_1_numexpr = create_formula_test('3 >= 1', input_backend='numexpr') -test_gteq_2_root = create_formula_test('3 >= 3', input_backend='root') -test_gteq_2_numexpr = create_formula_test('3 >= 3', input_backend='numexpr') -test_gteq_3_root = create_formula_test('5 >= 3', input_backend='root') -test_gteq_3_numexpr = create_formula_test('5 >= 3', input_backend='numexpr') -test_lt_1_root = create_formula_test('3 < 1', input_backend='root') -test_lt_1_numexpr = create_formula_test('3 < 1', input_backend='numexpr') -test_lt_2_root = create_formula_test('3 < 3', input_backend='root') -test_lt_2_numexpr = create_formula_test('3 < 3', input_backend='numexpr') -test_lt_3_root = create_formula_test('5 < 3', input_backend='root') -test_lt_3_numexpr = create_formula_test('5 < 3', input_backend='numexpr') -test_lteq_1_root = create_formula_test('3 <= 1', input_backend='root') -test_lteq_1_numexpr = create_formula_test('3 <= 1', input_backend='numexpr') -test_lteq_2_root = create_formula_test('3 <= 3', input_backend='root') -test_lteq_2_numexpr = create_formula_test('3 <= 3', input_backend='numexpr') -test_lteq_3_root = create_formula_test('5 <= 3', input_backend='root') -test_lteq_3_numexpr = create_formula_test('5 <= 3', input_backend='numexpr') - -test_and_1_root = create_formula_test('0 && 0', input_backend='root') -test_and_1_numexpr = create_formula_test('0 & 0', input_backend='numexpr') -test_and_2_root = create_formula_test('0 && 1', input_backend='root') -test_and_2_numexpr = create_formula_test('0 & 1', input_backend='numexpr') -test_and_3_root = create_formula_test('1 && 0', input_backend='root') -test_and_3_numexpr = create_formula_test('1 & 0', input_backend='numexpr') -test_and_4_root = create_formula_test('1 && 1', input_backend='root') -test_and_4_numexpr = create_formula_test('1 & 1', input_backend='numexpr') -test_and_5_root = create_formula_test('3 && 5', input_backend='root') -test_and_5_numexpr = create_formula_test('3 & 5', input_backend='numexpr') - -test_or_1_root = create_formula_test('0 || 0', input_backend='root') -test_or_1_numexpr = create_formula_test('0 | 0', input_backend='numexpr') -test_or_2_root = create_formula_test('0 || 1', input_backend='root') -test_or_2_numexpr = create_formula_test('0 | 1', input_backend='numexpr') -test_or_3_root = create_formula_test('1 || 0', input_backend='root') -test_or_3_numexpr = create_formula_test('1 | 0', input_backend='numexpr') -test_or_4_root = create_formula_test('1 || 1', input_backend='root') -test_or_4_numexpr = create_formula_test('1 | 1', input_backend='numexpr') -test_or_5_root = create_formula_test('3 || 5', input_backend='root') -test_or_5_numexpr = create_formula_test('3 | 5', input_backend='numexpr') - -test_xor_1_root = create_formula_test('0 ^ 0', input_backend='root') -test_xor_1_numexpr = create_formula_test('0 ^ 0', input_backend='numexpr') -test_xor_2_root = create_formula_test('0 ^ 1', input_backend='root') -test_xor_2_numexpr = create_formula_test('0 ^ 1', input_backend='numexpr') -test_xor_3_root = create_formula_test('1 ^ 0', input_backend='root') -test_xor_3_numexpr = create_formula_test('1 ^ 0', input_backend='numexpr') -test_xor_4_root = create_formula_test('1 ^ 1', input_backend='root') -test_xor_4_numexpr = create_formula_test('1 ^ 1', input_backend='numexpr') -test_xor_5_root = create_formula_test('3 ^ 5', input_backend='root') -test_xor_5_numexpr = create_formula_test('3 ^ 5', input_backend='numexpr') - -test_not_1_root = create_formula_test('!0', input_backend='root') -test_not_1_numexpr = create_formula_test('~0', input_backend='numexpr') -test_not_2_root = create_formula_test('!1', input_backend='root') -test_not_2_numexpr = create_formula_test('~1', input_backend='numexpr') -test_not_3_root = create_formula_test('!0', input_backend='root') -test_not_3_numexpr = create_formula_test('~0', input_backend='numexpr') -test_not_4_root = create_formula_test('!1', input_backend='root') -test_not_4_numexpr = create_formula_test('~1', input_backend='numexpr') -test_not_5_root = create_formula_test('!5', input_backend='root') -test_not_5_numexpr = create_formula_test('~5', input_backend='numexpr') +test_add_root = pytest.helpers.create_formula_test("3.4 + 5e-7", input_backend="root") +test_add_numexpr = pytest.helpers.create_formula_test( + "3.4 + 5e-7", input_backend="numexpr" +) +test_sub_root = pytest.helpers.create_formula_test("3.4 - 5e-7", input_backend="root") +test_sub_numexpr = pytest.helpers.create_formula_test( + "3.4 - 5e-7", input_backend="numexpr" +) +test_mul_root = pytest.helpers.create_formula_test("3.4 * 5e-7", input_backend="root") +test_mul_numexpr = pytest.helpers.create_formula_test( + "3.4 * 5e-7", input_backend="numexpr" +) +test_div_root = pytest.helpers.create_formula_test("3.4 / 5e-7", input_backend="root") +test_div_numexpr = pytest.helpers.create_formula_test( + "3.4 / 5e-7", input_backend="numexpr" +) +test_mod_root = pytest.helpers.create_formula_test("3 % 5", input_backend="root") +test_mod_numexpr = pytest.helpers.create_formula_test("3 % 5", input_backend="numexpr") +test_pow_1_root = pytest.helpers.create_formula_test("3 ** 5", input_backend="root") +test_pow_1_numexpr = pytest.helpers.create_formula_test( + "3 ** 5", input_backend="numexpr" +) +test_pow_2_root = pytest.helpers.create_formula_test("3 ** -1.5", input_backend="root") +test_pow_2_numexpr = pytest.helpers.create_formula_test( + "3 ** -1.5", input_backend="numexpr" +) +test_pow_3_root = pytest.helpers.create_formula_test("3 **2", input_backend="root") +# test_pow_3_numexpr = pytest.helpers.create_formula_test('3 **2', input_backend='numexpr') +test_lshift_root = pytest.helpers.create_formula_test("3 << 5", input_backend="root") +test_lshift_numexpr = pytest.helpers.create_formula_test( + "3 << 5", input_backend="numexpr" +) +test_rshift_root = pytest.helpers.create_formula_test("3 >> 5", input_backend="root") +test_rshift_numexpr = pytest.helpers.create_formula_test( + "3 >> 5", input_backend="numexpr" +) + +test_eq_1_root = pytest.helpers.create_formula_test("3 == 5", input_backend="root") +test_eq_1_numexpr = pytest.helpers.create_formula_test( + "3 == 5", input_backend="numexpr" +) +test_eq_2_root = pytest.helpers.create_formula_test("3 == 3", input_backend="root") +test_eq_2_numexpr = pytest.helpers.create_formula_test( + "3 == 3", input_backend="numexpr" +) + +test_neq_1_root = pytest.helpers.create_formula_test("3 != 5", input_backend="root") +test_neq_1_numexpr = pytest.helpers.create_formula_test( + "3 != 5", input_backend="numexpr" +) +test_neq_2_root = pytest.helpers.create_formula_test("3 != 3", input_backend="root") +test_neq_2_numexpr = pytest.helpers.create_formula_test( + "3 != 3", input_backend="numexpr" +) +test_gt_1_root = pytest.helpers.create_formula_test("3 > 1", input_backend="root") +test_gt_1_numexpr = pytest.helpers.create_formula_test("3 > 1", input_backend="numexpr") +test_gt_2_root = pytest.helpers.create_formula_test("3 > 3", input_backend="root") +test_gt_2_numexpr = pytest.helpers.create_formula_test("3 > 3", input_backend="numexpr") +test_gt_3_root = pytest.helpers.create_formula_test("5 > 3", input_backend="root") +test_gt_3_numexpr = pytest.helpers.create_formula_test("5 > 3", input_backend="numexpr") +test_gteq_1_root = pytest.helpers.create_formula_test("3 >= 1", input_backend="root") +test_gteq_1_numexpr = pytest.helpers.create_formula_test( + "3 >= 1", input_backend="numexpr" +) +test_gteq_2_root = pytest.helpers.create_formula_test("3 >= 3", input_backend="root") +test_gteq_2_numexpr = pytest.helpers.create_formula_test( + "3 >= 3", input_backend="numexpr" +) +test_gteq_3_root = pytest.helpers.create_formula_test("5 >= 3", input_backend="root") +test_gteq_3_numexpr = pytest.helpers.create_formula_test( + "5 >= 3", input_backend="numexpr" +) +test_lt_1_root = pytest.helpers.create_formula_test("3 < 1", input_backend="root") +test_lt_1_numexpr = pytest.helpers.create_formula_test("3 < 1", input_backend="numexpr") +test_lt_2_root = pytest.helpers.create_formula_test("3 < 3", input_backend="root") +test_lt_2_numexpr = pytest.helpers.create_formula_test("3 < 3", input_backend="numexpr") +test_lt_3_root = pytest.helpers.create_formula_test("5 < 3", input_backend="root") +test_lt_3_numexpr = pytest.helpers.create_formula_test("5 < 3", input_backend="numexpr") +test_lteq_1_root = pytest.helpers.create_formula_test("3 <= 1", input_backend="root") +test_lteq_1_numexpr = pytest.helpers.create_formula_test( + "3 <= 1", input_backend="numexpr" +) +test_lteq_2_root = pytest.helpers.create_formula_test("3 <= 3", input_backend="root") +test_lteq_2_numexpr = pytest.helpers.create_formula_test( + "3 <= 3", input_backend="numexpr" +) +test_lteq_3_root = pytest.helpers.create_formula_test("5 <= 3", input_backend="root") +test_lteq_3_numexpr = pytest.helpers.create_formula_test( + "5 <= 3", input_backend="numexpr" +) + +test_and_1_root = pytest.helpers.create_formula_test("0 && 0", input_backend="root") +test_and_1_numexpr = pytest.helpers.create_formula_test( + "0 & 0", input_backend="numexpr" +) +test_and_2_root = pytest.helpers.create_formula_test("0 && 1", input_backend="root") +test_and_2_numexpr = pytest.helpers.create_formula_test( + "0 & 1", input_backend="numexpr" +) +test_and_3_root = pytest.helpers.create_formula_test("1 && 0", input_backend="root") +test_and_3_numexpr = pytest.helpers.create_formula_test( + "1 & 0", input_backend="numexpr" +) +test_and_4_root = pytest.helpers.create_formula_test("1 && 1", input_backend="root") +test_and_4_numexpr = pytest.helpers.create_formula_test( + "1 & 1", input_backend="numexpr" +) +test_and_5_root = pytest.helpers.create_formula_test("3 && 5", input_backend="root") +test_and_5_numexpr = pytest.helpers.create_formula_test( + "3 & 5", input_backend="numexpr" +) + +test_or_1_root = pytest.helpers.create_formula_test("0 || 0", input_backend="root") +test_or_1_numexpr = pytest.helpers.create_formula_test("0 | 0", input_backend="numexpr") +test_or_2_root = pytest.helpers.create_formula_test("0 || 1", input_backend="root") +test_or_2_numexpr = pytest.helpers.create_formula_test("0 | 1", input_backend="numexpr") +test_or_3_root = pytest.helpers.create_formula_test("1 || 0", input_backend="root") +test_or_3_numexpr = pytest.helpers.create_formula_test("1 | 0", input_backend="numexpr") +test_or_4_root = pytest.helpers.create_formula_test("1 || 1", input_backend="root") +test_or_4_numexpr = pytest.helpers.create_formula_test("1 | 1", input_backend="numexpr") +test_or_5_root = pytest.helpers.create_formula_test("3 || 5", input_backend="root") +test_or_5_numexpr = pytest.helpers.create_formula_test("3 | 5", input_backend="numexpr") + +test_xor_1_root = pytest.helpers.create_formula_test("0 ^ 0", input_backend="root") +test_xor_1_numexpr = pytest.helpers.create_formula_test( + "0 ^ 0", input_backend="numexpr" +) +test_xor_2_root = pytest.helpers.create_formula_test("0 ^ 1", input_backend="root") +test_xor_2_numexpr = pytest.helpers.create_formula_test( + "0 ^ 1", input_backend="numexpr" +) +test_xor_3_root = pytest.helpers.create_formula_test("1 ^ 0", input_backend="root") +test_xor_3_numexpr = pytest.helpers.create_formula_test( + "1 ^ 0", input_backend="numexpr" +) +test_xor_4_root = pytest.helpers.create_formula_test("1 ^ 1", input_backend="root") +test_xor_4_numexpr = pytest.helpers.create_formula_test( + "1 ^ 1", input_backend="numexpr" +) +test_xor_5_root = pytest.helpers.create_formula_test("3 ^ 5", input_backend="root") +test_xor_5_numexpr = pytest.helpers.create_formula_test( + "3 ^ 5", input_backend="numexpr" +) + +test_not_1_root = pytest.helpers.create_formula_test("!0", input_backend="root") +test_not_1_numexpr = pytest.helpers.create_formula_test("~0", input_backend="numexpr") +test_not_2_root = pytest.helpers.create_formula_test("!1", input_backend="root") +test_not_2_numexpr = pytest.helpers.create_formula_test("~1", input_backend="numexpr") +test_not_3_root = pytest.helpers.create_formula_test("!0", input_backend="root") +test_not_3_numexpr = pytest.helpers.create_formula_test("~0", input_backend="numexpr") +test_not_4_root = pytest.helpers.create_formula_test("!1", input_backend="root") +test_not_4_numexpr = pytest.helpers.create_formula_test("~1", input_backend="numexpr") +test_not_5_root = pytest.helpers.create_formula_test("!5", input_backend="root") +test_not_5_numexpr = pytest.helpers.create_formula_test("~5", input_backend="numexpr") diff --git a/tests/test_parser.py b/tests/test_parser.py index 6a9b8b8..ad26a41 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -1,8 +1,4 @@ # Licensed under a 3-clause BSD style license, see LICENSE. -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - import sys import pytest @@ -23,17 +19,14 @@ @pytest.mark.slow def test_long_chain(): args = [UC(str(i)) for i in range(1000)] - check_result(' + '.join(map(str, args)), Expression(IDs.ADD, *args)) + check_result(" + ".join(map(str, args)), Expression(IDs.ADD, *args)) def test_alternating_chain(): - string = '0' - expected = UC('0') + string = "0" + expected = UC("0") for i in range(1, 100): - op_name, op_id = { - 0: (' + ', IDs.ADD), - 1: (' - ', IDs.SUB) - }[i % 2] + op_name, op_id = {0: (" + ", IDs.ADD), 1: (" - ", IDs.SUB)}[i % 2] string += op_name + str(i) expected = Expression(op_id, expected, UC(str(i))) check_result(string, expected) @@ -42,31 +35,28 @@ def test_alternating_chain(): @pytest.mark.slow @pytest.mark.xfail(raises=RecursionError) def test_long_alternating_chain(): - string = '0' - expected = UC('0') + string = "0" + expected = UC("0") for i in range(1, 1000): - op_name, op_id = { - 0: (' + ', IDs.ADD), - 1: (' - ', IDs.SUB) - }[i % 2] + op_name, op_id = {0: (" + ", IDs.ADD), 1: (" - ", IDs.SUB)}[i % 2] string += op_name + str(i) expected = Expression(op_id, expected, UC(str(i))) check_result(string, expected) def test_3_deep_chain(): - string = 'sqrt(sqrt(sqrt(2)))' - expected = Expression(IDs.SQRT, Expression(IDs.SQRT, Expression(IDs.SQRT, UC('2')))) + string = "sqrt(sqrt(sqrt(2)))" + expected = Expression(IDs.SQRT, Expression(IDs.SQRT, Expression(IDs.SQRT, UC("2")))) check_result(string, expected) @pytest.mark.slow @pytest.mark.xfail(raises=RecursionError) def test_5_deep_chain(): - string = '2' - expected = UC('2') - for i in list(range(5)): - string = 'sqrt('+string+')' + string = "2" + expected = UC("2") + for _ in list(range(5)): + string = f"sqrt({string})" expected = Expression(IDs.SQRT, expected) print(string) check_result(string, expected) @@ -75,10 +65,10 @@ def test_5_deep_chain(): @pytest.mark.slow @pytest.mark.xfail(raises=RecursionError) def test_10_deep_chain(): - string = '2' - expected = UC('2') - for i in list(range(10)): - string = 'sqrt('+string+')' + string = "2" + expected = UC("2") + for _ in list(range(10)): + string = f"sqrt({string})" expected = Expression(IDs.SQRT, expected) print(string) check_result(string, expected) @@ -86,19 +76,19 @@ def test_10_deep_chain(): def test_parse_invalid_expression(): with pytest.raises(ParsingException): - from_numexpr('saadasd()&+|()') + from_numexpr("saadasd()&+|()") def test_invalid_arg_parse(): with pytest.raises(ValueError): - from_numexpr(Expression(IDs.SQRT, UC('2'))) + from_numexpr(Expression(IDs.SQRT, UC("2"))) def test_too_many_function_arguments(): with pytest.raises(ParsingException): - from_numexpr('sqrt(2, 3)') + from_numexpr("sqrt(2, 3)") def test_invalid_arg_to_string(): with pytest.raises(ValueError): - to_numexpr('1') + to_numexpr("1") diff --git a/tests/utils.py b/tests/utils.py index b643531..76cc712 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,11 +1,14 @@ # Licensed under a 3-clause BSD style license, see LICENSE. -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function import pytest -from formulate import ExpressionComponent, Expression, Variable, NamedConstant, UnnamedConstant +from formulate import ( + ExpressionComponent, + Expression, + Variable, + NamedConstant, + UnnamedConstant, +) def make_check_result(from_func, to_func): @@ -19,8 +22,13 @@ def check_result(input_string, expected_expression, expected_string=None, **kwar assert_equal_expressions(result, expected_expression) # TODO Stop stripping parentheses - assert (string.replace(' ', '').replace('(', '').replace(')', '') == - expected_string.replace(' ', '').replace('(', '').replace(')', '')), (string, expected_string) + assert string.replace(" ", "").replace("(", "").replace( + ")", "" + ) == expected_string.replace(" ", "").replace("(", "").replace(")", ""), ( + string, + expected_string, + ) + return check_result