Skip to content

Commit

Permalink
WIP: Add custom DistroFinders for hab download support
Browse files Browse the repository at this point in the history
  • Loading branch information
MHendricks committed Nov 19, 2024
1 parent b920692 commit e0537b6
Show file tree
Hide file tree
Showing 3 changed files with 199 additions and 0 deletions.
58 changes: 58 additions & 0 deletions hab/distro_finders/df_zip.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import logging

from .. import utils
from .zip_sidecar import DistroFinderZipSidecar

# Module-level logger named after this module's import path.
logger = logging.getLogger(__name__)


class DistroFinderZip(DistroFinderZipSidecar):
    """Finds distros stored as `.zip` archives that contain a `.hab.json` file.

    Unlike `DistroFinderZipSidecar`, the `.hab.json` is read from inside each
    zip archive found in `root`. The raw bytes of that file are cached in
    memory so `load_path` normally doesn't need to re-open the archive.
    """

    def __init__(self, root, site=None):
        super().__init__(root, site=site)
        # Search root for zip archives instead of sidecar `.hab.json` files.
        self.glob_str = "*.zip"
        # Name of the archive member that defines the distro.
        self.hab_filename = ".hab.json"
        # Maps `<zip path>/<member name>` to the raw bytes of that member.
        self._cache = {}

    def content(self, path):
        """Returns the zip archive containing the `.hab.json` file at `path`.

        `path` is expected to be `<zip path>/<member name>` as yielded by
        `distro_path_info`, so the archive is simply its parent.
        """
        return path.parent

    def get_text_files_from_zip(self, archive, cache_files, as_path=True):
        """Opens the zip archive and yields any cache_file paths that exist in
        the archive.

        To reduce re-opening the zip archive later this also caches the contents
        of each of these text files for later processing in load_path.

        Args:
            archive: A `zipfile.ZipFile` like object. It is used as a context
                manager, so it is closed once the generator is exhausted.
            cache_files: Member names to look for inside the archive.
            as_path: Currently unused, kept for interface compatibility.

        Yields:
            tuple: `(<zip path>/<member name>, <member name>)` for each of
                `cache_files` found in the archive.
        """
        zip_path = archive.filename
        if isinstance(zip_path, str):
            # `zipfile.ZipFile` stores `filename` as a plain str even when it
            # was opened from a Path, and `str / str` raises a TypeError.
            # Archives that already expose a path object are used as is.
            from pathlib import Path

            zip_path = Path(zip_path)

        with archive:
            # Build the member lookup once instead of once per cache_file.
            members = set(archive.namelist())
            for cache_file in cache_files:
                if cache_file not in members:
                    continue

                fn = zip_path / cache_file
                self._cache[fn] = archive.read(cache_file)
                yield fn, cache_file

    def clear_cache(self, persistent=False):
        """Clear cached data in memory. If `persistent` is True then also remove
        cache data from disk if it exists.

        This class only keeps an in-memory cache, so `persistent` has no
        additional effect here.
        """
        self._cache = {}

    def distro_path_info(self):
        """Yields `(None, hab_path, False)` for each zip in `root` that contains
        a `self.hab_filename` member.

        `hab_path` is `<zip path>/<hab_filename>`. As a side effect the contents
        of that member are cached for `load_path`.
        """
        for zip_path in self.root.glob(self.glob_str):
            archive = self.archive(zip_path)
            for hab_path, _ in self.get_text_files_from_zip(
                archive, cache_files=[self.hab_filename]
            ):
                yield None, hab_path, False

    def load_path(self, distro, path):
        """Returns the raw distro dictionary for `path` with version set.

        Args:
            distro: Not used by this implementation.
            path: `<zip path>/<hab_filename>` as yielded by `distro_path_info`.
        """
        logger.debug(f'Loading json: "{path}"')
        try:
            data = self._cache[path]
        except KeyError:
            # The cache may have been cleared since `distro_path_info` ran.
            # Re-read the file from the archive instead of failing.
            with self.archive(self.content(path)) as archive:
                data = archive.read(self.hab_filename)
            self._cache[path] = data
        data = data.decode("utf-8")
        data = utils.loads_json(data, source=path)
        # Pull the version from the zip filename if its not explicitly set
        if "version" not in data:
            data["version"] = self.version_for_path(path)
        return data
84 changes: 84 additions & 0 deletions hab/distro_finders/s3_zip.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import logging
from hashlib import sha256

import remotezip
from cloudpathlib import CloudPath, S3Client
from requests_aws4auth import AWS4Auth

from .df_zip import DistroFinderZip

# Module-level logger named after this module's import path.
logger = logging.getLogger(__name__)


class DistroFinderS3Zip(DistroFinderZip):
    """Finds distro `.zip` archives stored in an aws s3 bucket.

    The archives are accessed with `remotezip.RemoteZip` using credentials
    generated from a `cloudpathlib.S3Client`.

    Args:
        root: The s3 location to search. Passed to `cloudpathlib.CloudPath`.
        site: Passed to the parent class.
        client: An existing `S3Client` to use. If None, one is created.
        profile_name: Used to create the `S3Client` if `client` is None.
        **object_filters: Accepted but currently not used.
    """

    def __init__(
        self, root, site=None, client=None, profile_name=None, **object_filters
    ):
        # Root should not be cast to a pathlib.Path object on this class.
        super().__init__("", site=site)
        # TODO: `object_filters` and the bucket name are not stored yet.

        self.client = client
        if self.client is None:
            if profile_name:
                self.client = S3Client(profile_name=profile_name)
            else:
                self.client = S3Client()

        self.root = CloudPath(root, client=self.client)

    def _credentials(self):
        """Returns the credentials needed for requests to connect to aws s3 bucket.
        Generates these credentials using the client object.

        The `(auth, headers)` tuple is generated on first use and then reused
        for subsequent calls.
        """

        try:
            return self.__credentials
        except AttributeError:
            pass
        # The `x-amz-content-sha256` header is required for all AWS Signature
        # Version 4 requests. It provides a hash of the request payload. If
        # there is no payload, you must provide the hash of an empty string.
        headers = {"x-amz-content-sha256": sha256(b"").hexdigest()}

        location = self.client.client.get_bucket_location(Bucket=self.root.bucket)[
            "LocationConstraint"
        ]
        # S3 reports a null LocationConstraint for buckets in us-east-1, but
        # request signing requires the actual region name.
        if location is None:
            location = "us-east-1"
        auth = AWS4Auth(
            refreshable_credentials=self.client.sess.get_credentials(),
            region=location,
            service="s3",
        )

        self.__credentials = (auth, headers)
        return self.__credentials

    def archive(self, zip_path):
        """Returns a `zipfile.Zipfile` like instance for zip_path.

        `zip_path` must provide `as_url()` (for example a
        `cloudpathlib.CloudPath`) and point to a .zip file in aws s3.
        """
        logger.debug(f"Connecting to s3 for url: {zip_path}")
        auth, headers = self._credentials()
        ret = remotezip.RemoteZip(zip_path.as_url(), auth=auth, headers=headers)
        # Store the path object so paths to members of the archive can be
        # built by joining onto `filename`.
        ret.filename = zip_path
        return ret

    def clear_cache(self, persistent=False):
        """Clear cached data in memory. If `persistent` is True then also remove
        cache data from disk if it exists.
        """
        if persistent:
            # NOTE(review): `remove_download_cache` is not defined in the
            # visible code, presumably it is provided by a base class.
            self.remove_download_cache()
        super().clear_cache(persistent=persistent)

    def content(self, path):
        """Returns the .zip archive that contains the `.hab.json` at `path`.

        If `path` is a str, the trailing `/<hab_filename>` is removed. Any
        other path object returns its parent, matching `DistroFinderZip`.
        """
        if isinstance(path, str):
            return path.removesuffix(f"/{self.hab_filename}")
        # The paths yielded by this finder are path objects built from
        # `archive.filename`, which don't provide `str.removesuffix`.
        return path.parent

    def install(self, path, dest):
        """Not implemented yet.

        TODO: Download zip if not in cache, then call super on cached zip file.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("Using ZipFile.extract on this is a bad idea")
57 changes: 57 additions & 0 deletions hab/distro_finders/zip_sidecar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import logging
import re
import zipfile

from packaging.version import VERSION_PATTERN

from .. import utils
from .distro_finder import DistroFinder

# Module-level logger named after this module's import path.
logger = logging.getLogger(__name__)


class DistroFinderZipSidecar(DistroFinder):
    """Working with zipped distros that have a sidecar `dist_name_v0.0.0.hab.json`
    file. This is useful when it can't extract the `.hab.json` from the .zip file.
    """

    # Splits `<name>_v<version>` using the version pattern from `packaging`.
    # VERSION_PATTERN contains whitespace and comments, so re.VERBOSE is required.
    version_regex = re.compile(
        rf"(?P<name>.+)_v{VERSION_PATTERN}", flags=re.VERBOSE | re.IGNORECASE
    )

    def __init__(self, root, site=None):
        super().__init__(root, site)
        # Distros are identified by their sidecar json files.
        self.glob_str = "*.hab.json"

    def archive(self, zip_path):
        """Returns a `zipfile.Zipfile` like instance for zip_path."""
        return zipfile.ZipFile(zip_path)

    def content(self, path):
        """Returns the path to the sidecar .zip file.
        This replaces `.hab.json` with `.zip`.
        """
        # The first with_suffix call drops `.json`, the second swaps `.hab`
        # for `.zip`.
        return path.with_suffix("").with_suffix(".zip")

    def install(self, path, dest):
        """Extracts every member of the zip archive for `path` into `dest`.

        Args:
            path: Path to the distro file. `self.content` is used to resolve
                this to the zip archive.
            dest: Directory the archive members are extracted into.
        """
        path = self.content(path)
        logger.debug(f"Extracting to {dest} from zip {path}")
        with self.archive(path) as archive:
            members = archive.namelist()
            total = len(members)
            # Count from 1 so the logged progress runs from 1/total to
            # total/total instead of stopping at total - 1.
            for i, member in enumerate(members, start=1):
                logger.debug(f"Extracting file({i}/{total}): {member}")
                archive.extract(member, dest)

    def load_path(self, distro, path):
        """Returns the raw distro dictionary with version set.

        Args:
            distro: Not used by this implementation.
            path: Path to the sidecar `.hab.json` file to load.
        """
        logger.debug(f'Loading "{path}"')
        data = utils.load_json_file(path)
        # Pull the version from the sidecar filename if its not explicitly set
        if "version" not in data:
            data["version"] = self.version_for_path(path)
        return data

    def version_for_path(self, path):
        """Returns the `release` group of `version_regex` matched against
        `str(path)`.

        Raises:
            AttributeError: If `path` doesn't match `version_regex`.
        """
        return self.version_regex.match(str(path)).group("release")

0 comments on commit e0537b6

Please sign in to comment.