Skip to content

Commit

Permalink
changes for arrow connector
Browse files Browse the repository at this point in the history
  • Loading branch information
sabbasani committed Jul 29, 2024
1 parent 13645a8 commit 6d9d75b
Show file tree
Hide file tree
Showing 4 changed files with 147 additions and 147 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,6 @@ public ArrowAbstractMetadata(ArrowFlightConfig config, ArrowFlightClientHandler
this.clientHandler = requireNonNull(clientHandler);
}

protected abstract String getDataSourceSpecificSchemaName(ArrowFlightConfig config, String schemaName);

protected abstract String getDataSourceSpecificTableName(ArrowFlightConfig config, String tableName);

@Override
public ConnectorTableHandle getTableHandle(ConnectorSession session, SchemaTableName tableName)
{
Expand All @@ -86,10 +82,6 @@ public ConnectorTableHandle getTableHandle(ConnectorSession session, SchemaTable
return new ArrowTableHandle(tableName.getSchemaName(), tableName.getTableName());
}

protected abstract ArrowFlightRequest getArrowFlightRequest(ArrowFlightConfig config, Optional<String> query, String schema, String table);

protected abstract ArrowFlightRequest getArrowFlightRequest(ArrowFlightConfig config, String schema);

public List<Field> getColumnsList(String schema, String table, ConnectorSession connectorSession)
{
try {
Expand Down Expand Up @@ -256,10 +248,10 @@ public ConnectorTableMetadata getTableMetadata(ConnectorSession session, Connect
ArrowType.FloatingPoint floatingPoint = (ArrowType.FloatingPoint) field.getType();
switch (floatingPoint.getPrecision()) {
case SINGLE:
meta.add(new ColumnMetadata(columnName, RealType.REAL)); // Float4
meta.add(new ColumnMetadata(columnName, RealType.REAL));
break;
case DOUBLE:
meta.add(new ColumnMetadata(columnName, DoubleType.DOUBLE)); // Float8
meta.add(new ColumnMetadata(columnName, DoubleType.DOUBLE));
break;
default:
throw new ArrowException(ARROW_FLIGHT_ERROR, "Invalid floating point precision " + floatingPoint.getPrecision());
Expand Down Expand Up @@ -317,4 +309,12 @@ public Map<SchemaTableName, List<ColumnMetadata>> listTableColumns(ConnectorSess
}
return columns.build();
}

protected abstract ArrowFlightRequest getArrowFlightRequest(ArrowFlightConfig config, Optional<String> query, String schema, String table);

protected abstract ArrowFlightRequest getArrowFlightRequest(ArrowFlightConfig config, String schema);

protected abstract String getDataSourceSpecificSchemaName(ArrowFlightConfig config, String schemaName);

protected abstract String getDataSourceSpecificTableName(ArrowFlightConfig config, String tableName);
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@ public ArrowAbstractSplitManager(ArrowFlightClientHandler client)
this.clientHandler = client;
}

protected abstract ArrowFlightRequest getArrowFlightRequest(ArrowFlightConfig config, ArrowTableLayoutHandle tableLayoutHandle);

@Override
public ConnectorSplitSource getSplits(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorTableLayoutHandle layout, SplitSchedulingContext splitSchedulingContext)
{
Expand All @@ -59,4 +57,6 @@ public ConnectorSplitSource getSplits(ConnectorTransactionHandle transactionHand
logger.info("created %d splits from arrow tickets", splits.size());
return new FixedSplitSource(splits);
}

protected abstract ArrowFlightRequest getArrowFlightRequest(ArrowFlightConfig config, ArrowTableLayoutHandle tableLayoutHandle);
}
Original file line number Diff line number Diff line change
Expand Up @@ -48,46 +48,6 @@ public ArrowFlightClientHandler(ArrowFlightConfig config)
this.config = config;
}

private void initializeClient(Optional<String> uri)
{
if (!isClientClosed.get()) {
return;
}
try {
allocator = new RootAllocator(Long.MAX_VALUE);
Optional<InputStream> trustedCertificate = Optional.empty();

Location location;
if (uri.isPresent()) {
location = new Location(uri.get());
}
else {
if (config.getArrowFlightServerSslEnabled() != null && !config.getArrowFlightServerSslEnabled()) {
location = Location.forGrpcInsecure(config.getFlightServerName(), config.getArrowFlightPort());
}
else {
location = Location.forGrpcTls(config.getFlightServerName(), config.getArrowFlightPort());
}
}

FlightClient.Builder flightClientBuilder = FlightClient.builder(allocator, location);
if (config.getVerifyServer() != null && !config.getVerifyServer()) {
flightClientBuilder.verifyServer(false);
}
else if (config.getFlightServerSSLCertificate() != null) {
trustedCertificate = Optional.of(new FileInputStream(config.getFlightServerSSLCertificate()));
flightClientBuilder.trustedCertificates(trustedCertificate.get()).useTls();
}

FlightClient flightClient = flightClientBuilder.build();
this.arrowFlightClient = new ArrowFlightClient(flightClient, trustedCertificate, allocator);
isClientClosed.set(false);
}
catch (Exception ex) {
throw new ArrowException(ARROW_FLIGHT_ERROR, "The flight client could not be obtained." + ex.getMessage(), ex);
}
}

public ArrowFlightConfig getConfig()
{
return config;
Expand Down Expand Up @@ -122,8 +82,6 @@ public FlightInfo getFlightInfo(ArrowFlightRequest request, ConnectorSession con
}
}

protected abstract CredentialCallOption getCallOptions(ConnectorSession connectorSession);

public synchronized void close() throws Exception
{
if (arrowFlightClient != null) {
Expand All @@ -137,6 +95,61 @@ public synchronized void close() throws Exception
isClientClosed.set(true);
}

public void resetTimer()
{
shutdownTimer();
scheduleCloseTask();
}

public void shutdownTimer()
{
if (scheduledExecutorService != null) {
scheduledExecutorService.shutdownNow();
}
}

protected abstract CredentialCallOption getCallOptions(ConnectorSession connectorSession);

private void initializeClient(Optional<String> uri)
{
if (!isClientClosed.get()) {
return;
}
try {
allocator = new RootAllocator(Long.MAX_VALUE);
Optional<InputStream> trustedCertificate = Optional.empty();

Location location;
if (uri.isPresent()) {
location = new Location(uri.get());
}
else {
if (config.getArrowFlightServerSslEnabled() != null && !config.getArrowFlightServerSslEnabled()) {
location = Location.forGrpcInsecure(config.getFlightServerName(), config.getArrowFlightPort());
}
else {
location = Location.forGrpcTls(config.getFlightServerName(), config.getArrowFlightPort());
}
}

FlightClient.Builder flightClientBuilder = FlightClient.builder(allocator, location);
if (config.getVerifyServer() != null && !config.getVerifyServer()) {
flightClientBuilder.verifyServer(false);
}
else if (config.getFlightServerSSLCertificate() != null) {
trustedCertificate = Optional.of(new FileInputStream(config.getFlightServerSSLCertificate()));
flightClientBuilder.trustedCertificates(trustedCertificate.get()).useTls();
}

FlightClient flightClient = flightClientBuilder.build();
this.arrowFlightClient = new ArrowFlightClient(flightClient, trustedCertificate, allocator);
isClientClosed.set(false);
}
catch (Exception ex) {
throw new ArrowException(ARROW_FLIGHT_ERROR, "The flight client could not be obtained." + ex.getMessage(), ex);
}
}

private void scheduleCloseTask()
{
scheduledExecutorService = Executors.newScheduledThreadPool(1);
Expand All @@ -152,17 +165,4 @@ private void scheduleCloseTask()
};
scheduledExecutorService.schedule(closeTask, TIMER_DURATION_IN_MINUTES, TimeUnit.MINUTES);
}

public void resetTimer()
{
shutdownTimer();
scheduleCloseTask();
}

public void shutdownTimer()
{
if (scheduledExecutorService != null) {
scheduledExecutorService.shutdownNow();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,86 @@ public ArrowPageSource(ArrowSplit split, List<ArrowColumnHandle> columnHandles,
getFlightStream(clientHandler, split.getTicket(), connectorSession);
}

@Override
public long getCompletedBytes()
{
return 0;
}

@Override
public long getCompletedPositions()
{
return currentPosition;
}

@Override
public long getReadTimeNanos()
{
return 0;
}

@Override
public boolean isFinished()
{
return completed;
}

@Override
public long getSystemMemoryUsage()
{
return 0;
}

@Override
public Page getNextPage()
{
if (vectorSchemaRoot.isPresent()) {
vectorSchemaRoot.get().close();
vectorSchemaRoot = Optional.empty();
}

if (flightStream.next()) {
vectorSchemaRoot = Optional.ofNullable(flightStream.getRoot());
}

if (!vectorSchemaRoot.isPresent()) {
completed = true;
}

if (isFinished()) {
return null;
}

currentPosition++;

List<Block> blocks = new ArrayList<>();
for (int columnIndex = 0; columnIndex < columnHandles.size(); columnIndex++) {
FieldVector vector = vectorSchemaRoot.get().getVector(columnIndex);
Type type = columnHandles.get(columnIndex).getColumnType();

Block block = buildBlockFromVector(vector, type);
blocks.add(block);
}

return new Page(vectorSchemaRoot.get().getRowCount(), blocks.toArray(new Block[0]));
}

@Override
public void close() throws IOException
{
if (vectorSchemaRoot.isPresent()) {
vectorSchemaRoot.get().close();
}
if (flightStream != null) {
try {
flightStream.close();
}
catch (Exception e) {
logger.error(e);
}
}
}

private void getFlightStream(ArrowFlightClientHandler clientHandler, byte[] ticket, ConnectorSession connectorSession)
{
try {
Expand Down Expand Up @@ -483,84 +563,4 @@ private Block buildBlockFromTimeStampSecVector(TimeStampSecVector vector, Type t
}
return builder.build();
}

@Override
public long getCompletedBytes()
{
return 0;
}

@Override
public long getCompletedPositions()
{
return currentPosition;
}

@Override
public long getReadTimeNanos()
{
return 0;
}

@Override
public boolean isFinished()
{
return completed;
}

@Override
public long getSystemMemoryUsage()
{
return 0;
}

@Override
public Page getNextPage()
{
if (vectorSchemaRoot.isPresent()) {
vectorSchemaRoot.get().close();
vectorSchemaRoot = Optional.empty();
}

if (flightStream.next()) {
vectorSchemaRoot = Optional.ofNullable(flightStream.getRoot());
}

if (!vectorSchemaRoot.isPresent()) {
completed = true;
}

if (isFinished()) {
return null;
}

currentPosition++;

List<Block> blocks = new ArrayList<>();
for (int columnIndex = 0; columnIndex < columnHandles.size(); columnIndex++) {
FieldVector vector = vectorSchemaRoot.get().getVector(columnIndex);
Type type = columnHandles.get(columnIndex).getColumnType();

Block block = buildBlockFromVector(vector, type);
blocks.add(block);
}

return new Page(vectorSchemaRoot.get().getRowCount(), blocks.toArray(new Block[0]));
}

@Override
public void close() throws IOException
{
if (vectorSchemaRoot.isPresent()) {
vectorSchemaRoot.get().close();
}
if (flightStream != null) {
try {
flightStream.close();
}
catch (Exception e) {
logger.error(e);
}
}
}
}

0 comments on commit 6d9d75b

Please sign in to comment.