From 48110ee545f60fa49c7da4dbff8b9ded95230781 Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Thu, 3 Oct 2024 13:24:19 +0200 Subject: [PATCH 01/24] S3 support using libcurl --- .../all_cuda-118_arch-aarch64.yaml | 2 + .../all_cuda-118_arch-x86_64.yaml | 2 + .../all_cuda-125_arch-aarch64.yaml | 2 + .../all_cuda-125_arch-x86_64.yaml | 2 + cpp/include/kvikio/remote_handle.hpp | 136 ++++++++++ dependencies.yaml | 7 + python/kvikio/kvikio/_lib/remote_handle.pyx | 62 +++++ python/kvikio/kvikio/benchmarks/s3_io.py | 247 ++++++++++++++++++ python/kvikio/kvikio/remote_file.py | 44 ++++ python/kvikio/pyproject.toml | 3 + python/kvikio/tests/test_s3_io.py | 131 ++++++++++ 11 files changed, 638 insertions(+) create mode 100644 python/kvikio/kvikio/benchmarks/s3_io.py create mode 100644 python/kvikio/tests/test_s3_io.py diff --git a/conda/environments/all_cuda-118_arch-aarch64.yaml b/conda/environments/all_cuda-118_arch-aarch64.yaml index 0e7f4b3e21..ef1215d51b 100644 --- a/conda/environments/all_cuda-118_arch-aarch64.yaml +++ b/conda/environments/all_cuda-118_arch-aarch64.yaml @@ -6,6 +6,7 @@ channels: - conda-forge - nvidia dependencies: +- boto3>=1.21.21 - c-compiler - cmake>=3.26.4,!=3.30.0 - cuda-python>=11.7.1,<12.0a0 @@ -18,6 +19,7 @@ dependencies: - doxygen=1.9.1 - gcc_linux-aarch64=11.* - libcurl>=7.87.0 +- moto>=4.0.8 - ninja - numcodecs !=0.12.0 - numpy>=1.23,<3.0a0 diff --git a/conda/environments/all_cuda-118_arch-x86_64.yaml b/conda/environments/all_cuda-118_arch-x86_64.yaml index 293085e8f7..842b984cc6 100644 --- a/conda/environments/all_cuda-118_arch-x86_64.yaml +++ b/conda/environments/all_cuda-118_arch-x86_64.yaml @@ -6,6 +6,7 @@ channels: - conda-forge - nvidia dependencies: +- boto3>=1.21.21 - c-compiler - cmake>=3.26.4,!=3.30.0 - cuda-python>=11.7.1,<12.0a0 @@ -20,6 +21,7 @@ dependencies: - libcufile-dev=1.4.0.31 - libcufile=1.4.0.31 - libcurl>=7.87.0 +- moto>=4.0.8 - ninja - numcodecs !=0.12.0 - numpy>=1.23,<3.0a0 diff --git 
a/conda/environments/all_cuda-125_arch-aarch64.yaml b/conda/environments/all_cuda-125_arch-aarch64.yaml index 1e4a370ff6..9a4b3e94bd 100644 --- a/conda/environments/all_cuda-125_arch-aarch64.yaml +++ b/conda/environments/all_cuda-125_arch-aarch64.yaml @@ -6,6 +6,7 @@ channels: - conda-forge - nvidia dependencies: +- boto3>=1.21.21 - c-compiler - cmake>=3.26.4,!=3.30.0 - cuda-nvcc @@ -19,6 +20,7 @@ dependencies: - gcc_linux-aarch64=11.* - libcufile-dev - libcurl>=7.87.0 +- moto>=4.0.8 - ninja - numcodecs !=0.12.0 - numpy>=1.23,<3.0a0 diff --git a/conda/environments/all_cuda-125_arch-x86_64.yaml b/conda/environments/all_cuda-125_arch-x86_64.yaml index 44d8772a71..2b926acf29 100644 --- a/conda/environments/all_cuda-125_arch-x86_64.yaml +++ b/conda/environments/all_cuda-125_arch-x86_64.yaml @@ -6,6 +6,7 @@ channels: - conda-forge - nvidia dependencies: +- boto3>=1.21.21 - c-compiler - cmake>=3.26.4,!=3.30.0 - cuda-nvcc @@ -19,6 +20,7 @@ dependencies: - gcc_linux-64=11.* - libcufile-dev - libcurl>=7.87.0 +- moto>=4.0.8 - ninja - numcodecs !=0.12.0 - numpy>=1.23,<3.0a0 diff --git a/cpp/include/kvikio/remote_handle.hpp b/cpp/include/kvikio/remote_handle.hpp index e036ebcb37..2188af1941 100644 --- a/cpp/include/kvikio/remote_handle.hpp +++ b/cpp/include/kvikio/remote_handle.hpp @@ -18,6 +18,7 @@ #include #include #include +#include #include #include #include @@ -151,6 +152,141 @@ class HttpEndpoint : public RemoteEndpoint { ~HttpEndpoint() override = default; }; +/** + * @brief + */ +class S3Endpoint : public RemoteEndpoint { + private: + std::string _url; + std::string _aws_sigv4; + std::string _aws_userpwd; + + static std::string parse_aws_argument(std::optional aws_arg, + const std::string& env_var, + const std::string& err_msg, + bool allow_empty = false) + { + if (aws_arg.has_value()) { return std::move(*aws_arg); } + + char const* env = std::getenv(env_var.c_str()); + if (env == nullptr) { + if (allow_empty) { return std::string(); } + throw 
std::invalid_argument(err_msg); + } + return std::string(env); + } + + static std::string url_from_bucket_and_object(const std::string& bucket_name, + const std::string& object_name, + const std::optional& aws_region, + std::optional aws_endpoint_url) + { + std::string endpoint_url = + parse_aws_argument(std::move(aws_endpoint_url), + "AWS_ENDPOINT_URL", + "S3: must provide `aws_endpoint_url` if AWS_ENDPOINT_URL isn't set.", + true); + std::stringstream ss; + if (endpoint_url.empty()) { + std::string region = + parse_aws_argument(std::move(aws_region), + "AWS_DEFAULT_REGION", + "S3: must provide `aws_region` if AWS_DEFAULT_REGION isn't set."); + // We default to the official AWS url scheme. + ss << "https://" << bucket_name << ".s3." << region << ".amazonaws.com/" << object_name; + } else { + ss << endpoint_url << "/" << bucket_name << "/" << object_name; + } + return ss.str(); + } + + public: + /** + * @brief Given an url like "s3:///", return the name of the bucket and object. + * + * @throws std::invalid_argument if url is ill-formed or is missing the bucket or object name. + * + * @param s3_url S3 url. + * @return Pair of strings: [bucket-name, object-name]. + */ + static std::pair parse_s3_url(std::string const& s3_url) + { + if (s3_url.empty()) { throw std::invalid_argument("The S3 url cannot be an empty string."); } + if (s3_url.size() < 5 || s3_url.substr(0, 5) != "s3://") { + throw std::invalid_argument("The S3 url must start with the S3 scheme (\"s3://\")."); + } + std::string p = s3_url.substr(5); + if (p.empty()) { throw std::invalid_argument("The S3 url cannot be an empty string."); } + size_t pos = p.find_first_of('/'); + std::string bucket_name = p.substr(0, pos); + if (bucket_name.empty()) { + throw std::invalid_argument("The S3 url does not contain a bucket name."); + } + std::string object_name = (pos == std::string::npos) ? 
"" : p.substr(pos + 1); + if (object_name.empty()) { + throw std::invalid_argument("The S3 url does not contain an object name."); + } + return std::make_pair(std::move(bucket_name), std::move(object_name)); + } + + S3Endpoint(std::string url, + std::optional aws_region = std::nullopt, + std::optional aws_access_key = std::nullopt, + std::optional aws_secret_access_key = std::nullopt) + : _url{std::move(url)} + { + std::string region = + parse_aws_argument(std::move(aws_region), + "AWS_DEFAULT_REGION", + "S3: must provide `aws_region` if AWS_DEFAULT_REGION isn't set."); + + std::string access_key = + parse_aws_argument(std::move(aws_access_key), + "AWS_ACCESS_KEY_ID", + "S3: must provide `aws_access_key` if AWS_ACCESS_KEY_ID isn't set."); + + std::string secret_access_key = parse_aws_argument( + std::move(aws_secret_access_key), + "AWS_SECRET_ACCESS_KEY", + "S3: must provide `aws_secret_access_key` if AWS_SECRET_ACCESS_KEY isn't set."); + + // Create the CURLOPT_AWS_SIGV4 option + { + std::stringstream ss; + ss << "aws:amz:" << region << ":s3"; + _aws_sigv4 = ss.str(); + } + // Create the CURLOPT_USERPWD option + { + std::stringstream ss; + ss << access_key << ":" << secret_access_key; + _aws_userpwd = ss.str(); + } + } + S3Endpoint(const std::string& bucket_name, + const std::string& object_name, + std::optional aws_region = std::nullopt, + std::optional aws_access_key = std::nullopt, + std::optional aws_secret_access_key = std::nullopt, + std::optional aws_endpoint_url = std::nullopt) + : S3Endpoint(url_from_bucket_and_object( + bucket_name, object_name, aws_region, std::move(aws_endpoint_url)), + std::move(aws_region), + std::move(aws_access_key), + std::move(aws_secret_access_key)) + { + } + + void setopt(CurlHandle& curl) override + { + curl.setopt(CURLOPT_URL, _url.c_str()); + curl.setopt(CURLOPT_AWS_SIGV4, _aws_sigv4.c_str()); + curl.setopt(CURLOPT_USERPWD, _aws_userpwd.c_str()); + } + std::string str() override { return _url; } + ~S3Endpoint() override = 
default; +}; + /** * @brief Handle of remote file. */ diff --git a/dependencies.yaml b/dependencies.yaml index 39ba3aaa17..85bf871150 100644 --- a/dependencies.yaml +++ b/dependencies.yaml @@ -345,6 +345,13 @@ dependencies: - pytest - pytest-cov - rangehttpserver + - boto3>=1.21.21 + - output_types: [requirements, pyproject] + packages: + - moto[server]>=4.0.8 + - output_types: conda + packages: + - moto>=4.0.8 specific: - output_types: [conda, requirements, pyproject] matrices: diff --git a/python/kvikio/kvikio/_lib/remote_handle.pyx b/python/kvikio/kvikio/_lib/remote_handle.pyx index 5e58da32f0..357a965595 100644 --- a/python/kvikio/kvikio/_lib/remote_handle.pyx +++ b/python/kvikio/kvikio/_lib/remote_handle.pyx @@ -23,6 +23,13 @@ cdef extern from "" nogil: cdef cppclass cpp_HttpEndpoint "kvikio::HttpEndpoint": cpp_HttpEndpoint(string url) except + + cdef cppclass cpp_S3Endpoint "kvikio::S3Endpoint": + cpp_S3Endpoint(string url) except + + cpp_S3Endpoint(string bucket_name, string object_name) except + + + pair[string, string] cpp_parse_s3_url \ + "kvikio::S3Endpoint::parse_s3_url"(string url) except + + cdef cppclass cpp_RemoteHandle "kvikio::RemoteHandle": cpp_RemoteHandle( unique_ptr[cpp_RemoteEndpoint] endpoint, size_t nbytes ) @@ -67,6 +74,59 @@ cdef class RemoteFile: ret._handle = make_unique[cpp_RemoteHandle](move(ep), n) return ret + @classmethod + def open_s3_from_http_url( + cls, + url: str, + nbytes: Optional[int], + ): + cdef RemoteFile ret = RemoteFile() + cdef unique_ptr[cpp_S3Endpoint] ep = make_unique[cpp_S3Endpoint]( + _to_string(url) + ) + if nbytes is None: + ret._handle = make_unique[cpp_RemoteHandle](move(ep)) + return ret + cdef size_t n = nbytes + ret._handle = make_unique[cpp_RemoteHandle](move(ep), n) + return ret + + @classmethod + def open_s3( + cls, + bucket_name: str, + object_name: str, + nbytes: Optional[int], + ): + cdef RemoteFile ret = RemoteFile() + cdef 
unique_ptr[cpp_S3Endpoint] ep = make_unique[cpp_S3Endpoint]( + _to_string(bucket_name), _to_string(object_name) + ) + if nbytes is None: + ret._handle = make_unique[cpp_RemoteHandle](move(ep)) + return ret + cdef size_t n = nbytes + ret._handle = make_unique[cpp_RemoteHandle](move(ep), n) + return ret + + @classmethod + def open_s3_from_s3_url( + cls, + url: str, + nbytes: Optional[int], + ): + cdef pair[string, string] bucket_and_object = cpp_parse_s3_url(_to_string(url)) + cdef RemoteFile ret = RemoteFile() + cdef unique_ptr[cpp_S3Endpoint] ep = make_unique[cpp_S3Endpoint]( + bucket_and_object.first, bucket_and_object.second + ) + if nbytes is None: + ret._handle = make_unique[cpp_RemoteHandle](move(ep)) + return ret + cdef size_t n = nbytes + ret._handle = make_unique[cpp_RemoteHandle](move(ep), n) + return ret + def nbytes(self) -> int: return deref(self._handle).nbytes() diff --git a/python/kvikio/kvikio/benchmarks/s3_io.py b/python/kvikio/kvikio/benchmarks/s3_io.py new file mode 100644 index 0000000000..6130885442 --- /dev/null +++ b/python/kvikio/kvikio/benchmarks/s3_io.py @@ -0,0 +1,247 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# See file LICENSE for terms. + +import argparse +import contextlib +import multiprocessing +import os +import socket +import statistics +import sys +import time +from functools import partial +from typing import ContextManager +from urllib.parse import urlparse + +import boto3 +import cupy +import numpy +from dask.utils import format_bytes + +import kvikio +import kvikio.defaults + + +def get_local_port() -> int: + """Return an available port""" + sock = socket.socket() + sock.bind(("127.0.0.1", 0)) + port = sock.getsockname()[1] + sock.close() + return port + + +def start_s3_server(lifetime: int): + """Start a server and run it for `lifetime` minutes. + NB: to stop before `lifetime`, kill the process/thread running this function. 
+ """ + from moto.server import ThreadedMotoServer + + # Silence the activity info from ThreadedMotoServer + sys.stderr = open(os.devnull, "w") + url = urlparse(os.environ["AWS_ENDPOINT_URL"]) + server = ThreadedMotoServer(ip_address=url.hostname, port=url.port) + server.start() + time.sleep(lifetime) + + +@contextlib.contextmanager +def local_s3_server(lifetime: int): + """Start a server and run it for `lifetime` minutes or kill it on context exit""" + # Use fake aws credentials + os.environ["AWS_ACCESS_KEY_ID"] = "foobar_key" + os.environ["AWS_SECRET_ACCESS_KEY"] = "foobar_secret" + os.environ["AWS_DEFAULT_REGION"] = "us-east-1" + p = multiprocessing.Process(target=start_s3_server, args=(lifetime,)) + p.start() + yield + p.kill() + + +def create_client_and_bucket(): + client = boto3.client("s3", endpoint_url=os.getenv("AWS_ENDPOINT_URL", None)) + try: + client.create_bucket(Bucket=args.bucket, ACL="public-read-write") + except ( + client.exceptions.BucketAlreadyOwnedByYou, + client.exceptions.BucketAlreadyExists, + ): + pass + except Exception: + print( + "Problem accessing the S3 server? using wrong credentials? Try setting " + "AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, and/or AWS_ENDPOINT_URL. 
" + "Alternatively, use the bundled server `--use-bundled-server`\n", + file=sys.stderr, + flush=True, + ) + raise + return client + + +def run_numpy_like(args, xp): + # Upload data to S3 server + data = numpy.arange(args.nelem, dtype=args.dtype) + recv = xp.empty_like(data) + + client = create_client_and_bucket() + client.put_object(Bucket=args.bucket, Key="data", Body=bytes(data)) + server_address = os.environ["AWS_ENDPOINT_URL"] + url = f"{server_address}/{args.bucket}/data" + + def run() -> float: + t0 = time.perf_counter() + with kvikio.RemoteFile.open_s3_from_http_url(url) as f: + res = f.read(recv) + t1 = time.perf_counter() + assert res == args.nbytes, f"IO mismatch, expected {args.nbytes} got {res}" + xp.testing.assert_array_equal(data, recv) + return t1 - t0 + + for _ in range(args.nruns): + yield run() + + +def run_cudf(args, libcudf_s3_io: bool): + import cudf + + cudf.set_option("libcudf_s3_io", libcudf_s3_io) + + # Upload data to S3 server + create_client_and_bucket() + data = cupy.random.rand(args.nelem).astype(args.dtype) + df = cudf.DataFrame({"a": data}) + df.to_parquet(f"s3://{args.bucket}/data1") + + def run() -> float: + t0 = time.perf_counter() + cudf.read_parquet(f"s3://{args.bucket}/data1") + t1 = time.perf_counter() + return t1 - t0 + + for _ in range(args.nruns): + yield run() + + +API = { + "cupy-kvikio": partial(run_numpy_like, xp=cupy), + "numpy-kvikio": partial(run_numpy_like, xp=numpy), + "cudf-kvikio": partial(run_cudf, libcudf_s3_io=True), + "cudf-fsspec": partial(run_cudf, libcudf_s3_io=False), +} + + +def main(args): + cupy.cuda.set_allocator(None) # Disable CuPy's default memory pool + cupy.arange(10) # Make sure CUDA is initialized + + os.environ["KVIKIO_NTHREADS"] = str(args.nthreads) + kvikio.defaults.num_threads_reset(args.nthreads) + + print("Roundtrip benchmark") + print("--------------------------------------") + print(f"nelem | {args.nelem} ({format_bytes(args.nbytes)})") + print(f"dtype | {args.dtype}") + 
print(f"nthreads | {args.nthreads}") + print(f"nruns | {args.nruns}") + print(f"server | {os.getenv('AWS_ENDPOINT_URL', 'http://*.amazonaws.com')}") + if args.use_bundled_server: + print("--------------------------------------") + print("Using the bundled local server is slow") + print("and can be misleading. Consider using") + print("a local MinIO or official S3 server.") + print("======================================") + + # Run each benchmark using the requested APIs + for api in args.api: + res = [] + for elapsed in API[api](args): + res.append(elapsed) + + def pprint_api_res(name, samples): + samples = [args.nbytes / s for s in samples] # Convert to throughput + mean = statistics.harmonic_mean(samples) if len(samples) > 1 else samples[0] + ret = f"{api}-{name}".ljust(18) + ret += f"| {format_bytes(mean).rjust(10)}/s".ljust(14) + if len(samples) > 1: + stdev = statistics.stdev(samples) / mean * 100 + ret += " ± %5.2f %%" % stdev + ret += " (" + for sample in samples: + ret += f"{format_bytes(sample)}/s, " + ret = ret[:-2] + ")" # Replace trailing comma + return ret + + print(pprint_api_res("read", res)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Roundtrip benchmark") + parser.add_argument( + "-n", + "--nelem", + metavar="NELEM", + default="1024", + type=int, + help="Number of elements (default: %(default)s).", + ) + parser.add_argument( + "--dtype", + metavar="DATATYPE", + default="float32", + type=numpy.dtype, + help="The data type of each element (default: %(default)s).", + ) + parser.add_argument( + "--nruns", + metavar="RUNS", + default=1, + type=int, + help="Number of runs per API (default: %(default)s).", + ) + parser.add_argument( + "-t", + "--nthreads", + metavar="THREADS", + default=1, + type=int, + help="Number of threads to use (default: %(default)s).", + ) + parser.add_argument( + "--use-bundled-server", + action="store_true", + help="Launch and use a local slow S3 server (ThreadedMotoServer).", + ) + 
parser.add_argument( + "--bundled-server-lifetime", + metavar="SECONDS", + default=3600, + type=int, + help="Maximum lifetime of the bundled server (default: %(default)s).", + ) + parser.add_argument( + "--bucket", + metavar="NAME", + default="kvikio-s3-benchmark", + type=str, + help="Name of the AWS S3 bucket to use (default: %(default)s).", + ) + parser.add_argument( + "--api", + metavar="API", + default=list(API.keys())[0], # defaults to the first API + nargs="+", + choices=tuple(API.keys()) + ("all",), + help="List of APIs to use {%(choices)s} (default: %(default)s).", + ) + args = parser.parse_args() + args.nbytes = args.nelem * args.dtype.itemsize + if "all" in args.api: + args.api = tuple(API.keys()) + + ctx: ContextManager = contextlib.nullcontext() + if args.use_bundled_server: + os.environ["AWS_ENDPOINT_URL"] = f"http://127.0.0.1:{get_local_port()}" + ctx = local_s3_server(args.bundled_server_lifetime) + with ctx: + main(args) diff --git a/python/kvikio/kvikio/remote_file.py b/python/kvikio/kvikio/remote_file.py index 52bbe8010f..5227126278 100644 --- a/python/kvikio/kvikio/remote_file.py +++ b/python/kvikio/kvikio/remote_file.py @@ -68,6 +68,50 @@ def open_http( """ return RemoteFile(_get_remote_module().RemoteFile.open_http(url, nbytes)) + @classmethod + def open_s3( + cls, + bucket_name: str, + object_name: str, + nbytes: Optional[int] = None, + ) -> RemoteFile: + return RemoteFile( + _get_remote_module().RemoteFile.open_s3(bucket_name, object_name, nbytes) + ) + + @classmethod + def open_s3_url( + cls, + url: str, + nbytes: Optional[int] = None, + ) -> RemoteFile: + url = url.lower() + if url.startswith("http://") or url.startswith("https://"): + return cls.open_s3_from_http_url(url, nbytes) + if url.startswith("s3://"): + return cls.open_s3_from_s3_url(url, nbytes) + raise ValueError(f"Unsupported protocol in url: {url}") + + @classmethod + def open_s3_from_http_url( + cls, + url: str, + nbytes: Optional[int] = None, + ) -> RemoteFile: + return 
RemoteFile( + _get_remote_module().RemoteFile.open_s3_from_http_url(url, nbytes) + ) + + @classmethod + def open_s3_from_s3_url( + cls, + url: str, + nbytes: Optional[int] = None, + ) -> RemoteFile: + return RemoteFile( + _get_remote_module().RemoteFile.open_s3_from_s3_url(url, nbytes) + ) + def close(self) -> None: """Close the file""" pass diff --git a/python/kvikio/pyproject.toml b/python/kvikio/pyproject.toml index 04f04cfa6f..25a961a858 100644 --- a/python/kvikio/pyproject.toml +++ b/python/kvikio/pyproject.toml @@ -39,8 +39,10 @@ classifiers = [ [project.optional-dependencies] test = [ + "boto3>=1.21.21", "cuda-python>=11.7.1,<12.0a0", "dask>=2022.05.2", + "moto[server]>=4.0.8", "pytest", "pytest-cov", "rangehttpserver", @@ -140,4 +142,5 @@ regex = "(?P.*)" filterwarnings = [ "error", "ignore:Jitify is performing a one-time only warm-up to populate the persistent cache", + "ignore::DeprecationWarning:botocore.*", ] diff --git a/python/kvikio/tests/test_s3_io.py b/python/kvikio/tests/test_s3_io.py new file mode 100644 index 0000000000..2daab28700 --- /dev/null +++ b/python/kvikio/tests/test_s3_io.py @@ -0,0 +1,131 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# See file LICENSE for terms. + +import multiprocessing as mp +import socket +import time +from contextlib import contextmanager + +import pytest + +import kvikio +import kvikio.defaults + +pytestmark = pytest.mark.skipif( + not kvikio.is_remote_file_available(), + reason=( + "RemoteFile not available, please build KvikIO " + "with libcurl (-DKvikIO_REMOTE_SUPPORT=ON)" + ), +) + +# Notice, we import boto and moto after the `is_remote_file_available` check. +import boto3 # noqa: E402 +import moto # noqa: E402 +import moto.server # noqa: E402 + + +@pytest.fixture(scope="session") +def endpoint_ip(): + return "127.0.0.1" + + +@pytest.fixture(scope="session") +def endpoint_port(): + # Return a free port per worker session. 
+ sock = socket.socket() + sock.bind(("127.0.0.1", 0)) + port = sock.getsockname()[1] + sock.close() + return port + + +def start_s3_server(ip_address, port): + server = moto.server.ThreadedMotoServer(ip_address=ip_address, port=port) + server.start() + time.sleep(600) + print("ThreadedMotoServer shutting down because of timeout (10min)") + + +@pytest.fixture(scope="session") +def s3_base(endpoint_ip, endpoint_port): + """Fixture to set up moto server in separate process""" + with pytest.MonkeyPatch.context() as monkeypatch: + # Use fake aws credentials + monkeypatch.setenv("AWS_ACCESS_KEY_ID", "foobar_key") + monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "foobar_secret") + monkeypatch.setenv("AWS_DEFAULT_REGION", "us-east-1") + monkeypatch.setenv("AWS_ENDPOINT_URL", f"http://{endpoint_ip}:{endpoint_port}") + + p = mp.Process(target=start_s3_server, args=(endpoint_ip, endpoint_port)) + p.start() + yield f"http://{endpoint_ip}:{endpoint_port}" + p.kill() + + +@contextmanager +def s3_context(s3_base, bucket, files=None): + if files is None: + files = {} + client = boto3.client("s3", endpoint_url=s3_base) + client.create_bucket(Bucket=bucket, ACL="public-read-write") + for f, data in files.items(): + client.put_object(Bucket=bucket, Key=f, Body=data) + yield s3_base + for f, data in files.items(): + try: + client.delete_object(Bucket=bucket, Key=f) + except Exception: + pass + + +@pytest.mark.parametrize("size", [10, 100, 1000]) +@pytest.mark.parametrize("nthreads", [1, 3]) +@pytest.mark.parametrize("tasksize", [99, 999]) +@pytest.mark.parametrize("buffer_size", [101, 1001]) +def test_read(s3_base, xp, size, nthreads, tasksize, buffer_size): + bucket_name = "test_read" + object_name = "a1" + a = xp.arange(size) + with s3_context( + s3_base=s3_base, bucket=bucket_name, files={object_name: bytes(a)} + ) as server_address: + with kvikio.defaults.set_num_threads(nthreads): + with kvikio.defaults.set_task_size(tasksize): + with 
kvikio.defaults.set_bounce_buffer_size(buffer_size): + with kvikio.RemoteFile.open_s3_url( + f"{server_address}/{bucket_name}/{object_name}" + ) as f: + assert f.nbytes() == a.nbytes + b = xp.empty_like(a) + assert f.read(buf=b) == a.nbytes + xp.testing.assert_array_equal(a, b) + + with kvikio.RemoteFile.open_s3(bucket_name, object_name) as f: + assert f.nbytes() == a.nbytes + b = xp.empty_like(a) + assert f.read(buf=b) == a.nbytes + xp.testing.assert_array_equal(a, b) + + +@pytest.mark.parametrize( + "start,end", + [ + (0, 10 * 4096), + (1, int(1.3 * 4096)), + (int(2.1 * 4096), int(5.6 * 4096)), + (42, int(2**20)), + ], +) +def test_read_with_file_offset(s3_base, xp, start, end): + bucket_name = "test_read_with_file_offset" + object_name = "a1" + a = xp.arange(end, dtype=xp.int64) + with s3_context( + s3_base=s3_base, bucket=bucket_name, files={object_name: bytes(a)} + ) as server_address: + url = f"{server_address}/{bucket_name}/{object_name}" + with kvikio.RemoteFile.open_s3_from_http_url(url) as f: + b = xp.zeros(shape=(end - start,), dtype=xp.int64) + assert f.read(b, file_offset=start * a.itemsize) == b.nbytes + xp.testing.assert_array_equal(a[start:end], b) From ae033edeecec87139cbf937941c7089a8b494797 Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Wed, 9 Oct 2024 13:50:08 +0200 Subject: [PATCH 02/24] doc --- cpp/include/kvikio/remote_handle.hpp | 41 ++++++++++++++++++++++++---- 1 file changed, 36 insertions(+), 5 deletions(-) diff --git a/cpp/include/kvikio/remote_handle.hpp b/cpp/include/kvikio/remote_handle.hpp index 2188af1941..3a71d4e612 100644 --- a/cpp/include/kvikio/remote_handle.hpp +++ b/cpp/include/kvikio/remote_handle.hpp @@ -146,6 +146,11 @@ class HttpEndpoint : public RemoteEndpoint { std::string _url; public: + /** + * @brief Create a http endpoint from a url. + * + * @param url The full http url to the remote file. 
+ */ HttpEndpoint(std::string url) : _url{std::move(url)} {} void setopt(CurlHandle& curl) override { curl.setopt(CURLOPT_URL, _url.c_str()); } std::string str() override { return _url; } @@ -153,7 +158,7 @@ class HttpEndpoint : public RemoteEndpoint { }; /** - * @brief + * @brief A remote endpoint using AWS's S3 protocol. */ class S3Endpoint : public RemoteEndpoint { private: @@ -182,10 +187,7 @@ class S3Endpoint : public RemoteEndpoint { std::optional aws_endpoint_url) { std::string endpoint_url = - parse_aws_argument(std::move(aws_endpoint_url), - "AWS_ENDPOINT_URL", - "S3: must provide `aws_endpoint_url` if AWS_ENDPOINT_URL isn't set.", - true); + parse_aws_argument(std::move(aws_endpoint_url), "AWS_ENDPOINT_URL", "", true); std::stringstream ss; if (endpoint_url.empty()) { std::string region = @@ -229,6 +231,19 @@ class S3Endpoint : public RemoteEndpoint { return std::make_pair(std::move(bucket_name), std::move(object_name)); } + /** + * @brief Create a S3 endpoint from a url. + * + * @param url The full http url to the S3 file. NB: this should an url starting with + * "http://" or "https://". If you have an S3 url of the form "s3:///", + * please use `S3Endpoint::parse_s3_url()` to convert it. + * @param aws_region The AWS region, such as "us-east-1", to use. If nullopt, the value of the + * `AWS_DEFAULT_REGION` environment variable is used. + * @param aws_access_key The AWS access key to use. If nullopt, the value of the + * `AWS_ACCESS_KEY_ID` environment variable is used. + * @param aws_secret_access_key The AWS secret access key to use. If nullopt, the value of the + * `AWS_SECRET_ACCESS_KEY` environment variable is used. + */ S3Endpoint(std::string url, std::optional aws_region = std::nullopt, std::optional aws_access_key = std::nullopt, @@ -263,6 +278,22 @@ class S3Endpoint : public RemoteEndpoint { _aws_userpwd = ss.str(); } } + + /** + * @brief Create a S3 endpoint from a bucket and object name. + * + * @param bucket_name The name of the S3 bucket. 
+ * @param object_name The name of the S3 object. + * @param aws_region The AWS region, such as "us-east-1", to use. If nullopt, the value of the + * `AWS_DEFAULT_REGION` environment variable is used. + * @param aws_access_key The AWS access key to use. If nullopt, the value of the + * `AWS_ACCESS_KEY_ID` environment variable is used. + * @param aws_secret_access_key The AWS secret access key to use. If nullopt, the value of the + * `AWS_SECRET_ACCESS_KEY` environment variable is used. + * @param aws_endpoint_url Overwrite the endpoint url to use. If nullopt, the value of + * the `AWS_ENDPOINT_URL` environment variable is used. If this is also not set, the regular + * AWS url scheme is used: "https://.s3..amazonaws.com/" + */ S3Endpoint(const std::string& bucket_name, const std::string& object_name, std::optional aws_region = std::nullopt, From 351169afe6cad93a66e80afecfac16f0f98234df Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Wed, 9 Oct 2024 14:10:48 +0200 Subject: [PATCH 03/24] cleanup --- python/kvikio/kvikio/_lib/remote_handle.pyx | 14 ++-- python/kvikio/kvikio/benchmarks/s3_io.py | 2 +- python/kvikio/kvikio/remote_file.py | 73 ++++++++++++++------- python/kvikio/tests/test_s3_io.py | 2 +- 4 files changed, 59 insertions(+), 32 deletions(-) diff --git a/python/kvikio/kvikio/_lib/remote_handle.pyx b/python/kvikio/kvikio/_lib/remote_handle.pyx index 357a965595..11563007dc 100644 --- a/python/kvikio/kvikio/_lib/remote_handle.pyx +++ b/python/kvikio/kvikio/_lib/remote_handle.pyx @@ -77,14 +77,15 @@ cdef class RemoteFile: return ret @classmethod - def open_s3_from_http_url( + def open_s3( cls, - url: str, + bucket_name: str, + object_name: str, nbytes: Optional[int], ): cdef RemoteFile ret = RemoteFile() cdef unique_ptr[cpp_S3Endpoint] ep = make_unique[cpp_S3Endpoint]( - _to_string(url) + _to_string(bucket_name), _to_string(object_name) ) if nbytes is None: ret._handle = make_unique[cpp_RemoteHandle](move(ep)) @@ -94,15 +95,14 @@ cdef class 
RemoteFile: return ret @classmethod - def open_s3( + def open_s3_from_http_url( cls, - bucket_name: str, - object_name: str, + url: str, nbytes: Optional[int], ): cdef RemoteFile ret = RemoteFile() cdef unique_ptr[cpp_S3Endpoint] ep = make_unique[cpp_S3Endpoint]( - _to_string(bucket_name), _to_string(object_name) + _to_string(url) ) if nbytes is None: ret._handle = make_unique[cpp_RemoteHandle](move(ep)) diff --git a/python/kvikio/kvikio/benchmarks/s3_io.py b/python/kvikio/kvikio/benchmarks/s3_io.py index 6130885442..4311f5f012 100644 --- a/python/kvikio/kvikio/benchmarks/s3_io.py +++ b/python/kvikio/kvikio/benchmarks/s3_io.py @@ -91,7 +91,7 @@ def run_numpy_like(args, xp): def run() -> float: t0 = time.perf_counter() - with kvikio.RemoteFile.open_s3_from_http_url(url) as f: + with kvikio.RemoteFile.open_s3_url(url) as f: res = f.read(recv) t1 = time.perf_counter() assert res == args.nbytes, f"IO mismatch, expected {args.nbytes} got {res}" diff --git a/python/kvikio/kvikio/remote_file.py b/python/kvikio/kvikio/remote_file.py index 5227126278..c4f93d86d9 100644 --- a/python/kvikio/kvikio/remote_file.py +++ b/python/kvikio/kvikio/remote_file.py @@ -75,6 +75,26 @@ def open_s3( object_name: str, nbytes: Optional[int] = None, ) -> RemoteFile: + """Open a AWS S3 file from a bucket name and object name. + + Please make sure to set the AWS environment variables: + - `AWS_DEFAULT_REGION` + - `AWS_ACCESS_KEY_ID` + - `AWS_SECRET_ACCESS_KEY` + + Additionally, to overwrite the AWS endpoint, set `AWS_ENDPOINT_URL`. + See + + Parameters + ---------- + bucket_name + The bucket name of the file. + object_name + The object name of the file. + nbytes + The size of the file. If None, KvikIO will ask the server + for the file size. + """ return RemoteFile( _get_remote_module().RemoteFile.open_s3(bucket_name, object_name, nbytes) ) @@ -85,32 +105,39 @@ def open_s3_url( url: str, nbytes: Optional[int] = None, ) -> RemoteFile: + """Open a AWS S3 file from an URL. 
+ + The `url` can take two forms: + - A full http url such as "http://127.0.0.1/my/file", or + - A S3 url such as "s3:///". + + Please make sure to set the AWS environment variables: + - `AWS_DEFAULT_REGION` + - `AWS_ACCESS_KEY_ID` + - `AWS_SECRET_ACCESS_KEY` + + Additionally, if `url` is a S3 url, it is possible to overwrite the AWS endpoint + by setting `AWS_ENDPOINT_URL`. + See + + Parameters + ---------- + url + Either a http url or a S3 url. + nbytes + The size of the file. If None, KvikIO will ask the server + for the file size. + """ url = url.lower() if url.startswith("http://") or url.startswith("https://"): - return cls.open_s3_from_http_url(url, nbytes) + return RemoteFile( + _get_remote_module().RemoteFile.open_s3_from_http_url(url, nbytes) + ) if url.startswith("s3://"): - return cls.open_s3_from_s3_url(url, nbytes) - raise ValueError(f"Unsupported protocol in url: {url}") - - @classmethod - def open_s3_from_http_url( - cls, - url: str, - nbytes: Optional[int] = None, - ) -> RemoteFile: - return RemoteFile( - _get_remote_module().RemoteFile.open_s3_from_http_url(url, nbytes) - ) - - @classmethod - def open_s3_from_s3_url( - cls, - url: str, - nbytes: Optional[int] = None, - ) -> RemoteFile: - return RemoteFile( - _get_remote_module().RemoteFile.open_s3_from_s3_url(url, nbytes) - ) + return RemoteFile( + _get_remote_module().RemoteFile.open_s3_from_s3_url(url, nbytes) + ) + raise ValueError(f"Unsupported protocol: {url}") def close(self) -> None: """Close the file""" pass diff --git a/python/kvikio/tests/test_s3_io.py b/python/kvikio/tests/test_s3_io.py index 2daab28700..1893d6b6d2 100644 --- a/python/kvikio/tests/test_s3_io.py +++ b/python/kvikio/tests/test_s3_io.py @@ -125,7 +125,7 @@ def test_read_with_file_offset(s3_base, xp, start, end): s3_base=s3_base, bucket=bucket_name, files={object_name: bytes(a)} ) as server_address: url = f"{server_address}/{bucket_name}/{object_name}" - with kvikio.RemoteFile.open_s3_from_http_url(url) as f: + with 
kvikio.RemoteFile.open_s3_url(url) as f: b = xp.zeros(shape=(end - start,), dtype=xp.int64) assert f.read(b, file_offset=start * a.itemsize) == b.nbytes xp.testing.assert_array_equal(a[start:end], b) From 8e57584dd83aa74dc429d61c7439da6b271bd693 Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Wed, 9 Oct 2024 14:23:02 +0200 Subject: [PATCH 04/24] [[nodiscard]] --- cpp/include/kvikio/remote_handle.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/include/kvikio/remote_handle.hpp b/cpp/include/kvikio/remote_handle.hpp index 3a71d4e612..27b957cbac 100644 --- a/cpp/include/kvikio/remote_handle.hpp +++ b/cpp/include/kvikio/remote_handle.hpp @@ -211,7 +211,7 @@ class S3Endpoint : public RemoteEndpoint { * @param s3_url S3 url. * @return Pair of strings: [bucket-name, object-name]. */ - static std::pair parse_s3_url(std::string const& s3_url) + [[nodiscard]] static std::pair parse_s3_url(std::string const& s3_url) { if (s3_url.empty()) { throw std::invalid_argument("The S3 url cannot be an empty string."); } if (s3_url.size() < 5 || s3_url.substr(0, 5) != "s3://") { @@ -234,7 +234,7 @@ class S3Endpoint : public RemoteEndpoint { /** * @brief Create a S3 endpoint from a url. * - * @param url The full http url to the S3 file. NB: this should an url starting with + * @param url The full http url to the S3 file. NB: this should be an url starting with * "http://" or "https://". If you have an S3 url of the form "s3:///", * please use `S3Endpoint::parse_s3_url()` to convert it. * @param aws_region The AWS region, such as "us-east-1", to use. If nullopt, the value of the From 488c060cd9e0e3316a20d08081605f29d68dedfa Mon Sep 17 00:00:00 2001 From: "Mads R. B. 
Kristensen" Date: Wed, 9 Oct 2024 14:25:11 +0200 Subject: [PATCH 05/24] const: going east --- cpp/include/kvikio/remote_handle.hpp | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/cpp/include/kvikio/remote_handle.hpp b/cpp/include/kvikio/remote_handle.hpp index 27b957cbac..f5d815cc5f 100644 --- a/cpp/include/kvikio/remote_handle.hpp +++ b/cpp/include/kvikio/remote_handle.hpp @@ -90,7 +90,7 @@ inline std::size_t callback_device_memory(char* data, void* context) { auto ctx = reinterpret_cast(context); - const std::size_t nbytes = size * nmemb; + std::size_t const nbytes = size * nmemb; if (ctx->size < ctx->offset + nbytes) { ctx->overflow_error = true; return CURL_WRITEFUNC_ERROR; @@ -167,8 +167,8 @@ class S3Endpoint : public RemoteEndpoint { std::string _aws_userpwd; static std::string parse_aws_argument(std::optional aws_arg, - const std::string& env_var, - const std::string& err_msg, + std::string const& env_var, + std::string const& err_msg, bool allow_empty = false) { if (aws_arg.has_value()) { return std::move(*aws_arg); } @@ -181,9 +181,9 @@ class S3Endpoint : public RemoteEndpoint { return std::string(env); } - static std::string url_from_bucket_and_object(const std::string& bucket_name, - const std::string& object_name, - const std::optional& aws_region, + static std::string url_from_bucket_and_object(std::string const& bucket_name, + std::string const& object_name, + std::optional const& aws_region, std::optional aws_endpoint_url) { std::string endpoint_url = @@ -294,8 +294,8 @@ class S3Endpoint : public RemoteEndpoint { * the `AWS_ENDPOINT_URL` environment variable is used. 
If this is also not set, the regular * AWS url scheme is used: "https://.s3..amazonaws.com/" */ - S3Endpoint(const std::string& bucket_name, - const std::string& object_name, + S3Endpoint(std::string const& bucket_name, + std::string const& object_name, std::optional aws_region = std::nullopt, std::optional aws_access_key = std::nullopt, std::optional aws_secret_access_key = std::nullopt, @@ -396,7 +396,7 @@ class RemoteHandle { << " bytes file (" << _endpoint->str() << ")"; throw std::invalid_argument(ss.str()); } - const bool is_host_mem = is_host_memory(buf); + bool const is_host_mem = is_host_memory(buf); auto curl = create_curl_handle(); _endpoint->setopt(curl); From 70dc29cd31896f3b0b3602fb843b04a72943696d Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Wed, 9 Oct 2024 14:31:00 +0200 Subject: [PATCH 06/24] clean up benchmark --- python/kvikio/kvikio/benchmarks/s3_io.py | 29 +++--------------------- python/kvikio/tests/test_benchmarks.py | 29 ++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 26 deletions(-) diff --git a/python/kvikio/kvikio/benchmarks/s3_io.py b/python/kvikio/kvikio/benchmarks/s3_io.py index 4311f5f012..668af8e550 100644 --- a/python/kvikio/kvikio/benchmarks/s3_io.py +++ b/python/kvikio/kvikio/benchmarks/s3_io.py @@ -102,32 +102,9 @@ def run() -> float: yield run() -def run_cudf(args, libcudf_s3_io: bool): - import cudf - - cudf.set_option("libcudf_s3_io", libcudf_s3_io) - - # Upload data to S3 server - create_client_and_bucket() - data = cupy.random.rand(args.nelem).astype(args.dtype) - df = cudf.DataFrame({"a": data}) - df.to_parquet(f"s3://{args.bucket}/data1") - - def run() -> float: - t0 = time.perf_counter() - cudf.read_parquet(f"s3://{args.bucket}/data1") - t1 = time.perf_counter() - return t1 - t0 - - for _ in range(args.nruns): - yield run() - - API = { - "cupy-kvikio": partial(run_numpy_like, xp=cupy), - "numpy-kvikio": partial(run_numpy_like, xp=numpy), - "cudf-kvikio": partial(run_cudf, libcudf_s3_io=True), - 
"cudf-fsspec": partial(run_cudf, libcudf_s3_io=False), + "cupy": partial(run_numpy_like, xp=cupy), + "numpy": partial(run_numpy_like, xp=numpy), } @@ -138,7 +115,7 @@ def main(args): os.environ["KVIKIO_NTHREADS"] = str(args.nthreads) kvikio.defaults.num_threads_reset(args.nthreads) - print("Roundtrip benchmark") + print("Remote S3 benchmark") print("--------------------------------------") print(f"nelem | {args.nelem} ({format_bytes(args.nbytes)})") print(f"dtype | {args.dtype}") diff --git a/python/kvikio/tests/test_benchmarks.py b/python/kvikio/tests/test_benchmarks.py index 5b5602e53a..307b0b258d 100644 --- a/python/kvikio/tests/test_benchmarks.py +++ b/python/kvikio/tests/test_benchmarks.py @@ -109,3 +109,32 @@ def test_http_io(run_cmd, api): cwd=benchmarks_path, ) assert retcode == 0 + + +@pytest.mark.parametrize( + "api", + [ + "cupy", + "numpy", + ], +) +def test_s3_io(run_cmd, api): + """Test benchmarks/s3_io.py""" + + if not kvikio.is_remote_file_available(): + pytest.skip( + "RemoteFile not available, please build KvikIO " + "with libcurl (-DKvikIO_REMOTE_SUPPORT=ON)" + ) + retcode = run_cmd( + cmd=[ + sys.executable, + "http_io.py", + "-n", + "1000", + "--api", + api, + ], + cwd=benchmarks_path, + ) + assert retcode == 0 From 5f488996c2d5d5df655fefcc60bdd0d69bcea03d Mon Sep 17 00:00:00 2001 From: "Mads R. B. 
Kristensen" Date: Fri, 11 Oct 2024 08:39:23 +0200 Subject: [PATCH 07/24] Apply suggestions from code review Co-authored-by: Vyas Ramasubramani --- cpp/include/kvikio/remote_handle.hpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/cpp/include/kvikio/remote_handle.hpp b/cpp/include/kvikio/remote_handle.hpp index f5d815cc5f..7d876ab244 100644 --- a/cpp/include/kvikio/remote_handle.hpp +++ b/cpp/include/kvikio/remote_handle.hpp @@ -186,11 +186,11 @@ class S3Endpoint : public RemoteEndpoint { std::optional const& aws_region, std::optional aws_endpoint_url) { - std::string endpoint_url = + auto const endpoint_url = parse_aws_argument(std::move(aws_endpoint_url), "AWS_ENDPOINT_URL", "", true); std::stringstream ss; if (endpoint_url.empty()) { - std::string region = + auto const region = parse_aws_argument(std::move(aws_region), "AWS_DEFAULT_REGION", "S3: must provide `aws_region` if AWS_DEFAULT_REGION isn't set."); @@ -250,17 +250,17 @@ class S3Endpoint : public RemoteEndpoint { std::optional aws_secret_access_key = std::nullopt) : _url{std::move(url)} { - std::string region = + auto const region = parse_aws_argument(std::move(aws_region), "AWS_DEFAULT_REGION", "S3: must provide `aws_region` if AWS_DEFAULT_REGION isn't set."); - std::string access_key = + auto const access_key = parse_aws_argument(std::move(aws_access_key), "AWS_ACCESS_KEY_ID", "S3: must provide `aws_access_key` if AWS_ACCESS_KEY_ID isn't set."); - std::string secret_access_key = parse_aws_argument( + auto const secret_access_key = parse_aws_argument( std::move(aws_secret_access_key), "AWS_SECRET_ACCESS_KEY", "S3: must provide `aws_secret_access_key` if AWS_SECRET_ACCESS_KEY isn't set."); From 66460f8aee014ea39379d80099ba2fa9e686fbaf Mon Sep 17 00:00:00 2001 From: "Mads R. B. 
Kristensen" Date: Fri, 11 Oct 2024 08:40:11 +0200 Subject: [PATCH 08/24] Use regex Co-authored-by: Vyas Ramasubramani --- cpp/include/kvikio/remote_handle.hpp | 22 ++++++---------------- 1 file changed, 6 insertions(+), 16 deletions(-) diff --git a/cpp/include/kvikio/remote_handle.hpp b/cpp/include/kvikio/remote_handle.hpp index 7d876ab244..4b5a5acc02 100644 --- a/cpp/include/kvikio/remote_handle.hpp +++ b/cpp/include/kvikio/remote_handle.hpp @@ -19,6 +19,7 @@ #include #include #include +#include #include #include #include @@ -213,22 +214,11 @@ class S3Endpoint : public RemoteEndpoint { */ [[nodiscard]] static std::pair parse_s3_url(std::string const& s3_url) { - if (s3_url.empty()) { throw std::invalid_argument("The S3 url cannot be an empty string."); } - if (s3_url.size() < 5 || s3_url.substr(0, 5) != "s3://") { - throw std::invalid_argument("The S3 url must start with the S3 scheme (\"s3://\")."); - } - std::string p = s3_url.substr(5); - if (p.empty()) { throw std::invalid_argument("The S3 url cannot be an empty string."); } - size_t pos = p.find_first_of('/'); - std::string bucket_name = p.substr(0, pos); - if (bucket_name.empty()) { - throw std::invalid_argument("The S3 url does not contain a bucket name."); - } - std::string object_name = (pos == std::string::npos) ? "" : p.substr(pos + 1); - if (object_name.empty()) { - throw std::invalid_argument("The S3 url does not contain an object name."); - } - return std::make_pair(std::move(bucket_name), std::move(object_name)); + // Regular expression to match s3:/// + std::regex pattern{R"(s3://([^/]+)/(.+))"}; + std::smatch matches; + if (std::regex_match(s3_url, matches, pattern)) { return {matches[1].str(), matches[2].str()}; } + throw std::invalid_argument("Input string does not match the expected S3 URL format."); } /** From 84007c1b797919c6a6dd806793541246fcdea47b Mon Sep 17 00:00:00 2001 From: "Mads R. B. 
Kristensen" Date: Fri, 11 Oct 2024 10:28:55 +0200 Subject: [PATCH 09/24] cleanup --- cpp/include/kvikio/remote_handle.hpp | 32 +++++++++++++++++++++++----- 1 file changed, 27 insertions(+), 5 deletions(-) diff --git a/cpp/include/kvikio/remote_handle.hpp b/cpp/include/kvikio/remote_handle.hpp index 4b5a5acc02..3e0518d05d 100644 --- a/cpp/include/kvikio/remote_handle.hpp +++ b/cpp/include/kvikio/remote_handle.hpp @@ -167,28 +167,50 @@ class S3Endpoint : public RemoteEndpoint { std::string _aws_sigv4; std::string _aws_userpwd; + /** + * @brief Parse a AWS argument such as `aws_region` or `aws_access_key`. + * + * If not nullopt, the optional's value is returned otherwise the environment + * variable `env_var` is used. If that also doesn't have a value: + * - if `err_msg` is empty, the empty string is returned. + * - if `err_msg` is not empty, `std::invalid_argument(`err_msg`)` is thrown. + * + * @param aws_arg The AWS argument to parse. + * @param env_var The name of the environment variable to check if `aws_arg` isn't set. + * @param err_msg The error message to throw on error or the empty string. + * @return The parsed AWS argument or the empty string. + */ static std::string parse_aws_argument(std::optional aws_arg, std::string const& env_var, - std::string const& err_msg, - bool allow_empty = false) + std::string const& err_msg = "") { if (aws_arg.has_value()) { return std::move(*aws_arg); } char const* env = std::getenv(env_var.c_str()); if (env == nullptr) { - if (allow_empty) { return std::string(); } + if (err_msg.empty()) { return std::string(); } throw std::invalid_argument(err_msg); } return std::string(env); } + /** + * @brief Get url from a AWS S3 bucket and object name. + * + * @param bucket_name The name of the S3 bucket. + * @param object_name The name of the S3 object. + * @param aws_region The AWS region, such as "us-east-1", to use. If nullopt, the value of the + * `AWS_DEFAULT_REGION` environment variable is used. 
+ * @param aws_endpoint_url Overwrite the endpoint url to use. If nullopt, the value of + * the `AWS_ENDPOINT_URL` environment variable is used. If this is also not set, the regular + * AWS url scheme is used: "https://.s3..amazonaws.com/" + */ static std::string url_from_bucket_and_object(std::string const& bucket_name, std::string const& object_name, std::optional const& aws_region, std::optional aws_endpoint_url) { - auto const endpoint_url = - parse_aws_argument(std::move(aws_endpoint_url), "AWS_ENDPOINT_URL", "", true); + auto const endpoint_url = parse_aws_argument(std::move(aws_endpoint_url), "AWS_ENDPOINT_URL"); std::stringstream ss; if (endpoint_url.empty()) { auto const region = From 6104106ec4ef725640c32775f1a5a1637ad29faa Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Mon, 21 Oct 2024 10:27:13 +0200 Subject: [PATCH 10/24] RemoteFile._from_endpoint --- python/kvikio/kvikio/_lib/remote_handle.pyx | 88 +++++++++++---------- 1 file changed, 48 insertions(+), 40 deletions(-) diff --git a/python/kvikio/kvikio/_lib/remote_handle.pyx b/python/kvikio/kvikio/_lib/remote_handle.pyx index 32cfe4dea9..29f1af9056 100644 --- a/python/kvikio/kvikio/_lib/remote_handle.pyx +++ b/python/kvikio/kvikio/_lib/remote_handle.pyx @@ -57,20 +57,27 @@ cdef string _to_string(str s): else: return string() +# Help function to cast an endpoint to its base class `RemoteEndpoint` +cdef extern from *: + """ + template + std::unique_ptr cast_to_remote_endpoint(T endpoint) + { + return std::move(endpoint); + } + """ + cdef unique_ptr[cpp_RemoteEndpoint] cast_to_remote_endpoint[T](T handle) except + + cdef class RemoteFile: cdef unique_ptr[cpp_RemoteHandle] _handle - @classmethod - def open_http( - cls, - url: str, + @staticmethod + cdef RemoteFile _from_endpoint( + unique_ptr[cpp_RemoteEndpoint] ep, nbytes: Optional[int], ): cdef RemoteFile ret = RemoteFile() - cdef unique_ptr[cpp_HttpEndpoint] ep = make_unique[cpp_HttpEndpoint]( - _to_string(url) - ) if nbytes is None: 
ret._handle = make_unique[cpp_RemoteHandle](move(ep)) return ret @@ -78,58 +85,59 @@ cdef class RemoteFile: ret._handle = make_unique[cpp_RemoteHandle](move(ep), n) return ret - @classmethod + @staticmethod + def open_http( + url: str, + nbytes: Optional[int], + ): + return RemoteFile._from_endpoint( + cast_to_remote_endpoint( + make_unique[cpp_HttpEndpoint](_to_string(url)) + ), + nbytes + ) + + @staticmethod def open_s3( - cls, bucket_name: str, object_name: str, nbytes: Optional[int], ): - cdef RemoteFile ret = RemoteFile() - cdef unique_ptr[cpp_S3Endpoint] ep = make_unique[cpp_S3Endpoint]( - _to_string(bucket_name), _to_string(object_name) + return RemoteFile._from_endpoint( + cast_to_remote_endpoint( + make_unique[cpp_S3Endpoint]( + _to_string(bucket_name), _to_string(object_name) + ) + ), + nbytes ) - if nbytes is None: - ret._handle = make_unique[cpp_RemoteHandle](move(ep)) - return ret - cdef size_t n = nbytes - ret._handle = make_unique[cpp_RemoteHandle](move(ep), n) - return ret - @classmethod + @staticmethod def open_s3_from_http_url( - cls, url: str, nbytes: Optional[int], ): - cdef RemoteFile ret = RemoteFile() - cdef unique_ptr[cpp_S3Endpoint] ep = make_unique[cpp_S3Endpoint]( - _to_string(url) + return RemoteFile._from_endpoint( + cast_to_remote_endpoint( + make_unique[cpp_S3Endpoint](_to_string(url)) + ), + nbytes ) - if nbytes is None: - ret._handle = make_unique[cpp_RemoteHandle](move(ep)) - return ret - cdef size_t n = nbytes - ret._handle = make_unique[cpp_RemoteHandle](move(ep), n) - return ret - @classmethod + @staticmethod def open_s3_from_s3_url( - cls, url: str, nbytes: Optional[int], ): cdef pair[string, string] bucket_and_object = cpp_parse_s3_url(_to_string(url)) - cdef RemoteFile ret = RemoteFile() - cdef unique_ptr[cpp_S3Endpoint] ep = make_unique[cpp_S3Endpoint]( - bucket_and_object.first, bucket_and_object.second + return RemoteFile._from_endpoint( + cast_to_remote_endpoint( + make_unique[cpp_S3Endpoint]( + bucket_and_object.first, 
bucket_and_object.second + ) + ), + nbytes ) - if nbytes is None: - ret._handle = make_unique[cpp_RemoteHandle](move(ep)) - return ret - cdef size_t n = nbytes - ret._handle = make_unique[cpp_RemoteHandle](move(ep), n) - return ret def nbytes(self) -> int: return deref(self._handle).nbytes() From 8595c0114648d0d81e4db29fbce946bddf13c411 Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Mon, 21 Oct 2024 10:52:33 +0200 Subject: [PATCH 11/24] Apply suggestions from code review Co-authored-by: Lawrence Mitchell Co-authored-by: Vyas Ramasubramani --- cpp/include/kvikio/remote_handle.hpp | 11 +++++++---- python/kvikio/kvikio/_lib/remote_handle.pyx | 2 -- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/cpp/include/kvikio/remote_handle.hpp b/cpp/include/kvikio/remote_handle.hpp index 3e0518d05d..e14bc5af10 100644 --- a/cpp/include/kvikio/remote_handle.hpp +++ b/cpp/include/kvikio/remote_handle.hpp @@ -148,7 +148,7 @@ class HttpEndpoint : public RemoteEndpoint { public: /** - * @brief Create a http endpoint from a url. + * @brief Create an http endpoint from a url. * * @param url The full http url to the remote file. */ @@ -168,15 +168,15 @@ class S3Endpoint : public RemoteEndpoint { std::string _aws_userpwd; /** - * @brief Parse a AWS argument such as `aws_region` or `aws_access_key`. + * @brief Unwrap an optional parameter, obtaining a default from the environment. * * If not nullopt, the optional's value is returned otherwise the environment * variable `env_var` is used. If that also doesn't have a value: * - if `err_msg` is empty, the empty string is returned. * - if `err_msg` is not empty, `std::invalid_argument(`err_msg`)` is thrown. * - * @param aws_arg The AWS argument to parse. - * @param env_var The name of the environment variable to check if `aws_arg` isn't set. + * @param value The value to unwrap. + * @param env_var The name of the environment variable to check if `value` isn't set. 
* @param err_msg The error message to throw on error or the empty string. * @return The parsed AWS argument or the empty string. */ @@ -197,6 +197,9 @@ class S3Endpoint : public RemoteEndpoint { /** * @brief Get url from a AWS S3 bucket and object name. * + * @throws std::invalid_argument if no region is specified and no default region is + * specified in the environment. + * * @param bucket_name The name of the S3 bucket. * @param object_name The name of the S3 object. * @param aws_region The AWS region, such as "us-east-1", to use. If nullopt, the value of the diff --git a/python/kvikio/kvikio/_lib/remote_handle.pyx b/python/kvikio/kvikio/_lib/remote_handle.pyx index 29f1af9056..5406b32413 100644 --- a/python/kvikio/kvikio/_lib/remote_handle.pyx +++ b/python/kvikio/kvikio/_lib/remote_handle.pyx @@ -25,8 +25,6 @@ cdef extern from "" nogil: cdef cppclass cpp_S3Endpoint "kvikio::S3Endpoint": cpp_S3Endpoint(string url) except + - - cdef cppclass cpp_S3Endpoint "kvikio::S3Endpoint": cpp_S3Endpoint(string bucket_name, string object_name) except + pair[string, string] cpp_parse_s3_url \ From b50835a2794365fae6c8ac6988e32b38c5d5aef5 Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Mon, 21 Oct 2024 10:56:03 +0200 Subject: [PATCH 12/24] rename to unwrap_or_default --- cpp/include/kvikio/remote_handle.hpp | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/cpp/include/kvikio/remote_handle.hpp b/cpp/include/kvikio/remote_handle.hpp index e14bc5af10..19809eb456 100644 --- a/cpp/include/kvikio/remote_handle.hpp +++ b/cpp/include/kvikio/remote_handle.hpp @@ -180,9 +180,9 @@ class S3Endpoint : public RemoteEndpoint { * @param err_msg The error message to throw on error or the empty string. * @return The parsed AWS argument or the empty string. 
*/ - static std::string parse_aws_argument(std::optional aws_arg, - std::string const& env_var, - std::string const& err_msg = "") + static std::string unwrap_or_default(std::optional aws_arg, + std::string const& env_var, + std::string const& err_msg = "") { if (aws_arg.has_value()) { return std::move(*aws_arg); } @@ -213,13 +213,13 @@ class S3Endpoint : public RemoteEndpoint { std::optional const& aws_region, std::optional aws_endpoint_url) { - auto const endpoint_url = parse_aws_argument(std::move(aws_endpoint_url), "AWS_ENDPOINT_URL"); + auto const endpoint_url = unwrap_or_default(std::move(aws_endpoint_url), "AWS_ENDPOINT_URL"); std::stringstream ss; if (endpoint_url.empty()) { auto const region = - parse_aws_argument(std::move(aws_region), - "AWS_DEFAULT_REGION", - "S3: must provide `aws_region` if AWS_DEFAULT_REGION isn't set."); + unwrap_or_default(std::move(aws_region), + "AWS_DEFAULT_REGION", + "S3: must provide `aws_region` if AWS_DEFAULT_REGION isn't set."); // We default to the official AWS url scheme. ss << "https://" << bucket_name << ".s3." 
<< region << ".amazonaws.com/" << object_name; } else { @@ -266,16 +266,16 @@ class S3Endpoint : public RemoteEndpoint { : _url{std::move(url)} { auto const region = - parse_aws_argument(std::move(aws_region), - "AWS_DEFAULT_REGION", - "S3: must provide `aws_region` if AWS_DEFAULT_REGION isn't set."); + unwrap_or_default(std::move(aws_region), + "AWS_DEFAULT_REGION", + "S3: must provide `aws_region` if AWS_DEFAULT_REGION isn't set."); auto const access_key = - parse_aws_argument(std::move(aws_access_key), - "AWS_ACCESS_KEY_ID", - "S3: must provide `aws_access_key` if AWS_ACCESS_KEY_ID isn't set."); + unwrap_or_default(std::move(aws_access_key), + "AWS_ACCESS_KEY_ID", + "S3: must provide `aws_access_key` if AWS_ACCESS_KEY_ID isn't set."); - auto const secret_access_key = parse_aws_argument( + auto const secret_access_key = unwrap_or_default( std::move(aws_secret_access_key), "AWS_SECRET_ACCESS_KEY", "S3: must provide `aws_secret_access_key` if AWS_SECRET_ACCESS_KEY isn't set."); From c6a416fc631827d345c6b31d4fb001395f58dd6b Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Mon, 21 Oct 2024 11:47:11 +0200 Subject: [PATCH 13/24] const --- cpp/include/kvikio/remote_handle.hpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cpp/include/kvikio/remote_handle.hpp b/cpp/include/kvikio/remote_handle.hpp index 19809eb456..53af35290f 100644 --- a/cpp/include/kvikio/remote_handle.hpp +++ b/cpp/include/kvikio/remote_handle.hpp @@ -134,7 +134,7 @@ class RemoteEndpoint { * * @returns A string description. 
*/ - virtual std::string str() = 0; + virtual std::string str() const = 0; virtual ~RemoteEndpoint() = default; }; @@ -154,7 +154,7 @@ class HttpEndpoint : public RemoteEndpoint { */ HttpEndpoint(std::string url) : _url{std::move(url)} {} void setopt(CurlHandle& curl) override { curl.setopt(CURLOPT_URL, _url.c_str()); } - std::string str() override { return _url; } + std::string str() const override { return _url; } ~HttpEndpoint() override = default; }; @@ -329,7 +329,7 @@ class S3Endpoint : public RemoteEndpoint { curl.setopt(CURLOPT_AWS_SIGV4, _aws_sigv4.c_str()); curl.setopt(CURLOPT_USERPWD, _aws_userpwd.c_str()); } - std::string str() override { return _url; } + std::string str() const override { return _url; } ~S3Endpoint() override = default; }; From f81de011b0b029068c8b105cb45609737bb6cd68 Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Mon, 21 Oct 2024 11:54:12 +0200 Subject: [PATCH 14/24] expose a __str__ of an endpoint --- cpp/include/kvikio/remote_handle.hpp | 7 +++++++ python/kvikio/kvikio/_lib/remote_handle.pyx | 11 ++++++++--- python/kvikio/kvikio/remote_file.py | 3 +++ python/kvikio/tests/test_http_io.py | 4 ++++ 4 files changed, 22 insertions(+), 3 deletions(-) diff --git a/cpp/include/kvikio/remote_handle.hpp b/cpp/include/kvikio/remote_handle.hpp index 53af35290f..203ffd06a9 100644 --- a/cpp/include/kvikio/remote_handle.hpp +++ b/cpp/include/kvikio/remote_handle.hpp @@ -393,6 +393,13 @@ class RemoteHandle { */ [[nodiscard]] std::size_t nbytes() const noexcept { return _nbytes; } + /** + * @brief Get a const reference to the underlying remote endpoint. + * + * @return The remote endpoint. + */ + [[nodiscard]] RemoteEndpoint const& endpoint() const noexcept { return *_endpoint; } + /** * @brief Read from remote source into buffer (host or device memory). 
* diff --git a/python/kvikio/kvikio/_lib/remote_handle.pyx b/python/kvikio/kvikio/_lib/remote_handle.pyx index 5406b32413..b6455b462f 100644 --- a/python/kvikio/kvikio/_lib/remote_handle.pyx +++ b/python/kvikio/kvikio/_lib/remote_handle.pyx @@ -18,12 +18,12 @@ from kvikio._lib.future cimport IOFuture, _wrap_io_future, future cdef extern from "" nogil: cdef cppclass cpp_RemoteEndpoint "kvikio::RemoteEndpoint": - pass + string str() except + - cdef cppclass cpp_HttpEndpoint "kvikio::HttpEndpoint": + cdef cppclass cpp_HttpEndpoint "kvikio::HttpEndpoint"(cpp_RemoteEndpoint): cpp_HttpEndpoint(string url) except + - cdef cppclass cpp_S3Endpoint "kvikio::S3Endpoint": + cdef cppclass cpp_S3Endpoint "kvikio::S3Endpoint"(cpp_RemoteEndpoint): cpp_S3Endpoint(string url) except + cpp_S3Endpoint(string bucket_name, string object_name) except + @@ -36,6 +36,7 @@ cdef extern from "" nogil: ) except + cpp_RemoteHandle(unique_ptr[cpp_RemoteEndpoint] endpoint) except + int nbytes() except + + const cpp_RemoteEndpoint& endpoint() except + size_t read( void* buf, size_t size, @@ -137,6 +138,10 @@ cdef class RemoteFile: nbytes ) + def __str__(self) -> str: + cdef string ep_str = deref(self._handle).endpoint().str() + return f'<{self.__class__.__name__} "{ep_str.decode()}">' + def nbytes(self) -> int: return deref(self._handle).nbytes() diff --git a/python/kvikio/kvikio/remote_file.py b/python/kvikio/kvikio/remote_file.py index c4f93d86d9..5cd45d3c32 100644 --- a/python/kvikio/kvikio/remote_file.py +++ b/python/kvikio/kvikio/remote_file.py @@ -149,6 +149,9 @@ def __enter__(self) -> RemoteFile: def __exit__(self, exc_type, exc_val, exc_tb) -> None: self.close() + def __str__(self) -> str: + return str(self._handle) + def nbytes(self) -> int: """Get the file size. 
diff --git a/python/kvikio/tests/test_http_io.py b/python/kvikio/tests/test_http_io.py index 70abec71b6..5c2c3888cd 100644 --- a/python/kvikio/tests/test_http_io.py +++ b/python/kvikio/tests/test_http_io.py @@ -47,6 +47,7 @@ def test_read(http_server, tmpdir, xp, size, nthreads, tasksize): with kvikio.defaults.set_task_size(tasksize): with kvikio.RemoteFile.open_http(f"{http_server}/a") as f: assert f.nbytes() == a.nbytes + assert f"{http_server}/a" in str(f) b = xp.empty_like(a) assert f.read(b) == a.nbytes xp.testing.assert_array_equal(a, b) @@ -60,6 +61,7 @@ def test_large_read(http_server, tmpdir, xp, nthreads): with kvikio.defaults.set_num_threads(nthreads): with kvikio.RemoteFile.open_http(f"{http_server}/a") as f: assert f.nbytes() == a.nbytes + assert f"{http_server}/a" in str(f) b = xp.empty_like(a) assert f.read(b) == a.nbytes xp.testing.assert_array_equal(a, b) @@ -71,6 +73,7 @@ def test_error_too_small_file(http_server, tmpdir, xp): a.tofile(tmpdir / "a") with kvikio.RemoteFile.open_http(f"{http_server}/a") as f: assert f.nbytes() == a.nbytes + assert f"{http_server}/a" in str(f) with pytest.raises( ValueError, match=r"cannot read 0\+100 bytes into a 10 bytes file" ): @@ -88,6 +91,7 @@ def test_no_range_support(http_server, tmpdir, xp): b = xp.empty_like(a) with kvikio.RemoteFile.open_http(f"{http_server}/a") as f: assert f.nbytes() == a.nbytes + assert f"{http_server}/a" in str(f) with pytest.raises( OverflowError, match="maybe the server doesn't support file ranges?" ): From 718f291fe03274477c5de09bf102a52f97095ab9 Mon Sep 17 00:00:00 2001 From: "Mads R. B. 
Kristensen" Date: Mon, 21 Oct 2024 12:08:32 +0200 Subject: [PATCH 15/24] fix typo --- python/kvikio/kvikio/remote_file.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/kvikio/kvikio/remote_file.py b/python/kvikio/kvikio/remote_file.py index 5cd45d3c32..f10f4b49f9 100644 --- a/python/kvikio/kvikio/remote_file.py +++ b/python/kvikio/kvikio/remote_file.py @@ -133,7 +133,7 @@ def open_s3_url( return RemoteFile( _get_remote_module().RemoteFile.open_s3_from_http_url(url, nbytes) ) - if url.startswith("s://"): + if url.startswith("s3://"): return RemoteFile( _get_remote_module().RemoteFile.open_s3_from_s3_url(url, nbytes) ) From 34ffb03d753024efb5750ab17dc7d7288ff4255f Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Mon, 21 Oct 2024 12:17:24 +0200 Subject: [PATCH 16/24] benchmark: show url --- python/kvikio/kvikio/benchmarks/s3_io.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/python/kvikio/kvikio/benchmarks/s3_io.py b/python/kvikio/kvikio/benchmarks/s3_io.py index 668af8e550..335d4a8f65 100644 --- a/python/kvikio/kvikio/benchmarks/s3_io.py +++ b/python/kvikio/kvikio/benchmarks/s3_io.py @@ -86,8 +86,7 @@ def run_numpy_like(args, xp): client = create_client_and_bucket() client.put_object(Bucket=args.bucket, Key="data", Body=bytes(data)) - server_address = os.environ["AWS_ENDPOINT_URL"] - url = f"{server_address}/{args.bucket}/data" + url = f"s3://{args.bucket}/data" def run() -> float: t0 = time.perf_counter() @@ -121,7 +120,7 @@ def main(args): print(f"dtype | {args.dtype}") print(f"nthreads | {args.nthreads}") print(f"nruns | {args.nruns}") - print(f"server | {os.getenv('AWS_ENDPOINT_URL', 'http://*.amazonaws.com')}") + print(f"file | s3://{args.bucket}/data") if args.use_bundled_server: print("--------------------------------------") print("Using the bundled local server is slow") From de151c01881563a45ab9045ca4710548eec9843c Mon Sep 17 00:00:00 2001 From: "Mads R. B. 
Kristensen" Date: Mon, 21 Oct 2024 12:33:11 +0200 Subject: [PATCH 17/24] benchmark clean up --- python/kvikio/kvikio/benchmarks/s3_io.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/python/kvikio/kvikio/benchmarks/s3_io.py b/python/kvikio/kvikio/benchmarks/s3_io.py index 335d4a8f65..6133cecb95 100644 --- a/python/kvikio/kvikio/benchmarks/s3_io.py +++ b/python/kvikio/kvikio/benchmarks/s3_io.py @@ -61,17 +61,15 @@ def local_s3_server(lifetime: int): def create_client_and_bucket(): client = boto3.client("s3", endpoint_url=os.getenv("AWS_ENDPOINT_URL", None)) try: - client.create_bucket(Bucket=args.bucket, ACL="public-read-write") - except ( - client.exceptions.BucketAlreadyOwnedByYou, - client.exceptions.BucketAlreadyExists, - ): - pass + bucket_names = {bucket["Name"] for bucket in client.list_buckets()["Buckets"]} + if args.bucket not in bucket_names: + client.create_bucket(Bucket=args.bucket, ACL="public-read-write") except Exception: print( "Problem accessing the S3 server? using wrong credentials? Try setting " - "AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, and/or AWS_ENDPOINT_URL. " - "Alternatively, use the bundled server `--use-bundled-server`\n", + "AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, and/or AWS_ENDPOINT_URL. Also, " + "if the bucket doesn't exist, make sure you have the required permission. " + "Alternatively, use the bundled server `--use-bundled-server`:\n", file=sys.stderr, flush=True, ) From 9489bef0e65e0a3a6f37a267e0d0c8bc9b43dbf2 Mon Sep 17 00:00:00 2001 From: "Mads R. B. 
Kristensen" Date: Mon, 21 Oct 2024 13:32:50 +0200 Subject: [PATCH 18/24] doc --- cpp/include/kvikio/remote_handle.hpp | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/cpp/include/kvikio/remote_handle.hpp b/cpp/include/kvikio/remote_handle.hpp index 203ffd06a9..98512ef704 100644 --- a/cpp/include/kvikio/remote_handle.hpp +++ b/cpp/include/kvikio/remote_handle.hpp @@ -204,9 +204,10 @@ class S3Endpoint : public RemoteEndpoint { * @param object_name The name of the S3 object. * @param aws_region The AWS region, such as "us-east-1", to use. If nullopt, the value of the * `AWS_DEFAULT_REGION` environment variable is used. - * @param aws_endpoint_url Overwrite the endpoint url to use. If nullopt, the value of - * the `AWS_ENDPOINT_URL` environment variable is used. If this is also not set, the regular - * AWS url scheme is used: "https://.s3..amazonaws.com/" + * @param aws_endpoint_url Overwrite the endpoint url (including the protocol part) by using + * the scheme: "//". If nullopt, the value of the + * `AWS_ENDPOINT_URL` environment variable is used. If this is also not set, the regular AWS + * url scheme is used: "https://.s3..amazonaws.com/". */ static std::string url_from_bucket_and_object(std::string const& bucket_name, std::string const& object_name, @@ -305,9 +306,10 @@ class S3Endpoint : public RemoteEndpoint { * `AWS_ACCESS_KEY_ID` environment variable is used. * @param aws_secret_access_key The AWS secret access key to use. If nullopt, the value of the * `AWS_SECRET_ACCESS_KEY` environment variable is used. - * @param aws_endpoint_url Overwrite the endpoint url to use. If nullopt, the value of - * the `AWS_ENDPOINT_URL` environment variable is used. If this is also not set, the regular - * AWS url scheme is used: "https://.s3..amazonaws.com/" + * @param aws_endpoint_url Overwrite the endpoint url (including the protocol part) by using + * the scheme: "//". 
If nullopt, the value of the + * `AWS_ENDPOINT_URL` environment variable is used. If this is also not set, the regular AWS + * url scheme is used: "https://<bucket_name>.s3.<aws_region>.amazonaws.com/<object_name>". */ S3Endpoint(std::string const& bucket_name, std::string const& object_name, From fc371921b7e616c9282c4b74882d25e009a452a8 Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Mon, 21 Oct 2024 13:37:08 +0200 Subject: [PATCH 19/24] doc --- cpp/include/kvikio/remote_handle.hpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cpp/include/kvikio/remote_handle.hpp b/cpp/include/kvikio/remote_handle.hpp index 98512ef704..ca19470cd0 100644 --- a/cpp/include/kvikio/remote_handle.hpp +++ b/cpp/include/kvikio/remote_handle.hpp @@ -194,6 +194,7 @@ class S3Endpoint : public RemoteEndpoint { return std::string(env); } + public: /** * @brief Get url from a AWS S3 bucket and object name. * @@ -229,7 +230,6 @@ class S3Endpoint : public RemoteEndpoint { return ss.str(); } - public: /** * @brief Given an url like "s3://<bucket>/<object>", return the name of the bucket and object. * @@ -251,8 +251,8 @@ class S3Endpoint : public RemoteEndpoint { * @brief Create a S3 endpoint from a url. * * @param url The full http url to the S3 file. NB: this should be an url starting with - * "http://" or "https://". If you have an S3 url of the form "s3://<bucket>/<object>", - * please use `S3Endpoint::parse_s3_url()` to convert it. + * "http://" or "https://". If you have an S3 url of the form "s3://<bucket>/<object>", please + * use `S3Endpoint::parse_s3_url()` and `S3Endpoint::url_from_bucket_and_object()` to convert it. * @param aws_region The AWS region, such as "us-east-1", to use. If nullopt, the value of the * `AWS_DEFAULT_REGION` environment variable is used. * @param aws_access_key The AWS access key to use. If nullopt, the value of the From 52f05953706208873ae66c593659b0eed23c520f Mon Sep 17 00:00:00 2001 From: "Mads R. B. 
Kristensen" Date: Mon, 21 Oct 2024 16:53:11 +0200 Subject: [PATCH 20/24] check url --- cpp/include/kvikio/remote_handle.hpp | 8 +++++- python/kvikio/tests/test_s3_io.py | 40 +++++++++++++++++++++++----- 2 files changed, 41 insertions(+), 7 deletions(-) diff --git a/cpp/include/kvikio/remote_handle.hpp b/cpp/include/kvikio/remote_handle.hpp index ca19470cd0..826d2e2cc9 100644 --- a/cpp/include/kvikio/remote_handle.hpp +++ b/cpp/include/kvikio/remote_handle.hpp @@ -241,7 +241,7 @@ class S3Endpoint : public RemoteEndpoint { [[nodiscard]] static std::pair<std::string, std::string> parse_s3_url(std::string const& s3_url) { // Regular expression to match s3://<bucket>/<object> - std::regex pattern{R"(s3://([^/]+)/(.+))"}; + std::regex pattern{R"(^s3://([^/]+)/(.+))", std::regex_constants::icase}; std::smatch matches; if (std::regex_match(s3_url, matches, pattern)) { return {matches[1].str(), matches[2].str()}; } throw std::invalid_argument("Input string does not match the expected S3 URL format."); @@ -266,6 +266,12 @@ class S3Endpoint : public RemoteEndpoint { std::optional<std::string> aws_secret_access_key = std::nullopt) : _url{std::move(url)} { + // Regular expression to match http[s]:// + std::regex pattern{R"(^https?://.*)", std::regex_constants::icase}; + if (!std::regex_search(_url, pattern)) { + throw std::invalid_argument("url must start with http:// or https://"); + } + auto const region = unwrap_or_default(std::move(aws_region), "AWS_DEFAULT_REGION", diff --git a/python/kvikio/tests/test_s3_io.py b/python/kvikio/tests/test_s3_io.py index 1893d6b6d2..1f2bae95d0 100644 --- a/python/kvikio/tests/test_s3_io.py +++ b/python/kvikio/tests/test_s3_io.py @@ -79,6 +79,40 @@ def s3_context(s3_base, bucket, files=None): pass +def test_read_access(s3_base): + bucket_name = "bucket" + object_name = "data" + data = b"file content" + with s3_context( + s3_base=s3_base, bucket=bucket_name, files={object_name: bytes(data)} + ) as server_address: + with kvikio.RemoteFile.open_s3_url(f"s3://{bucket_name}/{object_name}") as f: + 
assert f.nbytes() == len(data) + got = bytearray(len(data)) + assert f.read(got) == len(got) + + with kvikio.RemoteFile.open_s3(bucket_name, object_name) as f: + assert f.nbytes() == len(data) + got = bytearray(len(data)) + assert f.read(got) == len(got) + + with kvikio.RemoteFile.open_s3_url( + f"{server_address}/{bucket_name}/{object_name}" + ) as f: + assert f.nbytes() == len(data) + got = bytearray(len(data)) + assert f.read(got) == len(got) + + with pytest.raises(ValueError, match="Unsupported protocol"): + kvikio.RemoteFile.open_s3_url(f"unknown://{bucket_name}/{object_name}") + + with pytest.raises(RuntimeError, match="URL returned error: 404"): + kvikio.RemoteFile.open_s3("unknown-bucket", object_name) + + with pytest.raises(RuntimeError, match="URL returned error: 404"): + kvikio.RemoteFile.open_s3(bucket_name, "unknown-file") + + @pytest.mark.parametrize("size", [10, 100, 1000]) @pytest.mark.parametrize("nthreads", [1, 3]) @pytest.mark.parametrize("tasksize", [99, 999]) @@ -101,12 +135,6 @@ def test_read(s3_base, xp, size, nthreads, tasksize, buffer_size): assert f.read(buf=b) == a.nbytes xp.testing.assert_array_equal(a, b) - with kvikio.RemoteFile.open_s3(bucket_name, object_name) as f: - assert f.nbytes() == a.nbytes - b = xp.empty_like(a) - assert f.read(buf=b) == a.nbytes - xp.testing.assert_array_equal(a, b) - @pytest.mark.parametrize( "start,end", From 376c918370e5c39beea8eaa430198269cce2a220 Mon Sep 17 00:00:00 2001 From: "Mads R. B. 
Kristensen" Date: Mon, 21 Oct 2024 17:13:58 +0200 Subject: [PATCH 21/24] doc --- cpp/include/kvikio/remote_handle.hpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/cpp/include/kvikio/remote_handle.hpp b/cpp/include/kvikio/remote_handle.hpp index 826d2e2cc9..a01e9cf14f 100644 --- a/cpp/include/kvikio/remote_handle.hpp +++ b/cpp/include/kvikio/remote_handle.hpp @@ -294,6 +294,9 @@ class S3Endpoint : public RemoteEndpoint { _aws_sigv4 = ss.str(); } // Create the CURLOPT_USERPWD option + // Notice, curl uses `secret_access_key` to generate a AWS V4 signature. It is NOT set + // over the wire. See + // <https://curl.se/libcurl/c/CURLOPT_USERPWD.html> { std::stringstream ss; ss << access_key << ":" << secret_access_key; From 7136a0c7c288e4cbc9f6530222451a31d3b7c68f Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Tue, 22 Oct 2024 13:31:08 +0200 Subject: [PATCH 22/24] cleanup --- python/kvikio/kvikio/benchmarks/s3_io.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/kvikio/kvikio/benchmarks/s3_io.py b/python/kvikio/kvikio/benchmarks/s3_io.py index 6133cecb95..7941462650 100644 --- a/python/kvikio/kvikio/benchmarks/s3_io.py +++ b/python/kvikio/kvikio/benchmarks/s3_io.py @@ -135,7 +135,7 @@ def main(args): def pprint_api_res(name, samples): samples = [args.nbytes / s for s in samples] # Convert to throughput mean = statistics.harmonic_mean(samples) if len(samples) > 1 else samples[0] - ret = f"{api}-{name}".ljust(18) + ret = f"{api}-{name}".ljust(12) ret += f"| {format_bytes(mean).rjust(10)}/s".ljust(14) if len(samples) > 1: stdev = statistics.stdev(samples) / mean * 100 @@ -203,7 +203,7 @@ def pprint_api_res(name, samples): parser.add_argument( "--api", metavar="API", - default=list(API.keys())[0], # defaults to the first API + default="all", nargs="+", choices=tuple(API.keys()) + ("all",), help="List of APIs to use {%(choices)s} (default: %(default)s).", From 938151af35abafc5f70ef49bb9d07a66cb71a396 Mon Sep 17 00:00:00 2001 From: "Mads R. B. 
Kristensen" Date: Tue, 22 Oct 2024 20:12:15 +0200 Subject: [PATCH 23/24] Apply suggestions from code review Co-authored-by: Vyas Ramasubramani --- cpp/include/kvikio/remote_handle.hpp | 4 ++-- python/kvikio/kvikio/_lib/remote_handle.pyx | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/cpp/include/kvikio/remote_handle.hpp b/cpp/include/kvikio/remote_handle.hpp index a01e9cf14f..8dcdb6a815 100644 --- a/cpp/include/kvikio/remote_handle.hpp +++ b/cpp/include/kvikio/remote_handle.hpp @@ -170,7 +170,7 @@ class S3Endpoint : public RemoteEndpoint { /** * @brief Unwrap an optional parameter, obtaining a default from the environment. * - * If not nullopt, the optional's value is returned otherwise the environment + * If not nullopt, the optional's value is returned. Otherwise, the environment * variable `env_var` is used. If that also doesn't have a value: * - if `err_msg` is empty, the empty string is returned. * - if `err_msg` is not empty, `std::invalid_argument(`err_msg`)` is thrown. 
@@ -241,7 +241,7 @@ class S3Endpoint : public RemoteEndpoint { [[nodiscard]] static std::pair<std::string, std::string> parse_s3_url(std::string const& s3_url) { // Regular expression to match s3://<bucket>/<object> - std::regex pattern{R"(^s3://([^/]+)/(.+))", std::regex_constants::icase}; + std::regex const pattern{R"(^s3://([^/]+)/(.+))", std::regex_constants::icase}; std::smatch matches; if (std::regex_match(s3_url, matches, pattern)) { return {matches[1].str(), matches[2].str()}; } throw std::invalid_argument("Input string does not match the expected S3 URL format."); diff --git a/python/kvikio/kvikio/_lib/remote_handle.pyx b/python/kvikio/kvikio/_lib/remote_handle.pyx index b6455b462f..1e0b14acb9 100644 --- a/python/kvikio/kvikio/_lib/remote_handle.pyx +++ b/python/kvikio/kvikio/_lib/remote_handle.pyx @@ -56,7 +56,7 @@ cdef string _to_string(str s): else: return string() -# Help function to cast an endpoint to its base class `RemoteEndpoint` +# Helper function to cast an endpoint to its base class `RemoteEndpoint` cdef extern from *: """ template From 0a90950930f4cc6e437477ea4879194dbcf73542 Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Tue, 22 Oct 2024 20:17:33 +0200 Subject: [PATCH 24/24] doc --- cpp/include/kvikio/remote_handle.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/include/kvikio/remote_handle.hpp b/cpp/include/kvikio/remote_handle.hpp index 8dcdb6a815..809500f663 100644 --- a/cpp/include/kvikio/remote_handle.hpp +++ b/cpp/include/kvikio/remote_handle.hpp @@ -294,8 +294,8 @@ class S3Endpoint : public RemoteEndpoint { _aws_sigv4 = ss.str(); } // Create the CURLOPT_USERPWD option - // Notice, curl uses `secret_access_key` to generate a AWS V4 signature. It is NOT set - // over the wire. See + // Notice, curl uses `secret_access_key` to generate a AWS V4 signature. It is NOT included + // in the http header. See // <https://curl.se/libcurl/c/CURLOPT_USERPWD.html> { std::stringstream