Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make ArrowFlightConfig verifyServer true by default #24518

Merged
merged 2 commits into from
Feb 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
public class ArrowFlightConfig
{
private String server;
private Boolean verifyServer;
private boolean verifyServer = true;
private String flightServerSSLCertificate;
private Boolean arrowFlightServerSslEnabled;
private boolean arrowFlightServerSslEnabled;
private Integer arrowFlightPort;

public String getFlightServerName()
Expand All @@ -35,13 +35,13 @@ public ArrowFlightConfig setFlightServerName(String server)
return this;
}

public Boolean getVerifyServer()
public boolean getVerifyServer()
{
return verifyServer;
}

@Config("arrow-flight.server.verify")
public ArrowFlightConfig setVerifyServer(Boolean verifyServer)
public ArrowFlightConfig setVerifyServer(boolean verifyServer)
{
this.verifyServer = verifyServer;
return this;
Expand Down Expand Up @@ -71,13 +71,13 @@ public ArrowFlightConfig setFlightServerSSLCertificate(String flightServerSSLCer
return this;
}

public Boolean getArrowFlightServerSslEnabled()
public boolean getArrowFlightServerSslEnabled()
{
return arrowFlightServerSslEnabled;
}

@Config("arrow-flight.server-ssl-enabled")
public ArrowFlightConfig setArrowFlightServerSslEnabled(Boolean arrowFlightServerSslEnabled)
public ArrowFlightConfig setArrowFlightServerSslEnabled(boolean arrowFlightServerSslEnabled)
{
this.arrowFlightServerSslEnabled = arrowFlightServerSslEnabled;
return this;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,11 @@ public BaseArrowFlightClientHandler(BufferAllocator allocator, ArrowFlightConfig
protected FlightClient createFlightClient()
{
Location location;
if (config.getArrowFlightServerSslEnabled() != null && !config.getArrowFlightServerSslEnabled()) {
location = Location.forGrpcInsecure(config.getFlightServerName(), config.getArrowFlightPort());
if (config.getArrowFlightServerSslEnabled()) {
location = Location.forGrpcTls(config.getFlightServerName(), config.getArrowFlightPort());
}
else {
location = Location.forGrpcTls(config.getFlightServerName(), config.getArrowFlightPort());
location = Location.forGrpcInsecure(config.getFlightServerName(), config.getArrowFlightPort());
}
return createFlightClient(location);
}
Expand All @@ -67,10 +67,8 @@ protected FlightClient createFlightClient(Location location)
try {
Optional<InputStream> trustedCertificate = Optional.empty();
FlightClient.Builder flightClientBuilder = FlightClient.builder(allocator, location);
if (config.getVerifyServer() != null && !config.getVerifyServer()) {
flightClientBuilder.verifyServer(false);
}
else if (config.getFlightServerSSLCertificate() != null) {
flightClientBuilder.verifyServer(config.getVerifyServer());
if (config.getFlightServerSSLCertificate() != null) {
trustedCertificate = Optional.of(newInputStream(Paths.get(config.getFlightServerSSLCertificate())));
flightClientBuilder.trustedCertificates(trustedCertificate.get()).useTls();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

import java.io.File;
import java.util.Map;
import java.util.Optional;

import static com.facebook.presto.testing.TestingSession.testSessionBuilder;

Expand All @@ -53,11 +54,15 @@ private static DistributedQueryRunner createQueryRunner(
throws Exception
{
Session session = testSessionBuilder()
.setCatalog("arrow")
.setCatalog("arrowflight")
.setSchema("tpch")
.build();

DistributedQueryRunner queryRunner = DistributedQueryRunner.builder(session).setExtraProperties(extraProperties).build();
DistributedQueryRunner.Builder queryRunnerBuilder = DistributedQueryRunner.builder(session);
Optional<Integer> workerCount = getProperty("WORKER_COUNT").map(Integer::parseInt);
workerCount.ifPresent(queryRunnerBuilder::setNodeCount);

DistributedQueryRunner queryRunner = queryRunnerBuilder.setExtraProperties(extraProperties).build();

try {
queryRunner.installPlugin(new TestingArrowFlightPlugin());
Expand All @@ -66,10 +71,9 @@ private static DistributedQueryRunner createQueryRunner(
.putAll(catalogProperties)
.put("arrow-flight.server", "localhost")
.put("arrow-flight.server-ssl-enabled", "true")
.put("arrow-flight.server-ssl-certificate", "src/test/resources/server.crt")
.put("arrow-flight.server.verify", "true");
.put("arrow-flight.server-ssl-certificate", "src/test/resources/server.crt");

queryRunner.createCatalog("arrow", "arrow", properties.build());
queryRunner.createCatalog("arrowflight", "arrow-flight", properties.build());

return queryRunner;
}
Expand All @@ -78,6 +82,19 @@ private static DistributedQueryRunner createQueryRunner(
}
}

private static Optional<String> getProperty(String name)
{
String systemPropertyValue = System.getProperty(name);
if (systemPropertyValue != null) {
return Optional.of(systemPropertyValue);
}
String environmentVariableValue = System.getenv(name);
if (environmentVariableValue != null) {
return Optional.of(environmentVariableValue);
}
return Optional.empty();
}

public static void main(String[] args)
throws Exception
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ private static MapType createMapType(Type keyType, Type valueType)
private static FlightClient createFlightClient(BufferAllocator allocator) throws IOException
{
InputStream trustedCertificate = new ByteArrayInputStream(Files.readAllBytes(Paths.get("src/test/resources/server.crt")));
return FlightClient.builder(allocator, getServerLocation()).verifyServer(true).useTls().trustedCertificates(trustedCertificate).build();
return FlightClient.builder(allocator, getServerLocation()).useTls().trustedCertificates(trustedCertificate).build();
}

private void addTableToServer(FlightClient client, VectorSchemaRoot root, String tableName)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,6 @@ public class TestingArrowFlightPlugin
{
public TestingArrowFlightPlugin()
{
super("arrow", new TestingArrowModule(), new JsonModule());
super("arrow-flight", new TestingArrowModule(), new JsonModule());
}
}
Loading