diff --git a/README.md b/README.md index 0e0747b..68a0640 100644 --- a/README.md +++ b/README.md @@ -171,6 +171,36 @@ n_nationkey: [[24]] n_name: [["UNITED STATES"]] ``` +### Connecting via the new `flight_sql_client` CLI tool +You can also use the new `flight_sql_client` CLI tool to connect to the Flight SQL server, and then run a single command. This tool is built into the Docker image, and is also available as a standalone executable for Linux and MacOS. + +Example (run from the host computer's terminal): +```bash +flight_sql_client \ + --command Execute \ + --host "localhost" \ + --port 31337 \ + --username "flight_username" \ + --password "flight_password" \ + --query "SELECT version()" \ + --use-tls \ + --tls-skip-verify +``` + +That should return: +```text +Results from endpoint 1 of 1 +Schema: +version(): string + +Results: +version(): [ + "v0.10.0" + ] + +Total: 1 +``` + ### Tear-down Stop the docker image with: ```bash diff --git a/requirements.txt b/requirements.txt index 17768ad..0472f34 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,5 +2,5 @@ pandas==2.1.* duckdb==0.10.0 click==8.1.* pyarrow==15.0.0 -adbc-driver-flightsql==0.9.* -adbc-driver-manager==0.9.* +adbc-driver-flightsql==0.10.* +adbc-driver-manager==0.10.* diff --git a/scripts/test_flight_sql.py b/scripts/test_flight_sql.py index df3996c..1c15811 100644 --- a/scripts/test_flight_sql.py +++ b/scripts/test_flight_sql.py @@ -1,17 +1,39 @@ import os +from time import sleep +import pyarrow from adbc_driver_flightsql import dbapi as flight_sql, DatabaseOptions -flight_password = os.getenv("FLIGHT_PASSWORD") -with flight_sql.connect(uri="grpc+tls://localhost:31337", - db_kwargs={"username": "flight_username", - "password": flight_password, - DatabaseOptions.TLS_SKIP_VERIFY.value: "true" # Not needed if you use a trusted CA-signed TLS cert - } - ) as conn: - with conn.cursor() as cur: - cur.execute("SELECT n_nationkey, n_name FROM nation WHERE n_nationkey = ?", - parameters=[24] - ) - x = cur.fetch_arrow_table() - print(x) +# Setup variables +max_attempts: int = 10 +sleep_interval: int = 10 +flight_password = os.environ["FLIGHT_PASSWORD"] + +def main(): + for attempt in range(max_attempts): + try: + with flight_sql.connect(uri="grpc+tls://localhost:31337", + db_kwargs={"username": "flight_username", + "password": flight_password, + DatabaseOptions.TLS_SKIP_VERIFY.value: "true" # Not needed if you use a trusted CA-signed TLS cert + } + ) as conn: + with conn.cursor() as cur: + cur.execute("SELECT n_nationkey, n_name FROM nation WHERE n_nationkey = ?", + parameters=[24] + ) + x = cur.fetch_arrow_table() + print(x) + except Exception as e: + if attempt == max_attempts - 1: + raise e + else: + print(f"Attempt {attempt + 1} failed: {e}, sleeping for {sleep_interval} seconds") + sleep(sleep_interval) + else: + print("Success!") + break + + +if __name__ == "__main__": + main() diff --git a/src/flight_sql_client.cpp b/src/flight_sql_client.cpp index a5d297d..b5d44cf 100644 --- a/src/flight_sql_client.cpp +++ b/src/flight_sql_client.cpp @@ -21,6 +21,8 @@ #include #include #include +#include +#include #include "arrow/array/builder_binary.h" #include "arrow/array/builder_primitive.h" @@ -52,7 +54,10 @@ DEFINE_int32(port, 31337, "Port to connect to"); DEFINE_string(username, "", "Username"); DEFINE_string(password, "", "Password"); DEFINE_bool(use_tls, false, "Use TLS for connection"); -DEFINE_bool(tls_skip_verify, false, "Skip TLS certificate verification"); +DEFINE_string(tls_roots, "", "Path to Root certificates for TLS (in PEM format)"); +DEFINE_bool(tls_skip_verify, false, "Skip TLS server certificate verification"); +DEFINE_string(mtls_cert_chain, "", "Path to Certificate chain (in PEM format) used for mTLS authentication - if server requires it, must be accompanied by mtls_private_key"); +DEFINE_string(mtls_private_key, "", "Path to Private key (in PEM format) used for mTLS authentication - if server requires it"); DEFINE_string(command, "", "Method to run"); DEFINE_string(query, "", "Query"); @@ -102,6 +107,19 @@ Status PrintResults(FlightSqlClient& client, const FlightCallOptions& call_optio return Status::OK(); } +Status getPEMCertFileContents(const std::string& cert_file_path, std::string& cert_contents) { + std::ifstream cert_file(cert_file_path); + if (!cert_file.is_open()) { + return Status::IOError("Could not open file: " + cert_file_path); + } + + std::stringstream cert_stream; + cert_stream << cert_file.rdbuf(); + cert_contents = cert_stream.str(); + + return Status::OK(); +} + Status RunMain() { ARROW_ASSIGN_OR_RAISE(auto location, (FLAGS_use_tls) @@ -112,8 +130,24 @@ Status RunMain() { // Setup our options arrow::flight::FlightClientOptions options; + + if (!FLAGS_tls_roots.empty()) { + ARROW_RETURN_NOT_OK(getPEMCertFileContents(FLAGS_tls_roots, options.tls_root_certs)); + } + options.disable_server_verification = FLAGS_tls_skip_verify; + if (!FLAGS_mtls_cert_chain.empty()) { + ARROW_RETURN_NOT_OK(getPEMCertFileContents(FLAGS_mtls_cert_chain, options.cert_chain)); + + if (!FLAGS_mtls_private_key.empty()) { + ARROW_RETURN_NOT_OK(getPEMCertFileContents(FLAGS_mtls_private_key, options.private_key)); + } + else { + return Status::Invalid("mTLS private key file must be provided if mTLS certificate chain is provided"); + } + } + ARROW_ASSIGN_OR_RAISE(auto client, FlightClient::Connect(location, options)); FlightCallOptions call_options;