Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remote IO: S3 support #479

Merged
merged 27 commits into from
Oct 22, 2024
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions conda/environments/all_cuda-118_arch-aarch64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ channels:
- conda-forge
- nvidia
dependencies:
- boto3>=1.21.21
- c-compiler
- cmake>=3.26.4,!=3.30.0
- cuda-python>=11.7.1,<12.0a0
Expand All @@ -18,6 +19,7 @@ dependencies:
- doxygen=1.9.1
- gcc_linux-aarch64=11.*
- libcurl>=7.87.0
- moto>=4.0.8
- ninja
- numcodecs !=0.12.0
- numpy>=1.23,<3.0a0
Expand Down
2 changes: 2 additions & 0 deletions conda/environments/all_cuda-118_arch-x86_64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ channels:
- conda-forge
- nvidia
dependencies:
- boto3>=1.21.21
- c-compiler
- cmake>=3.26.4,!=3.30.0
- cuda-python>=11.7.1,<12.0a0
Expand All @@ -20,6 +21,7 @@ dependencies:
- libcufile-dev=1.4.0.31
- libcufile=1.4.0.31
- libcurl>=7.87.0
- moto>=4.0.8
- ninja
- numcodecs !=0.12.0
- numpy>=1.23,<3.0a0
Expand Down
2 changes: 2 additions & 0 deletions conda/environments/all_cuda-125_arch-aarch64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ channels:
- conda-forge
- nvidia
dependencies:
- boto3>=1.21.21
- c-compiler
- cmake>=3.26.4,!=3.30.0
- cuda-nvcc
Expand All @@ -19,6 +20,7 @@ dependencies:
- gcc_linux-aarch64=11.*
- libcufile-dev
- libcurl>=7.87.0
- moto>=4.0.8
- ninja
- numcodecs !=0.12.0
- numpy>=1.23,<3.0a0
Expand Down
2 changes: 2 additions & 0 deletions conda/environments/all_cuda-125_arch-x86_64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ channels:
- conda-forge
- nvidia
dependencies:
- boto3>=1.21.21
- c-compiler
- cmake>=3.26.4,!=3.30.0
- cuda-nvcc
Expand All @@ -19,6 +20,7 @@ dependencies:
- gcc_linux-64=11.*
- libcufile-dev
- libcurl>=7.87.0
- moto>=4.0.8
- ninja
- numcodecs !=0.12.0
- numpy>=1.23,<3.0a0
Expand Down
197 changes: 193 additions & 4 deletions cpp/include/kvikio/remote_handle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
#include <cstddef>
#include <cstring>
#include <memory>
#include <optional>
#include <regex>
#include <sstream>
#include <stdexcept>
#include <string>
Expand Down Expand Up @@ -89,7 +91,7 @@ inline std::size_t callback_device_memory(char* data,
void* context)
{
auto ctx = reinterpret_cast<CallbackContext*>(context);
const std::size_t nbytes = size * nmemb;
std::size_t const nbytes = size * nmemb;
if (ctx->size < ctx->offset + nbytes) {
ctx->overflow_error = true;
return CURL_WRITEFUNC_ERROR;
Expand Down Expand Up @@ -132,7 +134,7 @@ class RemoteEndpoint {
*
* @returns A string description.
*/
virtual std::string str() = 0;
virtual std::string str() const = 0;

virtual ~RemoteEndpoint() = default;
};
Expand All @@ -145,12 +147,192 @@ class HttpEndpoint : public RemoteEndpoint {
std::string _url;

public:
/**
* @brief Create an http endpoint from a url.
*
* @param url The full http url to the remote file.
*/
HttpEndpoint(std::string url) : _url{std::move(url)} {}
void setopt(CurlHandle& curl) override { curl.setopt(CURLOPT_URL, _url.c_str()); }
std::string str() override { return _url; }
std::string str() const override { return _url; }
~HttpEndpoint() override = default;
};

/**
* @brief A remote endpoint using AWS's S3 protocol.
*/
class S3Endpoint : public RemoteEndpoint {
private:
std::string _url;
std::string _aws_sigv4;
std::string _aws_userpwd;

/**
* @brief Unwrap an optional parameter, obtaining a default from the environment.
*
* If not nullopt, the optional's value is returned otherwise the environment
madsbk marked this conversation as resolved.
Show resolved Hide resolved
* variable `env_var` is used. If that also doesn't have a value:
* - if `err_msg` is empty, the empty string is returned.
* - if `err_msg` is not empty, `std::invalid_argument(`err_msg`)` is thrown.
*
* @param value The value to unwrap.
* @param env_var The name of the environment variable to check if `value` isn't set.
* @param err_msg The error message to throw on error or the empty string.
vyasr marked this conversation as resolved.
Show resolved Hide resolved
* @return The parsed AWS argument or the empty string.
*/
static std::string unwrap_or_default(std::optional<std::string> aws_arg,
std::string const& env_var,
std::string const& err_msg = "")
{
if (aws_arg.has_value()) { return std::move(*aws_arg); }

char const* env = std::getenv(env_var.c_str());
if (env == nullptr) {
if (err_msg.empty()) { return std::string(); }
throw std::invalid_argument(err_msg);
}
return std::string(env);
}

/**
* @brief Get url from a AWS S3 bucket and object name.
*
* @throws std::invalid_argument if no region is specified and no default region is
* specified in the environment.
*
* @param bucket_name The name of the S3 bucket.
madsbk marked this conversation as resolved.
Show resolved Hide resolved
* @param object_name The name of the S3 object.
* @param aws_region The AWS region, such as "us-east-1", to use. If nullopt, the value of the
* `AWS_DEFAULT_REGION` environment variable is used.
* @param aws_endpoint_url Overwrite the endpoint url to use. If nullopt, the value of
* the `AWS_ENDPOINT_URL` environment variable is used. If this is also not set, the regular
* AWS url scheme is used: "https://<bucket>.s3.<region>.amazonaws.com/<object>"
*/
static std::string url_from_bucket_and_object(std::string const& bucket_name,
std::string const& object_name,
std::optional<std::string> const& aws_region,
std::optional<std::string> aws_endpoint_url)
{
auto const endpoint_url = unwrap_or_default(std::move(aws_endpoint_url), "AWS_ENDPOINT_URL");
std::stringstream ss;
if (endpoint_url.empty()) {
auto const region =
unwrap_or_default(std::move(aws_region),
"AWS_DEFAULT_REGION",
"S3: must provide `aws_region` if AWS_DEFAULT_REGION isn't set.");
// We default to the official AWS url scheme.
ss << "https://" << bucket_name << ".s3." << region << ".amazonaws.com/" << object_name;
} else {
ss << endpoint_url << "/" << bucket_name << "/" << object_name;
}
return ss.str();
}

public:
/**
* @brief Given an url like "s3://<bucket>/<object>", return the name of the bucket and object.
*
* @throws std::invalid_argument if url is ill-formed or is missing the bucket or object name.
*
* @param s3_url S3 url.
* @return Pair of strings: [bucket-name, object-name].
*/
[[nodiscard]] static std::pair<std::string, std::string> parse_s3_url(std::string const& s3_url)
{
// Regular expression to match s3://<bucket>/<object>
std::regex pattern{R"(s3://([^/]+)/(.+))"};
std::smatch matches;
if (std::regex_match(s3_url, matches, pattern)) { return {matches[1].str(), matches[2].str()}; }
throw std::invalid_argument("Input string does not match the expected S3 URL format.");
}

/**
* @brief Create a S3 endpoint from a url.
*
* @param url The full http url to the S3 file. NB: this should be an url starting with
* "http://" or "https://". If you have an S3 url of the form "s3://<bucket>/<object>",
* please use `S3Endpoint::parse_s3_url()` to convert it.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we only use https please and reject http? Or do you want that for testing?

It looks like yes. I would like this interface to be "safe" by default, and so I would like the user to have to explicitly opt in to using an unencrypted link, given that we send secrets over the wire.

Also, how does parse_s3_url help directly? That returns a std::pair not a std::string. Should one use url_from_bucket_and_object on the result?

Should we error-check and raise if the URL doesn't start with https:// or http://?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we only use https please and reject http? Or do you want that for testing?

We also wants http for high performance access to public data.

It looks like yes. I would like this interface to be "safe" by default, and so I would like the user to have to explicitly opt in to using an unencrypted link, given that we send secrets over the wire.

NB: only a time specific signature are send over the wire, curl uses aws_secret_access_key to generate the AWS authentication signature V4. Of cause, the payload is send unencrypted.

I think it is reasonable to use https by default and accept http if the user overwrite the endpoint url explicitly?

* @param aws_region The AWS region, such as "us-east-1", to use. If nullopt, the value of the
* `AWS_DEFAULT_REGION` environment variable is used.
* @param aws_access_key The AWS access key to use. If nullopt, the value of the
* `AWS_ACCESS_KEY_ID` environment variable is used.
* @param aws_secret_access_key The AWS secret access key to use. If nullopt, the value of the
* `AWS_SECRET_ACCESS_KEY` environment variable is used.
*/
S3Endpoint(std::string url,
std::optional<std::string> aws_region = std::nullopt,
std::optional<std::string> aws_access_key = std::nullopt,
std::optional<std::string> aws_secret_access_key = std::nullopt)
: _url{std::move(url)}
{
auto const region =
unwrap_or_default(std::move(aws_region),
"AWS_DEFAULT_REGION",
"S3: must provide `aws_region` if AWS_DEFAULT_REGION isn't set.");

auto const access_key =
unwrap_or_default(std::move(aws_access_key),
"AWS_ACCESS_KEY_ID",
"S3: must provide `aws_access_key` if AWS_ACCESS_KEY_ID isn't set.");

auto const secret_access_key = unwrap_or_default(
std::move(aws_secret_access_key),
"AWS_SECRET_ACCESS_KEY",
"S3: must provide `aws_secret_access_key` if AWS_SECRET_ACCESS_KEY isn't set.");

// Create the CURLOPT_AWS_SIGV4 option
{
std::stringstream ss;
ss << "aws:amz:" << region << ":s3";
_aws_sigv4 = ss.str();
}
// Create the CURLOPT_USERPWD option
{
std::stringstream ss;
ss << access_key << ":" << secret_access_key;
_aws_userpwd = ss.str();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be nice to have std::format in C++20...

}
}

/**
* @brief Create a S3 endpoint from a bucket and object name.
*
* @param bucket_name The name of the S3 bucket.
* @param object_name The name of the S3 object.
* @param aws_region The AWS region, such as "us-east-1", to use. If nullopt, the value of the
* `AWS_DEFAULT_REGION` environment variable is used.
* @param aws_access_key The AWS access key to use. If nullopt, the value of the
* `AWS_ACCESS_KEY_ID` environment variable is used.
* @param aws_secret_access_key The AWS secret access key to use. If nullopt, the value of the
* `AWS_SECRET_ACCESS_KEY` environment variable is used.
* @param aws_endpoint_url Overwrite the endpoint url to use. If nullopt, the value of
* the `AWS_ENDPOINT_URL` environment variable is used. If this is also not set, the regular
* AWS url scheme is used: "https://<bucket>.s3.<region>.amazonaws.com/<object>"
wence- marked this conversation as resolved.
Show resolved Hide resolved
*/
S3Endpoint(std::string const& bucket_name,
std::string const& object_name,
std::optional<std::string> aws_region = std::nullopt,
std::optional<std::string> aws_access_key = std::nullopt,
std::optional<std::string> aws_secret_access_key = std::nullopt,
std::optional<std::string> aws_endpoint_url = std::nullopt)
: S3Endpoint(url_from_bucket_and_object(
bucket_name, object_name, aws_region, std::move(aws_endpoint_url)),
std::move(aws_region),
std::move(aws_access_key),
std::move(aws_secret_access_key))
{
}

void setopt(CurlHandle& curl) override
{
curl.setopt(CURLOPT_URL, _url.c_str());
curl.setopt(CURLOPT_AWS_SIGV4, _aws_sigv4.c_str());
curl.setopt(CURLOPT_USERPWD, _aws_userpwd.c_str());
}
std::string str() const override { return _url; }
~S3Endpoint() override = default;
};

/**
* @brief Handle of remote file.
*/
Expand Down Expand Up @@ -211,6 +393,13 @@ class RemoteHandle {
*/
[[nodiscard]] std::size_t nbytes() const noexcept { return _nbytes; }

/**
* @brief Get a const reference to the underlying remote endpoint.
*
* @return The remote endpoint.
*/
[[nodiscard]] RemoteEndpoint const& endpoint() const noexcept { return *_endpoint; }

/**
* @brief Read from remote source into buffer (host or device memory).
*
Expand All @@ -229,7 +418,7 @@ class RemoteHandle {
<< " bytes file (" << _endpoint->str() << ")";
throw std::invalid_argument(ss.str());
}
const bool is_host_mem = is_host_memory(buf);
bool const is_host_mem = is_host_memory(buf);
auto curl = create_curl_handle();
_endpoint->setopt(curl);

Expand Down
7 changes: 7 additions & 0 deletions dependencies.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,13 @@ dependencies:
- pytest
- pytest-cov
- rangehttpserver
- boto3>=1.21.21
- output_types: [requirements, pyproject]
packages:
- moto[server]>=4.0.8
- output_types: conda
packages:
- moto>=4.0.8
specific:
- output_types: [conda, requirements, pyproject]
matrices:
Expand Down
Loading