# Source code for xorbits._mars.config

# Copyright 2022-2023 XProbe Inc.
# derived from copyright 1999-2021 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import functools
import operator
import os
import threading
import warnings
from copy import deepcopy
from functools import reduce
from typing import Any, Dict, Union

# Message template for redirected (renamed) options; ``Config.redirect_option``
# formats it with the old option name (``source``) and the new one (``target``).
_DEFAULT_REDIRECT_WARN = (
    "Option {source} has been replaced by {target} "
    "and might be removed in a future release."
)


class OptionError(Exception):
    """Raised when an option assignment is not allowed, e.g. through a path of options that does not exist."""


class Redirection:
    """Option placeholder that forwards reads and writes to another option.

    Parameters
    ----------
    item : str
        Dotted name of the target option, resolved starting from the ``root``
        of the object this redirection is bound to.
    warn : str, optional
        Message passed to :func:`warnings.warn` on the first read or write
        after binding. A falsy value disables the warning.
    """

    def __init__(self, item, warn=None):
        self._items = item.split(".")
        self._warn = warn
        # Start in the "already warned" state so the probing read performed
        # by ``bind`` does not emit the warning.
        self._warned = True
        self._parent = None

    def bind(self, attr_dict):
        """Attach to ``attr_dict`` and eagerly resolve the target option.

        Resolving here makes a missing target fail at bind time; afterwards
        the warning is armed for the first real access.
        """
        self._parent = attr_dict
        self.getvalue()
        self._warned = False

    def _warn_once(self):
        # Emit the configured warning at most once per arming.
        if not self._warn or self._warned:
            return
        self._warned = True
        warnings.warn(self._warn)

    def _resolve(self, names):
        # Walk ``names`` attribute by attribute from the root of the tree.
        node = self._parent.root
        for name in names:
            node = getattr(node, name)
        return node

    def getvalue(self):
        """Return the current value of the target option."""
        self._warn_once()
        return self._resolve(self._items)

    def setvalue(self, value):
        """Assign ``value`` to the target option."""
        self._warn_once()
        *path, leaf = self._items
        setattr(self._resolve(path), leaf, value)


class AttributeDict(dict):
    """Dictionary of options whose entries are also reachable as attributes.

    Each entry is either a nested ``AttributeDict`` (an option group) or a
    ``(value, validator)`` tuple (a single option). ``value`` may be a
    :class:`Redirection`, in which case reads and writes are forwarded to the
    option it points to.
    """

    def __init__(self, *args, **kwargs):
        # ``_inited`` is set first: until it becomes True, ``__setattr__``
        # stores attributes as plain instance state instead of as options.
        self._inited = False
        # enclosing group, or None for the top of the tree
        self._parent = kwargs.pop("_parent", None)
        # resolved lazily by ``root``
        self._root = None
        super().__init__(*args, **kwargs)
        self._inited = True

    @property
    def root(self):
        """The top-level ``AttributeDict`` of this tree (cached after first use)."""
        if self._root is not None:
            return self._root
        if self._parent is None:
            self._root = self
        else:
            self._root = self._parent.root
        return self._root

    def __getattr__(self, item):
        # Only called when normal attribute lookup fails, so real attributes
        # and dict methods take precedence over option names.
        if item in self:
            val = self[item]
            if isinstance(val, AttributeDict):
                return val
            elif isinstance(val[0], Redirection):
                return val[0].getvalue()
            else:
                return val[0]
        return object.__getattribute__(self, item)

    def __dir__(self):
        return list(self.keys())

    def register(self, key, value, validator=None):
        """Store option ``key`` with ``value`` and an optional validator.

        A tuple of validators is combined with ``any_validator``. A
        :class:`Redirection` value is bound to this dict immediately.
        """
        if isinstance(validator, tuple):
            validator = any_validator(*validator)
        self[key] = value, validator
        if isinstance(value, Redirection):
            value.bind(self)

    def unregister(self, key):
        """Remove option ``key``; raises KeyError if absent."""
        del self[key]

    def _setattr(self, key, value, silent=False):
        """Set the option at dotted path ``key`` to ``value``.

        Intermediate groups must already exist unless ``silent`` is True. An
        existing option keeps its validator, which ``value`` must satisfy; a
        missing option is created without a validator. ``AttributeDict``
        values are stored as-is, replacing whatever was there.

        Raises
        ------
        OptionError
            If an intermediate group is missing (and ``silent`` is False), or
            if a plain value would overwrite an option group.
        ValueError
            If the existing validator rejects ``value``.
        """
        splits = key.split(".")
        target = self
        for k in splits[:-1]:
            if not silent and (
                not isinstance(target, AttributeDict) or k not in target
            ):
                raise OptionError("You can only set the value of existing options")
            target = target[k]
        key = splits[-1]

        if isinstance(value, AttributeDict):
            target[key] = value
            return
        if key not in target:
            target[key] = value, None
            return

        val = target[key]
        if isinstance(val, dict):
            # An option group has no (value, validator) pair to update;
            # indexing it like a leaf would fail with an opaque KeyError.
            raise OptionError(f"Cannot overwrite option group `{key}` with a value")
        validate = val[1]
        if validate is not None and not validate(value):
            raise ValueError(f"Cannot set value `{value}`")
        if isinstance(val[0], Redirection):
            # forward the write; the target applies its own validator
            val[0].setvalue(value)
        else:
            target[key] = value, validate

    def __setattr__(self, key, value):
        if key == "_inited":
            super().__setattr__(key, value)
            return
        try:
            # Names that already resolve through normal lookup (instance
            # state such as ``_parent``/``_root``, or class members) are
            # assigned as ordinary attributes.
            object.__getattribute__(self, key)
            super().__setattr__(key, value)
            return
        except AttributeError:
            pass

        if not self._inited:
            super().__setattr__(key, value)
        else:
            self._setattr(key, value)

    def to_dict(self):
        """Flatten into ``{dotted_name: value}``, skipping redirections."""
        d = dict()
        for k, v in self.items():
            if isinstance(v, AttributeDict):
                d.update(
                    {k + "." + sub_k: sub_v for sub_k, sub_v in v.to_dict().items()}
                )
            elif isinstance(v[0], Redirection):
                continue
            else:
                d[k] = v[0]
        return d


class Config:
    """Hierarchical registry of options addressed by dotted names.

    Options live in an :class:`AttributeDict` tree and can be accessed either
    as attributes (``config.dataframe.dtype_backend``) or through
    :meth:`get_option` / :meth:`set_option` with dotted names.
    """

    def __init__(self, config=None):
        # any falsy ``config`` (None or an empty mapping) gets a fresh tree
        self._config = config or AttributeDict()
        # dotted names of options reported by ``get_serializable``
        self._serialize_options = []

    def __dir__(self):
        return list(self._config.keys())

    def __getattr__(self, item):
        # ``object.__getattribute__`` raises AttributeError directly instead of
        # re-entering ``__getattr__`` if ``_config`` is not set yet.
        config = object.__getattribute__(self, "_config")
        return getattr(config, item)

    def __setattr__(self, key, value):
        # private names are instance state; everything else is an option
        if key.startswith("_"):
            object.__setattr__(self, key, value)
            return
        setattr(self._config, key, value)

    def reset_option(self, key):
        """Restore option ``key`` to its value in the module's ``_default_options``."""
        attr_list = key.split(".")
        # Deep-copy so that resetting a whole option group does not share the
        # AttributeDict owned by ``_default_options``; later writes through
        # this config would otherwise mutate the defaults too.
        setattr(
            reduce(getattr, attr_list[:-1], self._config),
            attr_list[-1],
            deepcopy(reduce(getattr, attr_list, _default_options)),
        )

    def register_option(self, option, value, validator=None, serialize=False):
        """Register a new option, creating intermediate groups as needed.

        Parameters
        ----------
        option : str
            Dotted option name.
        value : Any
            Initial value (may be a :class:`Redirection`).
        validator : callable or tuple of callables, optional
            Predicate the value must satisfy; a tuple means "any of".
        serialize : bool
            Whether the option is included by :meth:`get_serializable`.

        Raises
        ------
        AttributeError
            If an intermediate name is already a plain option, or the option
            is already registered.
        """
        splits = option.split(".")
        conf = self._config
        if isinstance(validator, tuple):
            validator = any_validator(*validator)

        for name in splits[:-1]:
            config = conf.get(name)
            if config is None:
                val = AttributeDict(_parent=conf)
                conf[name] = val
                conf = val
            elif not isinstance(config, dict):
                raise AttributeError(
                    f"Fail to set option: {option}, conflict has encountered"
                )
            else:
                conf = config

        key = splits[-1]
        if conf.get(key) is not None:
            raise AttributeError(f"Fail to set option: {option}, option has been set")

        conf.register(key, value, validator)
        if serialize:
            self._serialize_options.append(option)

    def redirect_option(self, option, target, warn=_DEFAULT_REDIRECT_WARN):
        """Register ``option`` as an alias that forwards to ``target``.

        ``warn`` is a template formatted with ``source`` and ``target``; the
        resulting message is warned on first use. Pass ``warn=None`` to
        redirect without a warning.
        """
        message = None if warn is None else warn.format(source=option, target=target)
        redir = Redirection(target, warn=message)
        self.register_option(option, redir)

    def unregister_option(self, option):
        """Remove a registered option.

        Raises
        ------
        AttributeError
            If the option or one of its parent groups does not exist.
        """
        splits = option.split(".")
        conf = self._config
        for name in splits[:-1]:
            config = conf.get(name)
            if not isinstance(config, dict):
                raise AttributeError(
                    f"Fail to unregister option: {option}, conflict has encountered"
                )
            else:
                conf = config

        key = splits[-1]
        if key not in conf:
            raise AttributeError(
                f"Option {option} not configured, thus failed to unregister."
            )
        conf.unregister(key)

    def copy(self):
        """Return an independent deep copy of this config."""
        new_options = Config(deepcopy(self._config))
        # The serialization list lives on the Config, not in the tree; without
        # this the copy's ``get_serializable`` would silently return ``{}``.
        new_options._serialize_options = list(self._serialize_options)
        return new_options

    def _locate_option(self, option: str):
        """Resolve dotted ``option`` to ``(group, key)``; raise AttributeError if absent."""
        *path, key = option.split(".")
        conf = self._config
        for name in path:
            conf = conf.get(name)
            if not isinstance(conf, dict):
                raise AttributeError(f"No such keys(s): {option}.")
        if key not in conf:
            raise AttributeError(f"No such keys(s): {option}.")
        return conf, key

    def get_option(self, option: str) -> Any:
        """
        Retrieves the value of the specified option.

        Parameters
        ----------
        option : str
            The name of the option to retrieve. Can be a nested option using dot notation (e.g., dataframe.dtype_backend).

        Returns
        -------
        Any
            The value of the specified option. For a redirected option the
            value of its target is returned; for an option group the nested
            group object is returned, as with attribute access.

        Raises
        ------
        AttributeError
            If the specified option does not exist.

        Examples
        --------
        >>> xorbits.options.get_option("show_progress")
        'auto'
        """
        conf, key = self._locate_option(option)
        entry = conf[key]
        if isinstance(entry, dict):
            return entry
        value = entry[0]
        if isinstance(value, Redirection):
            return value.getvalue()
        return value

    def set_option(self, option: str, value: Any) -> None:
        """
        Sets the value of the specified option.

        Parameters
        ----------
        option : str
            The name of the option to set. Can be a nested option using dot notation (e.g., dataframe.dtype_backend).
        value : Any
            The value to set for the specified option.

        Returns
        -------
        None

        Raises
        ------
        AttributeError
            If the specified option does not exist, or names a group of
            options rather than a single option.
        ValueError
            If the provided value is invalid for the option.

        Examples
        --------
        >>> xorbits.options.set_option("show_progress", False)
        """
        conf, key = self._locate_option(option)
        entry = conf[key]
        if isinstance(entry, dict):
            raise AttributeError(
                f"Option {option} is a group of options and cannot be set directly."
            )
        old_value, validator = entry
        if validator is not None and not validator(value):
            raise ValueError(f"Invalid value {value} for option {option}")
        if isinstance(old_value, Redirection):
            # Forward the write to the target option (which applies its own
            # validator) instead of replacing the redirection itself.
            old_value.setvalue(value)
        else:
            conf[key] = value, validator

    def update(self, new_config: Union["Config", Dict]):
        """Merge options from a dotted-name dict or another Config.

        Unknown options are registered (without a validator); existing ones
        are assigned and therefore validated.
        """
        if not isinstance(new_config, dict):
            # A Config's tree stores ``(value, validator)`` tuples and nested
            # groups; flatten it to ``{dotted_name: value}`` first.
            new_config = new_config.to_dict()
        for option, value in new_config.items():
            try:
                self.register_option(option, value)
            except AttributeError:
                setattr(self, option, value)

    def get_serializable(self):
        """Return ``{dotted_name: value}`` for every option registered with ``serialize=True``."""
        d = dict()
        for k in self._serialize_options:
            parts = k.split(".")
            v = self
            for p in parts:
                v = getattr(v, p)
            d[k] = v
        return d

    def fill_serialized(self, d):
        """Assign each ``{dotted_name: value}`` pair in ``d`` to this config."""
        for k, v in d.items():
            parts = k.split(".")
            cf = self
            for p in parts[:-1]:
                cf = getattr(cf, p)
            setattr(cf, parts[-1], v)

    def to_dict(self):
        """Flatten all options (excluding redirections) into ``{dotted_name: value}``."""
        return self._config.to_dict()


@contextlib.contextmanager
def option_context(config=None):
    """
    Context manager to temporarily set options in a ``with`` statement.

    The temporary options replace the current thread's global options for the
    duration of the block and are restored afterwards, even on error.

    Parameters
    ----------
    config : dict
        A dictionary of key-value option pairs to set temporarily.

    Yields
    ------
    Config
        The temporary options object that is active inside the block.

    Examples
    --------
    >>> with xorbits.option_context({'show_progress': False}):
    ...     pass  # code here runs with show_progress=False
    >>> # Outside the context, the original configuration is restored
    """
    global_options = get_global_option()

    try:
        config = config or dict()
        local_options = Config(deepcopy(global_options._config))
        # The serialization list lives on the Config rather than in its tree;
        # carry it over so get_serializable() still works inside the context.
        local_options._serialize_options = list(global_options._serialize_options)
        local_options.update(config)
        _options_local.default_options = local_options
        yield local_options
    finally:
        _options_local.default_options = global_options
def is_interactive():
    """Return True when ``__main__`` has no ``__file__`` (e.g. an interactive session)."""
    import __main__ as main

    return not hasattr(main, "__file__")


# validators
def any_validator(*validators):
    """Combine validators: the result accepts a value if any of them does."""

    def validate(x):
        return any(validator(x) for validator in validators)

    return validate


def all_validator(*validators):
    """Combine validators: the result accepts a value only if all of them do."""

    def validate(x):
        return all(validator(x) for validator in validators)

    return validate


def _instance_check(typ, v):
    return isinstance(v, typ)


is_null = functools.partial(operator.is_, None)
is_bool = functools.partial(_instance_check, bool)
is_integer = functools.partial(_instance_check, int)
is_float = functools.partial(_instance_check, float)
is_numeric = functools.partial(_instance_check, (float, int))
is_string = functools.partial(_instance_check, str)
is_dict = functools.partial(_instance_check, dict)
is_list = functools.partial(_instance_check, list)


def is_in(vals):
    """Build a validator that accepts only members of ``vals``."""

    def validate(x):
        return x in vals

    return validate


default_options = Config()
default_options.register_option("tcp_timeout", 30, validator=is_integer)
default_options.register_option("verbose", False, validator=is_bool)
default_options.register_option("kv_store", ":inproc:", validator=is_string)
default_options.register_option("check_interval", 20, validator=is_integer)
default_options.register_option(
    "show_progress", "auto", validator=any_validator(is_bool, is_string)
)
default_options.register_option("serialize_method", "pickle")

# dataframe-related options
default_options.register_option(
    "dataframe.dtype_backend",
    "numpy_nullable",
    validator=is_in(("numpy_nullable", "pyarrow")),
)
default_options.register_option(
    "dataframe.arrow_array.pandas_only",
    None,
    validator=any_validator(is_null, is_bool),
)

# learn options
assume_finite = os.environ.get("SKLEARN_ASSUME_FINITE")
if assume_finite is not None:
    # NOTE(review): bool() of any non-empty string is True, so "0" or "False"
    # also enable this; confirm that is the intended parsing.
    assume_finite = bool(assume_finite)
working_memory = os.environ.get("SKLEARN_WORKING_MEMORY")
if working_memory is not None:
    working_memory = int(working_memory)
default_options.register_option(
    "learn.assume_finite", assume_finite, validator=any_validator(is_null, is_bool)
)
default_options.register_option(
    "learn.working_memory", working_memory, validator=any_validator(is_null, is_integer)
)

# the number of combined chunks in tree reduction or tree add
default_options.register_option("combine_size", 4, validator=is_integer, serialize=True)

# the default chunk store size
default_options.register_option(
    "chunk_store_limit", 512 * 1024**2, validator=is_numeric
)
default_options.register_option(
    "chunk_size", None, validator=any_validator(is_null, is_integer), serialize=True
)

# rechunk
default_options.register_option(
    "rechunk.threshold", 4, validator=is_integer, serialize=True
)
default_options.register_option(
    "rechunk.chunk_size_limit", int(1e8), validator=is_integer, serialize=True
)
default_options.register_option(
    "bincount.chunk_size_limit", int(1e8), validator=is_integer, serialize=True
)

# deploy
default_options.register_option("deploy.open_browser", True, validator=is_bool)

# optimization
default_options.register_option("optimize_tileable_graph", True, validator=is_bool)

# eager mode
default_options.register_option("eager_mode", False, validator=is_bool)

# optimization
default_options.register_option(
    "optimize.head_optimize_threshold", 1000, validator=is_integer
)

# debug
default_options.register_option("warn_duplicated_execution", False, validator=is_bool)

# client serialize type
default_options.register_option("client.serial_type", "arrow", validator=is_string)

# custom log dir
default_options.register_option(
    "custom_log_dir", None, validator=any_validator(is_null, is_string)
)

# vineyard
default_options.register_option(
    "vineyard.socket", os.environ.get("VINEYARD_IPC_SOCKET", None)
)
default_options.register_option(
    "vineyard.enabled", os.environ.get("WITH_VINEYARD", None) is not None
)

# Per-thread active options; the importing thread uses ``default_options``.
_options_local = threading.local()
_options_local.default_options = default_options

# Snapshot of the built-in defaults, used by ``Config.reset_option``.
_default_options = default_options.copy()


def get_global_option():
    """Return the Config active in the current thread, creating one on first use.

    A thread without its own options gets a deep copy of ``default_options``.
    """
    ret = getattr(_options_local, "default_options", None)
    if ret is None:
        ret = Config(deepcopy(default_options._config))
        # The serialization list lives on the Config rather than in its tree;
        # copy it so get_serializable() works in every thread.
        ret._serialize_options = list(default_options._serialize_options)
        _options_local.default_options = ret
    return ret


class OptionsProxy:
    """Forward attribute reads, writes and ``dir()`` to ``get_global_option()``."""

    def __dir__(self):
        return dir(get_global_option())

    def __getattribute__(self, attr):
        return getattr(get_global_option(), attr)

    def __setattr__(self, key, value):
        setattr(get_global_option(), key, value)


options = OptionsProxy()

# NOTE: this assignment goes through OptionsProxy.__setattr__, so the text is
# stored on the Config that is active at import time, not on the proxy.
options.__doc__ = """
An object for accessing and modifying global job-level configuration options.

This object provides access to all the configuration options defined in the global configuration.
It allows getting and setting option values, as well as accessing nested options using dot notation
(e.g., dataframe.dtype_backend).

Examples:
    >>> options.show_progress
    'auto'
    >>> options.set_option("show_progress", False)
    >>> options.show_progress
    False

See Also:
    :func:`get_option`: Function to retrieve option values.
    :func:`set_option`: Function to set option values.
"""

options.redirect_option("tensor.chunk_store_limit", "chunk_store_limit")
options.redirect_option("tensor.chunk_size", "chunk_size")
options.redirect_option("tensor.rechunk.threshold", "rechunk.threshold")
options.redirect_option("tensor.rechunk.chunk_size_limit", "rechunk.chunk_size_limit")