diff --git a/dvc_ssh/__init__.py b/dvc_ssh/__init__.py index fa81d7a..1cb6d37 100644 --- a/dvc_ssh/__init__.py +++ b/dvc_ssh/__init__.py @@ -1,17 +1,57 @@ import getpass import os.path import threading -from typing import ClassVar +from contextlib import suppress +from pathlib import Path, PurePath +from typing import ClassVar, Union +from asyncssh.config import SSHClientConfig from funcy import memoize, silent, wrap_prop, wrap_with from dvc.utils.objects import cached_property from dvc_objects.fs.base import FileSystem from dvc_objects.fs.utils import as_atomic +SSH_CONFIG = Path("~", ".ssh", "config").expanduser() +FilePath = Union[str, PurePath] + DEFAULT_PORT = 22 +def parse_config(*, host, user=(), port=(), local_user=None, config_files=None): + if config_files is None: + config_files = [SSH_CONFIG] + + if local_user is None: + with suppress(KeyError): + local_user = getpass.getuser() + + last_config = None + reload = False + config = SSHClientConfig( + last_config=last_config, + reload=reload, + canonical=False, + final=False, + local_user=local_user, + user=user, + host=host, + port=port, + ) + + if config_files: + if isinstance(config_files, (str, PurePath)): + # paths: Sequence[FilePath] = [config_files] #lint issue + paths = [config_files] + else: + paths = config_files + + for path in paths: + config.parse(Path(path)) + config.loaded = True + return config + + @wrap_with(threading.Lock()) @memoize def ask_password(host, user, port, desc): @@ -41,7 +81,7 @@ def unstrip_protocol(self, path: str) -> str: return f"ssh://{host}:{port}/{path}" def _prepare_credentials(self, **config): - from sshfs.config import parse_config + # from sshfs.config import parse_config from .client import InteractiveSSHClient