Skip to content

Commit

Permalink
Added mTLS support in flight_sql_client. Updated README.md with instr…
Browse files Browse the repository at this point in the history
…uctions on how to use the client.
  • Loading branch information
prmoore77 committed Feb 29, 2024
1 parent 775081c commit 1918799
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 16 deletions.
30 changes: 30 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,36 @@ n_nationkey: [[24]]
n_name: [["UNITED STATES"]]
```

### Connecting via the new `flight_sql_client` CLI tool
You can also use the new `flight_sql_client` CLI tool to connect to the Flight SQL server, and then run a single command. This tool is built into the Docker image, and is also available as a standalone executable for Linux and MacOS.

Example (run from the host computer's terminal):
```bash
flight_sql_client \
--command Execute \
--host "localhost" \
--port 31337 \
--username "flight_username" \
--password "flight_password" \
--query "SELECT version()" \
--use-tls \
--tls-skip-verify
```

That should return:
```text
Results from endpoint 1 of 1
Schema:
version(): string
Results:
version(): [
"v0.10.0"
]
Total: 1
```

### Tear-down
Stop the docker image with:
```bash
Expand Down
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@ pandas==2.1.*
duckdb==0.10.0
click==8.1.*
pyarrow==15.0.0
adbc-driver-flightsql==0.9.*
adbc-driver-manager==0.9.*
adbc-driver-flightsql==0.10.*
adbc-driver-manager==0.10.*
48 changes: 35 additions & 13 deletions scripts/test_flight_sql.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,39 @@
import os
from time import sleep
import pyarrow
from adbc_driver_flightsql import dbapi as flight_sql, DatabaseOptions

flight_password = os.getenv("FLIGHT_PASSWORD")

with flight_sql.connect(uri="grpc+tls://localhost:31337",
db_kwargs={"username": "flight_username",
"password": flight_password,
DatabaseOptions.TLS_SKIP_VERIFY.value: "true" # Not needed if you use a trusted CA-signed TLS cert
}
) as conn:
with conn.cursor() as cur:
cur.execute("SELECT n_nationkey, n_name FROM nation WHERE n_nationkey = ?",
parameters=[24]
)
x = cur.fetch_arrow_table()
print(x)
# Setup variables
max_attempts: int = 10
sleep_interval: int = 10
flight_password = os.environ["FLIGHT_PASSWORD"]

def main():
for attempt in range(max_attempts):
try:
with flight_sql.connect(uri="grpc+tls://localhost:31337",
db_kwargs={"username": "flight_username",
"password": flight_password,
DatabaseOptions.TLS_SKIP_VERIFY.value: "true" # Not needed if you use a trusted CA-signed TLS cert
}
) as conn:
with conn.cursor() as cur:
cur.execute("SELECT n_nationkey, n_name FROM nation WHERE n_nationkey = ?",
parameters=[24]
)
x = cur.fetch_arrow_table()
print(x)
except Exception as e:
if attempt == max_attempts - 1:
raise e
else:
print(f"Attempt {attempt + 1} failed: {e}, sleeping for {sleep_interval} seconds")
sleep(sleep_interval)
else:
print("Success!")
break


if __name__ == "__main__":
main()
36 changes: 35 additions & 1 deletion src/flight_sql_client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
#include <iostream>
#include <memory>
#include <optional>
#include <sstream>
#include <fstream>

#include "arrow/array/builder_binary.h"
#include "arrow/array/builder_primitive.h"
Expand Down Expand Up @@ -52,7 +54,10 @@ DEFINE_int32(port, 31337, "Port to connect to");
DEFINE_string(username, "", "Username");
DEFINE_string(password, "", "Password");
DEFINE_bool(use_tls, false, "Use TLS for connection");
DEFINE_bool(tls_skip_verify, false, "Skip TLS certificate verification");
DEFINE_string(tls_roots, "", "Path to Root certificates for TLS (in PEM format)");
DEFINE_bool(tls_skip_verify, false, "Skip TLS server certificate verification");
DEFINE_string(mtls_cert_chain, "", "Path to Certificate chain (in PEM format) used for mTLS authentication - if server requires it, must be accompanied by mtls_private_key");
DEFINE_string(mtls_private_key, "", "Path to Private key (in PEM format) used for mTLS authentication - if server requires it");

DEFINE_string(command, "", "Method to run");
DEFINE_string(query, "", "Query");
Expand Down Expand Up @@ -102,6 +107,19 @@ Status PrintResults(FlightSqlClient& client, const FlightCallOptions& call_optio
return Status::OK();
}

Status getPEMCertFileContents(const std::string& cert_file_path, std::string& cert_contents) {
std::ifstream cert_file(cert_file_path);
if (!cert_file.is_open()) {
return Status::IOError("Could not open file: " + cert_file_path);
}

std::stringstream cert_stream;
cert_stream << cert_file.rdbuf();
cert_contents = cert_stream.str();

return Status::OK();
}

Status RunMain() {
ARROW_ASSIGN_OR_RAISE(auto location,
(FLAGS_use_tls)
Expand All @@ -112,8 +130,24 @@ Status RunMain() {

// Setup our options
arrow::flight::FlightClientOptions options;

if (!FLAGS_tls_roots.empty()) {
ARROW_RETURN_NOT_OK(getPEMCertFileContents(FLAGS_tls_roots, options.tls_root_certs));
}

options.disable_server_verification = FLAGS_tls_skip_verify;

if (!FLAGS_mtls_cert_chain.empty()) {
ARROW_RETURN_NOT_OK(getPEMCertFileContents(FLAGS_mtls_cert_chain, options.cert_chain));

if (!FLAGS_mtls_private_key.empty()) {
ARROW_RETURN_NOT_OK(getPEMCertFileContents(FLAGS_mtls_private_key, options.private_key));
}
else {
return Status::Invalid("mTLS private key file must be provided if mTLS certificate chain is provided");
}
}

ARROW_ASSIGN_OR_RAISE(auto client, FlightClient::Connect(location, options));

FlightCallOptions call_options;
Expand Down

0 comments on commit 1918799

Please sign in to comment.