diff --git a/docs/flow/modin/config.rst b/docs/flow/modin/config.rst index 0b2afb5934a..924fb8f5c25 100644 --- a/docs/flow/modin/config.rst +++ b/docs/flow/modin/config.rst @@ -56,3 +56,21 @@ API. # Changing value of `NPartitions` modin.config.NPartitions.put(16) print(modin.config.NPartitions.get()) # prints '16' + +One can also use config variables with a context manager in order to use +some config only for a certain part of the code: + +.. code-block:: python + + import modin.config as cfg + + # Default value for this config is 'False' + print(cfg.RangePartitioning.get()) # False + + # Set the config to 'True' inside of the context-manager + with cfg.RangePartitioning(True): + print(cfg.RangePartitioning.get()) # True + df.merge(...) # will use range-partitioning impl + + # Once the context is over, the config gets back to its previous value + print(cfg.RangePartitioning.get()) # False diff --git a/modin/config/pubsub.py b/modin/config/pubsub.py index a49f906ca64..3070dd89bcc 100644 --- a/modin/config/pubsub.py +++ b/modin/config/pubsub.py @@ -194,6 +194,15 @@ class Parameter(object): """ Base class describing interface for configuration entities. + To set the parameter's value you can use both ``Parameter.put(val)`` and ``Parameter(val)``, + the latter also supports a context-manager use-case. Exiting the context will result into + resetting the parameter's value to the previous value. + + Parameters + ---------- + value : object + A value to set to this parameter. + Attributes ---------- choices : Optional[Sequence[str]] @@ -209,6 +218,18 @@ class Parameter(object): ``ValueSource``. _deprecation_descriptor : Optional[DeprecationDescriptor] Indicate whether this parameter is deprecated. + + Examples + -------- + >>> class MyParameter(Parameter, type=bool): + ... default = False + >>> MyParameter.get() + False + >>> with MyParameter(True): + ... print(MyParameter.get()) # True + True + >>> MyParameter.get() + False """ choices: Optional[Tuple[str, ...]] = None @@ -221,6 +242,16 @@ class Parameter(object): _once: DefaultDict[Any, list] = defaultdict(list) _deprecation_descriptor: Optional[DeprecationDescriptor] = None + def __init__(self, value): + self._previous_val = self.get() + self.put(value) + + def __enter__(self, *args, **kwargs): # noqa: GL08 + return self + + def __exit__(self, *args, **kwargs): # noqa: GL08 + self.put(self._previous_val) + @classmethod def _get_raw_from_config(cls) -> str: """ diff --git a/modin/config/test/test_parameter.py b/modin/config/test/test_parameter.py index cc9bbb5f3f2..7f48815024e 100644 --- a/modin/config/test/test_parameter.py +++ b/modin/config/test/test_parameter.py @@ -101,3 +101,25 @@ def test_init_validation(vartype): parameter = make_prefilled(vartype, "bad value") with pytest.raises(ValueError): parameter.get() + + +def test_context_manager(): + parameter = make_prefilled(vartype=bool, varinit="False") + + # simple case + assert parameter.get() is False + with parameter(True): + assert parameter.get() is True + assert parameter.get() is False + + # nested case + assert parameter.get() is False + with parameter(True): + assert parameter.get() is True + with parameter(False): + assert parameter.get() is False + with parameter(False): + assert parameter.get() is False + assert parameter.get() is False + assert parameter.get() is True + assert parameter.get() is False