From e2274e482c141ee80ee9b488640336b01de740d2 Mon Sep 17 00:00:00 2001 From: sai bhaskar reddy Date: Tue, 9 Jul 2024 10:55:10 +0530 Subject: [PATCH] Arrow Flight connector template Co-authored-by: sai bhaskar reddy Co-authored-by: SthuthiGhosh9400 Co-authored-by: lithinwxd Co-authored-by: Steve Burnett Co-authored-by: elbinpallimalilibm Co-authored-by: Steve Burnett Co-authored-by: Timothy Meehan --- .github/workflows/arrow-flight-tests.yml | 82 ++ .github/workflows/test-other-modules.yml | 3 +- pom.xml | 1 + presto-base-arrow-flight/pom.xml | 359 +++++++ .../plugin/arrow/AbstractArrowMetadata.java | 269 +++++ .../arrow/AbstractArrowSplitManager.java | 62 ++ .../plugin/arrow/ArrowColumnHandle.java | 61 ++ .../facebook/plugin/arrow/ArrowConnector.java | 87 ++ .../plugin/arrow/ArrowConnectorFactory.java | 98 ++ .../plugin/arrow/ArrowConnectorId.java | 53 + .../facebook/plugin/arrow/ArrowErrorCode.java | 43 + .../facebook/plugin/arrow/ArrowException.java | 31 + .../plugin/arrow/ArrowFlightClient.java | 54 + .../arrow/ArrowFlightClientHandler.java | 125 +++ .../plugin/arrow/ArrowFlightConfig.java | 84 ++ .../plugin/arrow/ArrowHandleResolver.java | 55 + .../facebook/plugin/arrow/ArrowModule.java | 46 + .../plugin/arrow/ArrowPageSource.java | 163 +++ .../plugin/arrow/ArrowPageSourceProvider.java | 52 + .../facebook/plugin/arrow/ArrowPageUtils.java | 968 ++++++++++++++++++ .../facebook/plugin/arrow/ArrowPlugin.java | 49 + .../com/facebook/plugin/arrow/ArrowSplit.java | 90 ++ .../plugin/arrow/ArrowTableHandle.java | 73 ++ .../plugin/arrow/ArrowTableLayoutHandle.java | 86 ++ .../plugin/arrow/ArrowTransactionHandle.java | 22 + .../plugin/arrow/ArrowFlightQueryRunner.java | 68 ++ .../plugin/arrow/ArrowMetadataUtil.java | 75 ++ .../plugin/arrow/ArrowPageUtilsTest.java | 707 +++++++++++++ .../plugin/arrow/TestArrowColumnHandle.java | 83 ++ .../TestArrowFlightIntegrationSmokeTest.java | 71 ++ .../plugin/arrow/TestArrowFlightQueries.java | 175 ++++ .../plugin/arrow/TestArrowHandleResolver.java | 67 ++ .../facebook/plugin/arrow/TestArrowSplit.java | 74 ++ .../plugin/arrow/TestArrowTableHandle.java | 37 + .../arrow/TestArrowTableLayoutHandle.java | 116 +++ .../plugin/arrow/TestingArrowFactory.java | 23 + .../TestingArrowFlightClientHandler.java | 36 + .../arrow/TestingArrowFlightConfig.java | 113 ++ .../arrow/TestingArrowFlightRequest.java | 98 ++ .../plugin/arrow/TestingArrowMetadata.java | 158 +++ .../plugin/arrow/TestingArrowModule.java | 35 + .../plugin/arrow/TestingArrowPlugin.java | 33 + .../arrow/TestingArrowQueryBuilder.java | 305 ++++++ .../plugin/arrow/TestingArrowQueryRunner.java | 74 ++ .../plugin/arrow/TestingArrowServer.java | 310 ++++++ .../arrow/TestingArrowSplitManager.java | 50 + .../arrow/TestingConnectionProperties.java | 35 + .../plugin/arrow/TestingH2DatabaseSetup.java | 273 +++++ .../arrow/TestingInteractionProperties.java | 57 ++ .../plugin/arrow/TestingRequestData.java | 40 + .../src/test/resources/server.crt | 21 + .../src/test/resources/server.key | 28 + presto-docs/src/main/sphinx/connector.rst | 1 + .../sphinx/connector/base-arrow-flight.rst | 95 ++ 54 files changed, 6273 insertions(+), 1 deletion(-) create mode 100644 .github/workflows/arrow-flight-tests.yml create mode 100644 presto-base-arrow-flight/pom.xml create mode 100644 presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/AbstractArrowMetadata.java create mode 100644 presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/AbstractArrowSplitManager.java create mode 100644 presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowColumnHandle.java create mode 100644 presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnector.java create mode 100644 presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnectorFactory.java create mode 100644 presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnectorId.java create mode 100644 presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowErrorCode.java create mode 100644 presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowException.java create mode 100644 presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightClient.java create mode 100644 presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightClientHandler.java create mode 100644 presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightConfig.java create mode 100644 presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowHandleResolver.java create mode 100644 presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowModule.java create mode 100644 presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageSource.java create mode 100644 presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageSourceProvider.java create mode 100644 presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageUtils.java create mode 100644 presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPlugin.java create mode 100644 presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowSplit.java create mode 100644 presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTableHandle.java create mode 100644 presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTableLayoutHandle.java create mode 100644 presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTransactionHandle.java create mode 100644 presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowFlightQueryRunner.java create mode 100644 presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowMetadataUtil.java create mode 100644 presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowPageUtilsTest.java create mode 100644 presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowColumnHandle.java create mode 100644 presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowFlightIntegrationSmokeTest.java create mode 100644 presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowFlightQueries.java create mode 100644 presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowHandleResolver.java create mode 100644 presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowSplit.java create mode 100644 presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowTableHandle.java create mode 100644 presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowTableLayoutHandle.java create mode 100644 presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowFactory.java create mode 100644 presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowFlightClientHandler.java create mode 100644 presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowFlightConfig.java create mode 100644 presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowFlightRequest.java create mode 100644 presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowMetadata.java create mode 100644 presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowModule.java create mode 100644 presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowPlugin.java create mode 100644 presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowQueryBuilder.java create mode 100644 presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowQueryRunner.java create mode 100644 presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowServer.java create mode 100644 presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowSplitManager.java create mode 100644 presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingConnectionProperties.java create mode 100644 presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingH2DatabaseSetup.java create mode 100644 presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingInteractionProperties.java create mode 100644 presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingRequestData.java create mode 100644 presto-base-arrow-flight/src/test/resources/server.crt create mode 100644 presto-base-arrow-flight/src/test/resources/server.key create mode 100644 presto-docs/src/main/sphinx/connector/base-arrow-flight.rst diff --git a/.github/workflows/arrow-flight-tests.yml b/.github/workflows/arrow-flight-tests.yml new file mode 100644 index 0000000000000..ee77c122536e1 --- /dev/null +++ b/.github/workflows/arrow-flight-tests.yml @@ -0,0 +1,82 @@ +name: arrow flight tests + +on: + pull_request: + +env: + CONTINUOUS_INTEGRATION: true + MAVEN_OPTS: "-Xmx1024M -XX:+ExitOnOutOfMemoryError" + MAVEN_INSTALL_OPTS: "-Xmx2G -XX:+ExitOnOutOfMemoryError" + MAVEN_FAST_INSTALL: "-B -V --quiet -T 1C -DskipTests -Dair.check.skip-all --no-transfer-progress -Dmaven.javadoc.skip=true" + MAVEN_TEST: "-B -Dair.check.skip-all -Dmaven.javadoc.skip=true -DLogTestDurationListener.enabled=true --no-transfer-progress --fail-at-end" + RETRY: .github/bin/retry + +jobs: + changes: + runs-on: ubuntu-latest + permissions: + pull-requests: read + outputs: + codechange: ${{ steps.filter.outputs.codechange }} + steps: + - uses: dorny/paths-filter@v2 + id: filter + with: + filters: | + codechange: + - '!presto-docs/**' + test: + runs-on: ubuntu-latest + needs: changes + strategy: + fail-fast: false + matrix: + modules: + - ":presto-base-arrow-flight" # Only run tests for the `presto-base-arrow-flight` module + + timeout-minutes: 80 + concurrency: + group: ${{ github.workflow }}-test-${{ matrix.modules }}-${{ github.event.pull_request.number }} + cancel-in-progress: true + + steps: + # Checkout the code only if there are changes in the relevant files + - uses: actions/checkout@v4 + if: needs.changes.outputs.codechange == 'true' + with: + show-progress: false + + # Set up Java for the build environment + - uses: actions/setup-java@v2 + if: needs.changes.outputs.codechange == 'true' + with: + distribution: 'temurin' + java-version: 8 + + # Cache Maven dependencies to speed up the build + - name: Cache local Maven repository + if: needs.changes.outputs.codechange == 'true' + id: cache-maven + uses: actions/cache@v2 + with: + path: ~/.m2/repository + key: ${{ runner.os }}-maven-2-${{ hashFiles('**/pom.xml') }} + restore-keys: | + ${{ runner.os }}-maven-2- + + # Resolve Maven dependencies (if cache is not found) + - name: Populate Maven cache + if: steps.cache-maven.outputs.cache-hit != 'true' && needs.changes.outputs.codechange == 'true' + run: ./mvnw de.qaware.maven:go-offline-maven-plugin:resolve-dependencies --no-transfer-progress && .github/bin/download_nodejs + + # Install dependencies for the target module + - name: Maven Install + if: needs.changes.outputs.codechange == 'true' + run: | + export MAVEN_OPTS="${MAVEN_INSTALL_OPTS}" + ./mvnw install ${MAVEN_FAST_INSTALL} -am -pl ${{ matrix.modules }} + + # Run Maven tests for the target module + - name: Maven Tests + if: needs.changes.outputs.codechange == 'true' + run: ./mvnw test ${MAVEN_TEST} -pl ${{ matrix.modules }} diff --git a/.github/workflows/test-other-modules.yml b/.github/workflows/test-other-modules.yml index 3c3f84817a0d1..1065fe3d0c6ef 100644 --- a/.github/workflows/test-other-modules.yml +++ b/.github/workflows/test-other-modules.yml @@ -84,4 +84,5 @@ jobs: !presto-test-coverage, !presto-iceberg, !presto-singlestore, - !presto-native-sidecar-plugin' + !presto-native-sidecar-plugin, + !presto-base-arrow-flight' diff --git a/pom.xml b/pom.xml index ee4920c084624..b9ec84c9f1537 100644 --- a/pom.xml +++ b/pom.xml @@ -204,6 +204,7 @@ presto-hana presto-openapi presto-native-sidecar-plugin + presto-base-arrow-flight diff --git a/presto-base-arrow-flight/pom.xml b/presto-base-arrow-flight/pom.xml new file mode 100644 index 0000000000000..76a6bb6f4d02f --- /dev/null +++ b/presto-base-arrow-flight/pom.xml @@ -0,0 +1,359 @@ + + + 4.0.0 + + + com.facebook.presto + presto-root + 0.291-SNAPSHOT + + presto-base-arrow-flight + presto-base-arrow-flight + Presto - Base Arrow Flight Connector + + + ${project.parent.basedir} + 4.10.0 + 17.0.0 + 4.1.110.Final + 1.6.20 + 2.28.0 + 0.27.0 + + + + + com.facebook.airlift + bootstrap + + + + com.facebook.airlift + log + + + + com.google.guava + guava + + + + com.google.j2objc + j2objc-annotations + + + + + + javax.inject + javax.inject + + + + com.facebook.presto + presto-spi + provided + + + + io.airlift + slice + provided + + + + com.fasterxml.jackson.core + jackson-annotations + provided + + + + com.facebook.presto + presto-common + provided + + + + io.airlift + units + provided + + + + com.google.code.findbugs + jsr305 + true + + + + org.apache.arrow + arrow-memory-core + ${arrow.version} + + + org.slf4j + slf4j-api + + + + + + com.google.inject + guice + + + + com.facebook.airlift + configuration + + + + org.apache.arrow + arrow-jdbc + ${arrow.version} + + + org.slf4j + slf4j-api + + + + + + org.apache.arrow + arrow-vector + ${arrow.version} + + + org.slf4j + slf4j-api + + + + commons-codec + commons-codec + + + + com.fasterxml.jackson.datatype + jackson-datatype-jsr310 + + + + + + + org.testng + testng + test + + + + io.airlift.tpch + tpch + test + + + + joda-time + joda-time + + + + org.jdbi + jdbi3-core + + + + com.facebook.presto + presto-tpch + test + + + + com.facebook.airlift + json + test + + + + com.facebook.presto + presto-testng-services + test + + + + com.facebook.airlift + testing + test + ${dep.airlift.version} + + + + com.fasterxml.jackson.core + jackson-databind + provided + + + + com.fasterxml.jackson.core + jackson-core + provided + + + + org.apache.arrow + flight-core + ${arrow.version} + + + org.slf4j + slf4j-api + + + + com.google.j2objc + j2objc-annotations + + + + + + com.facebook.presto + presto-main + test + + + + com.facebook.presto + presto-tests + test + + + + com.h2database + h2 + test + + + + + + + io.netty + netty-transport-native-unix-common + ${netty.version} + + + + io.netty + netty-common + ${netty.version} + + + + io.netty + netty-buffer + ${netty.version} + + + + io.netty + netty-handler + ${netty.version} + + + + io.netty + netty-transport + ${netty.version} + + + + io.netty + netty-codec + ${netty.version} + + + + io.netty + netty-handler-proxy + ${netty.version} + + + + io.netty + netty-codec-http + ${netty.version} + + + + org.jetbrains.kotlin + kotlin-stdlib-common + ${kotlin.version} + + + + com.google.errorprone + error_prone_annotations + ${error_prone_annotations.version} + + + + io.perfmark + perfmark-api + ${perfmark-api.version} + + + + org.apache.arrow + arrow-algorithm + ${arrow.version} + compile + + + + + + + + + org.apache.maven.plugins + maven-surefire-plugin + + -Xss10M + + + + org.apache.maven.plugins + maven-enforcer-plugin + + + org.apache.maven.plugins + maven-dependency-plugin + + + org.basepom.maven + duplicate-finder-maven-plugin + 1.2.1 + + + module-info + META-INF.versions.9.module-info + + + arrow-git.properties + about.html + + + + + + check + + + + + + + diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/AbstractArrowMetadata.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/AbstractArrowMetadata.java new file mode 100644 index 0000000000000..bdadfb96f3b63 --- /dev/null +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/AbstractArrowMetadata.java @@ -0,0 +1,269 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.airlift.log.Logger; +import com.facebook.presto.common.type.BigintType; +import com.facebook.presto.common.type.BooleanType; +import com.facebook.presto.common.type.DateType; +import com.facebook.presto.common.type.DecimalType; +import com.facebook.presto.common.type.DoubleType; +import com.facebook.presto.common.type.IntegerType; +import com.facebook.presto.common.type.RealType; +import com.facebook.presto.common.type.SmallintType; +import com.facebook.presto.common.type.TimeType; +import com.facebook.presto.common.type.TimestampType; +import com.facebook.presto.common.type.TinyintType; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.common.type.VarbinaryType; +import com.facebook.presto.common.type.VarcharType; +import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.ColumnMetadata; +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.ConnectorTableHandle; +import com.facebook.presto.spi.ConnectorTableLayout; +import com.facebook.presto.spi.ConnectorTableLayoutHandle; +import com.facebook.presto.spi.ConnectorTableLayoutResult; +import com.facebook.presto.spi.ConnectorTableMetadata; +import com.facebook.presto.spi.Constraint; +import com.facebook.presto.spi.NotFoundException; +import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.SchemaTableName; +import com.facebook.presto.spi.SchemaTablePrefix; +import com.facebook.presto.spi.StandardErrorCode; +import com.facebook.presto.spi.connector.ConnectorMetadata; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.apache.arrow.flight.FlightDescriptor; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +import static com.facebook.plugin.arrow.ArrowErrorCode.ARROW_FLIGHT_ERROR; +import static java.util.Objects.requireNonNull; + +public abstract class AbstractArrowMetadata + implements ConnectorMetadata +{ + private static final Logger logger = Logger.get(AbstractArrowMetadata.class); + private final ArrowFlightConfig config; + private final ArrowFlightClientHandler clientHandler; + + public AbstractArrowMetadata(ArrowFlightConfig config, ArrowFlightClientHandler clientHandler) + { + this.config = requireNonNull(config, "config is null"); + this.clientHandler = requireNonNull(clientHandler, "clientHandler is null"); + } + + private Type getPrestoTypeForArrowFloatingPointType(ArrowType.FloatingPoint floatingPoint) + { + switch (floatingPoint.getPrecision()) { + case SINGLE: + return RealType.REAL; + case DOUBLE: + return DoubleType.DOUBLE; + default: + throw new ArrowException(ARROW_FLIGHT_ERROR, "Invalid floating point precision " + floatingPoint.getPrecision()); + } + } + + private Type getPrestoTypeForArrowIntType(ArrowType.Int intType) + { + switch (intType.getBitWidth()) { + case 64: + return BigintType.BIGINT; + case 32: + return IntegerType.INTEGER; + case 16: + return SmallintType.SMALLINT; + case 8: + return TinyintType.TINYINT; + default: + throw new ArrowException(ARROW_FLIGHT_ERROR, "Invalid bit width " + intType.getBitWidth()); + } + } + + protected Type getPrestoTypeFromArrowField(Field field) + { + switch (field.getType().getTypeID()) { + case Int: + ArrowType.Int intType = (ArrowType.Int) field.getType(); + return getPrestoTypeForArrowIntType(intType); + case Binary: + case LargeBinary: + case FixedSizeBinary: + return VarbinaryType.VARBINARY; + case Date: + return DateType.DATE; + case Timestamp: + return TimestampType.TIMESTAMP; + case Utf8: + case LargeUtf8: + return VarcharType.VARCHAR; + case FloatingPoint: + ArrowType.FloatingPoint floatingPoint = (ArrowType.FloatingPoint) field.getType(); + return getPrestoTypeForArrowFloatingPointType(floatingPoint); + case Decimal: + ArrowType.Decimal decimalType = (ArrowType.Decimal) field.getType(); + return DecimalType.createDecimalType(decimalType.getPrecision(), decimalType.getScale()); + case Bool: + return BooleanType.BOOLEAN; + case Time: + return TimeType.TIME; + default: + throw new UnsupportedOperationException("The data type " + field.getType().getTypeID() + " is not supported."); + } + } + + protected abstract FlightDescriptor getFlightDescriptor(Optional query, String schema, String table); + + protected abstract String getDataSourceSpecificSchemaName(ArrowFlightConfig config, String schemaName); + + protected abstract String getDataSourceSpecificTableName(ArrowFlightConfig config, String tableName); + + @Override + public ConnectorTableHandle getTableHandle(ConnectorSession session, SchemaTableName tableName) + { + if (!listSchemaNames(session).contains(tableName.getSchemaName())) { + return null; + } + + if (!listTables(session, Optional.ofNullable(tableName.getSchemaName())).contains(tableName)) { + return null; + } + return new ArrowTableHandle(tableName.getSchemaName(), tableName.getTableName()); + } + + public List getColumnsList(String schema, String table, ConnectorSession connectorSession) + { + try { + String dataSourceSpecificSchemaName = getDataSourceSpecificSchemaName(config, schema); + String dataSourceSpecificTableName = getDataSourceSpecificTableName(config, table); + FlightDescriptor flightDescriptor = getFlightDescriptor(Optional.empty(), + dataSourceSpecificSchemaName, dataSourceSpecificTableName); + + Optional flightschema = clientHandler.getSchema(flightDescriptor, connectorSession); + List fields = flightschema.map(Schema::getFields).orElse(Collections.emptyList()); + return fields; + } + catch (Exception e) { + throw new ArrowException(ARROW_FLIGHT_ERROR, "The table columns could not be listed for the table " + table, e); + } + } + + @Override + public Map getColumnHandles(ConnectorSession session, ConnectorTableHandle tableHandle) + { + Map columnHandles = new HashMap<>(); + + String schemaValue = ((ArrowTableHandle) tableHandle).getSchema(); + String tableValue = ((ArrowTableHandle) tableHandle).getTable(); + String dbSpecificSchemaValue = getDataSourceSpecificSchemaName(config, schemaValue); + String dBSpecificTableName = getDataSourceSpecificTableName(config, tableValue); + List columnList = getColumnsList(dbSpecificSchemaValue, dBSpecificTableName, session); + + for (Field field : columnList) { + String columnName = field.getName(); + logger.debug("The value of the flight columnName is:- %s", columnName); + + Type type = getPrestoTypeFromArrowField(field); + columnHandles.put(columnName, new ArrowColumnHandle(columnName, type)); + } + return columnHandles; + } + + @Override + public List getTableLayouts(ConnectorSession session, ConnectorTableHandle table, Constraint constraint, Optional> desiredColumns) + { + if (!(table instanceof ArrowTableHandle)) { + throw new PrestoException( + StandardErrorCode.INVALID_CAST_ARGUMENT, + "Invalid table handle: Expected an instance of ArrowTableHandle but received " + + table.getClass().getSimpleName()); + } + + ArrowTableHandle tableHandle = (ArrowTableHandle) table; + + List columns = new ArrayList<>(); + if (desiredColumns.isPresent()) { + List arrowColumns = new ArrayList<>(desiredColumns.get()); + columns = (List) (List) arrowColumns; + } + + ConnectorTableLayout layout = new ConnectorTableLayout(new ArrowTableLayoutHandle(tableHandle, columns, constraint.getSummary())); + return ImmutableList.of(new ConnectorTableLayoutResult(layout, constraint.getSummary())); + } + + @Override + public ConnectorTableLayout getTableLayout(ConnectorSession session, ConnectorTableLayoutHandle handle) + { + return new ConnectorTableLayout(handle); + } + + @Override + public ConnectorTableMetadata getTableMetadata(ConnectorSession session, ConnectorTableHandle table) + { + List meta = new ArrayList<>(); + List columnList = getColumnsList(((ArrowTableHandle) table).getSchema(), ((ArrowTableHandle) table).getTable(), session); + + for (Field field : columnList) { + String columnName = field.getName(); + Type fieldType = getPrestoTypeFromArrowField(field); + meta.add(new ColumnMetadata(columnName, fieldType)); + } + return new ConnectorTableMetadata(new SchemaTableName(((ArrowTableHandle) table).getSchema(), ((ArrowTableHandle) table).getTable()), meta); + } + + @Override + public ColumnMetadata getColumnMetadata(ConnectorSession session, ConnectorTableHandle tableHandle, ColumnHandle columnHandle) + { + return ((ArrowColumnHandle) columnHandle).getColumnMetadata(); + } + + @Override + public Map> listTableColumns(ConnectorSession session, SchemaTablePrefix prefix) + { + requireNonNull(prefix, "prefix is null"); + ImmutableMap.Builder> columns = ImmutableMap.builder(); + List tables; + if (prefix.getSchemaName() != null && prefix.getTableName() != null) { + tables = ImmutableList.of(new SchemaTableName(prefix.getSchemaName(), prefix.getTableName())); + } + else { + tables = listTables(session, Optional.of(prefix.getSchemaName())); + } + + for (SchemaTableName tableName : tables) { + try { + ConnectorTableHandle tableHandle = getTableHandle(session, tableName); + columns.put(tableName, getTableMetadata(session, tableHandle).getColumns()); + } + catch (ClassCastException | NotFoundException e) { + throw new ArrowException(ARROW_FLIGHT_ERROR, "The table columns could not be listed for the table " + tableName, e); + } + catch (Exception e) { + throw new ArrowException(ARROW_FLIGHT_ERROR, e.getMessage(), e); + } + } + return columns.build(); + } +} diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/AbstractArrowSplitManager.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/AbstractArrowSplitManager.java new file mode 100644 index 0000000000000..e92afe8b0a4b8 --- /dev/null +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/AbstractArrowSplitManager.java @@ -0,0 +1,62 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.airlift.log.Logger; +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.ConnectorSplitSource; +import com.facebook.presto.spi.ConnectorTableLayoutHandle; +import com.facebook.presto.spi.FixedSplitSource; +import com.facebook.presto.spi.connector.ConnectorSplitManager; +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import org.apache.arrow.flight.FlightDescriptor; +import org.apache.arrow.flight.FlightInfo; + +import java.util.List; + +import static com.google.common.collect.ImmutableList.toImmutableList; + +public abstract class AbstractArrowSplitManager + implements ConnectorSplitManager +{ + private static final Logger logger = Logger.get(AbstractArrowSplitManager.class); + private final ArrowFlightClientHandler clientHandler; + + public AbstractArrowSplitManager(ArrowFlightClientHandler client) + { + this.clientHandler = client; + } + + protected abstract FlightDescriptor getFlightDescriptor(ArrowFlightConfig config, ArrowTableLayoutHandle tableLayoutHandle); + + @Override + public ConnectorSplitSource getSplits(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorTableLayoutHandle layout, SplitSchedulingContext splitSchedulingContext) + { + ArrowTableLayoutHandle tableLayoutHandle = (ArrowTableLayoutHandle) layout; + ArrowTableHandle tableHandle = tableLayoutHandle.getTableHandle(); + FlightDescriptor flightDescriptor = getFlightDescriptor(clientHandler.getConfig(), + tableLayoutHandle); + + FlightInfo flightInfo = clientHandler.getFlightInfo(flightDescriptor, session); + List splits = flightInfo.getEndpoints() + .stream() + .map(info -> new ArrowSplit( + tableHandle.getSchema(), + tableHandle.getTable(), + info.getTicket().getBytes(), + info.getLocations().stream().map(location -> location.getUri().toString()).collect(toImmutableList()))) + .collect(toImmutableList()); + return new FixedSplitSource(splits); + } +} diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowColumnHandle.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowColumnHandle.java new file mode 100644 index 0000000000000..1ee3791e37457 --- /dev/null +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowColumnHandle.java @@ -0,0 +1,61 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.presto.common.type.Type; +import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.ColumnMetadata; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import static java.util.Objects.requireNonNull; + +public class ArrowColumnHandle + implements ColumnHandle +{ + private final String columnName; + private final Type columnType; + + @JsonCreator + public ArrowColumnHandle( + @JsonProperty("columnName") String columnName, + @JsonProperty("columnType") Type columnType) + { + this.columnName = requireNonNull(columnName, "columnName is null"); + this.columnType = requireNonNull(columnType, "type is null"); + } + + @JsonProperty + public String getColumnName() + { + return columnName; + } + + @JsonProperty + public Type getColumnType() + { + return columnType; + } + + public ColumnMetadata getColumnMetadata() + { + return new ColumnMetadata(columnName, columnType); + } + + @Override + public String toString() + { + return columnName + ":" + columnType; + } +} diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnector.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnector.java new file mode 100644 index 0000000000000..1028af2414308 --- /dev/null +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnector.java @@ -0,0 +1,87 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.presto.spi.ConnectorHandleResolver; +import com.facebook.presto.spi.connector.Connector; +import com.facebook.presto.spi.connector.ConnectorMetadata; +import com.facebook.presto.spi.connector.ConnectorPageSourceProvider; +import com.facebook.presto.spi.connector.ConnectorSplitManager; +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import com.facebook.presto.spi.transaction.IsolationLevel; +import com.google.inject.Inject; + +import java.util.Optional; + +import static java.util.Objects.requireNonNull; + +public class ArrowConnector + implements Connector +{ + private final ConnectorMetadata metadata; + private final ConnectorSplitManager splitManager; + private final ConnectorPageSourceProvider pageSourceProvider; + private final ConnectorHandleResolver handleResolver; + + private final ArrowFlightClientHandler arrowFlightClientHandler; + + @Inject + public ArrowConnector(ConnectorMetadata metadata, + ConnectorHandleResolver handleResolver, + ConnectorSplitManager splitManager, + ConnectorPageSourceProvider pageSourceProvider, + ArrowFlightClientHandler arrowFlightClientHandler) + { + this.metadata = requireNonNull(metadata, "Metadata is null"); + this.handleResolver = requireNonNull(handleResolver, "handleResolver is null"); + this.splitManager = requireNonNull(splitManager, "SplitManager is null"); + this.pageSourceProvider = requireNonNull(pageSourceProvider, "PageSinkProvider is null"); + this.arrowFlightClientHandler = requireNonNull(arrowFlightClientHandler, "arrow flight handler is null"); + } + + public Optional getHandleResolver() + { + return Optional.of(handleResolver); + } + + @Override + public ConnectorTransactionHandle beginTransaction(IsolationLevel isolationLevel, boolean readOnly) + { + return ArrowTransactionHandle.INSTANCE; + } + + @Override + public ConnectorMetadata getMetadata(ConnectorTransactionHandle transactionHandle) + { + return metadata; + } + + @Override + public ConnectorSplitManager getSplitManager() + { + return splitManager; + } + + @Override + public ConnectorPageSourceProvider getPageSourceProvider() + { + return pageSourceProvider; + } + + @Override + public void shutdown() + { + arrowFlightClientHandler.closeRootallocator(); + } +} diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnectorFactory.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnectorFactory.java new file mode 100644 index 0000000000000..e070b2c4624d6 --- /dev/null +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnectorFactory.java @@ -0,0 +1,98 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.airlift.bootstrap.Bootstrap; +import com.facebook.presto.common.type.TypeManager; +import com.facebook.presto.spi.ConnectorHandleResolver; +import com.facebook.presto.spi.NodeManager; +import com.facebook.presto.spi.classloader.ThreadContextClassLoader; +import com.facebook.presto.spi.connector.Connector; +import com.facebook.presto.spi.connector.ConnectorContext; +import com.facebook.presto.spi.connector.ConnectorFactory; +import com.facebook.presto.spi.function.FunctionMetadataManager; +import com.facebook.presto.spi.function.StandardFunctionResolution; +import com.facebook.presto.spi.relation.RowExpressionService; +import com.google.inject.ConfigurationException; +import com.google.inject.Injector; +import com.google.inject.Module; + +import java.util.Map; + +import static com.facebook.plugin.arrow.ArrowErrorCode.ARROW_FLIGHT_ERROR; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Strings.isNullOrEmpty; +import static com.google.common.base.Throwables.throwIfUnchecked; +import static java.util.Objects.requireNonNull; + +public class ArrowConnectorFactory + implements ConnectorFactory +{ + private final String name; + private final Module module; + private final ClassLoader classLoader; + + public ArrowConnectorFactory(String name, Module module, ClassLoader classLoader) + { + checkArgument(!isNullOrEmpty(name), "name is null or empty"); + this.name = requireNonNull(name, "name is null"); + this.module = requireNonNull(module, "module is null"); + this.classLoader = requireNonNull(classLoader, "classLoader is null"); + } + + @Override + public String getName() + { + return name; + } + + @Override + public ConnectorHandleResolver getHandleResolver() + { + return new ArrowHandleResolver(); + } + + @Override + public Connector create(String catalogName, Map requiredConfig, ConnectorContext context) + { + requireNonNull(requiredConfig, "requiredConfig is null"); + + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { + Bootstrap app = new Bootstrap( + binder -> { + binder.bind(TypeManager.class).toInstance(context.getTypeManager()); + binder.bind(FunctionMetadataManager.class).toInstance(context.getFunctionMetadataManager()); + binder.bind(StandardFunctionResolution.class).toInstance(context.getStandardFunctionResolution()); + binder.bind(RowExpressionService.class).toInstance(context.getRowExpressionService()); + binder.bind(NodeManager.class).toInstance(context.getNodeManager()); + }, + new ArrowModule(catalogName), + module); + + Injector injector = app + .doNotInitializeLogging() + .setRequiredConfigurationProperties(requiredConfig) + .initialize(); + + return injector.getInstance(ArrowConnector.class); + } + catch (ConfigurationException ex) { + throw new ArrowException(ARROW_FLIGHT_ERROR, "The connector instance could not be created.", ex); + } + catch (Exception e) { + throwIfUnchecked(e); + throw new RuntimeException(e); + } + } +} diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnectorId.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnectorId.java new file mode 100644 index 0000000000000..dce08bac4ac24 --- /dev/null +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnectorId.java @@ -0,0 +1,53 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import java.util.Objects; + +import static java.util.Objects.requireNonNull; + +public class ArrowConnectorId +{ + private final String id; + + public ArrowConnectorId(String id) + { + this.id = requireNonNull(id, "id is null"); + } + + @Override + public String toString() + { + return id; + } + + @Override + public int hashCode() + { + return Objects.hash(id); + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if ((obj == null) || (getClass() != obj.getClass())) { + return false; + } + ArrowConnectorId other = (ArrowConnectorId) obj; + return Objects.equals(this.id, other.id); + } +} diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowErrorCode.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowErrorCode.java new file mode 100644 index 0000000000000..2e33f736a62c5 --- /dev/null +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowErrorCode.java @@ -0,0 +1,43 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.presto.common.ErrorCode; +import com.facebook.presto.common.ErrorType; +import com.facebook.presto.spi.ErrorCodeSupplier; + +import static com.facebook.presto.common.ErrorType.EXTERNAL; +import static com.facebook.presto.common.ErrorType.INTERNAL_ERROR; + +public enum ArrowErrorCode + implements ErrorCodeSupplier +{ + ARROW_INVALID_TABLE(0, EXTERNAL), + ARROW_INVALID_CREDENTAILS(1, EXTERNAL), + ARROW_FLIGHT_ERROR(2, EXTERNAL), + ARROW_INTERNAL_ERROR(3, INTERNAL_ERROR); + + private final ErrorCode errorCode; + + ArrowErrorCode(int code, ErrorType type) + { + errorCode = new ErrorCode(code + 0x0509_0000, name(), type); + } + + @Override + public ErrorCode toErrorCode() + { + return errorCode; + } +} diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowException.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowException.java new file mode 100644 index 0000000000000..ba2c6edba589c --- /dev/null +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowException.java @@ -0,0 +1,31 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.presto.spi.ErrorCodeSupplier; +import com.facebook.presto.spi.PrestoException; + +public class ArrowException + extends PrestoException +{ + public ArrowException(ErrorCodeSupplier errorCode, String message) + { + super(errorCode, message); + } + + public ArrowException(ErrorCodeSupplier errorCode, String message, Throwable cause) + { + super(errorCode, message, cause); + } +} diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightClient.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightClient.java new file mode 100644 index 0000000000000..3d12617a0839d --- /dev/null +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightClient.java @@ -0,0 +1,54 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import org.apache.arrow.flight.FlightClient; + +import java.io.IOException; +import java.io.InputStream; +import java.util.Optional; + +import static java.util.Objects.requireNonNull; + +public class ArrowFlightClient + implements AutoCloseable +{ + private final FlightClient flightClient; + private final Optional trustedCertificate; + + public ArrowFlightClient(FlightClient flightClient, Optional trustedCertificate) + { + this.flightClient = requireNonNull(flightClient, "flightClient cannot be null"); + this.trustedCertificate = requireNonNull(trustedCertificate, "trustedCertificate is null"); + } + + public FlightClient getFlightClient() + { + return flightClient; + } + + public Optional getTrustedCertificate() + { + return trustedCertificate; + } + + @Override + public void close() throws InterruptedException, IOException + { + flightClient.close(); + if (trustedCertificate.isPresent()) { + trustedCertificate.get().close(); + } + } +} diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightClientHandler.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightClientHandler.java new file mode 100644 index 0000000000000..545e1c8bc99a0 --- /dev/null +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightClientHandler.java @@ -0,0 +1,125 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.airlift.log.Logger; +import com.facebook.presto.spi.ConnectorSession; +import org.apache.arrow.flight.FlightClient; +import org.apache.arrow.flight.FlightDescriptor; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.Location; +import org.apache.arrow.flight.grpc.CredentialCallOption; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.types.pojo.Schema; + +import java.io.FileInputStream; +import java.io.InputStream; +import java.util.Optional; + +import static com.facebook.plugin.arrow.ArrowErrorCode.ARROW_FLIGHT_ERROR; + +public abstract class ArrowFlightClientHandler +{ + private static final Logger logger = Logger.get(ArrowFlightClientHandler.class); + private final ArrowFlightConfig config; + + private RootAllocator allocator; + + public ArrowFlightClientHandler(ArrowFlightConfig config) + { + this.config = requireNonNull(config, "config is null"); + } + + private ArrowFlightClient initializeClient(Optional uri) + { + try { + Optional trustedCertificate = Optional.empty(); + + Location location; + if (uri.isPresent()) { + location = new Location(uri.get()); + } + else { + if (config.getArrowFlightServerSslEnabled() != null && !config.getArrowFlightServerSslEnabled()) { + location = Location.forGrpcInsecure(config.getFlightServerName(), config.getArrowFlightPort()); + } + else { + location = Location.forGrpcTls(config.getFlightServerName(), config.getArrowFlightPort()); + } + } + + if (null == allocator) { + initializeAllocator(); + } + + FlightClient.Builder flightClientBuilder = FlightClient.builder(allocator, location); + if (config.getVerifyServer() != null && !config.getVerifyServer()) { + flightClientBuilder.verifyServer(false); + } + else if (config.getFlightServerSSLCertificate() != null) { + trustedCertificate = Optional.of(new FileInputStream(config.getFlightServerSSLCertificate())); + flightClientBuilder.trustedCertificates(trustedCertificate.get()).useTls(); + } + + FlightClient flightClient = flightClientBuilder.build(); + return new ArrowFlightClient(flightClient, trustedCertificate); + } + catch (Exception e) { + throw new ArrowException(ARROW_FLIGHT_ERROR, "The flight client could not be obtained." + ex.getMessage(), ex); + } + } + + private synchronized void initializeAllocator() + { + if (allocator == null) { + allocator = new RootAllocator(Long.MAX_VALUE); + } + } + + protected abstract CredentialCallOption getCallOptions(ConnectorSession connectorSession); + + public ArrowFlightConfig getConfig() + { + return config; + } + + public ArrowFlightClient getClient(Optional uri) + { + return initializeClient(uri); + } + + public FlightInfo getFlightInfo(FlightDescriptor flightDescriptor, ConnectorSession connectorSession) + { + try (ArrowFlightClient client = getClient(Optional.empty())) { + CredentialCallOption auth = this.getCallOptions(connectorSession); + FlightInfo flightInfo = client.getFlightClient().getInfo(flightDescriptor, auth); + return flightInfo; + } + catch (Exception e) { + throw new ArrowException(ARROW_FLIGHT_ERROR, "The flight information could not be obtained from the flight server." + e.getMessage(), e); + } + } + + public Optional getSchema(FlightDescriptor flightDescriptor, ConnectorSession connectorSession) + { + return getFlightInfo(flightDescriptor, connectorSession).getSchemaOptional(); + } + + public void closeRootallocator() + { + if (null != allocator) { + allocator.close(); + } + } +} diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightConfig.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightConfig.java new file mode 100644 index 0000000000000..a30301dcdc361 --- /dev/null +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightConfig.java @@ -0,0 +1,84 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.airlift.configuration.Config; + +public class ArrowFlightConfig +{ + private String server; + private Boolean verifyServer; + private String flightServerSSLCertificate; + private Boolean arrowFlightServerSslEnabled; + private Integer arrowFlightPort; + public String getFlightServerName() + { + return server; + } + + public Boolean getVerifyServer() + { + return verifyServer; + } + + public Boolean getArrowFlightServerSslEnabled() + { + return arrowFlightServerSslEnabled; + } + + public String getFlightServerSSLCertificate() + { + return flightServerSSLCertificate; + } + + public Integer getArrowFlightPort() + { + return arrowFlightPort; + } + + @Config("arrow-flight.server") + public ArrowFlightConfig setFlightServerName(String server) + { + this.server = server; + return this; + } + + @Config("arrow-flight.server.verify") + public ArrowFlightConfig setVerifyServer(Boolean verifyServer) + { + this.verifyServer = verifyServer; + return this; + } + + @Config("arrow-flight.server.port") + public ArrowFlightConfig setArrowFlightPort(Integer arrowFlightPort) + { + this.arrowFlightPort = arrowFlightPort; + return this; + } + + @Config("arrow-flight.server-ssl-certificate") + public ArrowFlightConfig setFlightServerSSLCertificate(String flightServerSSLCertificate) + { + this.flightServerSSLCertificate = flightServerSSLCertificate; + return this; + } + + @Config("arrow-flight.server-ssl-enabled") + public ArrowFlightConfig setArrowFlightServerSslEnabled(Boolean arrowFlightServerSslEnabled) + { + this.arrowFlightServerSslEnabled = arrowFlightServerSslEnabled; + return this; + } +} diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowHandleResolver.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowHandleResolver.java new file mode 100644 index 0000000000000..8b231b98a6ee6 --- /dev/null +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowHandleResolver.java @@ -0,0 +1,55 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.ConnectorHandleResolver; +import com.facebook.presto.spi.ConnectorSplit; +import com.facebook.presto.spi.ConnectorTableHandle; +import com.facebook.presto.spi.ConnectorTableLayoutHandle; +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; + +public class ArrowHandleResolver + implements ConnectorHandleResolver +{ + @Override + public Class getTableHandleClass() + { + return ArrowTableHandle.class; + } + + @Override + public Class getTableLayoutHandleClass() + { + return ArrowTableLayoutHandle.class; + } + + @Override + public Class getColumnHandleClass() + { + return ArrowColumnHandle.class; + } + + @Override + public Class getSplitClass() + { + return ArrowSplit.class; + } + + @Override + public Class getTransactionHandleClass() + { + return ArrowTransactionHandle.class; + } +} diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowModule.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowModule.java new file mode 100644 index 0000000000000..b20762f8ad497 --- /dev/null +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowModule.java @@ -0,0 +1,46 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.presto.spi.ConnectorHandleResolver; +import com.facebook.presto.spi.connector.Connector; +import com.facebook.presto.spi.connector.ConnectorPageSourceProvider; +import com.google.inject.Binder; +import com.google.inject.Module; +import com.google.inject.Scopes; + +import static com.facebook.airlift.configuration.ConfigBinder.configBinder; +import static java.util.Objects.requireNonNull; + +public class ArrowModule + implements Module +{ + protected final String connectorId; + + public ArrowModule(String connectorId) + { + this.connectorId = requireNonNull(connectorId, "connector id is null"); + } + + public void configure(Binder binder) + { + configBinder(binder).bindConfig(ArrowFlightConfig.class); + binder.bind(ArrowConnector.class).in(Scopes.SINGLETON); + binder.bind(ArrowConnectorId.class).toInstance(new ArrowConnectorId(connectorId)); + binder.bind(ConnectorHandleResolver.class).to(ArrowHandleResolver.class).in(Scopes.SINGLETON); + binder.bind(ArrowPageSourceProvider.class).in(Scopes.SINGLETON); + binder.bind(ConnectorPageSourceProvider.class).to(ArrowPageSourceProvider.class).in(Scopes.SINGLETON); + binder.bind(Connector.class).to(ArrowConnector.class).in(Scopes.SINGLETON); + } +} diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageSource.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageSource.java new file mode 100644 index 0000000000000..ec51125a34848 --- /dev/null +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageSource.java @@ -0,0 +1,163 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.airlift.log.Logger; +import com.facebook.presto.common.Page; +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.spi.ConnectorPageSource; +import com.facebook.presto.spi.ConnectorSession; +import org.apache.arrow.flight.FlightRuntimeException; +import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.flight.Ticket; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.dictionary.Dictionary; + +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; + +import static com.facebook.plugin.arrow.ArrowErrorCode.ARROW_FLIGHT_ERROR; + +public class ArrowPageSource + implements ConnectorPageSource +{ + private static final Logger logger = Logger.get(ArrowPageSource.class); + private final ArrowSplit split; + private final List columnHandles; + private boolean completed; + private int currentPosition; + private Optional vectorSchemaRoot = Optional.empty(); + private ArrowFlightClient flightClient; + private FlightStream flightStream; + + public ArrowPageSource(ArrowSplit split, List columnHandles, ArrowFlightClientHandler clientHandler, + ConnectorSession connectorSession) + { + this.columnHandles = columnHandles; + this.split = split; + getFlightStream(clientHandler, split.getTicket(), connectorSession); + } + + private void getFlightStream(ArrowFlightClientHandler clientHandler, byte[] ticket, ConnectorSession connectorSession) + { + try { + Optional uri = (split.getLocationUrls().isEmpty()) ? + Optional.empty() : Optional.of(split.getLocationUrls().get(0)); + flightClient = clientHandler.getClient(uri); + flightStream = flightClient.getFlightClient().getStream(new Ticket(ticket), clientHandler.getCallOptions(connectorSession)); + } + catch (FlightRuntimeException e) { + throw new ArrowException(ARROW_FLIGHT_ERROR, e.getMessage(), e); + } + } + + @Override + public long getCompletedBytes() + { + return 0; + } + + @Override + public long getCompletedPositions() + { + return currentPosition; + } + + @Override + public long getReadTimeNanos() + { + return 0; + } + + @Override + public boolean isFinished() + { + return completed; + } + + @Override + public long getSystemMemoryUsage() + { + return 0; + } + + @Override + public Page getNextPage() + { + if (vectorSchemaRoot.isPresent()) { + vectorSchemaRoot.get().close(); + vectorSchemaRoot = Optional.empty(); + } + + if (flightStream.next()) { + vectorSchemaRoot = Optional.ofNullable(flightStream.getRoot()); + } + + if (!vectorSchemaRoot.isPresent()) { + completed = true; + } + + if (isFinished()) { + return null; + } + + currentPosition++; + + List blocks = new ArrayList<>(); + for (int columnIndex = 0; columnIndex < columnHandles.size(); columnIndex++) { + FieldVector vector = vectorSchemaRoot.get().getVector(columnIndex); + Type type = columnHandles.get(columnIndex).getColumnType(); + + boolean isDictionaryBlock = vector.getField().getDictionary() != null; + Dictionary dictionary = null; + if (isDictionaryBlock) { + dictionary = flightStream.getDictionaryProvider().lookup(vector.getField().getDictionary().getId()); + } + Block block = null != dictionary ? ArrowPageUtils.buildBlockFromVector(vector, type, dictionary.getVector(), isDictionaryBlock) : + ArrowPageUtils.buildBlockFromVector(vector, type, null, false); + blocks.add(block); + } + + return new Page(vectorSchemaRoot.get().getRowCount(), blocks.toArray(new Block[0])); + } + + @Override + public void close() + { + if (vectorSchemaRoot.isPresent()) { + vectorSchemaRoot.get().close(); + vectorSchemaRoot = Optional.empty(); + } + if (flightStream != null) { + try { + flightStream.close(); + } + catch (Exception e) { + logger.error(e); + } + } + try { + if (flightClient != null) { + flightClient.close(); + flightClient = null; + } + } + catch (Exception ex) { + logger.error("Failed to close the flight client: %s", ex.getMessage(), ex); + } + } +} diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageSourceProvider.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageSourceProvider.java new file mode 100644 index 0000000000000..f3bb41c3e35d4 --- /dev/null +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageSourceProvider.java @@ -0,0 +1,52 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.airlift.log.Logger; +import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.ConnectorPageSource; +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.ConnectorSplit; +import com.facebook.presto.spi.SplitContext; +import com.facebook.presto.spi.connector.ConnectorPageSourceProvider; +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import com.google.common.collect.ImmutableList; + +import javax.inject.Inject; + +import java.util.List; + +public class ArrowPageSourceProvider + implements ConnectorPageSourceProvider +{ + private static final Logger logger = Logger.get(ArrowPageSourceProvider.class); + private ArrowFlightClientHandler clientHandler; + @Inject + public ArrowPageSourceProvider(ArrowFlightClientHandler clientHandler) + { + this.clientHandler = clientHandler; + } + + @Override + public ConnectorPageSource createPageSource(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorSplit split, List columns, SplitContext splitContext) + { + ImmutableList.Builder columnHandles = ImmutableList.builder(); + for (ColumnHandle handle : columns) { + columnHandles.add((ArrowColumnHandle) handle); + } + ArrowSplit arrowSplit = (ArrowSplit) split; + logger.debug("Processing split with flight ticket"); + return new ArrowPageSource(arrowSplit, columnHandles.build(), clientHandler, session); + } +} diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageUtils.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageUtils.java new file mode 100644 index 0000000000000..04b20d00c24a8 --- /dev/null +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageUtils.java @@ -0,0 +1,968 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.block.BlockBuilder; +import com.facebook.presto.common.block.DictionaryBlock; +import com.facebook.presto.common.type.ArrayType; +import com.facebook.presto.common.type.BigintType; +import com.facebook.presto.common.type.BooleanType; +import com.facebook.presto.common.type.CharType; +import com.facebook.presto.common.type.DateType; +import com.facebook.presto.common.type.DecimalType; +import com.facebook.presto.common.type.Decimals; +import com.facebook.presto.common.type.DoubleType; +import com.facebook.presto.common.type.IntegerType; +import com.facebook.presto.common.type.RealType; +import com.facebook.presto.common.type.RowType; +import com.facebook.presto.common.type.SmallintType; +import com.facebook.presto.common.type.TimeType; +import com.facebook.presto.common.type.TimestampType; +import com.facebook.presto.common.type.TinyintType; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.common.type.VarbinaryType; +import com.facebook.presto.common.type.VarcharType; +import com.google.common.base.CharMatcher; +import io.airlift.slice.Slice; +import io.airlift.slice.Slices; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.BitVector; +import org.apache.arrow.vector.DateDayVector; +import org.apache.arrow.vector.DateMilliVector; +import org.apache.arrow.vector.DecimalVector; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.Float4Vector; +import org.apache.arrow.vector.Float8Vector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.NullVector; +import org.apache.arrow.vector.SmallIntVector; +import org.apache.arrow.vector.TimeMicroVector; +import org.apache.arrow.vector.TimeMilliVector; +import org.apache.arrow.vector.TimeSecVector; +import org.apache.arrow.vector.TimeStampMicroVector; +import org.apache.arrow.vector.TimeStampMilliTZVector; +import org.apache.arrow.vector.TimeStampMilliVector; +import org.apache.arrow.vector.TimeStampSecVector; +import org.apache.arrow.vector.TinyIntVector; +import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.VarBinaryVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.impl.UnionListReader; +import org.apache.arrow.vector.types.FloatingPointPrecision; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.util.JsonStringArrayList; + +import java.math.BigDecimal; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.time.LocalTime; +import java.util.List; +import java.util.concurrent.TimeUnit; + +import static java.util.Objects.requireNonNull; + +public class ArrowPageUtils +{ + private ArrowPageUtils() + { + } + + public static Block buildBlockFromVector(FieldVector vector, Type type, FieldVector dictionary, boolean isDictionaryVector) + { + if (isDictionaryVector) { + return buildBlockFromDictionaryVector(vector, dictionary); + } + else if (vector instanceof BitVector) { + return buildBlockFromBitVector((BitVector) vector, type); + } + else if (vector instanceof TinyIntVector) { + return buildBlockFromTinyIntVector((TinyIntVector) vector, type); + } + else if (vector instanceof IntVector) { + return buildBlockFromIntVector((IntVector) vector, type); + } + else if (vector instanceof SmallIntVector) { + return buildBlockFromSmallIntVector((SmallIntVector) vector, type); + } + else if (vector instanceof BigIntVector) { + return buildBlockFromBigIntVector((BigIntVector) vector, type); + } + else if (vector instanceof DecimalVector) { + return buildBlockFromDecimalVector((DecimalVector) vector, type); + } + else if (vector instanceof NullVector) { + return buildBlockFromNullVector((NullVector) vector, type); + } + else if (vector instanceof TimeStampMicroVector) { + return buildBlockFromTimeStampMicroVector((TimeStampMicroVector) vector, type); + } + else if (vector instanceof TimeStampMilliVector) { + return buildBlockFromTimeStampMilliVector((TimeStampMilliVector) vector, type); + } + else if (vector instanceof Float4Vector) { + return buildBlockFromFloat4Vector((Float4Vector) vector, type); + } + else if (vector instanceof Float8Vector) { + return buildBlockFromFloat8Vector((Float8Vector) vector, type); + } + else if (vector instanceof VarCharVector) { + if (type instanceof CharType) { + return buildCharTypeBlockFromVarcharVector((VarCharVector) vector, type); + } + else if (type instanceof TimeType) { + return buildTimeTypeBlockFromVarcharVector((VarCharVector) vector, type); + } + else { + return buildBlockFromVarCharVector((VarCharVector) vector, type); + } + } + else if (vector instanceof VarBinaryVector) { + return buildBlockFromVarBinaryVector((VarBinaryVector) vector, type); + } + else if (vector instanceof DateDayVector) { + return buildBlockFromDateDayVector((DateDayVector) vector, type); + } + else if (vector instanceof DateMilliVector) { + return buildBlockFromDateMilliVector((DateMilliVector) vector, type); + } + else if (vector instanceof TimeMilliVector) { + return buildBlockFromTimeMilliVector((TimeMilliVector) vector, type); + } + else if (vector instanceof TimeSecVector) { + return buildBlockFromTimeSecVector((TimeSecVector) vector, type); + } + else if (vector instanceof TimeStampSecVector) { + return buildBlockFromTimeStampSecVector((TimeStampSecVector) vector, type); + } + else if (vector instanceof TimeMicroVector) { + return buildBlockFromTimeMicroVector((TimeMicroVector) vector, type); + } + else if (vector instanceof TimeStampMilliTZVector) { + return buildBlockFromTimeMilliTZVector((TimeStampMilliTZVector) vector, type); + } + else if (vector instanceof ListVector) { + return buildBlockFromListVector((ListVector) vector, type); + } + + throw new UnsupportedOperationException("Unsupported vector type: " + vector.getClass().getSimpleName()); + } + + public static Block buildBlockFromDictionaryVector(FieldVector fieldVector, FieldVector dictionaryVector) + { + // Validate inputs + requireNonNull(fieldVector, "encoded vector is null"); + requireNonNull(dictionaryVector, "dictionary vector is null"); + + // Create a BlockBuilder for the decoded vector's data type + Type prestoType = getPrestoTypeFromArrowType(dictionaryVector.getField().getType()); + + Block dictionaryblock = null; + // Populate the block dynamically based on vector type + for (int i = 0; i < dictionaryVector.getValueCount(); i++) { + if (!dictionaryVector.isNull(i)) { + dictionaryblock = appendValueToBlock(dictionaryVector, prestoType); + } + } + + return getDictionaryBlock(fieldVector, dictionaryblock); + + // Create the Presto DictionaryBlock + } + + private static DictionaryBlock getDictionaryBlock(FieldVector fieldVector, Block dictionaryblock) + { + if (fieldVector instanceof IntVector) { + // Get the Arrow indices vector + IntVector indicesVector = (IntVector) fieldVector; + int[] ids = new int[indicesVector.getValueCount()]; + for (int i = 0; i < indicesVector.getValueCount(); i++) { + ids[i] = indicesVector.get(i); + } + return new DictionaryBlock(ids.length, dictionaryblock, ids); + } + else if (fieldVector instanceof SmallIntVector) { + // Get the SmallInt indices vector + SmallIntVector smallIntIndicesVector = (SmallIntVector) fieldVector; + int[] ids = new int[smallIntIndicesVector.getValueCount()]; + for (int i = 0; i < smallIntIndicesVector.getValueCount(); i++) { + ids[i] = smallIntIndicesVector.get(i); + } + return new DictionaryBlock(ids.length, dictionaryblock, ids); + } + else if (fieldVector instanceof TinyIntVector) { + // Get the TinyInt indices vector + TinyIntVector tinyIntIndicesVector = (TinyIntVector) fieldVector; + int[] ids = new int[tinyIntIndicesVector.getValueCount()]; + for (int i = 0; i < tinyIntIndicesVector.getValueCount(); i++) { + ids[i] = tinyIntIndicesVector.get(i); + } + return new DictionaryBlock(ids.length, dictionaryblock, ids); + } + else { + // Handle the case where the FieldVector is of an unsupported type + throw new IllegalArgumentException("Unsupported FieldVector type: " + fieldVector.getClass()); + } + } + + private static Type getPrestoTypeFromArrowType(ArrowType arrowType) + { + if (arrowType instanceof ArrowType.Utf8) { + return VarcharType.VARCHAR; + } + else if (arrowType instanceof ArrowType.Int) { + ArrowType.Int intType = (ArrowType.Int) arrowType; + if (intType.getBitWidth() == 8 || intType.getBitWidth() == 16 || intType.getBitWidth() == 32) { + return IntegerType.INTEGER; + } + else if (intType.getBitWidth() == 64) { + return BigintType.BIGINT; + } + } + else if (arrowType instanceof ArrowType.FloatingPoint) { + ArrowType.FloatingPoint fpType = (ArrowType.FloatingPoint) arrowType; + FloatingPointPrecision precision = fpType.getPrecision(); + + if (precision == FloatingPointPrecision.SINGLE) { // 32-bit float + return RealType.REAL; + } + else if (precision == FloatingPointPrecision.DOUBLE) { // 64-bit float + return DoubleType.DOUBLE; + } + else { + throw new UnsupportedOperationException("Unsupported FloatingPoint precision: " + precision); + } + } + else if (arrowType instanceof ArrowType.Bool) { + return BooleanType.BOOLEAN; + } + else if (arrowType instanceof ArrowType.Binary) { + return VarbinaryType.VARBINARY; + } + else if (arrowType instanceof ArrowType.Decimal) { + return DecimalType.createDecimalType(); + } + throw new UnsupportedOperationException("Unsupported ArrowType: " + arrowType); + } + + private static Block appendValueToBlock(ValueVector vector, Type prestoType) + { + if (vector instanceof VarCharVector) { + return buildBlockFromVarCharVector((VarCharVector) vector, prestoType); + } + else if (vector instanceof IntVector) { + return buildBlockFromIntVector((IntVector) vector, prestoType); + } + else if (vector instanceof BigIntVector) { + return buildBlockFromBigIntVector((BigIntVector) vector, prestoType); + } + else if (vector instanceof Float4Vector) { + return buildBlockFromFloat4Vector((Float4Vector) vector, prestoType); + } + else if (vector instanceof Float8Vector) { + return buildBlockFromFloat8Vector((Float8Vector) vector, prestoType); + } + else if (vector instanceof BitVector) { + return buildBlockFromBitVector((BitVector) vector, prestoType); + } + else if (vector instanceof VarBinaryVector) { + return buildBlockFromVarBinaryVector((VarBinaryVector) vector, prestoType); + } + else if (vector instanceof DecimalVector) { + return buildBlockFromDecimalVector((DecimalVector) vector, prestoType); + } + else if (vector instanceof TinyIntVector) { + return buildBlockFromTinyIntVector((TinyIntVector) vector, prestoType); + } + else if (vector instanceof SmallIntVector) { + return buildBlockFromSmallIntVector((SmallIntVector) vector, prestoType); + } + else if (vector instanceof DateDayVector) { + return buildBlockFromDateDayVector((DateDayVector) vector, prestoType); + } + else if (vector instanceof TimeStampMilliTZVector) { + return buildBlockFromTimeStampMicroVector((TimeStampMicroVector) vector, prestoType); + } + else { + throw new UnsupportedOperationException("Unsupported vector type: " + vector.getClass()); + } + } + + public static Block buildBlockFromTimeMilliTZVector(TimeStampMilliTZVector vector, Type type) + { + if (!(type instanceof TimestampType)) { + throw new IllegalArgumentException("Type must be a TimestampType for TimeStampMilliTZVector"); + } + + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + long millis = vector.get(i); + type.writeLong(builder, millis); + } + } + return builder.build(); + } + + public static Block buildBlockFromBitVector(BitVector vector, Type type) + { + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + type.writeBoolean(builder, vector.get(i) == 1); + } + } + return builder.build(); + } + + public static Block buildBlockFromIntVector(IntVector vector, Type type) + { + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + type.writeLong(builder, vector.get(i)); + } + } + return builder.build(); + } + + public static Block buildBlockFromSmallIntVector(SmallIntVector vector, Type type) + { + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + type.writeLong(builder, vector.get(i)); + } + } + return builder.build(); + } + + public static Block buildBlockFromTinyIntVector(TinyIntVector vector, Type type) + { + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + type.writeLong(builder, vector.get(i)); + } + } + return builder.build(); + } + + public static Block buildBlockFromBigIntVector(BigIntVector vector, Type type) + { + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + type.writeLong(builder, vector.get(i)); + } + } + return builder.build(); + } + + public static Block buildBlockFromDecimalVector(DecimalVector vector, Type type) + { + if (!(type instanceof DecimalType)) { + throw new IllegalArgumentException("Type must be a DecimalType for DecimalVector"); + } + + DecimalType decimalType = (DecimalType) type; + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + BigDecimal decimal = vector.getObject(i); // Get the BigDecimal value + if (decimalType.isShort()) { + builder.writeLong(decimal.unscaledValue().longValue()); + } + else { + Slice slice = Decimals.encodeScaledValue(decimal); + decimalType.writeSlice(builder, slice, 0, slice.length()); + } + } + } + return builder.build(); + } + + public static Block buildBlockFromNullVector(NullVector vector, Type type) + { + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + builder.appendNull(); + } + return builder.build(); + } + + public static Block buildBlockFromTimeStampMicroVector(TimeStampMicroVector vector, Type type) + { + if (!(type instanceof TimestampType)) { + throw new IllegalArgumentException("Expected TimestampType but got " + type.getClass().getName()); + } + + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + long micros = vector.get(i); + long millis = TimeUnit.MICROSECONDS.toMillis(micros); + type.writeLong(builder, millis); + } + } + return builder.build(); + } + + public static Block buildBlockFromTimeStampMilliVector(TimeStampMilliVector vector, Type type) + { + if (!(type instanceof TimestampType)) { + throw new IllegalArgumentException("Expected TimestampType but got " + type.getClass().getName()); + } + + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + long millis = vector.get(i); + type.writeLong(builder, millis); + } + } + return builder.build(); + } + + public static Block buildBlockFromFloat8Vector(Float8Vector vector, Type type) + { + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + type.writeDouble(builder, vector.get(i)); + } + } + return builder.build(); + } + + public static Block buildBlockFromFloat4Vector(Float4Vector vector, Type type) + { + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + int intBits = Float.floatToIntBits(vector.get(i)); + type.writeLong(builder, intBits); + } + } + return builder.build(); + } + + public static Block buildBlockFromVarBinaryVector(VarBinaryVector vector, Type type) + { + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + byte[] value = vector.get(i); + type.writeSlice(builder, Slices.wrappedBuffer(value)); + } + } + return builder.build(); + } + + public static Block buildBlockFromVarCharVector(VarCharVector vector, Type type) + { + if (!(type instanceof VarcharType)) { + throw new IllegalArgumentException("Expected VarcharType but got " + type.getClass().getName()); + } + + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + String value = new String(vector.get(i), StandardCharsets.UTF_8); + type.writeSlice(builder, Slices.utf8Slice(value)); + } + } + return builder.build(); + } + + public static Block buildBlockFromDateDayVector(DateDayVector vector, Type type) + { + if (!(type instanceof DateType)) { + throw new IllegalArgumentException("Expected DateType but got " + type.getClass().getName()); + } + + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + type.writeLong(builder, vector.get(i)); + } + } + return builder.build(); + } + + public static Block buildBlockFromDateMilliVector(DateMilliVector vector, Type type) + { + if (!(type instanceof DateType)) { + throw new IllegalArgumentException("Expected DateType but got " + type.getClass().getName()); + } + + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + DateType dateType = (DateType) type; + long days = TimeUnit.MILLISECONDS.toDays(vector.get(i)); + dateType.writeLong(builder, days); + } + } + return builder.build(); + } + + public static Block buildBlockFromTimeSecVector(TimeSecVector vector, Type type) + { + if (!(type instanceof TimeType)) { + throw new IllegalArgumentException("Type must be a TimeType for TimeSecVector"); + } + + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + int value = vector.get(i); + long millis = TimeUnit.SECONDS.toMillis(value); + type.writeLong(builder, millis); + } + } + return builder.build(); + } + + public static Block buildBlockFromTimeMilliVector(TimeMilliVector vector, Type type) + { + if (!(type instanceof TimeType)) { + throw new IllegalArgumentException("Type must be a TimeType for TimeSecVector"); + } + + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + long millis = vector.get(i); + type.writeLong(builder, millis); + } + } + return builder.build(); + } + + public static Block buildBlockFromTimeMicroVector(TimeMicroVector vector, Type type) + { + if (!(type instanceof TimeType)) { + throw new IllegalArgumentException("Type must be a TimeType for TimemicroVector"); + } + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + long value = vector.get(i); + long micro = TimeUnit.MICROSECONDS.toMillis(value); + type.writeLong(builder, micro); + } + } + return builder.build(); + } + + public static Block buildBlockFromTimeStampSecVector(TimeStampSecVector vector, Type type) + { + if (!(type instanceof TimestampType)) { + throw new IllegalArgumentException("Type must be a TimestampType for TimeStampSecVector"); + } + + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + long value = vector.get(i); + long millis = TimeUnit.SECONDS.toMillis(value); + type.writeLong(builder, millis); + } + } + return builder.build(); + } + + public static Block buildCharTypeBlockFromVarcharVector(VarCharVector vector, Type type) + { + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + String value = new String(vector.get(i), StandardCharsets.UTF_8); + type.writeSlice(builder, Slices.utf8Slice(CharMatcher.is(' ').trimTrailingFrom(value))); + } + } + return builder.build(); + } + + public static Block buildTimeTypeBlockFromVarcharVector(VarCharVector vector, Type type) + { + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + String timeString = new String(vector.get(i), StandardCharsets.UTF_8); + LocalTime time = LocalTime.parse(timeString); + long millis = Duration.between(LocalTime.MIN, time).toMillis(); + type.writeLong(builder, millis); + } + } + return builder.build(); + } + + public static Block buildBlockFromListVector(ListVector vector, Type type) + { + if (!(type instanceof ArrayType)) { + throw new IllegalArgumentException("Type must be an ArrayType for ListVector"); + } + + ArrayType arrayType = (ArrayType) type; + Type elementType = arrayType.getElementType(); + BlockBuilder arrayBuilder = type.createBlockBuilder(null, vector.getValueCount()); + + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + arrayBuilder.appendNull(); + } + else { + BlockBuilder elementBuilder = arrayBuilder.beginBlockEntry(); + UnionListReader reader = vector.getReader(); + reader.setPosition(i); + + while (reader.next()) { + Object value = reader.readObject(); + if (value == null) { + elementBuilder.appendNull(); + } + else { + appendValueToBuilder(elementType, elementBuilder, value); + } + } + arrayBuilder.closeEntry(); + } + } + return arrayBuilder.build(); + } + + public static void appendValueToBuilder(Type type, BlockBuilder builder, Object value) + { + if (value == null) { + builder.appendNull(); + return; + } + + if (type instanceof VarcharType) { + writeVarcharType(type, builder, value); + } + else if (type instanceof SmallintType) { + writeSmallintType(type, builder, value); + } + else if (type instanceof TinyintType) { + writeTinyintType(type, builder, value); + } + else if (type instanceof BigintType) { + writeBigintType(type, builder, value); + } + else if (type instanceof IntegerType) { + writeIntegerType(type, builder, value); + } + else if (type instanceof DoubleType) { + writeDoubleType(type, builder, value); + } + else if (type instanceof BooleanType) { + writeBooleanType(type, builder, value); + } + else if (type instanceof DecimalType) { + writeDecimalType((DecimalType) type, builder, value); + } + else if (type instanceof ArrayType) { + writeArrayType((ArrayType) type, builder, value); + } + else if (type instanceof RowType) { + writeRowType((RowType) type, builder, value); + } + else if (type instanceof DateType) { + writeDateType(type, builder, value); + } + else if (type instanceof TimestampType) { + writeTimestampType(type, builder, value); + } + else { + throw new IllegalArgumentException("Unsupported type: " + type); + } + } + + public static void writeVarcharType(Type type, BlockBuilder builder, Object value) + { + Slice slice = Slices.utf8Slice(value.toString()); + type.writeSlice(builder, slice); + } + + public static void writeSmallintType(Type type, BlockBuilder builder, Object value) + { + if (value instanceof Number) { + type.writeLong(builder, ((Number) value).shortValue()); + } + else if (value instanceof JsonStringArrayList) { + for (Object obj : (JsonStringArrayList) value) { + try { + short shortValue = Short.parseShort(obj.toString()); + type.writeLong(builder, shortValue); + } + catch (NumberFormatException e) { + throw new IllegalArgumentException("Invalid number format in JsonStringArrayList for SmallintType: " + obj, e); + } + } + } + else { + throw new IllegalArgumentException("Unsupported type for SmallintType: " + value.getClass()); + } + } + + public static void writeTinyintType(Type type, BlockBuilder builder, Object value) + { + if (value instanceof Number) { + type.writeLong(builder, ((Number) value).byteValue()); + } + else if (value instanceof JsonStringArrayList) { + for (Object obj : (JsonStringArrayList) value) { + try { + byte byteValue = Byte.parseByte(obj.toString()); + type.writeLong(builder, byteValue); + } + catch (NumberFormatException e) { + throw new IllegalArgumentException("Invalid number format in JsonStringArrayList for TinyintType: " + obj, e); + } + } + } + else { + throw new IllegalArgumentException("Unsupported type for TinyintType: " + value.getClass()); + } + } + + public static void writeBigintType(Type type, BlockBuilder builder, Object value) + { + if (value instanceof Long) { + type.writeLong(builder, (Long) value); + } + else if (value instanceof Integer) { + type.writeLong(builder, ((Integer) value).longValue()); + } + else if (value instanceof JsonStringArrayList) { + for (Object obj : (JsonStringArrayList) value) { + try { + long longValue = Long.parseLong(obj.toString()); + type.writeLong(builder, longValue); + } + catch (NumberFormatException e) { + throw new IllegalArgumentException("Invalid number format in JsonStringArrayList: " + obj, e); + } + } + } + else { + throw new IllegalArgumentException("Unsupported type for BigintType: " + value.getClass()); + } + } + + public static void writeIntegerType(Type type, BlockBuilder builder, Object value) + { + if (value instanceof Integer) { + type.writeLong(builder, (Integer) value); + } + else if (value instanceof JsonStringArrayList) { + for (Object obj : (JsonStringArrayList) value) { + try { + int intValue = Integer.parseInt(obj.toString()); + type.writeLong(builder, intValue); + } + catch (NumberFormatException e) { + throw new IllegalArgumentException("Invalid number format in JsonStringArrayList: " + obj, e); + } + } + } + else { + throw new IllegalArgumentException("Unsupported type for IntegerType: " + value.getClass()); + } + } + + public static void writeDoubleType(Type type, BlockBuilder builder, Object value) + { + if (value instanceof Double) { + type.writeDouble(builder, (Double) value); + } + else if (value instanceof Float) { + type.writeDouble(builder, ((Float) value).doubleValue()); + } + else if (value instanceof JsonStringArrayList) { + for (Object obj : (JsonStringArrayList) value) { + try { + double doubleValue = Double.parseDouble(obj.toString()); + type.writeDouble(builder, doubleValue); + } + catch (NumberFormatException e) { + throw new IllegalArgumentException("Invalid number format in JsonStringArrayList: " + obj, e); + } + } + } + else { + throw new IllegalArgumentException("Unsupported type for DoubleType: " + value.getClass()); + } + } + + public static void writeBooleanType(Type type, BlockBuilder builder, Object value) + { + if (value instanceof Boolean) { + type.writeBoolean(builder, (Boolean) value); + } + else { + throw new IllegalArgumentException("Unsupported type for BooleanType: " + value.getClass()); + } + } + + public static void writeDecimalType(DecimalType type, BlockBuilder builder, Object value) + { + if (value instanceof BigDecimal) { + BigDecimal decimalValue = (BigDecimal) value; + if (type.isShort()) { + // write ShortDecimalType + long unscaledValue = decimalValue.unscaledValue().longValue(); + type.writeLong(builder, unscaledValue); + } + else { + // write LongDecimalType + Slice slice = Decimals.encodeScaledValue(decimalValue); + type.writeSlice(builder, slice); + } + } + else if (value instanceof Long) { + // Direct handling for ShortDecimalType using long + if (type.isShort()) { + type.writeLong(builder, (Long) value); + } + else { + throw new IllegalArgumentException("Long value is not supported for LongDecimalType."); + } + } + else { + throw new IllegalArgumentException("Unsupported type for DecimalType: " + value.getClass()); + } + } + + public static void writeArrayType(ArrayType type, BlockBuilder builder, Object value) + { + Type elementType = type.getElementType(); + BlockBuilder arrayBuilder = builder.beginBlockEntry(); + for (Object element : (Iterable) value) { + appendValueToBuilder(elementType, arrayBuilder, element); + } + builder.closeEntry(); + } + + public static void writeRowType(RowType type, BlockBuilder builder, Object value) + { + List rowValues = (List) value; + BlockBuilder rowBuilder = builder.beginBlockEntry(); + List fields = type.getFields(); + for (int i = 0; i < fields.size(); i++) { + Type fieldType = fields.get(i).getType(); + appendValueToBuilder(fieldType, rowBuilder, rowValues.get(i)); + } + builder.closeEntry(); + } + + public static void writeDateType(Type type, BlockBuilder builder, Object value) + { + if (value instanceof java.sql.Date || value instanceof java.time.LocalDate) { + int daysSinceEpoch = (int) (value instanceof java.sql.Date + ? ((java.sql.Date) value).toLocalDate().toEpochDay() + : ((java.time.LocalDate) value).toEpochDay()); + type.writeLong(builder, daysSinceEpoch); + } + else { + throw new IllegalArgumentException("Unsupported type for DateType: " + value.getClass()); + } + } + + public static void writeTimestampType(Type type, BlockBuilder builder, Object value) + { + if (value instanceof java.sql.Timestamp) { + long millis = ((java.sql.Timestamp) value).getTime(); + type.writeLong(builder, millis); + } + else if (value instanceof java.time.Instant) { + long millis = ((java.time.Instant) value).toEpochMilli(); + type.writeLong(builder, millis); + } + else if (value instanceof Long) { // write long epoch milliseconds directly + type.writeLong(builder, (Long) value); + } + else { + throw new IllegalArgumentException("Unsupported type for TimestampType: " + value.getClass()); + } + } +} diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPlugin.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPlugin.java new file mode 100644 index 0000000000000..bb5599b00b625 --- /dev/null +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPlugin.java @@ -0,0 +1,49 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.presto.spi.Plugin; +import com.facebook.presto.spi.connector.ConnectorFactory; +import com.google.common.collect.ImmutableList; +import com.google.inject.Module; + +import static com.google.common.base.MoreObjects.firstNonNull; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Strings.isNullOrEmpty; +import static java.util.Objects.requireNonNull; + +public class ArrowPlugin + implements Plugin +{ + protected final String name; + protected final Module module; + + public ArrowPlugin(String name, Module module) + { + checkArgument(!isNullOrEmpty(name), "name is null or empty"); + this.name = name; + this.module = requireNonNull(module, "module is null"); + } + + private static ClassLoader getClassLoader() + { + return firstNonNull(Thread.currentThread().getContextClassLoader(), ArrowPlugin.class.getClassLoader()); + } + + @Override + public Iterable getConnectorFactories() + { + return ImmutableList.of(new ArrowConnectorFactory(name, module, getClassLoader())); + } +} diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowSplit.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowSplit.java new file mode 100644 index 0000000000000..db65912de8c58 --- /dev/null +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowSplit.java @@ -0,0 +1,90 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.presto.spi.ConnectorSplit; +import com.facebook.presto.spi.HostAddress; +import com.facebook.presto.spi.NodeProvider; +import com.facebook.presto.spi.schedule.NodeSelectionStrategy; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import javax.annotation.Nullable; + +import java.util.Collections; +import java.util.List; + +public class ArrowSplit + implements ConnectorSplit +{ + private final String schemaName; + private final String tableName; + private final byte[] ticket; + private final List locationUrls; + + @JsonCreator + public ArrowSplit( + @JsonProperty("schemaName") @Nullable String schemaName, + @JsonProperty("tableName") String tableName, + @JsonProperty("ticket") byte[] ticket, + @JsonProperty("locationUrls") List locationUrls) + { + this.schemaName = schemaName; + this.tableName = tableName; + this.ticket = ticket; + this.locationUrls = locationUrls; + } + + @Override + public NodeSelectionStrategy getNodeSelectionStrategy() + { + return NodeSelectionStrategy.NO_PREFERENCE; + } + + @Override + public List getPreferredNodes(NodeProvider nodeProvider) + { + return Collections.emptyList(); + } + + @Override + public Object getInfo() + { + return this.getInfoMap(); + } + + @JsonProperty + public String getSchemaName() + { + return schemaName; + } + + @JsonProperty + public String getTableName() + { + return tableName; + } + + @JsonProperty + public byte[] getTicket() + { + return ticket; + } + + @JsonProperty + public List getLocationUrls() + { + return locationUrls; + } +} diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTableHandle.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTableHandle.java new file mode 100644 index 0000000000000..cef04e4372a5a --- /dev/null +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTableHandle.java @@ -0,0 +1,73 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.presto.spi.ConnectorTableHandle; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.Objects; + +public class ArrowTableHandle + implements ConnectorTableHandle +{ + private final String schema; + private final String table; + + @JsonCreator + public ArrowTableHandle( + @JsonProperty("schema") String schema, + @JsonProperty("table") String table) + { + this.schema = schema; + this.table = table; + } + + @JsonProperty("schema") + public String getSchema() + { + return schema; + } + + @JsonProperty("table") + public String getTable() + { + return table; + } + + @Override + public String toString() + { + return schema + ":" + table; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + ArrowTableHandle that = (ArrowTableHandle) o; + return Objects.equals(schema, that.schema) && Objects.equals(table, that.table); + } + + @Override + public int hashCode() + { + return Objects.hash(schema, table); + } +} diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTableLayoutHandle.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTableLayoutHandle.java new file mode 100644 index 0000000000000..46e94a4e1143c --- /dev/null +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTableLayoutHandle.java @@ -0,0 +1,86 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.presto.common.predicate.TupleDomain; +import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.ConnectorTableLayoutHandle; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.List; +import java.util.Objects; + +import static java.util.Objects.requireNonNull; + +public class ArrowTableLayoutHandle + implements ConnectorTableLayoutHandle +{ + private final ArrowTableHandle tableHandle; + private final List columnHandles; + private final TupleDomain tupleDomain; + + @JsonCreator + public ArrowTableLayoutHandle(@JsonProperty("table") ArrowTableHandle table, + @JsonProperty("columnHandles") List columnHandles, + @JsonProperty("tupleDomain") TupleDomain domain) + { + this.tableHandle = requireNonNull(table, "table is null"); + this.columnHandles = requireNonNull(columnHandles, "columns are null"); + this.tupleDomain = requireNonNull(domain, "domain is null"); + } + + @JsonProperty("table") + public ArrowTableHandle getTableHandle() + { + return tableHandle; + } + + @JsonProperty("tupleDomain") + public TupleDomain getTupleDomain() + { + return tupleDomain; + } + + @JsonProperty("columnHandles") + public List getColumnHandles() + { + return columnHandles; + } + + @Override + public String toString() + { + return "tableHandle:" + tableHandle + ", columnHandles:" + columnHandles + ", tupleDomain:" + tupleDomain; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + ArrowTableLayoutHandle arrowTableLayoutHandle = (ArrowTableLayoutHandle) o; + return Objects.equals(tableHandle, arrowTableLayoutHandle.tableHandle) && Objects.equals(columnHandles, arrowTableLayoutHandle.columnHandles) && Objects.equals(tupleDomain, arrowTableLayoutHandle.tupleDomain); + } + + @Override + public int hashCode() + { + return Objects.hash(tableHandle, columnHandles, tupleDomain); + } +} diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTransactionHandle.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTransactionHandle.java new file mode 100644 index 0000000000000..07eb7385cfbcf --- /dev/null +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTransactionHandle.java @@ -0,0 +1,22 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; + +public enum ArrowTransactionHandle + implements ConnectorTransactionHandle +{ + INSTANCE +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowFlightQueryRunner.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowFlightQueryRunner.java new file mode 100644 index 0000000000000..2795a33668f6f --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowFlightQueryRunner.java @@ -0,0 +1,68 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.plugin.arrow; + +import com.facebook.airlift.log.Logger; +import com.facebook.presto.Session; +import com.facebook.presto.tests.DistributedQueryRunner; +import com.google.common.collect.ImmutableMap; + +import java.util.Map; + +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; + +public class ArrowFlightQueryRunner +{ + private static final Logger logger = Logger.get(ArrowFlightQueryRunner.class); + private ArrowFlightQueryRunner() + { + throw new UnsupportedOperationException("This is a utility class and cannot be instantiated"); + } + + public static DistributedQueryRunner createQueryRunner() throws Exception + { + return createQueryRunner(ImmutableMap.of(), TestingArrowFactory.class); + } + + private static DistributedQueryRunner createQueryRunner(Map catalogProperties, Class factoryClass) throws Exception + { + Session session = testSessionBuilder() + .setCatalog("arrow") + .setSchema("testdb") + .build(); + + DistributedQueryRunner queryRunner = DistributedQueryRunner.builder(session).build(); + + try { + String connectorName = "arrow"; + queryRunner.installPlugin(new ArrowPlugin(connectorName, new TestingArrowModule())); + + ImmutableMap.Builder properties = ImmutableMap.builder() + .putAll(catalogProperties) + .put("arrow-flight.server", "127.0.0.1") + .put("arrow-flight.server-ssl-enabled", "true") + .put("arrow-flight.server.port", "9443") + .put("arrow-flight.server.verify", "false"); + + queryRunner.createCatalog(connectorName, connectorName, properties.build()); + + return queryRunner; + } + catch (Exception e) { + logger.error(e); + throw new RuntimeException("Failed to create ArrowQueryRunner", e); + } + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowMetadataUtil.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowMetadataUtil.java new file mode 100644 index 0000000000000..c4ab656d41bb3 --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowMetadataUtil.java @@ -0,0 +1,75 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.airlift.json.JsonCodec; +import com.facebook.airlift.json.JsonCodecFactory; +import com.facebook.airlift.json.JsonObjectMapperProvider; +import com.facebook.presto.common.type.StandardTypes; +import com.facebook.presto.common.type.Type; +import com.fasterxml.jackson.databind.DeserializationContext; +import com.fasterxml.jackson.databind.deser.std.FromStringDeserializer; +import com.google.common.collect.ImmutableMap; + +import java.util.Map; + +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.VarcharType.VARCHAR; +import static com.google.common.base.Preconditions.checkArgument; +import static java.util.Locale.ENGLISH; +import static org.testng.Assert.assertEquals; + +final class ArrowMetadataUtil +{ + private ArrowMetadataUtil() {} + + public static final JsonCodec COLUMN_CODEC; + public static final JsonCodec TABLE_CODEC; + + static { + JsonObjectMapperProvider provider = new JsonObjectMapperProvider(); + provider.setJsonDeserializers(ImmutableMap.of(Type.class, new TestingTypeDeserializer())); + JsonCodecFactory codecFactory = new JsonCodecFactory(provider); + COLUMN_CODEC = codecFactory.jsonCodec(ArrowColumnHandle.class); + TABLE_CODEC = codecFactory.jsonCodec(ArrowTableHandle.class); + } + + public static final class TestingTypeDeserializer + extends FromStringDeserializer + { + private final Map types = ImmutableMap.of( + StandardTypes.BIGINT, BIGINT, + StandardTypes.VARCHAR, VARCHAR); + + public TestingTypeDeserializer() + { + super(Type.class); + } + + @Override + protected Type _deserialize(String value, DeserializationContext context) + { + Type type = types.get(value.toLowerCase(ENGLISH)); + checkArgument(type != null, "Unknown type %s", value); + return type; + } + } + + public static void assertJsonRoundTrip(JsonCodec codec, T object) + { + String json = codec.toJson(object); + T copy = codec.fromJson(json); + assertEquals(copy, object); + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowPageUtilsTest.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowPageUtilsTest.java new file mode 100644 index 0000000000000..626e7815c13ad --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowPageUtilsTest.java @@ -0,0 +1,707 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.block.BlockBuilder; +import com.facebook.presto.common.block.DictionaryBlock; +import com.facebook.presto.common.type.ArrayType; +import com.facebook.presto.common.type.BigintType; +import com.facebook.presto.common.type.BooleanType; +import com.facebook.presto.common.type.DateType; +import com.facebook.presto.common.type.DecimalType; +import com.facebook.presto.common.type.Decimals; +import com.facebook.presto.common.type.DoubleType; +import com.facebook.presto.common.type.IntegerType; +import com.facebook.presto.common.type.RowType; +import com.facebook.presto.common.type.SmallintType; +import com.facebook.presto.common.type.TimestampType; +import com.facebook.presto.common.type.TinyintType; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.common.type.VarcharType; +import io.airlift.slice.Slice; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.BaseIntVector; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.BitVector; +import org.apache.arrow.vector.DecimalVector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.SmallIntVector; +import org.apache.arrow.vector.TimeStampMicroVector; +import org.apache.arrow.vector.TinyIntVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.impl.UnionListWriter; +import org.apache.arrow.vector.dictionary.Dictionary; +import org.apache.arrow.vector.dictionary.DictionaryEncoder; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.DictionaryEncoding; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.math.BigDecimal; +import java.math.BigInteger; +import java.nio.charset.StandardCharsets; +import java.time.LocalDate; +import java.util.Arrays; +import java.util.List; +import java.util.Optional; + +import static com.facebook.plugin.arrow.ArrowPageUtils.buildBlockFromDictionaryVector; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotNull; +import static org.testng.Assert.assertTrue; + +public class ArrowPageUtilsTest +{ + private static final int DICTIONARY_LENGTH = 10; + private static final int VECTOR_LENGTH = 50; + private BufferAllocator allocator; + + @BeforeClass + public void setUp() + { + // Initialize the Arrow allocator + allocator = new RootAllocator(Integer.MAX_VALUE); + System.out.println("Allocator initialized: " + allocator); + } + + @Test + public void testBuildBlockFromBitVector() + { + // Create a BitVector and populate it with values + BitVector bitVector = new BitVector("bitVector", allocator); + bitVector.allocateNew(3); // Allocating space for 3 elements + + bitVector.set(0, 1); // Set value to 1 (true) + bitVector.set(1, 0); // Set value to 0 (false) + bitVector.setNull(2); // Set null value + + bitVector.setValueCount(3); + + // Build the block from the vector + Block resultBlock = ArrowPageUtils.buildBlockFromBitVector(bitVector, BooleanType.BOOLEAN); + + // Now verify the result block + assertEquals(3, resultBlock.getPositionCount()); // Should have 3 positions + assertTrue(resultBlock.isNull(2)); // The 3rd element should be null + } + + @Test + public void testBuildBlockFromTinyIntVector() + { + // Create a TinyIntVector and populate it with values + TinyIntVector tinyIntVector = new TinyIntVector("tinyIntVector", allocator); + tinyIntVector.allocateNew(3); // Allocating space for 3 elements + tinyIntVector.set(0, 10); + tinyIntVector.set(1, 20); + tinyIntVector.setNull(2); // Set null value + + tinyIntVector.setValueCount(3); + + // Build the block from the vector + Block resultBlock = ArrowPageUtils.buildBlockFromTinyIntVector(tinyIntVector, TinyintType.TINYINT); + + // Now verify the result block + assertEquals(3, resultBlock.getPositionCount()); // Should have 3 positions + assertTrue(resultBlock.isNull(2)); // The 3rd element should be null + } + + @Test + public void testBuildBlockFromSmallIntVector() + { + // Create a SmallIntVector and populate it with values + SmallIntVector smallIntVector = new SmallIntVector("smallIntVector", allocator); + smallIntVector.allocateNew(3); // Allocating space for 3 elements + smallIntVector.set(0, 10); + smallIntVector.set(1, 20); + smallIntVector.setNull(2); // Set null value + + smallIntVector.setValueCount(3); + + // Build the block from the vector + Block resultBlock = ArrowPageUtils.buildBlockFromSmallIntVector(smallIntVector, SmallintType.SMALLINT); + + // Now verify the result block + assertEquals(3, resultBlock.getPositionCount()); // Should have 3 positions + assertTrue(resultBlock.isNull(2)); // The 3rd element should be null + } + + @Test + public void testBuildBlockFromIntVector() + { + // Create an IntVector and populate it with values + IntVector intVector = new IntVector("intVector", allocator); + intVector.allocateNew(3); // Allocating space for 3 elements + intVector.set(0, 10); + intVector.set(1, 20); + intVector.set(2, 30); + + intVector.setValueCount(3); + + // Build the block from the vector + Block resultBlock = ArrowPageUtils.buildBlockFromIntVector(intVector, IntegerType.INTEGER); + + // Now verify the result block + assertEquals(3, resultBlock.getPositionCount()); // Should have 3 positions + assertEquals(10, resultBlock.getInt(0)); // The 1st element should be 10 + assertEquals(20, resultBlock.getInt(1)); // The 2nd element should be 20 + assertEquals(30, resultBlock.getInt(2)); // The 3rd element should be 30 + } + + @Test + public void testBuildBlockFromBigIntVector() + throws InstantiationException, IllegalAccessException + { + // Create a BigIntVector and populate it with values + BigIntVector bigIntVector = new BigIntVector("bigIntVector", allocator); + bigIntVector.allocateNew(3); // Allocating space for 3 elements + + bigIntVector.set(0, 10L); + bigIntVector.set(1, 20L); + bigIntVector.set(2, 30L); + + bigIntVector.setValueCount(3); + + // Build the block from the vector + Block resultBlock = ArrowPageUtils.buildBlockFromBigIntVector(bigIntVector, BigintType.BIGINT); + + // Now verify the result block + assertEquals(10L, resultBlock.getInt(0)); // The 1st element should be 10L + assertEquals(20L, resultBlock.getInt(1)); // The 2nd element should be 20L + assertEquals(30L, resultBlock.getInt(2)); // The 3rd element should be 30L + } + + @Test + public void testBuildBlockFromDecimalVector() + { + // Create a DecimalVector and populate it with values + DecimalVector decimalVector = new DecimalVector("decimalVector", allocator, 10, 2); // Precision = 10, Scale = 2 + decimalVector.allocateNew(2); // Allocating space for 2 elements + decimalVector.set(0, new BigDecimal("123.45")); + + decimalVector.setValueCount(2); + + // Build the block from the vector + Block resultBlock = ArrowPageUtils.buildBlockFromDecimalVector(decimalVector, DecimalType.createDecimalType(10, 2)); + + // Now verify the result block + assertEquals(2, resultBlock.getPositionCount()); // Should have 2 positions + assertTrue(resultBlock.isNull(1)); // The 2nd element should be null + } + + @Test + public void testBuildBlockFromTimeStampMicroVector() + { + // Create a TimeStampMicroVector and populate it with values + TimeStampMicroVector timestampMicroVector = new TimeStampMicroVector("timestampMicroVector", allocator); + timestampMicroVector.allocateNew(3); // Allocating space for 3 elements + timestampMicroVector.set(0, 1000000L); // 1 second in microseconds + timestampMicroVector.set(1, 2000000L); // 2 seconds in microseconds + timestampMicroVector.setNull(2); // Set null value + + timestampMicroVector.setValueCount(3); + + // Build the block from the vector + Block resultBlock = ArrowPageUtils.buildBlockFromTimeStampMicroVector(timestampMicroVector, TimestampType.TIMESTAMP); + + // Now verify the result block + assertEquals(3, resultBlock.getPositionCount()); // Should have 3 positions + assertTrue(resultBlock.isNull(2)); // The 3rd element should be null + assertEquals(1000L, resultBlock.getLong(0)); // The 1st element should be 1000ms (1 second) + assertEquals(2000L, resultBlock.getLong(1)); // The 2nd element should be 2000ms (2 seconds) + } + + @Test + public void testBuildBlockFromListVector() + { + // Create a root allocator for Arrow vectors + try (BufferAllocator allocator = new RootAllocator(); + ListVector listVector = ListVector.empty("listVector", allocator)) { + // Allocate the vector and get the writer + listVector.allocateNew(); + UnionListWriter listWriter = listVector.getWriter(); + + int[] data = new int[] {1, 2, 3, 10, 20, 30, 100, 200, 300, 1000, 2000, 3000}; + int tmpIndex = 0; + + for (int i = 0; i < 4; i++) { // 4 lists to be added + listWriter.startList(); + for (int j = 0; j < 3; j++) { // Each list has 3 integers + listWriter.writeInt(data[tmpIndex]); + tmpIndex++; + } + listWriter.endList(); + } + + // Set the number of lists + listVector.setValueCount(4); + + // Create Presto ArrayType for Integer + ArrayType arrayType = new ArrayType(IntegerType.INTEGER); + + // Call the method to test + Block block = ArrowPageUtils.buildBlockFromListVector(listVector, arrayType); + + // Validate the result + assertEquals(block.getPositionCount(), 4); // 4 lists in the block + } + } + + @Test + public void testProcessDictionaryVector() + { + // Create dictionary vector + VarCharVector dictionaryVector = new VarCharVector("dictionary", allocator); + dictionaryVector.allocateNew(DICTIONARY_LENGTH); + for (int i = 0; i < DICTIONARY_LENGTH; i++) { + dictionaryVector.setSafe(i, String.valueOf(i).getBytes(StandardCharsets.UTF_8)); + } + dictionaryVector.setValueCount(DICTIONARY_LENGTH); + + // Create raw vector + VarCharVector rawVector = new VarCharVector("raw", allocator); + rawVector.allocateNew(VECTOR_LENGTH); + for (int i = 0; i < VECTOR_LENGTH; i++) { + int value = i % DICTIONARY_LENGTH; + rawVector.setSafe(i, String.valueOf(value).getBytes(StandardCharsets.UTF_8)); + } + rawVector.setValueCount(VECTOR_LENGTH); + + // Encode using dictionary + ArrowType.Int index = new ArrowType.Int(16, true); + Dictionary dictionary = new Dictionary(dictionaryVector, new DictionaryEncoding(1L, false, index)); + BaseIntVector encodedVector = (BaseIntVector) DictionaryEncoder.encode(rawVector, dictionary); + + // Process the dictionary vector + Block result = buildBlockFromDictionaryVector(encodedVector, dictionary.getVector()); + + // Verify the result + assertNotNull(result, "The BlockBuilder should not be null."); + assertEquals(result.getPositionCount(), 50); + } + + @Test + public void testBuildBlockFromDictionaryVector() + { + IntVector indicesVector = new IntVector("indices", allocator); + indicesVector.allocateNew(3); // allocating space for 3 values + + // Initialize a dummy dictionary vector + // Example: dictionary contains 3 string values + VarCharVector dictionaryVector = new VarCharVector("dictionary", allocator); + dictionaryVector.allocateNew(3); // allocating 3 elements in dictionary + + // Fill dictionaryVector with some values + dictionaryVector.set(0, "apple".getBytes()); + dictionaryVector.set(1, "banana".getBytes()); + dictionaryVector.set(2, "cherry".getBytes()); + dictionaryVector.setValueCount(3); + + // Set up index values (this would reference the dictionary) + indicesVector.set(0, 0); // First index points to "apple" + indicesVector.set(1, 1); // Second index points to "banana" + indicesVector.set(2, 2); + indicesVector.set(3, 2); // Third index points to "cherry" + indicesVector.setValueCount(4); + // Call the method under test + Block block = buildBlockFromDictionaryVector(indicesVector, dictionaryVector); + + // Assertions to check the dictionary block's behavior + assertNotNull(block); + assertTrue(block instanceof DictionaryBlock); + + DictionaryBlock dictionaryBlock = (DictionaryBlock) block; + + // Verify the dictionary block contains the right dictionary + + for (int i = 0; i < dictionaryBlock.getPositionCount(); i++) { + // Get the slice (string value) at the given position + Slice slice = dictionaryBlock.getSlice(i, 0, dictionaryBlock.getSliceLength(i)); + + // Assert based on the expected values + if (i == 0) { + assertEquals(slice.toStringUtf8(), "apple"); + } + else if (i == 1) { + assertEquals(slice.toStringUtf8(), "banana"); + } + else if (i == 2) { + assertEquals(slice.toStringUtf8(), "cherry"); + } + else if (i == 3) { + assertEquals(slice.toStringUtf8(), "cherry"); + } + } + } + + @Test + public void testBuildBlockFromDictionaryVectorBigInt() + { + BigIntVector indicesVector = new BigIntVector("indices", allocator); + + indicesVector.allocateNew(3); // allocating space for 3 values + indicesVector.set(0, 0L); + indicesVector.set(1, 1L); + indicesVector.set(2, 2L); + indicesVector.setValueCount(3); + + // Initialize a dummy dictionary vector + // Example: dictionary contains 3 string values + VarCharVector dictionaryVector = new VarCharVector("dictionary", allocator); + dictionaryVector.allocateNew(3); // allocating 3 elements in dictionary + + // Fill dictionaryVector with some values + dictionaryVector.set(0, "apple".getBytes()); + dictionaryVector.set(1, "banana".getBytes()); + dictionaryVector.set(2, "cherry".getBytes()); + dictionaryVector.setValueCount(3); + + // Call the method under test + Block block = buildBlockFromDictionaryVector(indicesVector, dictionaryVector); + + // Assertions to check the dictionary block's behavior + assertNotNull(block); + assertTrue(block instanceof DictionaryBlock); + + DictionaryBlock dictionaryBlock = (DictionaryBlock) block; + + // Verify the dictionary block contains the right dictionary + + for (int i = 0; i < dictionaryBlock.getPositionCount(); i++) { + // Get the slice (string value) at the given position + Slice slice = dictionaryBlock.getSlice(i, 0, dictionaryBlock.getSliceLength(i)); + + // Assert based on the expected values + if (i == 0) { + assertEquals(slice.toStringUtf8(), "apple"); + } + else if (i == 1) { + assertEquals(slice.toStringUtf8(), "banana"); + } + else if (i == 2) { + assertEquals(slice.toStringUtf8(), "cherry"); + } + } + } + + @Test + public void testBuildBlockFromDictionaryVectorSmallInt() + { + SmallIntVector indicesVector = new SmallIntVector("indices", allocator); + + indicesVector.allocateNew(3); // allocating space for 3 values + indicesVector.set(0, (short) 0); + indicesVector.set(1, (short) 1); + indicesVector.set(2, (short) 2); + indicesVector.setValueCount(3); + + // Initialize a dummy dictionary vector + // Example: dictionary contains 3 string values + VarCharVector dictionaryVector = new VarCharVector("dictionary", allocator); + dictionaryVector.allocateNew(3); // allocating 3 elements in dictionary + + // Fill dictionaryVector with some values + dictionaryVector.set(0, "apple".getBytes()); + dictionaryVector.set(1, "banana".getBytes()); + dictionaryVector.set(2, "cherry".getBytes()); + dictionaryVector.setValueCount(3); + + // Call the method under test + Block block = buildBlockFromDictionaryVector(indicesVector, dictionaryVector); + + // Assertions to check the dictionary block's behavior + assertNotNull(block); + assertTrue(block instanceof DictionaryBlock); + + DictionaryBlock dictionaryBlock = (DictionaryBlock) block; + + // Verify the dictionary block contains the right dictionary + + for (int i = 0; i < dictionaryBlock.getPositionCount(); i++) { + // Get the slice (string value) at the given position + Slice slice = dictionaryBlock.getSlice(i, 0, dictionaryBlock.getSliceLength(i)); + + // Assert based on the expected values + if (i == 0) { + assertEquals(slice.toStringUtf8(), "apple"); + } + else if (i == 1) { + assertEquals(slice.toStringUtf8(), "banana"); + } + else if (i == 2) { + assertEquals(slice.toStringUtf8(), "cherry"); + } + } + } + + @Test + public void testBuildBlockFromDictionaryVectorTinyInt() + { + TinyIntVector indicesVector = new TinyIntVector("indices", allocator); + + indicesVector.allocateNew(3); // allocating space for 3 values + indicesVector.set(0, (byte) 0); + indicesVector.set(1, (byte) 1); + indicesVector.set(2, (byte) 2); + indicesVector.setValueCount(3); + + // Initialize a dummy dictionary vector + // Example: dictionary contains 3 string values + VarCharVector dictionaryVector = new VarCharVector("dictionary", allocator); + dictionaryVector.allocateNew(3); // allocating 3 elements in dictionary + + // Fill dictionaryVector with some values + dictionaryVector.set(0, "apple".getBytes()); + dictionaryVector.set(1, "banana".getBytes()); + dictionaryVector.set(2, "cherry".getBytes()); + dictionaryVector.setValueCount(3); + + // Call the method under test + Block block = buildBlockFromDictionaryVector(indicesVector, dictionaryVector); + + // Assertions to check the dictionary block's behavior + assertNotNull(block); + assertTrue(block instanceof DictionaryBlock); + + DictionaryBlock dictionaryBlock = (DictionaryBlock) block; + + // Verify the dictionary block contains the right dictionary + + for (int i = 0; i < dictionaryBlock.getPositionCount(); i++) { + // Get the slice (string value) at the given position + Slice slice = dictionaryBlock.getSlice(i, 0, dictionaryBlock.getSliceLength(i)); + + // Assert based on the expected values + if (i == 0) { + assertEquals(slice.toStringUtf8(), "apple"); + } + else if (i == 1) { + assertEquals(slice.toStringUtf8(), "banana"); + } + else if (i == 2) { + assertEquals(slice.toStringUtf8(), "cherry"); + } + } + } + + @Test + public void testWriteVarcharType() + { + Type varcharType = VarcharType.createUnboundedVarcharType(); + BlockBuilder builder = varcharType.createBlockBuilder(null, 1); + + String value = "test_string"; + ArrowPageUtils.writeVarcharType(varcharType, builder, value); + + Block block = builder.build(); + Slice result = varcharType.getSlice(block, 0); + assertEquals(result.toStringUtf8(), value); + } + + @Test + public void testWriteSmallintType() + { + Type smallintType = SmallintType.SMALLINT; + BlockBuilder builder = smallintType.createBlockBuilder(null, 1); + + short value = 42; + ArrowPageUtils.writeSmallintType(smallintType, builder, value); + + Block block = builder.build(); + long result = smallintType.getLong(block, 0); + assertEquals(result, value); + } + + @Test + public void testWriteTinyintType() + { + Type tinyintType = TinyintType.TINYINT; + BlockBuilder builder = tinyintType.createBlockBuilder(null, 1); + + byte value = 7; + ArrowPageUtils.writeTinyintType(tinyintType, builder, value); + + Block block = builder.build(); + long result = tinyintType.getLong(block, 0); + assertEquals(result, value); + } + + @Test + public void testWriteBigintType() + { + Type bigintType = BigintType.BIGINT; + BlockBuilder builder = bigintType.createBlockBuilder(null, 1); + + long value = 123456789L; + ArrowPageUtils.writeBigintType(bigintType, builder, value); + + Block block = builder.build(); + long result = bigintType.getLong(block, 0); + assertEquals(result, value); + } + + @Test + public void testWriteIntegerType() + { + Type integerType = IntegerType.INTEGER; + BlockBuilder builder = integerType.createBlockBuilder(null, 1); + + int value = 42; + ArrowPageUtils.writeIntegerType(integerType, builder, value); + + Block block = builder.build(); + long result = integerType.getLong(block, 0); + assertEquals(result, value); + } + + @Test + public void testWriteDoubleType() + { + Type doubleType = DoubleType.DOUBLE; + BlockBuilder builder = doubleType.createBlockBuilder(null, 1); + + double value = 42.42; + ArrowPageUtils.writeDoubleType(doubleType, builder, value); + + Block block = builder.build(); + double result = doubleType.getDouble(block, 0); + assertEquals(result, value, 0.001); + } + + @Test + public void testWriteBooleanType() + { + Type booleanType = BooleanType.BOOLEAN; + BlockBuilder builder = booleanType.createBlockBuilder(null, 1); + + boolean value = true; + ArrowPageUtils.writeBooleanType(booleanType, builder, value); + + Block block = builder.build(); + boolean result = booleanType.getBoolean(block, 0); + assertEquals(result, value); + } + + @Test + public void testWriteArrayType() + { + Type elementType = IntegerType.INTEGER; + ArrayType arrayType = new ArrayType(elementType); + BlockBuilder builder = arrayType.createBlockBuilder(null, 1); + + List values = Arrays.asList(1, 2, 3); + ArrowPageUtils.writeArrayType(arrayType, builder, values); + + Block block = builder.build(); + Block arrayBlock = arrayType.getObject(block, 0); + assertEquals(arrayBlock.getPositionCount(), values.size()); + for (int i = 0; i < values.size(); i++) { + assertEquals(elementType.getLong(arrayBlock, i), values.get(i).longValue()); + } + } + + @Test + public void testWriteRowType() + { + RowType.Field field1 = new RowType.Field(Optional.of("field1"), IntegerType.INTEGER); + RowType.Field field2 = new RowType.Field(Optional.of("field2"), VarcharType.createUnboundedVarcharType()); + RowType rowType = RowType.from(Arrays.asList(field1, field2)); + BlockBuilder builder = rowType.createBlockBuilder(null, 1); + + List rowValues = Arrays.asList(42, "test"); + ArrowPageUtils.writeRowType(rowType, builder, rowValues); + + Block block = builder.build(); + Block rowBlock = rowType.getObject(block, 0); + assertEquals(IntegerType.INTEGER.getLong(rowBlock, 0), 42); + assertEquals(VarcharType.createUnboundedVarcharType().getSlice(rowBlock, 1).toStringUtf8(), "test"); + } + + @Test + public void testWriteDateType() + { + Type dateType = DateType.DATE; + BlockBuilder builder = dateType.createBlockBuilder(null, 1); + + LocalDate value = LocalDate.of(2020, 1, 1); + ArrowPageUtils.writeDateType(dateType, builder, value); + + Block block = builder.build(); + long result = dateType.getLong(block, 0); + assertEquals(result, value.toEpochDay()); + } + + @Test + public void testWriteTimestampType() + { + Type timestampType = TimestampType.TIMESTAMP; + BlockBuilder builder = timestampType.createBlockBuilder(null, 1); + + long value = 1609459200000L; // Jan 1, 2021, 00:00:00 UTC + ArrowPageUtils.writeTimestampType(timestampType, builder, value); + + Block block = builder.build(); + long result = timestampType.getLong(block, 0); + assertEquals(result, value); + } + + @Test + public void testWriteTimestampTypeWithSqlTimestamp() + { + Type timestampType = TimestampType.TIMESTAMP; + BlockBuilder builder = timestampType.createBlockBuilder(null, 1); + + java.sql.Timestamp timestamp = java.sql.Timestamp.valueOf("2021-01-01 00:00:00"); + long expectedMillis = timestamp.getTime(); + ArrowPageUtils.writeTimestampType(timestampType, builder, timestamp); + + Block block = builder.build(); + long result = timestampType.getLong(block, 0); + assertEquals(result, expectedMillis); + } + + @Test + public void testShortDecimalRetrieval() + { + DecimalType shortDecimalType = DecimalType.createDecimalType(10, 2); // Precision: 10, Scale: 2 + BlockBuilder builder = shortDecimalType.createBlockBuilder(null, 1); + + BigDecimal decimalValue = new BigDecimal("12345.67"); + ArrowPageUtils.writeDecimalType(shortDecimalType, builder, decimalValue); + + Block block = builder.build(); + long unscaledValue = shortDecimalType.getLong(block, 0); // Unscaled value: 1234567 + BigDecimal result = BigDecimal.valueOf(unscaledValue).movePointLeft(shortDecimalType.getScale()); + assertEquals(result, decimalValue); + } + + @Test + public void testLongDecimalRetrieval() + { + // Create a DecimalType with precision 38 and scale 10 + DecimalType longDecimalType = DecimalType.createDecimalType(38, 10); + BlockBuilder builder = longDecimalType.createBlockBuilder(null, 1); + BigDecimal decimalValue = new BigDecimal("1234567890.1234567890"); + ArrowPageUtils.writeDecimalType(longDecimalType, builder, decimalValue); + // Build the block after inserting the decimal value + Block block = builder.build(); + Slice unscaledSlice = longDecimalType.getSlice(block, 0); + BigInteger unscaledValue = Decimals.decodeUnscaledValue(unscaledSlice); + BigDecimal result = new BigDecimal(unscaledValue).movePointLeft(longDecimalType.getScale()); + // Assert the decoded result is equal to the original decimal value + assertEquals(result, decimalValue); + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowColumnHandle.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowColumnHandle.java new file mode 100644 index 0000000000000..1d9c490180abc --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowColumnHandle.java @@ -0,0 +1,83 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.presto.common.type.IntegerType; +import com.facebook.presto.spi.ColumnMetadata; +import org.testng.annotations.Test; + +import java.util.Locale; + +import static com.facebook.presto.testing.assertions.Assert.assertEquals; +import static org.testng.Assert.assertNotNull; + +public class TestArrowColumnHandle +{ + @Test + public void testConstructorAndGetters() + { + // Given + String columnName = "testColumn"; + // When + ArrowColumnHandle columnHandle = new ArrowColumnHandle(columnName, IntegerType.INTEGER); + + // Then + assertEquals(columnHandle.getColumnName(), columnName, "Column name should match the input"); + assertEquals(columnHandle.getColumnType(), IntegerType.INTEGER, "Column type should match the input"); + } + + @Test(expectedExceptions = NullPointerException.class) + public void testConstructorWithNullColumnName() + { + // Given + // When + new ArrowColumnHandle(null, IntegerType.INTEGER); // Should throw NullPointerException + } + + @Test(expectedExceptions = NullPointerException.class) + public void testConstructorWithNullColumnType() + { + // Given + String columnName = "testColumn"; + + // When + new ArrowColumnHandle(columnName, null); // Should throw NullPointerException + } + + @Test + public void testGetColumnMetadata() + { + // Given + String columnName = "testColumn"; + ArrowColumnHandle columnHandle = new ArrowColumnHandle(columnName, IntegerType.INTEGER); + + // When + ColumnMetadata columnMetadata = columnHandle.getColumnMetadata(); + + // Then + assertNotNull(columnMetadata, "ColumnMetadata should not be null"); + assertEquals(columnMetadata.getName(), columnName.toLowerCase(Locale.ENGLISH), "ColumnMetadata name should match the column name"); + assertEquals(columnMetadata.getType(), IntegerType.INTEGER, "ColumnMetadata type should match the column type"); + } + + @Test + public void testToString() + { + String columnName = "testColumn"; + ArrowColumnHandle columnHandle = new ArrowColumnHandle(columnName, IntegerType.INTEGER); + String result = columnHandle.toString(); + String expected = columnName + ":" + IntegerType.INTEGER; + assertEquals(result, expected, "toString() should return the correct string representation"); + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowFlightIntegrationSmokeTest.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowFlightIntegrationSmokeTest.java new file mode 100644 index 0000000000000..1bc50b08cbd37 --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowFlightIntegrationSmokeTest.java @@ -0,0 +1,71 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.plugin.arrow; + +import com.facebook.airlift.log.Logger; +import com.facebook.presto.testing.QueryRunner; +import com.facebook.presto.tests.AbstractTestIntegrationSmokeTest; +import com.facebook.presto.tests.DistributedQueryRunner; +import org.apache.arrow.flight.FlightServer; +import org.apache.arrow.flight.Location; +import org.apache.arrow.memory.RootAllocator; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; + +import java.io.File; + +public class TestArrowFlightIntegrationSmokeTest + extends AbstractTestIntegrationSmokeTest +{ + private static final Logger logger = Logger.get(TestArrowFlightIntegrationSmokeTest.class); + private static RootAllocator allocator; + private static FlightServer server; + private static Location serverLocation; + private DistributedQueryRunner arrowFlightQueryRunner; + + @BeforeClass + public void setup() + throws Exception + { + arrowFlightQueryRunner = getDistributedQueryRunner(); + File certChainFile = new File("src/test/resources/server.crt"); + File privateKeyFile = new File("src/test/resources/server.key"); + + allocator = new RootAllocator(Long.MAX_VALUE); + serverLocation = Location.forGrpcTls("127.0.0.1", 9443); + server = FlightServer.builder(allocator, serverLocation, new TestingArrowServer(allocator)) + .useTls(certChainFile, privateKeyFile) + .build(); + + server.start(); + logger.info("Server listening on port " + server.getPort()); + } + + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + return ArrowFlightQueryRunner.createQueryRunner(); + } + + @AfterClass(alwaysRun = true) + public void close() + throws InterruptedException + { + server.close(); + allocator.close(); + arrowFlightQueryRunner.close(); + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowFlightQueries.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowFlightQueries.java new file mode 100644 index 0000000000000..21b28745cd1d3 --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowFlightQueries.java @@ -0,0 +1,175 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.airlift.log.Logger; +import com.facebook.presto.Session; +import com.facebook.presto.common.type.TimeZoneKey; +import com.facebook.presto.testing.MaterializedResult; +import com.facebook.presto.testing.QueryRunner; +import com.facebook.presto.tests.AbstractTestQueries; +import com.facebook.presto.tests.DistributedQueryRunner; +import org.apache.arrow.flight.FlightServer; +import org.apache.arrow.flight.Location; +import org.apache.arrow.memory.RootAllocator; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.io.File; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.LocalTime; +import java.time.ZoneId; +import java.time.ZonedDateTime; +import java.time.format.DateTimeFormatter; + +import static com.facebook.presto.common.type.CharType.createCharType; +import static com.facebook.presto.common.type.DateType.DATE; +import static com.facebook.presto.common.type.IntegerType.INTEGER; +import static com.facebook.presto.common.type.TimeType.TIME; +import static com.facebook.presto.common.type.TimestampType.TIMESTAMP; +import static com.facebook.presto.common.type.VarcharType.VARCHAR; +import static com.facebook.presto.common.type.VarcharType.createVarcharType; +import static com.facebook.presto.testing.MaterializedResult.resultBuilder; +import static java.lang.String.format; +import static org.testng.Assert.assertTrue; + +public class TestArrowFlightQueries + extends AbstractTestQueries +{ + private static final Logger logger = Logger.get(TestArrowFlightQueries.class); + private static RootAllocator allocator; + private static FlightServer server; + private static Location serverLocation; + private DistributedQueryRunner arrowFlightQueryRunner; + + @BeforeClass + public void setup() + throws Exception + { + arrowFlightQueryRunner = getDistributedQueryRunner(); + File certChainFile = new File("src/test/resources/server.crt"); + File privateKeyFile = new File("src/test/resources/server.key"); + + allocator = new RootAllocator(Long.MAX_VALUE); + serverLocation = Location.forGrpcTls("127.0.0.1", 9443); + server = FlightServer.builder(allocator, serverLocation, new TestingArrowServer(allocator)) + .useTls(certChainFile, privateKeyFile) + .build(); + + server.start(); + logger.info("Server listening on port " + server.getPort()); + } + + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + return ArrowFlightQueryRunner.createQueryRunner(); + } + + @AfterClass(alwaysRun = true) + public void close() + throws InterruptedException + { + server.close(); + allocator.close(); + arrowFlightQueryRunner.close(); + } + + @Test + public void testShowCharColumns() + { + MaterializedResult actual = computeActual("SHOW COLUMNS FROM member"); + + MaterializedResult expectedUnparametrizedVarchar = resultBuilder(getSession(), VARCHAR, VARCHAR, VARCHAR, VARCHAR) + .row("id", "integer", "", "") + .row("name", "varchar", "", "") + .row("sex", "char", "", "") + .row("state", "char", "", "") + .build(); + + MaterializedResult expectedParametrizedVarchar = resultBuilder(getSession(), VARCHAR, VARCHAR, VARCHAR, VARCHAR) + .row("id", "integer", "", "") + .row("name", "varchar(50)", "", "") + .row("sex", "char(1)", "", "") + .row("state", "char(5)", "", "") + .build(); + + assertTrue(actual.equals(expectedParametrizedVarchar) || actual.equals(expectedUnparametrizedVarchar), + format("%s matches neither %s nor %s", actual, expectedParametrizedVarchar, expectedUnparametrizedVarchar)); + } + + @Test + public void testPredicateOnCharColumn() + { + MaterializedResult actualRow = computeActual("SELECT * from member WHERE state = 'CD'"); + MaterializedResult expectedRow = resultBuilder(getSession(), INTEGER, createVarcharType(50), createCharType(1), createCharType(5)) + .row(2, "MARY", "F", "CD ") + .build(); + assertTrue(actualRow.equals(expectedRow)); + } + + @Test + public void testSelectTime() + { + MaterializedResult actualRow = computeActual("SELECT * from event WHERE id = 1"); + Session session = getSession(); + MaterializedResult expectedRow = resultBuilder(session, INTEGER, DATE, TIME, TIMESTAMP) + .row(1, + getDate("2004-12-31"), + getTimeAtZone("23:59:59", session.getTimeZoneKey()), + getDateTimeAtZone("2005-12-31 23:59:59", session.getTimeZoneKey())) + .build(); + assertTrue(actualRow.equals(expectedRow)); + } + + private LocalDate getDate(String dateString) + { + DateTimeFormatter formatter = DateTimeFormatter.ofPattern("yyyy-MM-dd"); + LocalDate localDate = LocalDate.parse(dateString, formatter); + + return localDate; + } + + private LocalTime getTimeAtZone(String timeString, TimeZoneKey timeZoneKey) + { + DateTimeFormatter formatter = DateTimeFormatter.ofPattern("HH:mm:ss"); + LocalTime localTime = LocalTime.parse(timeString, formatter); + + LocalDateTime localDateTime = LocalDateTime.of(LocalDate.of(1970, 1, 1), localTime); + ZonedDateTime localZonedDateTime = localDateTime.atZone(ZoneId.systemDefault()); + + ZoneId zoneId = ZoneId.of(timeZoneKey.getId()); + ZonedDateTime zonedDateTime = localZonedDateTime.withZoneSameInstant(zoneId); + + LocalTime localTimeAtZone = zonedDateTime.toLocalTime(); + return localTimeAtZone; + } + + private LocalDateTime getDateTimeAtZone(String dateTimeString, TimeZoneKey timeZoneKey) + { + DateTimeFormatter formatter = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss"); + LocalDateTime localDateTime = LocalDateTime.parse(dateTimeString, formatter); + + ZonedDateTime localZonedDateTime = localDateTime.atZone(ZoneId.systemDefault()); + + ZoneId zoneId = ZoneId.of(timeZoneKey.getId()); + ZonedDateTime zonedDateTime = localZonedDateTime.withZoneSameInstant(zoneId); + + LocalDateTime localDateTimeAtZone = zonedDateTime.toLocalDateTime(); + return localDateTimeAtZone; + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowHandleResolver.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowHandleResolver.java new file mode 100644 index 0000000000000..ea95e9fec01b0 --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowHandleResolver.java @@ -0,0 +1,67 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import org.testng.annotations.Test; + +import static org.testng.Assert.assertEquals; + +public class TestArrowHandleResolver +{ + @Test + public void testGetTableHandleClass() + { + ArrowHandleResolver resolver = new ArrowHandleResolver(); + assertEquals( + resolver.getTableHandleClass(), + ArrowTableHandle.class, + "getTableHandleClass should return ArrowTableHandle class."); + } + @Test + public void testGetTableLayoutHandleClass() + { + ArrowHandleResolver resolver = new ArrowHandleResolver(); + assertEquals( + resolver.getTableLayoutHandleClass(), + ArrowTableLayoutHandle.class, + "getTableLayoutHandleClass should return ArrowTableLayoutHandle class."); + } + @Test + public void testGetColumnHandleClass() + { + ArrowHandleResolver resolver = new ArrowHandleResolver(); + assertEquals( + resolver.getColumnHandleClass(), + ArrowColumnHandle.class, + "getColumnHandleClass should return ArrowColumnHandle class."); + } + @Test + public void testGetSplitClass() + { + ArrowHandleResolver resolver = new ArrowHandleResolver(); + assertEquals( + resolver.getSplitClass(), + ArrowSplit.class, + "getSplitClass should return ArrowSplit class."); + } + @Test + public void testGetTransactionHandleClass() + { + ArrowHandleResolver resolver = new ArrowHandleResolver(); + assertEquals( + resolver.getTransactionHandleClass(), + ArrowTransactionHandle.class, + "getTransactionHandleClass should return ArrowTransactionHandle class."); + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowSplit.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowSplit.java new file mode 100644 index 0000000000000..65da26254bd34 --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowSplit.java @@ -0,0 +1,74 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.presto.spi.HostAddress; +import com.facebook.presto.spi.schedule.NodeSelectionStrategy; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; + +import java.util.Arrays; +import java.util.List; + +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotNull; +import static org.testng.Assert.assertTrue; + +@Test(singleThreaded = true) +public class TestArrowSplit +{ + private ArrowSplit arrowSplit; + private String schemaName; + private String tableName; + private byte[] ticket; + private List locationUrls; + + @BeforeMethod + public void setUp() + { + schemaName = "testSchema"; + tableName = "testTable"; + ticket = new byte[] {1, 2, 3, 4}; + locationUrls = Arrays.asList("http://localhost:8080", "http://localhost:8081"); + + // Instantiate ArrowSplit with mock data + arrowSplit = new ArrowSplit(schemaName, tableName, ticket, locationUrls); + } + + @Test + public void testConstructorAndGetters() + { + // Test that the constructor correctly initializes fields + assertEquals(arrowSplit.getSchemaName(), schemaName, "Schema name should match."); + assertEquals(arrowSplit.getTableName(), tableName, "Table name should match."); + assertEquals(arrowSplit.getTicket(), ticket, "Ticket byte array should match."); + assertEquals(arrowSplit.getLocationUrls(), locationUrls, "Location URLs list should match."); + } + + @Test + public void testNodeSelectionStrategy() + { + // Test that the node selection strategy is NO_PREFERENCE + assertEquals(arrowSplit.getNodeSelectionStrategy(), NodeSelectionStrategy.NO_PREFERENCE, "Node selection strategy should be NO_PREFERENCE."); + } + + @Test + public void testGetPreferredNodes() + { + // Test that the preferred nodes list is empty + List preferredNodes = arrowSplit.getPreferredNodes(null); + assertNotNull(preferredNodes, "Preferred nodes list should not be null."); + assertTrue(preferredNodes.isEmpty(), "Preferred nodes list should be empty."); + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowTableHandle.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowTableHandle.java new file mode 100644 index 0000000000000..2061fe5036534 --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowTableHandle.java @@ -0,0 +1,37 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.airlift.testing.EquivalenceTester; +import org.testng.annotations.Test; + +import static com.facebook.plugin.arrow.ArrowMetadataUtil.TABLE_CODEC; +import static com.facebook.plugin.arrow.ArrowMetadataUtil.assertJsonRoundTrip; + +public class TestArrowTableHandle +{ + @Test + public void testJsonRoundTrip() + { + assertJsonRoundTrip(TABLE_CODEC, new ArrowTableHandle("schema", "table")); + } + + @Test + public void testEquivalence() + { + EquivalenceTester.equivalenceTester() + .addEquivalentGroup( + new ArrowTableHandle("tm_engine", "employees")).check(); + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowTableLayoutHandle.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowTableLayoutHandle.java new file mode 100644 index 0000000000000..0ff7301c7e0ff --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowTableLayoutHandle.java @@ -0,0 +1,116 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.plugin.arrow; + +import com.facebook.presto.common.predicate.TupleDomain; +import com.facebook.presto.common.type.BigintType; +import com.facebook.presto.common.type.IntegerType; +import com.facebook.presto.common.type.VarcharType; +import com.facebook.presto.spi.ColumnHandle; +import org.testng.annotations.Test; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotEquals; + +public class TestArrowTableLayoutHandle +{ + @Test + public void testConstructorAndGetters() + { + ArrowTableHandle tableHandle = new ArrowTableHandle("schema", "table"); + List columnHandles = Arrays.asList( + new ArrowColumnHandle("column1", IntegerType.INTEGER), + new ArrowColumnHandle("column2", VarcharType.VARCHAR)); + TupleDomain tupleDomain = TupleDomain.all(); + + ArrowTableLayoutHandle layoutHandle = new ArrowTableLayoutHandle(tableHandle, columnHandles, tupleDomain); + + assertEquals(layoutHandle.getTableHandle(), tableHandle, "Table handle mismatch."); + assertEquals(layoutHandle.getColumnHandles(), columnHandles, "Column handles mismatch."); + assertEquals(layoutHandle.getTupleDomain(), tupleDomain, "Tuple domain mismatch."); + } + + @Test + public void testToString() + { + ArrowTableHandle tableHandle = new ArrowTableHandle("schema", "table"); + List columnHandles = Arrays.asList( + new ArrowColumnHandle("column1", IntegerType.INTEGER), + new ArrowColumnHandle("column2", BigintType.BIGINT)); + TupleDomain tupleDomain = TupleDomain.all(); + + ArrowTableLayoutHandle layoutHandle = new ArrowTableLayoutHandle(tableHandle, columnHandles, tupleDomain); + + String expectedString = "tableHandle:" + tableHandle + ", columnHandles:" + columnHandles + ", tupleDomain:" + tupleDomain; + assertEquals(layoutHandle.toString(), expectedString, "toString output mismatch."); + } + + @Test + public void testEqualsAndHashCode() + { + ArrowTableHandle tableHandle1 = new ArrowTableHandle("schema", "table"); + ArrowTableHandle tableHandle2 = new ArrowTableHandle("schema", "different_table"); + + List columnHandles1 = Arrays.asList( + new ArrowColumnHandle("column1", IntegerType.INTEGER), + new ArrowColumnHandle("column2", VarcharType.VARCHAR)); + List columnHandles2 = Collections.singletonList( + new ArrowColumnHandle("column1", IntegerType.INTEGER)); + + TupleDomain tupleDomain1 = TupleDomain.all(); + TupleDomain tupleDomain2 = TupleDomain.none(); + + ArrowTableLayoutHandle layoutHandle1 = new ArrowTableLayoutHandle(tableHandle1, columnHandles1, tupleDomain1); + ArrowTableLayoutHandle layoutHandle2 = new ArrowTableLayoutHandle(tableHandle1, columnHandles1, tupleDomain1); + ArrowTableLayoutHandle layoutHandle3 = new ArrowTableLayoutHandle(tableHandle2, columnHandles1, tupleDomain1); + ArrowTableLayoutHandle layoutHandle4 = new ArrowTableLayoutHandle(tableHandle1, columnHandles2, tupleDomain1); + ArrowTableLayoutHandle layoutHandle5 = new ArrowTableLayoutHandle(tableHandle1, columnHandles1, tupleDomain2); + + // Test equality + assertEquals(layoutHandle1, layoutHandle2, "Handles with same attributes should be equal."); + assertNotEquals(layoutHandle1, layoutHandle3, "Handles with different tableHandles should not be equal."); + assertNotEquals(layoutHandle1, layoutHandle4, "Handles with different columnHandles should not be equal."); + assertNotEquals(layoutHandle1, layoutHandle5, "Handles with different tupleDomains should not be equal."); + assertNotEquals(layoutHandle1, null, "Handle should not be equal to null."); + assertNotEquals(layoutHandle1, new Object(), "Handle should not be equal to an object of another class."); + + // Test hash codes + assertEquals(layoutHandle1.hashCode(), layoutHandle2.hashCode(), "Equal handles should have same hash code."); + assertNotEquals(layoutHandle1.hashCode(), layoutHandle3.hashCode(), "Handles with different tableHandles should have different hash codes."); + assertNotEquals(layoutHandle1.hashCode(), layoutHandle4.hashCode(), "Handles with different columnHandles should have different hash codes."); + } + + @Test(expectedExceptions = NullPointerException.class, expectedExceptionsMessageRegExp = "table is null") + public void testConstructorNullTableHandle() + { + new ArrowTableLayoutHandle(null, Collections.emptyList(), TupleDomain.all()); + } + + @Test(expectedExceptions = NullPointerException.class, expectedExceptionsMessageRegExp = "columns are null") + public void testConstructorNullColumnHandles() + { + new ArrowTableLayoutHandle(new ArrowTableHandle("schema", "table"), null, TupleDomain.all()); + } + + @Test(expectedExceptions = NullPointerException.class, expectedExceptionsMessageRegExp = "domain is null") + public void testConstructorNullTupleDomain() + { + new ArrowTableLayoutHandle(new ArrowTableHandle("schema", "table"), Collections.emptyList(), null); + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowFactory.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowFactory.java new file mode 100644 index 0000000000000..99e2d6fb96f46 --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowFactory.java @@ -0,0 +1,23 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +public class TestingArrowFactory + extends ArrowConnectorFactory +{ + public TestingArrowFactory() + { + super("arrow", new TestingArrowModule(), TestingArrowFactory.class.getClassLoader()); + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowFlightClientHandler.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowFlightClientHandler.java new file mode 100644 index 0000000000000..7eb2bc88a7907 --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowFlightClientHandler.java @@ -0,0 +1,36 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.presto.spi.ConnectorSession; +import org.apache.arrow.flight.auth2.BearerCredentialWriter; +import org.apache.arrow.flight.grpc.CredentialCallOption; + +import javax.inject.Inject; + +public class TestingArrowFlightClientHandler + extends ArrowFlightClientHandler +{ + @Inject + public TestingArrowFlightClientHandler(ArrowFlightConfig config) + { + super(config); + } + + @Override + protected CredentialCallOption getCallOptions(ConnectorSession connectorSession) + { + return new CredentialCallOption(new BearerCredentialWriter(null)); + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowFlightConfig.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowFlightConfig.java new file mode 100644 index 0000000000000..c55eab7ae97ef --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowFlightConfig.java @@ -0,0 +1,113 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.airlift.configuration.Config; +import com.facebook.airlift.configuration.ConfigSecuritySensitive; + +public class TestingArrowFlightConfig +{ + private String host; // non-static field + private String database; // non-static field + private String username; // non-static field + private String password; // non-static field + private String name; // non-static field + private Integer port; // non-static field + private Boolean ssl; + + public String getDataSourceHost() + { // non-static getter + return host; + } + + public String getDataSourceDatabase() + { // non-static getter + return database; + } + + public String getDataSourceUsername() + { // non-static getter + return username; + } + + public String getDataSourcePassword() + { // non-static getter + return password; + } + + public String getDataSourceName() + { // non-static getter + return name; + } + + public Integer getDataSourcePort() + { // non-static getter + return port; + } + + public Boolean getDataSourceSSL() + { // non-static getter + return ssl; + } + + @Config("data-source.host") + public TestingArrowFlightConfig setDataSourceHost(String host) + { // non-static setter + this.host = host; + return this; + } + + @Config("data-source.database") + public TestingArrowFlightConfig setDataSourceDatabase(String database) + { // non-static setter + this.database = database; + return this; + } + + @Config("data-source.username") + public TestingArrowFlightConfig setDataSourceUsername(String username) + { // non-static setter + this.username = username; + return this; + } + + @Config("data-source.password") + @ConfigSecuritySensitive + public TestingArrowFlightConfig setDataSourcePassword(String password) + { // non-static setter + this.password = password; + return this; + } + + @Config("data-source.name") + public TestingArrowFlightConfig setDataSourceName(String name) + { // non-static setter + this.name = name; + return this; + } + + @Config("data-source.port") + public TestingArrowFlightConfig setDataSourcePort(Integer port) + { // non-static setter + this.port = port; + return this; + } + + @Config("data-source.ssl") + public TestingArrowFlightConfig setDataSourceSSL(Boolean ssl) + { // non-static setter + this.ssl = ssl; + return this; + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowFlightRequest.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowFlightRequest.java new file mode 100644 index 0000000000000..dd019a7689cec --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowFlightRequest.java @@ -0,0 +1,98 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.SerializationFeature; + +import java.nio.charset.StandardCharsets; +import java.util.Optional; + +public class TestingArrowFlightRequest +{ + private final String schema; + private final String table; + private final Optional query; + private final ArrowFlightConfig config; + private final int noOfPartitions; + + private final TestingArrowFlightConfig testconfig; + + public TestingArrowFlightRequest(ArrowFlightConfig config, TestingArrowFlightConfig testconfig, String schema, String table, Optional query, int noOfPartitions) + { + this.config = config; + this.schema = schema; + this.table = table; + this.query = query; + this.testconfig = testconfig; + this.noOfPartitions = noOfPartitions; + } + + public TestingArrowFlightRequest(ArrowFlightConfig config, String schema, int noOfPartitions, TestingArrowFlightConfig testconfig) + { + this.schema = schema; + this.table = null; + this.query = Optional.empty(); + this.config = config; + this.testconfig = testconfig; + this.noOfPartitions = noOfPartitions; + } + + public String getSchema() + { + return schema; + } + + public String getTable() + { + return table; + } + + public Optional getQuery() + { + return query; + } + + public TestingRequestData build() + { + TestingRequestData requestData = new TestingRequestData(); + requestData.setConnectionProperties(getConnectionProperties()); + requestData.setInteractionProperties(createInteractionProperties()); + return requestData; + } + + public byte[] getCommand() + { + ObjectMapper objectMapper = new ObjectMapper(); + objectMapper.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false); + try { + String jsonString = objectMapper.writeValueAsString(build()); + return jsonString.getBytes(StandardCharsets.UTF_8); + } + catch (JsonProcessingException e) { + throw new ArrowException(ArrowErrorCode.ARROW_FLIGHT_ERROR, "JSON request cannot be created.", e); + } + } + + private TestingConnectionProperties getConnectionProperties() + { + return new TestingConnectionProperties(testconfig.getDataSourceDatabase(), testconfig.getDataSourcePassword(), testconfig.getDataSourceHost(), testconfig.getDataSourceSSL(), testconfig.getDataSourceUsername()); + } + + private TestingInteractionProperties createInteractionProperties() + { + return getQuery().isPresent() ? new TestingInteractionProperties(getQuery().get(), getSchema(), getTable()) : new TestingInteractionProperties(null, getSchema(), getTable()); + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowMetadata.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowMetadata.java new file mode 100644 index 0000000000000..3fad658bbbb58 --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowMetadata.java @@ -0,0 +1,158 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.airlift.log.Logger; +import com.facebook.presto.common.type.CharType; +import com.facebook.presto.common.type.TimeType; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.common.type.VarcharType; +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.NodeManager; +import com.facebook.presto.spi.SchemaTableName; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.google.common.collect.ImmutableList; +import org.apache.arrow.flight.Action; +import org.apache.arrow.flight.FlightDescriptor; +import org.apache.arrow.flight.Result; +import org.apache.arrow.vector.types.pojo.Field; + +import javax.inject.Inject; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.Optional; + +import static java.util.Locale.ENGLISH; + +public class TestingArrowMetadata + extends AbstractArrowMetadata +{ + private static final Logger logger = Logger.get(TestingArrowMetadata.class); + private static final ObjectMapper objectMapper = new ObjectMapper(); + private final NodeManager nodeManager; + private final TestingArrowFlightConfig testConfig; + private final ArrowFlightClientHandler clientHandler; + private final ArrowFlightConfig config; + + @Inject + public TestingArrowMetadata(ArrowFlightClientHandler clientHandler, NodeManager nodeManager, TestingArrowFlightConfig testConfig, ArrowFlightConfig config) + { + super(config, clientHandler); + this.nodeManager = nodeManager; + this.testConfig = testConfig; + this.clientHandler = clientHandler; + this.config = config; + } + + @Override + public List listSchemaNames(ConnectorSession session) + { + List listSchemas = extractSchemaAndTableData(Optional.empty(), session); + List names = new ArrayList<>(); + for (String value : listSchemas) { + names.add(value.toLowerCase(ENGLISH)); + } + return ImmutableList.copyOf(names); + } + + @Override + public List listTables(ConnectorSession session, Optional schemaName) + { + String schemaValue = schemaName.orElse(""); + String dataSourceSpecificSchemaName = getDataSourceSpecificSchemaName(config, schemaValue); + List listTables = extractSchemaAndTableData(Optional.ofNullable(dataSourceSpecificSchemaName), session); + List tables = new ArrayList<>(); + for (String value : listTables) { + tables.add(new SchemaTableName(dataSourceSpecificSchemaName.toLowerCase(ENGLISH), value.toLowerCase(ENGLISH))); + } + + return tables; + } + + public List extractSchemaAndTableData(Optional schema, ConnectorSession connectorSession) + { + try (ArrowFlightClient client = clientHandler.getClient(Optional.empty())) { + List names = new ArrayList<>(); + TestingArrowFlightRequest request = getArrowFlightRequest(schema.orElse(null)); + ObjectNode rootNode = (ObjectNode) objectMapper.readTree(request.getCommand()); + + String modifiedQueryJson = objectMapper.writeValueAsString(rootNode); + byte[] queryJsonBytes = modifiedQueryJson.getBytes(StandardCharsets.UTF_8); + Iterator iterator = client.getFlightClient().doAction(new Action("discovery", queryJsonBytes), clientHandler.getCallOptions(connectorSession)); + while (iterator.hasNext()) { + Result result = iterator.next(); + String jsonResult = new String(result.getBody(), StandardCharsets.UTF_8); + List tableNames = objectMapper.readValue(jsonResult, new TypeReference>() { + }); + names.addAll(tableNames); + } + return names; + } + catch (IOException | InterruptedException e) { + throw new RuntimeException(e); + } + } + + @Override + protected Type getPrestoTypeFromArrowField(Field field) + { + String columnLength = field.getMetadata().get("columnLength"); + int length = columnLength != null ? Integer.parseInt(columnLength) : 0; + + String nativeType = field.getMetadata().get("columnNativeType"); + + if ("CHAR".equals(nativeType) || "CHARACTER".equals(nativeType)) { + return CharType.createCharType(length); + } + else if ("VARCHAR".equals(nativeType)) { + return VarcharType.createVarcharType(length); + } + else if ("TIME".equals(nativeType)) { + return TimeType.TIME; + } + else { + return super.getPrestoTypeFromArrowField(field); + } + } + + @Override + protected String getDataSourceSpecificSchemaName(ArrowFlightConfig config, String schemaName) + { + return schemaName; + } + + @Override + protected String getDataSourceSpecificTableName(ArrowFlightConfig config, String tableName) + { + return tableName; + } + + @Override + protected FlightDescriptor getFlightDescriptor(Optional query, String schema, String table) + { + TestingArrowFlightRequest request = new TestingArrowFlightRequest(this.config, testConfig, schema, table, query, nodeManager.getWorkerNodes().size()); + return FlightDescriptor.command(request.getCommand()); + } + + private TestingArrowFlightRequest getArrowFlightRequest(String schema) + { + return new TestingArrowFlightRequest(config, schema, nodeManager.getWorkerNodes().size(), testConfig); + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowModule.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowModule.java new file mode 100644 index 0000000000000..cab5872a5507a --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowModule.java @@ -0,0 +1,35 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.presto.spi.connector.ConnectorMetadata; +import com.facebook.presto.spi.connector.ConnectorSplitManager; +import com.google.inject.Binder; +import com.google.inject.Module; +import com.google.inject.Scopes; + +import static com.facebook.airlift.configuration.ConfigBinder.configBinder; + +public class TestingArrowModule + implements Module +{ + @Override + public void configure(Binder binder) + { + configBinder(binder).bindConfig(TestingArrowFlightConfig.class); + binder.bind(ConnectorSplitManager.class).to(TestingArrowSplitManager.class).in(Scopes.SINGLETON); + binder.bind(ArrowFlightClientHandler.class).to(TestingArrowFlightClientHandler.class).in(Scopes.SINGLETON); + binder.bind(ConnectorMetadata.class).to(TestingArrowMetadata.class).in(Scopes.SINGLETON); + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowPlugin.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowPlugin.java new file mode 100644 index 0000000000000..cf67aaa983121 --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowPlugin.java @@ -0,0 +1,33 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.plugin.arrow; + +import com.facebook.presto.spi.connector.ConnectorFactory; +import org.testng.annotations.Test; + +import static com.facebook.airlift.testing.Assertions.assertInstanceOf; +import static com.google.common.collect.Iterables.getOnlyElement; + +public class TestingArrowPlugin +{ + @Test + public void testStartup() + { + ArrowModule testModule = new ArrowModule("arrow-flight"); + ArrowPlugin plugin = new ArrowPlugin("arrow-flight", testModule); + ConnectorFactory factory = getOnlyElement(plugin.getConnectorFactories()); + assertInstanceOf(factory, ArrowConnectorFactory.class); + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowQueryBuilder.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowQueryBuilder.java new file mode 100644 index 0000000000000..600f100a7e6ec --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowQueryBuilder.java @@ -0,0 +1,305 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.presto.common.predicate.Domain; +import com.facebook.presto.common.predicate.Range; +import com.facebook.presto.common.predicate.TupleDomain; +import com.facebook.presto.common.type.BigintType; +import com.facebook.presto.common.type.BooleanType; +import com.facebook.presto.common.type.CharType; +import com.facebook.presto.common.type.DateType; +import com.facebook.presto.common.type.DoubleType; +import com.facebook.presto.common.type.IntegerType; +import com.facebook.presto.common.type.RealType; +import com.facebook.presto.common.type.SmallintType; +import com.facebook.presto.common.type.TimeType; +import com.facebook.presto.common.type.TimestampType; +import com.facebook.presto.common.type.TinyintType; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.common.type.VarcharType; +import com.facebook.presto.spi.ColumnHandle; +import com.google.common.base.Joiner; +import com.google.common.collect.ImmutableList; +import io.airlift.slice.Slice; + +import java.sql.Time; +import java.sql.Timestamp; +import java.text.SimpleDateFormat; +import java.time.ZoneId; +import java.util.ArrayList; +import java.util.Date; +import java.util.List; +import java.util.Map; +import java.util.TimeZone; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.base.Strings.isNullOrEmpty; +import static com.google.common.collect.Iterables.getOnlyElement; +import static java.lang.Float.intBitsToFloat; +import static java.lang.Math.toIntExact; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; +import static java.util.stream.Collectors.joining; + +public class TestingArrowQueryBuilder +{ + // not all databases support booleans, so use 1=1 and 1=0 instead + private static final String ALWAYS_TRUE = "1=1"; + private static final String ALWAYS_FALSE = "1=0"; + public static final String DATE_FORMAT = "yyyy-MM-dd"; + public static final String TIMESTAMP_FORMAT = "yyyy-MM-dd HH:mm:ss.SSS"; + public static final String TIME_FORMAT = "HH:mm:ss"; + public static final TimeZone UTC_TIME_ZONE = TimeZone.getTimeZone(ZoneId.of("UTC")); + + public String buildSql( + String schema, + String table, + List columns, + Map columnExpressions, + TupleDomain tupleDomain) + { + StringBuilder sql = new StringBuilder(); + + sql.append("SELECT "); + sql.append(addColumnExpression(columns, columnExpressions)); + + sql.append(" FROM "); + if (!isNullOrEmpty(schema)) { + sql.append(quote(schema)).append('.'); + } + sql.append(quote(table)); + + List accumulator = new ArrayList<>(); + + if (tupleDomain != null && !tupleDomain.isAll()) { + List clauses = toConjuncts(columns, tupleDomain, accumulator); + if (!clauses.isEmpty()) { + sql.append(" WHERE ") + .append(Joiner.on(" AND ").join(clauses)); + } + } + + return sql.toString(); + } + + public static String convertEpochToString(long epochValue, Type type) + { + if (type instanceof DateType) { + long millis = TimeUnit.DAYS.toMillis(epochValue); + Date date = new Date(millis); + SimpleDateFormat dateFormat = new SimpleDateFormat(DATE_FORMAT); + dateFormat.setTimeZone(UTC_TIME_ZONE); + return dateFormat.format(date); + } + else if (type instanceof TimestampType) { + Timestamp timestamp = new Timestamp(epochValue); + SimpleDateFormat timestampFormat = new SimpleDateFormat(TIMESTAMP_FORMAT); + timestampFormat.setTimeZone(UTC_TIME_ZONE); + return timestampFormat.format(timestamp); + } + else if (type instanceof TimeType) { + long millis = TimeUnit.SECONDS.toMillis(epochValue / 1000); + Time time = new Time(millis); + SimpleDateFormat timeFormat = new SimpleDateFormat(TIME_FORMAT); + timeFormat.setTimeZone(UTC_TIME_ZONE); + return timeFormat.format(time); + } + else { + throw new UnsupportedOperationException(type + " is not supported."); + } + } + + protected static class TypeAndValue + { + private final Type type; + private final Object value; + + public TypeAndValue(Type type, Object value) + { + this.type = requireNonNull(type, "type is null"); + this.value = requireNonNull(value, "value is null"); + } + + public Type getType() + { + return type; + } + + public Object getValue() + { + return value; + } + } + + private String addColumnExpression(List columns, Map columnExpressions) + { + if (columns.isEmpty()) { + return "null"; + } + + return columns.stream() + .map(ArrowColumnHandle -> { + String columnAlias = quote(ArrowColumnHandle.getColumnName()); + String expression = columnExpressions.get(ArrowColumnHandle.getColumnName()); + if (expression == null) { + return columnAlias; + } + return format("%s AS %s", expression, columnAlias); + }) + .collect(joining(", ")); + } + + private static boolean isAcceptedType(Type type) + { + Type validType = requireNonNull(type, "type is null"); + return validType.equals(BigintType.BIGINT) || + validType.equals(TinyintType.TINYINT) || + validType.equals(SmallintType.SMALLINT) || + validType.equals(IntegerType.INTEGER) || + validType.equals(DoubleType.DOUBLE) || + validType.equals(RealType.REAL) || + validType.equals(BooleanType.BOOLEAN) || + validType.equals(DateType.DATE) || + validType.equals(TimeType.TIME) || + validType.equals(TimestampType.TIMESTAMP) || + validType instanceof VarcharType || + validType instanceof CharType; + } + private List toConjuncts(List columns, TupleDomain tupleDomain, List accumulator) + { + ImmutableList.Builder builder = ImmutableList.builder(); + for (ArrowColumnHandle column : columns) { + Type type = column.getColumnType(); + if (isAcceptedType(type)) { + Domain domain = tupleDomain.getDomains().get().get(column); + if (domain != null) { + builder.add(toPredicate(column.getColumnName(), domain, column, accumulator)); + } + } + } + return builder.build(); + } + + private String toPredicate(String columnName, Domain domain, ArrowColumnHandle columnHandle, List accumulator) + { + checkArgument(domain.getType().isOrderable(), "Domain type must be orderable"); + + if (domain.getValues().isNone()) { + return domain.isNullAllowed() ? quote(columnName) + " IS NULL" : ALWAYS_FALSE; + } + + if (domain.getValues().isAll()) { + return domain.isNullAllowed() ? ALWAYS_TRUE : quote(columnName) + " IS NOT NULL"; + } + + List disjuncts = new ArrayList<>(); + List singleValues = new ArrayList<>(); + for (Range range : domain.getValues().getRanges().getOrderedRanges()) { + checkState(!range.isAll()); // Already checked + if (range.isSingleValue()) { + singleValues.add(range.getSingleValue()); + } + else { + List rangeConjuncts = new ArrayList<>(); + if (!range.isLowUnbounded()) { + rangeConjuncts.add(toPredicate(columnName, range.isLowInclusive() ? ">=" : ">", range.getLowBoundedValue(), columnHandle, accumulator)); + } + if (!range.isHighUnbounded()) { + rangeConjuncts.add(toPredicate(columnName, range.isHighInclusive() ? "<=" : "<", range.getHighBoundedValue(), columnHandle, accumulator)); + } + // If rangeConjuncts is null, then the range was ALL, which should already have been checked for + checkState(!rangeConjuncts.isEmpty()); + disjuncts.add("(" + Joiner.on(" AND ").join(rangeConjuncts) + ")"); + } + } + + // Add back all of the possible single values either as an equality or an IN predicate + if (singleValues.size() == 1) { + disjuncts.add(toPredicate(columnName, "=", getOnlyElement(singleValues), columnHandle, accumulator)); + } + else if (singleValues.size() > 1) { + for (Object value : singleValues) { + bindValue(value, columnHandle, accumulator); + } + String values = Joiner.on(",").join(singleValues.stream().map(v -> + parameterValueToString(columnHandle.getColumnType(), v)).collect(Collectors.toList())); + disjuncts.add(quote(columnName) + " IN (" + values + ")"); + } + + // Add nullability disjuncts + checkState(!disjuncts.isEmpty()); + if (domain.isNullAllowed()) { + disjuncts.add(quote(columnName) + " IS NULL"); + } + + return "(" + Joiner.on(" OR ").join(disjuncts) + ")"; + } + + private String toPredicate(String columnName, String operator, Object value, ArrowColumnHandle columnHandle, List accumulator) + { + bindValue(value, columnHandle, accumulator); + return quote(columnName) + " " + operator + " " + parameterValueToString(columnHandle.getColumnType(), value); + } + private String quote(String name) + { + return "\"" + name + "\""; + } + + private String quoteValue(String name) + { + return "'" + name + "'"; + } + + private void bindValue(Object value, ArrowColumnHandle columnHandle, List accumulator) + { + Type type = columnHandle.getColumnType(); + accumulator.add(new TypeAndValue(type, value)); + } + + public static String convertLongToFloatString(Long value) + { + float floatFromIntBits = intBitsToFloat(toIntExact(value)); + return String.valueOf(floatFromIntBits); + } + + private String parameterValueToString(Type type, Object value) + { + Class javaType = type.getJavaType(); + if (type instanceof DateType && javaType == long.class) { + return quoteValue(convertEpochToString((Long) value, type)); + } + else if (type instanceof TimeType && javaType == long.class) { + return quoteValue(convertEpochToString((Long) value, type)); + } + else if (type instanceof TimestampType && javaType == long.class) { + return quoteValue(convertEpochToString((Long) value, type)); + } + else if (type instanceof RealType && javaType == long.class) { + return convertLongToFloatString((Long) value); + } + else if (javaType == boolean.class || javaType == double.class || javaType == long.class) { + return value.toString(); + } + else if (javaType == Slice.class) { + return quoteValue(((Slice) value).toStringUtf8()); + } + else { + return quoteValue(value.toString()); + } + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowQueryRunner.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowQueryRunner.java new file mode 100644 index 0000000000000..aeed01e6ca473 --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowQueryRunner.java @@ -0,0 +1,74 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.plugin.arrow; + +import com.facebook.airlift.log.Logger; +import com.facebook.presto.Session; +import com.facebook.presto.tests.DistributedQueryRunner; +import com.google.common.collect.ImmutableMap; + +import java.util.Map; + +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; + +public class TestingArrowQueryRunner +{ + private static DistributedQueryRunner queryRunner; + private static final Logger logger = Logger.get(TestingArrowQueryRunner.class); + private TestingArrowQueryRunner() + { + throw new UnsupportedOperationException("This is a utility class and cannot be instantiated"); + } + + public static DistributedQueryRunner createQueryRunner() throws Exception + { + if (queryRunner == null) { + queryRunner = createQueryRunner(ImmutableMap.of(), TestingArrowFactory.class); + } + return queryRunner; + } + + private static DistributedQueryRunner createQueryRunner(Map catalogProperties, Class factoryClass) throws Exception + { + Session session = testSessionBuilder() + .setCatalog("arrow") + .setSchema("testdb") + .build(); + + if (queryRunner == null) { + queryRunner = DistributedQueryRunner.builder(session).build(); + } + + try { + String connectorName = "arrow"; + queryRunner.installPlugin(new ArrowPlugin(connectorName, new TestingArrowModule())); + + ImmutableMap.Builder properties = ImmutableMap.builder() + .putAll(catalogProperties) + .put("arrow-flight.server", "127.0.0.1") + .put("arrow-flight.server-ssl-enabled", "true") + .put("arrow-flight.server.port", "9443") + .put("arrow-flight.server.verify", "false"); + + queryRunner.createCatalog(connectorName, connectorName, properties.build()); + + return queryRunner; + } + catch (Exception e) { + logger.error(e); + throw new RuntimeException("Failed to create ArrowQueryRunner", e); + } + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowServer.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowServer.java new file mode 100644 index 0000000000000..1363f2b2a3398 --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowServer.java @@ -0,0 +1,310 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.plugin.arrow; + +import com.facebook.airlift.log.Logger; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.arrow.adapter.jdbc.ArrowVectorIterator; +import org.apache.arrow.adapter.jdbc.JdbcToArrow; +import org.apache.arrow.adapter.jdbc.JdbcToArrowConfigBuilder; +import org.apache.arrow.flight.Action; +import org.apache.arrow.flight.ActionType; +import org.apache.arrow.flight.Criteria; +import org.apache.arrow.flight.FlightDescriptor; +import org.apache.arrow.flight.FlightEndpoint; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.FlightProducer; +import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.flight.PutResult; +import org.apache.arrow.flight.Result; +import org.apache.arrow.flight.Ticket; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.VectorLoader; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.VectorUnloader; +import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; +import org.apache.arrow.vector.types.DateUnit; +import org.apache.arrow.vector.types.FloatingPointPrecision; +import org.apache.arrow.vector.types.TimeUnit; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.types.pojo.Schema; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.ResultSet; +import java.sql.ResultSetMetaData; +import java.sql.Statement; +import java.util.ArrayList; +import java.util.Calendar; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.TimeZone; +import java.util.concurrent.ThreadLocalRandom; + +public class TestingArrowServer + implements FlightProducer +{ + private final RootAllocator allocator; + private final ObjectMapper objectMapper = new ObjectMapper(); + private static Connection connection; + private static final Logger logger = Logger.get(TestingArrowServer.class); + + public TestingArrowServer(RootAllocator allocator) throws Exception + { + this.allocator = allocator; + String h2JdbcUrl = "jdbc:h2:mem:testdb" + System.nanoTime() + "_" + ThreadLocalRandom.current().nextInt() + ";DB_CLOSE_DELAY=-1"; + TestingH2DatabaseSetup.setup(h2JdbcUrl); + this.connection = DriverManager.getConnection(h2JdbcUrl, "sa", ""); + } + + @Override + public void getStream(CallContext callContext, Ticket ticket, ServerStreamListener serverStreamListener) + { + try (Statement stmt = connection.createStatement()) { + // Convert ticket bytes to String and parse into JSON + String ticketString = new String(ticket.getBytes(), StandardCharsets.UTF_8); + JsonNode ticketJson = objectMapper.readTree(ticketString); + + // Extract interaction properties and validate + JsonNode interactionProperties = ticketJson.get("interactionProperties"); + if (interactionProperties == null || !interactionProperties.has("select_statement")) { + throw new IllegalArgumentException("Invalid ticket format: missing select_statement."); + } + + // Extract and validate the SQL query + String query = interactionProperties.get("select_statement").asText(); + if (query == null || query.trim().isEmpty()) { + throw new IllegalArgumentException("Query cannot be null or empty."); + } + + logger.info("Executing query: " + query); + query = query.toUpperCase(); // Optionally, to maintain consistency + + try (ResultSet rs = stmt.executeQuery(query)) { + JdbcToArrowConfigBuilder config = new JdbcToArrowConfigBuilder().setAllocator(allocator).setTargetBatchSize(2048) + .setCalendar(Calendar.getInstance(TimeZone.getDefault())); + ArrowVectorIterator iterator = JdbcToArrow.sqlToArrowVectorIterator(rs, config.build()); + boolean firstBatch = true; + + VectorLoader vectorLoader = null; + VectorSchemaRoot newRoot = null; + while (iterator.hasNext()) { + try (VectorSchemaRoot root = iterator.next()) { + VectorUnloader vectorUnloader = new VectorUnloader(root); + try (ArrowRecordBatch batch = vectorUnloader.getRecordBatch()) { + if (firstBatch) { + firstBatch = false; + newRoot = VectorSchemaRoot.create(root.getSchema(), allocator); + vectorLoader = new VectorLoader(newRoot); + serverStreamListener.start(newRoot); + } + if (vectorLoader != null) { + vectorLoader.load(batch); + } + serverStreamListener.putNext(); + } + } + } + if (newRoot != null) { + newRoot.close(); + } + serverStreamListener.completed(); + } + } + // Handle Arrow processing errors + catch (IOException e) { + logger.error("Arrow data processing failed", e); + serverStreamListener.error(e); + throw new RuntimeException("Failed to process Arrow data", e); + } + // Handle all other exceptions, including parsing errors + catch (Exception e) { + logger.error("Ticket processing failed", e); + serverStreamListener.error(e); + throw new RuntimeException("Failed to process the ticket", e); + } + } + + @Override + public void listFlights(CallContext callContext, Criteria criteria, StreamListener streamListener) + { + throw new UnsupportedOperationException("This operation is not supported"); + } + + @Override + public FlightInfo getFlightInfo(CallContext callContext, FlightDescriptor flightDescriptor) + { + try { + String jsonRequest = new String(flightDescriptor.getCommand(), StandardCharsets.UTF_8); + JsonNode rootNode = objectMapper.readTree(jsonRequest); + + String schemaName = rootNode.get("interactionProperties").get("schema_name").asText(null); + String tableName = rootNode.get("interactionProperties").get("table_name").asText(null); + String selectStatement = rootNode.get("interactionProperties").get("select_statement").asText(null); + + List fields = new ArrayList<>(); + if (schemaName != null && tableName != null) { + String query = "SELECT * FROM INFORMATION_SCHEMA.COLUMNS " + + "WHERE TABLE_SCHEMA='" + schemaName.toUpperCase() + "' " + + "AND TABLE_NAME='" + tableName.toUpperCase() + "'"; + + try (ResultSet rs = connection.createStatement().executeQuery(query)) { + while (rs.next()) { + String columnName = rs.getString("COLUMN_NAME"); + String dataType = rs.getString("TYPE_NAME"); + String charMaxLength = rs.getString("CHARACTER_MAXIMUM_LENGTH"); + int precision = rs.getInt("NUMERIC_PRECISION"); + int scale = rs.getInt("NUMERIC_SCALE"); + + ArrowType arrowType = convertSqlTypeToArrowType(dataType, precision, scale); + Map metaDataMap = new HashMap<>(); + metaDataMap.put("columnNativeType", dataType); + if (charMaxLength != null) { + metaDataMap.put("columnLength", charMaxLength); + } + FieldType fieldType = new FieldType(true, arrowType, null, metaDataMap); + Field field = new Field(columnName, fieldType, null); + fields.add(field); + } + } + } + else if (selectStatement != null) { + selectStatement = selectStatement.toUpperCase(); + logger.info("Executing SELECT query: " + selectStatement); + try (ResultSet rs = connection.createStatement().executeQuery(selectStatement)) { + ResultSetMetaData metaData = rs.getMetaData(); + int columnCount = metaData.getColumnCount(); + + for (int i = 1; i <= columnCount; i++) { + String columnName = metaData.getColumnName(i); + String columnType = metaData.getColumnTypeName(i); + int precision = metaData.getPrecision(i); + int scale = metaData.getScale(i); + + ArrowType arrowType = convertSqlTypeToArrowType(columnType, precision, scale); + Field field = new Field(columnName, FieldType.nullable(arrowType), null); + fields.add(field); + } + } + } + else { + throw new IllegalArgumentException("Either schema_name/table_name or select_statement must be provided."); + } + + Schema schema = new Schema(fields); + FlightEndpoint endpoint = new FlightEndpoint(new Ticket(flightDescriptor.getCommand())); + return new FlightInfo(schema, flightDescriptor, Collections.singletonList(endpoint), -1, -1); + } + catch (Exception e) { + logger.error(e); + throw new RuntimeException("Failed to retrieve FlightInfo", e); + } + } + + @Override + public Runnable acceptPut(CallContext callContext, FlightStream flightStream, StreamListener streamListener) + { + throw new UnsupportedOperationException("This operation is not supported"); + } + + @Override + public void doAction(CallContext callContext, Action action, StreamListener streamListener) + { + try { + String jsonRequest = new String(action.getBody(), StandardCharsets.UTF_8); + JsonNode rootNode = objectMapper.readTree(jsonRequest); + String schemaName = rootNode.get("interactionProperties").get("schema_name").asText(null); + + String query; + if (schemaName == null) { + query = "SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA"; + } + else { + schemaName = schemaName.toUpperCase(); + query = "SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA='" + schemaName + "'"; + } + ResultSet rs = connection.createStatement().executeQuery(query); + List names = new ArrayList<>(); + while (rs.next()) { + names.add(rs.getString(1)); + } + + String jsonResponse = objectMapper.writeValueAsString(names); + streamListener.onNext(new Result(jsonResponse.getBytes(StandardCharsets.UTF_8))); + streamListener.onCompleted(); + } + catch (Exception e) { + streamListener.onError(e); + } + } + + @Override + public void listActions(CallContext callContext, StreamListener streamListener) + { + throw new UnsupportedOperationException("This operation is not supported"); + } + + private ArrowType convertSqlTypeToArrowType(String sqlType, int precision, int scale) + { + switch (sqlType.toUpperCase()) { + case "VARCHAR": + case "CHAR": + case "CHARACTER VARYING": + case "CHARACTER": + case "CLOB": + return new ArrowType.Utf8(); + case "INTEGER": + case "INT": + return new ArrowType.Int(32, true); + case "BIGINT": + return new ArrowType.Int(64, true); + case "SMALLINT": + return new ArrowType.Int(16, true); + case "TINYINT": + return new ArrowType.Int(8, true); + case "DOUBLE": + case "DOUBLE PRECISION": + case "FLOAT": + return new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE); + case "REAL": + return new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE); + case "BOOLEAN": + return new ArrowType.Bool(); + case "DATE": + return new ArrowType.Date(DateUnit.DAY); + case "TIMESTAMP": + return new ArrowType.Timestamp(TimeUnit.MILLISECOND, null); + case "TIME": + return new ArrowType.Time(TimeUnit.MILLISECOND, 32); + case "DECIMAL": + case "NUMERIC": + return new ArrowType.Decimal(precision, scale); + case "BINARY": + case "VARBINARY": + return new ArrowType.Binary(); + case "NULL": + return new ArrowType.Null(); + default: + throw new IllegalArgumentException("Unsupported SQL type: " + sqlType); + } + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowSplitManager.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowSplitManager.java new file mode 100644 index 0000000000000..34694863bc277 --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowSplitManager.java @@ -0,0 +1,50 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.presto.spi.NodeManager; +import com.google.common.collect.ImmutableMap; +import org.apache.arrow.flight.FlightDescriptor; + +import javax.inject.Inject; + +import java.util.Optional; + +public class TestingArrowSplitManager + extends AbstractArrowSplitManager +{ + private TestingArrowFlightConfig testconfig; + + private final NodeManager nodeManager; + + @Inject + public TestingArrowSplitManager(ArrowFlightConfig config, ArrowFlightClientHandler client, TestingArrowFlightConfig testconfig, NodeManager nodeManager) + { + super(client); + this.testconfig = testconfig; + this.nodeManager = nodeManager; + } + + @Override + protected FlightDescriptor getFlightDescriptor(ArrowFlightConfig config, ArrowTableLayoutHandle tableLayoutHandle) + { + ArrowTableHandle tableHandle = tableLayoutHandle.getTableHandle(); + Optional query = Optional.of(new TestingArrowQueryBuilder().buildSql(tableHandle.getSchema(), + tableHandle.getTable(), + tableLayoutHandle.getColumnHandles(), ImmutableMap.of(), + tableLayoutHandle.getTupleDomain())); + TestingArrowFlightRequest request = new TestingArrowFlightRequest(config, testconfig, tableHandle.getSchema(), tableHandle.getTable(), query, nodeManager.getWorkerNodes().size()); + return FlightDescriptor.command(request.getCommand()); + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingConnectionProperties.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingConnectionProperties.java new file mode 100644 index 0000000000000..e158fdc67707e --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingConnectionProperties.java @@ -0,0 +1,35 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import javax.annotation.concurrent.Immutable; + +@Immutable +public class TestingConnectionProperties +{ + private final String database; + private final String password; + private final String host; + private final Boolean ssl; + private final String username; + + public TestingConnectionProperties(String database, String password, String host, Boolean ssl, String username) + { + this.database = database; + this.password = password; + this.host = host; + this.ssl = ssl; + this.username = username; + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingH2DatabaseSetup.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingH2DatabaseSetup.java new file mode 100644 index 0000000000000..7f21afc8c69fa --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingH2DatabaseSetup.java @@ -0,0 +1,273 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.airlift.log.Logger; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.common.type.VarcharType; +import com.facebook.presto.spi.ColumnMetadata; +import com.facebook.presto.spi.ConnectorTableMetadata; +import com.facebook.presto.spi.RecordCursor; +import com.facebook.presto.spi.RecordSet; +import com.facebook.presto.spi.SchemaTableName; +import com.facebook.presto.tpch.TpchMetadata; +import com.facebook.presto.tpch.TpchTableHandle; +import com.google.common.base.Joiner; +import io.airlift.tpch.TpchTable; +import org.jdbi.v3.core.Handle; +import org.jdbi.v3.core.Jdbi; +import org.jdbi.v3.core.statement.PreparedBatch; +import org.joda.time.DateTimeZone; + +import java.sql.Connection; +import java.sql.Date; +import java.sql.DriverManager; +import java.sql.ResultSet; +import java.sql.Statement; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; + +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.BooleanType.BOOLEAN; +import static com.facebook.presto.common.type.DateType.DATE; +import static com.facebook.presto.common.type.DoubleType.DOUBLE; +import static com.facebook.presto.common.type.IntegerType.INTEGER; +import static com.facebook.presto.tpch.TpchMetadata.TINY_SCHEMA_NAME; +import static com.facebook.presto.tpch.TpchRecordSet.createTpchRecordSet; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static io.airlift.tpch.TpchTable.CUSTOMER; +import static io.airlift.tpch.TpchTable.LINE_ITEM; +import static io.airlift.tpch.TpchTable.NATION; +import static io.airlift.tpch.TpchTable.ORDERS; +import static io.airlift.tpch.TpchTable.PART; +import static io.airlift.tpch.TpchTable.PART_SUPPLIER; +import static io.airlift.tpch.TpchTable.REGION; +import static io.airlift.tpch.TpchTable.SUPPLIER; +import static java.lang.String.format; +import static java.util.Collections.nCopies; + +public class TestingH2DatabaseSetup +{ + private static final Logger logger = Logger.get(TestingH2DatabaseSetup.class); + private TestingH2DatabaseSetup() + { + throw new UnsupportedOperationException("This is a utility class and cannot be instantiated"); + } + + public static void setup(String h2JdbcUrl) throws Exception + { + Class.forName("org.h2.Driver"); + + Connection conn = DriverManager.getConnection(h2JdbcUrl, "sa", ""); + + Jdbi jdbi = Jdbi.create(h2JdbcUrl, "sa", ""); + Handle handle = jdbi.open(); // Get a handle for the database connection + + TpchMetadata tpchMetadata = new TpchMetadata(""); + + Statement stmt = conn.createStatement(); + + // Create schema + stmt.execute("CREATE SCHEMA IF NOT EXISTS testdb"); + + stmt.execute("CREATE TABLE testdb.member (" + + " id INTEGER PRIMARY KEY," + + " name VARCHAR(50)," + + " sex CHAR(1)," + + " state CHAR(5)" + + ")"); + stmt.execute("INSERT INTO testdb.member VALUES(1, 'TOM', 'M', 'TMX '),(2, 'MARY', 'F', 'CD ')"); + + stmt.execute("CREATE TABLE testdb.event (" + + " id INTEGER PRIMARY KEY," + + " startDate DATE," + + " startTime TIME," + + " startTimestamp TIMESTAMP" + + ")"); + stmt.execute("INSERT INTO testdb.event VALUES(1, DATE '2004-12-31', TIME '23:59:59'," + + " TIMESTAMP '2005-12-31 23:59:59')"); + + stmt.execute("CREATE TABLE testdb.orders (\n" + + " orderkey BIGINT PRIMARY KEY,\n" + + " custkey BIGINT NOT NULL,\n" + + " orderstatus VARCHAR(1) NOT NULL,\n" + + " totalprice DOUBLE NOT NULL,\n" + + " orderdate DATE NOT NULL,\n" + + " orderpriority VARCHAR(15) NOT NULL,\n" + + " clerk VARCHAR(15) NOT NULL,\n" + + " shippriority INTEGER NOT NULL,\n" + + " comment VARCHAR(79) NOT NULL\n" + + ")"); + stmt.execute("CREATE INDEX custkey_index ON testdb.orders (custkey)"); + insertRows(tpchMetadata, ORDERS, handle); + + handle.execute("CREATE TABLE testdb.lineitem (\n" + + " orderkey BIGINT,\n" + + " partkey BIGINT NOT NULL,\n" + + " suppkey BIGINT NOT NULL,\n" + + " linenumber INTEGER,\n" + + " quantity DOUBLE NOT NULL,\n" + + " extendedprice DOUBLE NOT NULL,\n" + + " discount DOUBLE NOT NULL,\n" + + " tax DOUBLE NOT NULL,\n" + + " returnflag CHAR(1) NOT NULL,\n" + + " linestatus CHAR(1) NOT NULL,\n" + + " shipdate DATE NOT NULL,\n" + + " commitdate DATE NOT NULL,\n" + + " receiptdate DATE NOT NULL,\n" + + " shipinstruct VARCHAR(25) NOT NULL,\n" + + " shipmode VARCHAR(10) NOT NULL,\n" + + " comment VARCHAR(44) NOT NULL,\n" + + " PRIMARY KEY (orderkey, linenumber)" + + ")"); + insertRows(tpchMetadata, LINE_ITEM, handle); + + handle.execute(" CREATE TABLE testdb.partsupp (\n" + + " partkey BIGINT NOT NULL,\n" + + " suppkey BIGINT NOT NULL,\n" + + " availqty INTEGER NOT NULL,\n" + + " supplycost DOUBLE NOT NULL,\n" + + " comment VARCHAR(199) NOT NULL,\n" + + " PRIMARY KEY(partkey, suppkey)" + + ")"); + insertRows(tpchMetadata, PART_SUPPLIER, handle); + + handle.execute("CREATE TABLE testdb.nation (\n" + + " nationkey BIGINT PRIMARY KEY,\n" + + " name VARCHAR(25) NOT NULL,\n" + + " regionkey BIGINT NOT NULL,\n" + + " comment VARCHAR(152) NOT NULL\n" + + ")"); + insertRows(tpchMetadata, NATION, handle); + + handle.execute("CREATE TABLE testdb.region(\n" + + " regionkey BIGINT PRIMARY KEY,\n" + + " name VARCHAR(25) NOT NULL,\n" + + " comment VARCHAR(115) NOT NULL\n" + + ")"); + insertRows(tpchMetadata, REGION, handle); + handle.execute("CREATE TABLE testdb.part(\n" + + " partkey BIGINT PRIMARY KEY,\n" + + " name VARCHAR(55) NOT NULL,\n" + + " mfgr VARCHAR(25) NOT NULL,\n" + + " brand VARCHAR(10) NOT NULL,\n" + + " type VARCHAR(25) NOT NULL,\n" + + " size INTEGER NOT NULL,\n" + + " container VARCHAR(10) NOT NULL,\n" + + " retailprice DOUBLE NOT NULL,\n" + + " comment VARCHAR(23) NOT NULL\n" + + ")"); + insertRows(tpchMetadata, PART, handle); + handle.execute(" CREATE TABLE testdb.customer ( \n" + + " custkey BIGINT NOT NULL, \n" + + " name VARCHAR(25) NOT NULL, \n" + + " address VARCHAR(40) NOT NULL, \n" + + " nationkey BIGINT NOT NULL, \n" + + " phone VARCHAR(15) NOT NULL, \n" + + " acctbal DOUBLE NOT NULL, \n" + + " mktsegment VARCHAR(10) NOT NULL, \n" + + " comment VARCHAR(117) NOT NULL \n" + + " ) "); + insertRows(tpchMetadata, CUSTOMER, handle); + handle.execute(" CREATE TABLE testdb.supplier ( \n" + + " suppkey bigint NOT NULL, \n" + + " name varchar(25) NOT NULL, \n" + + " address varchar(40) NOT NULL, \n" + + " nationkey bigint NOT NULL, \n" + + " phone varchar(15) NOT NULL, \n" + + " acctbal double NOT NULL, \n" + + " comment varchar(101) NOT NULL \n" + + " ) "); + insertRows(tpchMetadata, SUPPLIER, handle); + + ResultSet resultSet1 = stmt.executeQuery("SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = 'TESTDB'"); + List tables = new ArrayList<>(); + while (resultSet1.next()) { + String tableName = resultSet1.getString("TABLE_NAME"); + tables.add(tableName); + } + logger.info("Tables in 'testdb' schema: %s", tables.stream().collect(Collectors.joining(", "))); + + ResultSet resultSet = stmt.executeQuery("SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA"); + List schemas = new ArrayList<>(); + while (resultSet.next()) { + String schemaName = resultSet.getString("SCHEMA_NAME"); + schemas.add(schemaName); + } + logger.info("Schemas: %s", schemas.stream().collect(Collectors.joining(", "))); + } + + private static void insertRows(TpchMetadata tpchMetadata, TpchTable tpchTable, Handle handle) + { + TpchTableHandle tableHandle = tpchMetadata.getTableHandle(null, new SchemaTableName(TINY_SCHEMA_NAME, tpchTable.getTableName())); + insertRows(tpchMetadata.getTableMetadata(null, tableHandle), handle, createTpchRecordSet(tpchTable, tableHandle.getScaleFactor())); + } + + private static void insertRows(ConnectorTableMetadata tableMetadata, Handle handle, RecordSet data) + { + List columns = tableMetadata.getColumns().stream() + .filter(columnMetadata -> !columnMetadata.isHidden()) + .collect(toImmutableList()); + + String schemaName = "testdb"; + String tableNameWithSchema = schemaName + "." + tableMetadata.getTable().getTableName(); + String vars = Joiner.on(',').join(nCopies(columns.size(), "?")); + String sql = format("INSERT INTO %s VALUES (%s)", tableNameWithSchema, vars); + + RecordCursor cursor = data.cursor(); + while (true) { + // insert 1000 rows at a time + PreparedBatch batch = handle.prepareBatch(sql); + for (int row = 0; row < 1000; row++) { + if (!cursor.advanceNextPosition()) { + if (batch.size() > 0) { + batch.execute(); + } + return; + } + for (int column = 0; column < columns.size(); column++) { + Type type = columns.get(column).getType(); + if (BOOLEAN.equals(type)) { + batch.bind(column, cursor.getBoolean(column)); + } + else if (BIGINT.equals(type)) { + batch.bind(column, cursor.getLong(column)); + } + else if (INTEGER.equals(type)) { + batch.bind(column, (int) cursor.getLong(column)); + } + else if (DOUBLE.equals(type)) { + batch.bind(column, cursor.getDouble(column)); + } + else if (type instanceof VarcharType) { + batch.bind(column, cursor.getSlice(column).toStringUtf8()); + } + else if (DATE.equals(type)) { + long millisUtc = TimeUnit.DAYS.toMillis(cursor.getLong(column)); + // H2 expects dates in to be millis at midnight in the JVM timezone + long localMillis = DateTimeZone.UTC.getMillisKeepLocal(DateTimeZone.getDefault(), millisUtc); + batch.bind(column, new Date(localMillis)); + } + else { + throw new IllegalArgumentException("Unsupported type " + type); + } + } + batch.add(); + } + batch.execute(); + } + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingInteractionProperties.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingInteractionProperties.java new file mode 100644 index 0000000000000..61d36d43fe575 --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingInteractionProperties.java @@ -0,0 +1,57 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.fasterxml.jackson.annotation.JsonProperty; + +import javax.annotation.concurrent.Immutable; + +@Immutable +public class TestingInteractionProperties +{ + @JsonProperty("select_statement") + private final String selectStatement; + + @JsonProperty("schema_name") + private final String schema; + + @JsonProperty("table_name") + private final String table; + + // Constructor to initialize the fields + public TestingInteractionProperties(String selectStatement, String schema, String table) + { + this.selectStatement = selectStatement; + this.schema = schema; + this.table = table; + } + + // Getters (no setters as the fields are final and immutable) + public String getSelectStatement() + { + return selectStatement; + } + + public String getSchema() + { + return schema; + } + + public String getTable() + { + return table; + } + + // No setters as the class is immutable +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingRequestData.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingRequestData.java new file mode 100644 index 0000000000000..aee379dc04311 --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingRequestData.java @@ -0,0 +1,40 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +public class TestingRequestData +{ + private TestingConnectionProperties connectionProperties; + private TestingInteractionProperties interactionProperties; + + public TestingConnectionProperties getConnectionProperties() + { + return connectionProperties; + } + + public TestingInteractionProperties getInteractionProperties() + { + return interactionProperties; + } + + public void setConnectionProperties(TestingConnectionProperties connectionProperties) + { + this.connectionProperties = connectionProperties; + } + + public void setInteractionProperties(TestingInteractionProperties interactionProperties) + { + this.interactionProperties = interactionProperties; + } +} diff --git a/presto-base-arrow-flight/src/test/resources/server.crt b/presto-base-arrow-flight/src/test/resources/server.crt new file mode 100644 index 0000000000000..9af01752e2959 --- /dev/null +++ b/presto-base-arrow-flight/src/test/resources/server.crt @@ -0,0 +1,21 @@ +-----BEGIN CERTIFICATE----- +MIIDazCCAlOgAwIBAgIUBYcCyz9qphcpgV9wIx2caf4ikoEwDQYJKoZIhvcNAQEL +BQAwRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM +GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAeFw0yNDA4MzAxMzM0MjJaFw0yNTA4 +MzAxMzM0MjJaMEUxCzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEw +HwYDVQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQwggEiMA0GCSqGSIb3DQEB +AQUAA4IBDwAwggEKAoIBAQCTu4vGRFav9/aNuGorx2v+xAevMRFfSi6x2onRCpMp +u+9L3bK3em+vSJcC/+QwvNkvTmf0ac8/G79BXQiUtPSV3tDSFnpO7LIICWpO2gJX +UKjTN+SZ3LfLEXSYAijVUCkWf8RswJLvJ12qNQRiv4IDrpBh/X+ICr+ALMa7y6cn +hBfpFtVyxlLkCp6Q71JgMGDjocv2195pc4uiMEDsn6tZfEKxmw1phX+CLke42srX ++AHoBQvlfcOBVIuMkxU7pZvvDDOuw0tSm0qomjRV+Azjn4oWbPDfkIEZIbBvgLwN +nxtKzSAgW+7xWQ0eVMRq1f5JPk6vLAV/yit8lFk8TCtdAgMBAAGjUzBRMB0GA1Ud +DgQWBBS4K4nHKbp4qw77KxR+tqUWXKCADjAfBgNVHSMEGDAWgBS4K4nHKbp4qw77 +KxR+tqUWXKCADjAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUAA4IBAQBI +yraxUUgEb53q0aF2ID1ml3/9uMZALK4xQHKYP0dtkqdaS0CvuMsqSe0AT7PMDJOD +Q+s1lxe0gJ/lb7dLSOQpFGAU7nnUsBaL5alCGmYgeF6ayLzNWobsn4amnjFpSlOH +P/nOQvOrRgVbphozhaYBSjMtwy2Uzj8G7ZHn+Vg/fdEUf/mGb7/4M9sL7iQWJmIg +slzNnA/LssAqGvyJ65cSkjGrb82VvDJv6JZ7fam6nrMEpq5D4H4GzOwupNvdItLH +Zf9IBsedkDEu8P+Rzp9kGIMbGTLfL3u/pP1st+pD4PfHYRsbl8aAUqYA7P+d29mf +aVcp98Z1h1s5e005BT1T +-----END CERTIFICATE----- diff --git a/presto-base-arrow-flight/src/test/resources/server.key b/presto-base-arrow-flight/src/test/resources/server.key new file mode 100644 index 0000000000000..894dc1fad8e61 --- /dev/null +++ b/presto-base-arrow-flight/src/test/resources/server.key @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQCTu4vGRFav9/aN +uGorx2v+xAevMRFfSi6x2onRCpMpu+9L3bK3em+vSJcC/+QwvNkvTmf0ac8/G79B +XQiUtPSV3tDSFnpO7LIICWpO2gJXUKjTN+SZ3LfLEXSYAijVUCkWf8RswJLvJ12q +NQRiv4IDrpBh/X+ICr+ALMa7y6cnhBfpFtVyxlLkCp6Q71JgMGDjocv2195pc4ui +MEDsn6tZfEKxmw1phX+CLke42srX+AHoBQvlfcOBVIuMkxU7pZvvDDOuw0tSm0qo +mjRV+Azjn4oWbPDfkIEZIbBvgLwNnxtKzSAgW+7xWQ0eVMRq1f5JPk6vLAV/yit8 +lFk8TCtdAgMBAAECggEAAKgSdrLajMUmFhql9CRafUMbQqLN8DW47+bn+mMY5NRW +O6jUUL7tTKLesu92sOXB9FUdnqcyudXSe4Shk2GbfagEFw7tA7lHEESUcZ3D6WXt +HiUvMaTatz8QXNWTn2EQEa7HLXGMpZ3v61/5cUPnHMOTli/ld3IOyE/KoU6GI2WP +5+8eKJUdJiZDLdywLTF/OefTuHhnbkNxsoPtR1DwO1UZKsX1OFo0b3x6ibqPZSu3 +x01QkYfTlUoVjUBNFQTcyfXKWX5rrAK9C/2h7TVJxbktj7V4DOGLb8ee6oN+yr4e +lIg6Vfz/KAakXKaRVappoxSxvFTZXqGhL8kBnLPFoQKBgQDIfr3ir65pi2Gzc+6e +0DWCkznfDqVvfgzoAxCrIqKp/6MqVUIJmProPjBztSAhiaNhK8sDbLFBssw0AOeZ +YEqMcY69ClyVrtDoGYBNeuCTTz3S8yXf0mzZe+CgWA6vNnbzEgKe7NqB+EbGjLnH +BcNoSE1GuBjt7Juvuor8P9XW/QKBgQC8oXtd0gQmpIyApgTXtE4AdbckPZE6sX+5 +aWWk9RY7yL9ivx86ej+gG2oj0YqmBtCbcabrYdF2Dv1idSpr+UoW5cxMa+jfAwSQ +8WNeUOeWk1OfEMrnbp8XLZMTDMGMTTyHdPMNpXq0zw7bvIReiRzeFb7fEgdE2+q+ +Me+GHU5D4QKBgDK+eTrFciQ+ZbTwk6VYVyK8NnpxD4f/ZC7Yj8BwnLDgBaDyQSuC +r4ZWLxcp8X7rghFW7yPnv5k8Mpi63eMgzt1q5FCOLc6olzEXOzTg87P061XXum9C +p9AHnVuXzeekpkhw937XvZoFh4w7E83+dG2RVxWeBJk7OFAqq4Cae3nVAoGAYXOD +yqqvnk8wj1417kKmcbJfFYgBObNt6xo6ewhrniNOTPO0bH+v00WWhj7BRJkMuOH0 +fHKixj1kRrOFYRb/YekCrRCq1Fw4xbEPxzBBFRe0Ad+pE/ugkVboPtU+QP++H7UZ +xJkTVcoLQRaZxEVN9qaBX7luq/J5yhz+Q+lr/8ECgYEAsDUXgtuRyMvKr1PFzrfQ +n/aVb4wmY39Zyr1sERHwHYLUvX2gsiNMaUq7lFWUNBGy0x0gcaMCWOlxXT3SMUP9 +aE3LCjrkUWP69/RuFjKdcustLGOtJmnr9ioSpZkkrR49tmiu1SWFsK2IunTd4GUg +/p/ITUEsA2oiyDZcpTkWx0Q= +-----END PRIVATE KEY----- diff --git a/presto-docs/src/main/sphinx/connector.rst b/presto-docs/src/main/sphinx/connector.rst index f07506b460cab..d337fe4ed12d1 100644 --- a/presto-docs/src/main/sphinx/connector.rst +++ b/presto-docs/src/main/sphinx/connector.rst @@ -9,6 +9,7 @@ from different data sources. :maxdepth: 1 connector/accumulo + connector/base-arrow-flight connector/bigquery connector/blackhole connector/cassandra diff --git a/presto-docs/src/main/sphinx/connector/base-arrow-flight.rst b/presto-docs/src/main/sphinx/connector/base-arrow-flight.rst new file mode 100644 index 0000000000000..0a5590c320090 --- /dev/null +++ b/presto-docs/src/main/sphinx/connector/base-arrow-flight.rst @@ -0,0 +1,95 @@ + +====================== +Arrow Flight Connector +====================== +This connector allows querying multiple data sources that are supported by an Arrow Flight server. Flight supports parallel transfers, allowing data to be streamed to or from a cluster of servers simultaneously. Official documentation for Arrow Flight can be found at https://arrow.apache.org/docs/format/Flight.html + +Getting Started with base-arrow-module: Essential Abstract Methods for Developers +--------------------------------------------------------------------------------- +To use the base-arrow-module, you need to implement certain abstract methods that are specific to your use case. Below are the required classes and their purposes: + +* ``ArrowFlightClientHandler.java`` + This class is responsible for initializing the Flight client and retrieving Flight information from the Flight server. To authenticate the Flight server, you must implement the abstract method ``getCallOptions`` in ArrowFlightClientHandler, which returns the ``CredentialCallOption`` specific to your Flight server. + +* ``AbstractArrowFlightRequest.java`` + Implement this class to define the request data, including the data source type, connection properties, the number of partitions and other data required to interact with database. + +* ``AbstractArrowMetadata.java`` + To retrieve metadata (schema and table information), implement the abstract methods in the ArrowAbstractMetadata class. + +* ``AbstractArrowSplitManager.java`` + Extend the ArrowAbstractSplitManager class to implement the Arrow Flight request, defining the Arrow split. + +* ``ArrowPlugin.java`` + Register your connector name by extending the ArrowPlugin class. + +* ``ArrowFlightRequest`` + The ``getCommand`` method in the ``ArrowFlightRequest`` interface should return a byte array for the Flight request. + + +Configuration +------------- +Create a catalog file +in ``etc/catalog`` named, for example, ``arrowmariadb.properties``, to +mount the Flight connector as the ``arrowmariadb`` catalog. +Create the file with the following contents, replacing the +connection properties as appropriate for your setup: + + +.. code-block:: none + + + connector.name= + arrow-flight.server= + arrow-flight.server.port= + + + +Add other properties that are required for your Flight server to connect. + +========================================== ============================================================== +Property Name Description +========================================== ============================================================== +``arrow-flight.server`` Endpoint of Arrow Flight server +``arrow-flight.server.port`` Flight server port +``arrow-flight.server-ssl-certificate`` Pass ssl certificate +``arrow-flight.server.verify`` To verify server +``arrow-flight.server-ssl-enabled`` Port is ssl enabled +========================================== ============================================================== + +Querying Arrow-Flight +--------------------- + +The Flight connector provides schema for each supported *database*. +Example for MariaDB is shown below. +To see the available schemas, run ``SHOW SCHEMAS``:: + + SHOW SCHEMAS FROM arrowmariadb; + +To view the tables in the MariaDB database named ``user``, +run ``SHOW TABLES``:: + + SHOW TABLES FROM arrowmariadb.user; + +To see a list of the columns in the ``admin`` table in the ``user`` database, +use either of the following commands:: + + DESCRIBE arrowmariadb.user.admin; + SHOW COLUMNS FROM arrowmariadb.user.admin; + +Finally, you can access the ``admin`` table in the ``user`` database:: + + SELECT * FROM arrowmariadb.user.admin; + +If you used a different name for your catalog properties file, use +that catalog name instead of ``arrowmariadb`` in the above examples. + + +Flight Connector Limitations +---------------------------- + +* SELECT and DESCRIBE queries are supported by this connector template. Implementing modules can add support for additional features. + +* Flight connector can query against only those datasources which are supported by the Flight server. + +* The user should have the Flight server running for the Flight connector to work.