diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc index f8f4cfa2f9b8e..47984fd85bc5d 100644 --- a/cpp/src/arrow/flight/client.cc +++ b/cpp/src/arrow/flight/client.cc @@ -523,6 +523,25 @@ class FlightClient::FlightClientImpl { return Status::OK(); } + Status GetSchema(const FlightCallOptions& options, const FlightDescriptor& descriptor, + std::unique_ptr* schema_result) { + pb::FlightDescriptor pb_descriptor; + pb::SchemaResult pb_response; + + RETURN_NOT_OK(internal::ToProto(descriptor, &pb_descriptor)); + + ClientRpc rpc(options); + RETURN_NOT_OK(rpc.SetToken(auth_handler_.get())); + Status s = internal::FromGrpcStatus( + stub_->GetSchema(&rpc.context, pb_descriptor, &pb_response)); + RETURN_NOT_OK(s); + + std::string str; + RETURN_NOT_OK(internal::FromProto(pb_response, &str)); + schema_result->reset(new SchemaResult(str)); + return Status::OK(); + } + Status DoGet(const FlightCallOptions& options, const Ticket& ticket, std::unique_ptr* out) { pb::Ticket pb_ticket; @@ -595,6 +614,12 @@ Status FlightClient::GetFlightInfo(const FlightCallOptions& options, return impl_->GetFlightInfo(options, descriptor, info); } +Status FlightClient::GetSchema(const FlightCallOptions& options, + const FlightDescriptor& descriptor, + std::unique_ptr* schema_result) { + return impl_->GetSchema(options, descriptor, schema_result); +} + Status FlightClient::ListFlights(std::unique_ptr* listing) { return ListFlights({}, {}, listing); } diff --git a/cpp/src/arrow/flight/client.h b/cpp/src/arrow/flight/client.h index 774b0f6f56bb3..82fd8aa6aef63 100644 --- a/cpp/src/arrow/flight/client.h +++ b/cpp/src/arrow/flight/client.h @@ -167,6 +167,20 @@ class ARROW_FLIGHT_EXPORT FlightClient { return GetFlightInfo({}, descriptor, info); } + /// \brief Request schema for a single flight, which may be an existing + /// dataset or a command to be executed + /// \param[in] options Per-RPC options + /// \param[in] descriptor the dataset request, whether a named dataset or + /// command + /// \param[out] schema_result the SchemaResult describing the dataset schema + /// \return Status + Status GetSchema(const FlightCallOptions& options, const FlightDescriptor& descriptor, + std::unique_ptr* schema_result); + Status GetSchema(const FlightDescriptor& descriptor, + std::unique_ptr* schema_result) { + return GetSchema({}, descriptor, schema_result); + } + /// \brief List all available flights known to the server /// \param[out] listing an iterator that returns a FlightInfo for each flight /// \return Status diff --git a/cpp/src/arrow/flight/flight_test.cc b/cpp/src/arrow/flight/flight_test.cc index 24420068f4113..c3c17e9394d17 100644 --- a/cpp/src/arrow/flight/flight_test.cc +++ b/cpp/src/arrow/flight/flight_test.cc @@ -556,6 +556,17 @@ TEST_F(TestFlightClient, GetFlightInfo) { AssertEqual(flights[0], *info); } +TEST_F(TestFlightClient, GetSchema) { + auto descr = FlightDescriptor::Path({"examples", "ints"}); + std::unique_ptr schema_result; + std::shared_ptr schema; + ipc::DictionaryMemo dict_memo; + + ASSERT_OK(client_->GetSchema(descr, &schema_result)); + ASSERT_NE(schema_result, nullptr); + ASSERT_OK(schema_result->GetSchema(&dict_memo, &schema)); +} + TEST_F(TestFlightClient, GetFlightInfoNotFound) { auto descr = FlightDescriptor::Path({"examples", "things"}); std::unique_ptr info; diff --git a/cpp/src/arrow/flight/internal.cc b/cpp/src/arrow/flight/internal.cc index 70033b1a70d82..1175cd67edc24 100644 --- a/cpp/src/arrow/flight/internal.cc +++ b/cpp/src/arrow/flight/internal.cc @@ -314,6 +314,11 @@ Status FromProto(const pb::FlightInfo& pb_info, FlightInfo::Data* info) { return Status::OK(); } +Status FromProto(const pb::SchemaResult& pb_result, std::string* result) { + *result = pb_result.schema(); + return Status::OK(); +} + Status SchemaToString(const Schema& schema, std::string* out) { // TODO(wesm): Do we care about better memory efficiency here? std::shared_ptr serialized_schema; @@ -344,6 +349,11 @@ Status ToProto(const FlightInfo& info, pb::FlightInfo* pb_info) { return Status::OK(); } +Status ToProto(const SchemaResult& result, pb::SchemaResult* pb_result) { + pb_result->set_schema(result.serialized_schema()); + return Status::OK(); +} + } // namespace internal } // namespace flight } // namespace arrow diff --git a/cpp/src/arrow/flight/internal.h b/cpp/src/arrow/flight/internal.h index f617d83634ccd..baf5c8f1a0551 100644 --- a/cpp/src/arrow/flight/internal.h +++ b/cpp/src/arrow/flight/internal.h @@ -89,12 +89,14 @@ Status FromProto(const pb::FlightData& pb_data, FlightDescriptor* descriptor, Status FromProto(const pb::FlightDescriptor& pb_descr, FlightDescriptor* descr); Status FromProto(const pb::FlightEndpoint& pb_endpoint, FlightEndpoint* endpoint); Status FromProto(const pb::FlightInfo& pb_info, FlightInfo::Data* info); +Status FromProto(const pb::SchemaResult& pb_result, std::string* result); Status ToProto(const FlightDescriptor& descr, pb::FlightDescriptor* pb_descr); Status ToProto(const FlightInfo& info, pb::FlightInfo* pb_info); Status ToProto(const ActionType& type, pb::ActionType* pb_type); Status ToProto(const Action& action, pb::Action* pb_action); Status ToProto(const Result& result, pb::Result* pb_result); +Status ToProto(const SchemaResult& result, pb::SchemaResult* pb_result); void ToProto(const Ticket& ticket, pb::Ticket* pb_ticket); } // namespace internal diff --git a/cpp/src/arrow/flight/server.cc b/cpp/src/arrow/flight/server.cc index 86e183f9da722..6794a356212ec 100644 --- a/cpp/src/arrow/flight/server.cc +++ b/cpp/src/arrow/flight/server.cc @@ -344,6 +344,27 @@ class FlightServiceImpl : public FlightService::Service { return grpc::Status::OK; } + grpc::Status GetSchema(ServerContext* context, const pb::FlightDescriptor* request, + pb::SchemaResult* response) { + CHECK_ARG_NOT_NULL(request, "FlightDescriptor cannot be null"); + GrpcServerCallContext flight_context; + GRPC_RETURN_NOT_GRPC_OK(CheckAuth(context, flight_context)); + + FlightDescriptor descr; + GRPC_RETURN_NOT_OK(internal::FromProto(*request, &descr)); + + std::unique_ptr result; + GRPC_RETURN_NOT_OK(server_->GetSchema(flight_context, descr, &result)); + + if (!result) { + // Treat null listing as no flights available + return grpc::Status(grpc::StatusCode::NOT_FOUND, "Flight not found"); + } + + GRPC_RETURN_NOT_OK(internal::ToProto(*result, response)); + return grpc::Status::OK; + } + grpc::Status DoGet(ServerContext* context, const pb::Ticket* request, ServerWriter* writer) { CHECK_ARG_NOT_NULL(request, "ticket cannot be null"); @@ -627,6 +648,12 @@ Status FlightServerBase::ListActions(const ServerCallContext& context, return Status::NotImplemented("NYI"); } +Status FlightServerBase::GetSchema(const ServerCallContext& context, + const FlightDescriptor& request, + std::unique_ptr* schema) { + return Status::NotImplemented("NYI"); +} + // ---------------------------------------------------------------------- // Implement RecordBatchStream diff --git a/cpp/src/arrow/flight/server.h b/cpp/src/arrow/flight/server.h index d53eb4378bfac..dfecae076cb74 100644 --- a/cpp/src/arrow/flight/server.h +++ b/cpp/src/arrow/flight/server.h @@ -166,6 +166,15 @@ class ARROW_FLIGHT_EXPORT FlightServerBase { const FlightDescriptor& request, std::unique_ptr* info); + /// \brief Retrieve the schema for the indicated descriptor + /// \param[in] context The call context. + /// \param[in] request may be null + /// \param[out] schema the returned flight schema provider + /// \return Status + virtual Status GetSchema(const ServerCallContext& context, + const FlightDescriptor& request, + std::unique_ptr* schema); + /// \brief Get a stream of IPC payloads to put on the wire /// \param[in] context The call context. /// \param[in] request an opaque ticket diff --git a/cpp/src/arrow/flight/test_util.cc b/cpp/src/arrow/flight/test_util.cc index b8d59f9ba0869..4d4cf9e663669 100644 --- a/cpp/src/arrow/flight/test_util.cc +++ b/cpp/src/arrow/flight/test_util.cc @@ -244,6 +244,20 @@ class FlightTestServer : public FlightServerBase { *out = std::move(actions); return Status::OK(); } + + Status GetSchema(const ServerCallContext& context, const FlightDescriptor& request, + std::unique_ptr* schema) override { + std::vector flights = ExampleFlightInfo(); + + for (const auto& info : flights) { + if (info.descriptor().Equals(request)) { + *schema = + std::unique_ptr(new SchemaResult(info.serialized_schema())); + return Status::OK(); + } + } + return Status::Invalid("Flight not found: ", request.ToString()); + } }; std::unique_ptr ExampleTestServer() { diff --git a/cpp/src/arrow/flight/types.cc b/cpp/src/arrow/flight/types.cc index 91e1f3d9deb99..4617524ce2aa8 100644 --- a/cpp/src/arrow/flight/types.cc +++ b/cpp/src/arrow/flight/types.cc @@ -118,6 +118,13 @@ std::string FlightDescriptor::ToString() const { return ss.str(); } +Status SchemaResult::GetSchema(ipc::DictionaryMemo* dictionary_memo, + std::shared_ptr* out) const { + io::BufferReader schema_reader(raw_schema_); + RETURN_NOT_OK(ipc::ReadSchema(&schema_reader, dictionary_memo, out)); + return Status::OK(); +} + Status FlightDescriptor::SerializeToString(std::string* out) const { pb::FlightDescriptor pb_descriptor; RETURN_NOT_OK(internal::ToProto(*this, &pb_descriptor)); diff --git a/cpp/src/arrow/flight/types.h b/cpp/src/arrow/flight/types.h index 152d88820bcad..ae8e61775893f 100644 --- a/cpp/src/arrow/flight/types.h +++ b/cpp/src/arrow/flight/types.h @@ -319,6 +319,24 @@ struct ARROW_FLIGHT_EXPORT FlightPayload { ipc::internal::IpcPayload ipc_message; }; +/// \brief Schema result returned after a schema request RPC +struct ARROW_FLIGHT_EXPORT SchemaResult { + public: + explicit SchemaResult(std::string schema) : raw_schema_(std::move(schema)) {} + + /// \brief return schema + /// \param[in,out] dictionary_memo for dictionary bookkeeping, will + /// be modified + /// \param[out] out the reconstructed Schema + Status GetSchema(ipc::DictionaryMemo* dictionary_memo, + std::shared_ptr* out) const; + + const std::string& serialized_schema() const { return raw_schema_; } + + private: + std::string raw_schema_; +}; + /// \brief The access coordinates for retireval of a dataset, returned by /// GetFlightInfo class ARROW_FLIGHT_EXPORT FlightInfo { diff --git a/cpp/src/arrow/python/flight.cc b/cpp/src/arrow/python/flight.cc index bbfc584e20f36..7c0fa48cbdaba 100644 --- a/cpp/src/arrow/python/flight.cc +++ b/cpp/src/arrow/python/flight.cc @@ -106,6 +106,15 @@ Status PyFlightServer::GetFlightInfo(const arrow::flight::ServerCallContext& con }); } +Status PyFlightServer::GetSchema(const arrow::flight::ServerCallContext& context, + const arrow::flight::FlightDescriptor& request, + std::unique_ptr* result) { + return SafeCallIntoPython([&] { + vtable_.get_schema(server_.obj(), context, request, result); + return CheckPyError(); + }); +} + Status PyFlightServer::DoGet(const arrow::flight::ServerCallContext& context, const arrow::flight::Ticket& request, std::unique_ptr* stream) { @@ -245,6 +254,16 @@ Status CreateFlightInfo(const std::shared_ptr& schema, return Status::OK(); } +Status CreateSchemaResult(const std::shared_ptr& schema, + std::unique_ptr* out) { + std::string schema_in; + RETURN_NOT_OK(arrow::flight::internal::SchemaToString(*schema, &schema_in)); + arrow::flight::SchemaResult value(schema_in); + *out = std::unique_ptr( + new arrow::flight::SchemaResult(value)); + return Status::OK(); +} + } // namespace flight } // namespace py } // namespace arrow diff --git a/cpp/src/arrow/python/flight.h b/cpp/src/arrow/python/flight.h index 3c5dc5ffee80f..35d23d3e2badb 100644 --- a/cpp/src/arrow/python/flight.h +++ b/cpp/src/arrow/python/flight.h @@ -45,6 +45,10 @@ class ARROW_PYTHON_EXPORT PyFlightServerVtable { const arrow::flight::FlightDescriptor&, std::unique_ptr*)> get_flight_info; + std::function*)> + get_schema; std::function*)> @@ -120,6 +124,9 @@ class ARROW_PYTHON_EXPORT PyFlightServer : public arrow::flight::FlightServerBas Status GetFlightInfo(const arrow::flight::ServerCallContext& context, const arrow::flight::FlightDescriptor& request, std::unique_ptr* info) override; + Status GetSchema(const arrow::flight::ServerCallContext& context, + const arrow::flight::FlightDescriptor& request, + std::unique_ptr* result) override; Status DoGet(const arrow::flight::ServerCallContext& context, const arrow::flight::Ticket& request, std::unique_ptr* stream) override; @@ -205,6 +212,11 @@ Status CreateFlightInfo(const std::shared_ptr& schema, int64_t total_records, int64_t total_bytes, std::unique_ptr* out); +/// \brief Create a SchemaResult from schema. +ARROW_PYTHON_EXPORT +Status CreateSchemaResult(const std::shared_ptr& schema, + std::unique_ptr* out); + } // namespace flight } // namespace py } // namespace arrow diff --git a/format/Flight.proto b/format/Flight.proto index 0c8f28e5315eb..a9691220969a2 100644 --- a/format/Flight.proto +++ b/format/Flight.proto @@ -61,6 +61,14 @@ service FlightService { */ rpc GetFlightInfo(FlightDescriptor) returns (FlightInfo) {} + /* + * For a given FlightDescriptor, get the Schema as described in Schema.fbs::Schema + * This is used when a consumer needs the Schema of flight stream. Similar to + * GetFlightInfo this interface may generate a new flight that was not previously + * available in ListFlights. + */ + rpc GetSchema(FlightDescriptor) returns (SchemaResult) {} + /* * Retrieve a single stream associated with a particular descriptor * associated with the referenced ticket. A Flight can be composed of one or @@ -169,6 +177,14 @@ message Result { bytes body = 1; } +/* + * Wrap the result of a getSchema call + */ +message SchemaResult { + // schema of the dataset as described in Schema.fbs::Schema. + bytes schema = 1; +} + /* * The name or tag for a Flight. May be used as a way to retrieve or generate * a flight or be used to expose a set of previously defined flights. diff --git a/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java b/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java index 8876a49089562..b35bf0d1dc5fc 100644 --- a/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java +++ b/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java @@ -221,6 +221,15 @@ public FlightInfo getInfo(FlightDescriptor descriptor, CallOption... options) { } } + /** + * Get schema for a stream. + * @param descriptor The descriptor for the stream. + * @param options RPC-layer hints for this call. + */ + public SchemaResult getSchema(FlightDescriptor descriptor, CallOption... options) { + return SchemaResult.fromProtocol(CallOptions.wrapStub(blockingStub, options).getSchema(descriptor.toProtocol())); + } + /** * Retrieve a stream from the server. * @param ticket The ticket granting access to the data stream. diff --git a/java/flight/src/main/java/org/apache/arrow/flight/FlightProducer.java b/java/flight/src/main/java/org/apache/arrow/flight/FlightProducer.java index fdb5e9f586cd1..aa3278b421368 100644 --- a/java/flight/src/main/java/org/apache/arrow/flight/FlightProducer.java +++ b/java/flight/src/main/java/org/apache/arrow/flight/FlightProducer.java @@ -55,6 +55,19 @@ void listFlights(CallContext context, Criteria criteria, */ FlightInfo getFlightInfo(CallContext context, FlightDescriptor descriptor); + /** + * Get schema for a particular data stream. + * + * @param context Per-call context. + * @param descriptor The descriptor identifying the data stream. + * @return Schema for the stream. + */ + default SchemaResult getSchema(CallContext context, FlightDescriptor descriptor) { + FlightInfo info = getFlightInfo(context, descriptor); + return new SchemaResult(info.getSchema()); + } + + /** * Accept uploaded data for a particular stream. * diff --git a/java/flight/src/main/java/org/apache/arrow/flight/FlightService.java b/java/flight/src/main/java/org/apache/arrow/flight/FlightService.java index 056b3e651210e..31c6561f51b09 100644 --- a/java/flight/src/main/java/org/apache/arrow/flight/FlightService.java +++ b/java/flight/src/main/java/org/apache/arrow/flight/FlightService.java @@ -223,6 +223,19 @@ public void getFlightInfo(Flight.FlightDescriptor request, StreamObserver responseObserver) { + try { + SchemaResult result = producer + .getSchema(makeContext((ServerCallStreamObserver) responseObserver), + new FlightDescriptor(request)); + responseObserver.onNext(result.toProtocol()); + responseObserver.onCompleted(); + } catch (Exception ex) { + responseObserver.onError(StatusUtils.toGrpcException(ex)); + } + } + /** * Call context for the service. */ diff --git a/java/flight/src/main/java/org/apache/arrow/flight/SchemaResult.java b/java/flight/src/main/java/org/apache/arrow/flight/SchemaResult.java new file mode 100644 index 0000000000000..764f4c70f33be --- /dev/null +++ b/java/flight/src/main/java/org/apache/arrow/flight/SchemaResult.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.Channels; + +import org.apache.arrow.flight.impl.Flight; +import org.apache.arrow.vector.ipc.ReadChannel; +import org.apache.arrow.vector.ipc.WriteChannel; +import org.apache.arrow.vector.ipc.message.MessageSerializer; +import org.apache.arrow.vector.types.pojo.Schema; + +import com.fasterxml.jackson.databind.util.ByteBufferBackedInputStream; +import com.google.common.collect.ImmutableList; +import com.google.protobuf.ByteString; + +/** + * Opaque result returned after executing a getSchema request. + * + *

POJO wrapper around the Flight protocol buffer message sharing the same name. + */ +public class SchemaResult { + + private final Schema schema; + + public SchemaResult(Schema schema) { + this.schema = schema; + } + + + public Schema getSchema() { + return schema; + } + + /** + * Converts to the protocol buffer representation. + */ + Flight.SchemaResult toProtocol() { + // Encode schema in a Message payload + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + try { + MessageSerializer.serialize(new WriteChannel(Channels.newChannel(baos)), schema); + } catch (IOException e) { + throw new RuntimeException(e); + } + return Flight.SchemaResult.newBuilder() + .setSchema(ByteString.copyFrom(baos.toByteArray())) + .build(); + + } + + /** + * Converts from the protocol buffer representation. + */ + static SchemaResult fromProtocol(Flight.SchemaResult pbSchemaResult) { + try { + final ByteBuffer schemaBuf = pbSchemaResult.getSchema().asReadOnlyByteBuffer(); + Schema schema = pbSchemaResult.getSchema().size() > 0 ? + MessageSerializer.deserializeSchema( + new ReadChannel(Channels.newChannel(new ByteBufferBackedInputStream(schemaBuf)))) + : new Schema(ImmutableList.of()); + return new SchemaResult(schema); + } catch (IOException e) { + throw new RuntimeException(e); + } + } +} diff --git a/java/flight/src/test/java/org/apache/arrow/flight/TestBasicOperation.java b/java/flight/src/test/java/org/apache/arrow/flight/TestBasicOperation.java index 634e38fc876dc..2e214a4f1e314 100644 --- a/java/flight/src/test/java/org/apache/arrow/flight/TestBasicOperation.java +++ b/java/flight/src/test/java/org/apache/arrow/flight/TestBasicOperation.java @@ -126,6 +126,14 @@ public void getDescriptor() throws Exception { }); } + @Test + public void getSchema() throws Exception { + test(c -> { + System.out.println(c.getSchema(FlightDescriptor.path("hello")).getSchema()); + }); + } + + @Test public void listActions() throws Exception { test(c -> { diff --git a/python/pyarrow/_flight.pyx b/python/pyarrow/_flight.pyx index af7e1ed7026b5..147d40794ee0b 100644 --- a/python/pyarrow/_flight.pyx +++ b/python/pyarrow/_flight.pyx @@ -508,6 +508,34 @@ cdef class FlightEndpoint: return self.endpoint == other.endpoint +cdef class SchemaResult: + """A result from a getschema request. Holding a schema""" + cdef: + unique_ptr[CSchemaResult] result + + def __init__(self, Schema schema): + """Create a SchemaResult from a schema. + + Parameters + ---------- + schema: Schema + the schema of the data in this flight. + """ + cdef: + shared_ptr[CSchema] c_schema = pyarrow_unwrap_schema(schema) + check_status(CreateSchemaResult(c_schema, &self.result)) + + @property + def schema(self): + """The schema of the data in this flight.""" + cdef: + shared_ptr[CSchema] schema + CDictionaryMemo dummy_memo + + check_status(self.result.get().GetSchema(&dummy_memo, &schema)) + return pyarrow_wrap_schema(schema) + + cdef class FlightInfo: """A description of a Flight stream.""" cdef: @@ -899,6 +927,22 @@ cdef class FlightClient: return result + def get_schema(self, descriptor: FlightDescriptor, + options: FlightCallOptions = None): + """Request schema for an available flight.""" + cdef: + SchemaResult result = SchemaResult.__new__(SchemaResult) + CFlightCallOptions* c_options = FlightCallOptions.unwrap(options) + CFlightDescriptor c_descriptor = \ + FlightDescriptor.unwrap(descriptor) + with nogil: + check_status( + self.client.get() + .GetSchema(deref(c_options), c_descriptor, &result.result) + ) + + return result + def do_get(self, ticket: Ticket, options: FlightCallOptions = None): """Request the data for a flight. @@ -1287,6 +1331,22 @@ cdef CStatus _get_flight_info(void* self, const CServerCallContext& context, info.reset(new CFlightInfo(deref(( result).info.get()))) return CStatus_OK() +cdef CStatus _get_schema(void* self, const CServerCallContext& context, + CFlightDescriptor c_descriptor, + unique_ptr[CSchemaResult]* info) except *: + """Callback for implementing Flight servers in Python.""" + cdef: + FlightDescriptor py_descriptor = \ + FlightDescriptor.__new__(FlightDescriptor) + py_descriptor.descriptor = c_descriptor + result = ( self).get_schema(ServerCallContext.wrap(context), + py_descriptor) + if not isinstance(result, SchemaResult): + raise TypeError("FlightServerBase.get_schema_info must return " + "a SchemaResult instance, but got {}".format( + type(result))) + info.reset(new CSchemaResult(deref(( result).result.get()))) + return CStatus_OK() cdef CStatus _do_put(void* self, const CServerCallContext& context, unique_ptr[CFlightMessageReader] reader, @@ -1556,6 +1616,7 @@ cdef class FlightServerBase: vtable.list_flights = &_list_flights vtable.get_flight_info = &_get_flight_info + vtable.get_schema = &_get_schema vtable.do_put = &_do_put vtable.do_get = &_do_get vtable.list_actions = &_list_actions @@ -1584,6 +1645,9 @@ cdef class FlightServerBase: def get_flight_info(self, context, descriptor): raise NotImplementedError + def get_schema(self, context, descriptor): + raise NotImplementedError + def do_put(self, context, descriptor, reader, writer: FlightMetadataWriter): raise NotImplementedError diff --git a/python/pyarrow/flight.py b/python/pyarrow/flight.py index 2d037f1d44886..b336365540d39 100644 --- a/python/pyarrow/flight.py +++ b/python/pyarrow/flight.py @@ -32,6 +32,7 @@ FlightDescriptor, FlightEndpoint, FlightInfo, + SchemaResult, FlightServerBase, FlightError, FlightInternalError, diff --git a/python/pyarrow/includes/libarrow_flight.pxd b/python/pyarrow/includes/libarrow_flight.pxd index 69b4653463a9e..f99b0a5050663 100644 --- a/python/pyarrow/includes/libarrow_flight.pxd +++ b/python/pyarrow/includes/libarrow_flight.pxd @@ -107,6 +107,10 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil: CStatus Deserialize(const c_string& serialized, unique_ptr[CFlightInfo]* out) + cdef cppclass CSchemaResult" arrow::flight::SchemaResult": + CSchemaResult(CSchemaResult result) + CStatus GetSchema(CDictionaryMemo* memo, shared_ptr[CSchema]* out) + cdef cppclass CFlightListing" arrow::flight::FlightListing": CStatus Next(unique_ptr[CFlightInfo]* info) @@ -219,7 +223,9 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil: CStatus GetFlightInfo(CFlightCallOptions& options, CFlightDescriptor& descriptor, unique_ptr[CFlightInfo]* info) - + CStatus GetSchema(CFlightCallOptions& options, + CFlightDescriptor& descriptor, + unique_ptr[CSchemaResult]* result) CStatus DoGet(CFlightCallOptions& options, CTicket& ticket, unique_ptr[CFlightStreamReader]* stream) CStatus DoPut(CFlightCallOptions& options, @@ -261,6 +267,9 @@ ctypedef CStatus cb_list_flights(object, const CServerCallContext&, ctypedef CStatus cb_get_flight_info(object, const CServerCallContext&, const CFlightDescriptor&, unique_ptr[CFlightInfo]*) +ctypedef CStatus cb_get_schema(object, const CServerCallContext&, + const CFlightDescriptor&, + unique_ptr[CSchemaResult]*) ctypedef CStatus cb_do_put(object, const CServerCallContext&, unique_ptr[CFlightMessageReader], unique_ptr[CFlightMetadataWriter]) @@ -286,6 +295,7 @@ cdef extern from "arrow/python/flight.h" namespace "arrow::py::flight" nogil: PyFlightServerVtable() function[cb_list_flights] list_flights function[cb_get_flight_info] get_flight_info + function[cb_get_schema] get_schema function[cb_do_put] do_put function[cb_do_get] do_get function[cb_do_action] do_action @@ -341,6 +351,10 @@ cdef extern from "arrow/python/flight.h" namespace "arrow::py::flight" nogil: int64_t total_bytes, unique_ptr[CFlightInfo]* out) + cdef CStatus CreateSchemaResult" arrow::py::flight::CreateSchemaResult"( + shared_ptr[CSchema] schema, + unique_ptr[CSchemaResult]* out) + cdef extern from "" namespace "std": unique_ptr[CFlightDataStream] move(unique_ptr[CFlightDataStream]) nogil diff --git a/python/pyarrow/tests/test_flight.py b/python/pyarrow/tests/test_flight.py index 3e213698e2575..c65e309cc32ae 100644 --- a/python/pyarrow/tests/test_flight.py +++ b/python/pyarrow/tests/test_flight.py @@ -221,6 +221,10 @@ def get_flight_info(self, context, descriptor): -1, ) + def get_schema(self, context, descriptor): + info = self.get_flight_info(context, descriptor) + return flight.SchemaResult(info.schema) + class CheckTicketFlightServer(FlightServerBase): """A Flight server that compares the given ticket to an expected value.""" @@ -498,6 +502,14 @@ def test_flight_get_info(): flight.Location.for_grpc_tcp('localhost', 5005) +def test_flight_get_schema(): + """Make sure GetSchema returns correct schema.""" + with flight_server(GetInfoFlightServer) as server_location: + client = flight.FlightClient.connect(server_location) + info = client.get_schema(flight.FlightDescriptor.for_command(b'')) + assert info.schema == pa.schema([('a', pa.int32())]) + + @pytest.mark.skipif(os.name == 'nt', reason="Unix sockets can't be tested on Windows") def test_flight_domain_socket():