diff --git a/.github/workflows/Windows.yml b/.github/workflows/Windows.yml index ef7f2a9..f8be85e 100644 --- a/.github/workflows/Windows.yml +++ b/.github/workflows/Windows.yml @@ -1,5 +1,5 @@ name: Windows -on: [push, pull_request,repository_dispatch] +on: [push, repository_dispatch] concurrency: group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || '' }}-${{ github.base_ref || '' }}-${{ github.ref != 'refs/heads/main' || github.sha }} cancel-in-progress: true diff --git a/.github/workflows/sql_tests.yml b/.github/workflows/sql_tests.yml new file mode 100644 index 0000000..0f7a587 --- /dev/null +++ b/.github/workflows/sql_tests.yml @@ -0,0 +1,29 @@ +name: SQL Unit tests + +on: [push, pull_request,repository_dispatch] + +defaults: + run: + shell: bash + +jobs: + unitTests: + name: SQL unit tests + runs-on: ubuntu-latest + env: + GEN: ninja + + steps: + - name: Install Ninja + run: sudo apt-get update && sudo apt-get install -y ninja-build + + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + submodules: true + + - name: Build Arrow extension + run: make release + + - name: Run SQL unit tests + run: make test \ No newline at end of file diff --git a/src/arrow_copy_functions.cpp b/src/arrow_copy_functions.cpp index 685d4f4..5502202 100644 --- a/src/arrow_copy_functions.cpp +++ b/src/arrow_copy_functions.cpp @@ -1,12 +1,11 @@ #include "arrow_copy_functions.hpp" -#include "arrow_to_ipc.hpp" -#include "duckdb/common/types/value.hpp" -#include "duckdb/main/config.hpp" -#include "duckdb/common/file_system.hpp" -#include "duckdb/function/table/arrow.hpp" #include "arrow/c/bridge.h" #include "arrow/io/file.h" #include "arrow/ipc/writer.h" +#include "arrow_to_ipc.hpp" +#include "duckdb/common/file_system.hpp" +#include "duckdb/function/table/arrow.hpp" +#include "duckdb/main/config.hpp" namespace duckdb { @@ -48,6 +47,37 @@ ArrowIPCWriteInitializeGlobal(ClientContext &context, FunctionData &bind_data, case LogicalTypeId::BIGINT: arrow_type = arrow::int64(); break; + case LogicalTypeId::VARCHAR: + arrow_type = arrow::utf8(); + break; + case LogicalTypeId::DOUBLE: + arrow_type = arrow::float64(); + break; + case LogicalTypeId::BOOLEAN: + arrow_type = arrow::boolean(); + break; + case LogicalTypeId::DATE: + arrow_type = arrow::date32(); + break; + case LogicalTypeId::TIMESTAMP_SEC: + arrow_type = arrow::timestamp(arrow::TimeUnit::SECOND); + break; + case LogicalTypeId::TIMESTAMP_MS: + arrow_type = arrow::timestamp(arrow::TimeUnit::MILLI); + break; + case LogicalTypeId::TIMESTAMP: + arrow_type = arrow::timestamp(arrow::TimeUnit::MICRO); + break; + case LogicalTypeId::TIMESTAMP_NS: + arrow_type = arrow::timestamp(arrow::TimeUnit::NANO); + break; + case LogicalTypeId::BLOB: + arrow_type = arrow::binary(); + break; + case LogicalTypeId::DECIMAL: + arrow_type = arrow::decimal(DecimalType::GetWidth(sql_type), + DecimalType::GetScale(sql_type)); + break; // Add more type conversions as needed default: throw IOException("Unsupported type for Arrow IPC: " + @@ -86,17 +116,8 @@ ArrowIPCWriteInitializeGlobal(ClientContext &context, FunctionData &bind_data, } vector> ArrowIPCWriteSelect(CopyToSelectInput &input) { - vector> result; - bool any_change = false; - - for (auto &expr : input.select_list) { - const auto &type = expr->return_type; - const auto &name = expr->GetAlias(); - - // All types supported by Arrow IPC - result.push_back(std::move(expr)); - } - // If no changes were made, return empty vector to avoid unnecessary + // All types supported by Arrow IPC + // As no changes were made, return empty vector to avoid unnecessary // projection return {}; } @@ -147,7 +168,6 @@ void ArrowIPCWriteSink(ExecutionContext &context, FunctionData &bind_data, void ArrowIPCWriteCombine(ExecutionContext &context, FunctionData &bind_data, GlobalFunctionData &gstate, LocalFunctionData &lstate) { - auto &arrow_bind = bind_data.Cast(); auto &global_state = gstate.Cast(); auto &local_state = lstate.Cast(); @@ -226,19 +246,31 @@ unique_ptr ArrowIPCCopyDeserialize(Deserializer &deserializer, unique_ptr ArrowIPCCopyFromBind(ClientContext &context, CopyInfo &info, - vector &names, + vector &expected_names, vector &expected_types) { auto &fs = FileSystem::GetFileSystem(context); auto handle = fs.OpenFile(info.file_path, FileFlags::FILE_FLAGS_READ); auto file_size = fs.GetFileSize(*handle); // Read file into memory - vector file_buffer(file_size); - handle->Read(file_buffer.data(), file_size); + auto s_file_buffer = make_shared_ptr>(file_size); + auto file_buffer = s_file_buffer.get(); + auto bytes_read = handle->Read(file_buffer->data(), file_size); + + if (bytes_read != file_size) { + throw IOException("Failed to read Arrow IPC file"); + } + auto buffer = file_buffer->data(); + auto buffer_size = file_buffer->size(); // Create stream decoder and buffer + if (std::string(reinterpret_cast(file_buffer->data()), 6) == ARROW_MAGIC) { + // ignore magic and footer if it is arrow file format + buffer = file_buffer->data() + 8; // skip 8byte magic at the start + buffer_size -= 16; // skip 8byte magic and 8byte footer + } auto stream_decoder = make_uniq(); - auto consume_result = stream_decoder->Consume(file_buffer.data(), file_size); + auto consume_result = stream_decoder->Consume(buffer, buffer_size); if (!consume_result.ok()) { throw IOException("Invalid Arrow IPC file"); } @@ -251,16 +283,17 @@ ArrowIPCCopyFromBind(ClientContext &context, CopyInfo &info, auto stream_factory_ptr = (uintptr_t)&stream_decoder->buffer(); auto stream_factory_produce = (stream_factory_produce_t)&ArrowIPCStreamBufferReader::CreateStream; - auto stream_factory_get_schema = - (stream_factory_get_schema_t)&ArrowIPCStreamBufferReader::GetSchema; // Store decoder and get buffer pointer auto result = make_uniq(stream_factory_produce, stream_factory_ptr); result->stream_decoder = std::move(stream_decoder); + result->file_buffer = s_file_buffer; auto &data = *result; - stream_factory_get_schema((ArrowArrayStream *)stream_factory_ptr, - data.schema_root.arrow_schema); + ArrowIPCStreamBufferReader::GetSchema((uintptr_t)&result->stream_decoder->buffer(), + data.schema_root); + vector names; + vector types; for (idx_t col_idx = 0; col_idx < (idx_t)data.schema_root.arrow_schema.n_children; col_idx++) { auto &schema = *data.schema_root.arrow_schema.children[col_idx]; @@ -273,10 +306,10 @@ ArrowIPCCopyFromBind(ClientContext &context, CopyInfo &info, if (schema.dictionary) { auto dictionary_type = ArrowType::GetArrowLogicalType( DBConfig::GetConfig(context), *schema.dictionary); - expected_types.emplace_back(dictionary_type->GetDuckType()); + types.emplace_back(dictionary_type->GetDuckType()); arrow_type->SetDictionary(std::move(dictionary_type)); } else { - expected_types.emplace_back(arrow_type->GetDuckType()); + types.emplace_back(arrow_type->GetDuckType()); } result->arrow_table.AddColumn(col_idx, std::move(arrow_type)); auto format = string(schema.format); @@ -287,6 +320,18 @@ ArrowIPCCopyFromBind(ClientContext &context, CopyInfo &info, names.push_back(name); } QueryResult::DeduplicateColumns(names); + if (expected_types.empty()) { + expected_names = names; + expected_types = types; + } else { + if (expected_names != names) { + throw IOException("Arrow IPC schema mismatch, column names mismatch"); + } + if (expected_types != types) { + // TODO add more detailed error message + throw IOException("Arrow IPC schema mismatch, column types mismatch"); + } + } return std::move(result); } diff --git a/src/arrow_scan_ipc.cpp b/src/arrow_scan_ipc.cpp index fa1b15c..a300455 100644 --- a/src/arrow_scan_ipc.cpp +++ b/src/arrow_scan_ipc.cpp @@ -24,7 +24,6 @@ TableFunction ArrowIPCTableFunction::GetFunction() { unique_ptr ArrowIPCTableFunction::ArrowScanBind( ClientContext &context, TableFunctionBindInput &input, vector &return_types, vector &names) { - auto stream_decoder = make_uniq(); // Get file path from input auto file_path = input.inputs[0].GetValue(); @@ -33,11 +32,25 @@ unique_ptr ArrowIPCTableFunction::ArrowScanBind( // Read file into memory auto file_size = fs.GetFileSize(*handle); - vector buffer(file_size); - handle->Read(buffer.data(), file_size); + auto s_file_buffer = make_shared_ptr>(file_size); + auto file_buffer = s_file_buffer.get(); + auto bytes_read = handle->Read(file_buffer->data(), file_size); + + if (bytes_read != file_size) { + throw IOException("Failed to read Arrow IPC file"); + } + // Create stream decoder and buffer + auto buffer = file_buffer->data(); + auto buffer_size = file_buffer->size(); + if (std::string(reinterpret_cast(file_buffer->data()), 6) == ARROW_MAGIC) { + // ignore magic and footer if it is arrow file format + buffer = file_buffer->data() + 8; // skip 8byte magic at the start + buffer_size -= 16; // skip 8byte magic and 8byte footer + } // Feed file into decoder - auto consume_result = stream_decoder->Consume(buffer.data(), file_size); + auto stream_decoder = make_uniq(); + auto consume_result = stream_decoder->Consume(buffer, buffer_size); if (!consume_result.ok()) { throw IOException("Invalid Arrow IPC file"); } @@ -58,7 +71,7 @@ unique_ptr ArrowIPCTableFunction::ArrowScanBind( // Store decoder result->stream_decoder = std::move(stream_decoder); - + result->file_buffer = s_file_buffer; // TODO Everything below this is identical to the bind in // duckdb/src/function/table/arrow.cpp auto &data = *result; diff --git a/src/include/arrow_scan_ipc.hpp b/src/include/arrow_scan_ipc.hpp index 66a7827..c0290eb 100644 --- a/src/include/arrow_scan_ipc.hpp +++ b/src/include/arrow_scan_ipc.hpp @@ -7,10 +7,13 @@ namespace duckdb { +const std::string ARROW_MAGIC = "ARROW1"; + struct ArrowIPCScanFunctionData : public ArrowScanFunctionData { public: using ArrowScanFunctionData::ArrowScanFunctionData; unique_ptr stream_decoder = nullptr; + shared_ptr> file_buffer; }; // IPC Table scan is identical to ArrowTableFunction arrow scan except instead diff --git a/src/include/arrow_stream_buffer.hpp b/src/include/arrow_stream_buffer.hpp index 550e271..73da634 100644 --- a/src/include/arrow_stream_buffer.hpp +++ b/src/include/arrow_stream_buffer.hpp @@ -30,12 +30,12 @@ struct ArrowIPCStreamBuffer : public arrow::ipc::Listener { bool is_eos_; /// Decoded a record batch - arrow::Status OnSchemaDecoded(std::shared_ptr schema); + arrow::Status OnSchemaDecoded(std::shared_ptr schema) override; /// Decoded a record batch arrow::Status - OnRecordBatchDecoded(std::shared_ptr record_batch); + OnRecordBatchDecoded(std::shared_ptr record_batch) override; /// Reached end of stream - arrow::Status OnEOS(); + arrow::Status OnEOS() override; public: /// Constructor @@ -71,14 +71,7 @@ struct ArrowIPCStreamBufferReader : public arrow::RecordBatchReader, ArrowIPCStreamBufferReader(std::shared_ptr buffer); /// Destructor - ~ArrowIPCStreamBufferReader() override { - // Clear batches first - if (buffer_) { - buffer_->batches().clear(); - // Let schema cleanup happen through ArrowSchemaWrapper - buffer_.reset(); - } - } + ~ArrowIPCStreamBufferReader() = default; /// Get the schema std::shared_ptr schema() const override; diff --git a/test/sql/arrow_ipc/write_arrow_ipc.test b/test/sql/arrow_ipc/write_arrow_ipc.test index 5f0b3d6..7547424 100644 --- a/test/sql/arrow_ipc/write_arrow_ipc.test +++ b/test/sql/arrow_ipc/write_arrow_ipc.test @@ -10,8 +10,13 @@ CREATE TABLE test_data AS SELECT * FROM range(0, 1000) tbl(i); statement ok COPY test_data TO '__TEST_DIR__/test.arrow' (FORMAT 'arrow'); +query I +SELECT COUNT(*) FROM arrow('__TEST_DIR__/test.arrow'); +---- +1000 + statement ok -CREATE TABLE test_verify(i INTEGER); +CREATE TABLE test_verify(i BIGINT); statement ok COPY test_verify FROM '__TEST_DIR__/test.arrow' (FORMAT 'arrow'); @@ -20,3 +25,30 @@ query I SELECT COUNT(*) FROM test_verify; ---- 1000 + +statement ok +CREATE TABLE all_types (id INTEGER, name TEXT, amount DOUBLE, is_active BOOLEAN, create_date DATE, created_at TIMESTAMP, big_number BIGINT, ratio DECIMAL(10,2), binary_data BLOB); + +statement ok +INSERT INTO all_types VALUES +(1, 'Alice', 99.99, TRUE, '2024-02-01', '2024-02-01 12:00:00', 123456789012345, 3.14, X'68656C6C6F'), +(2, 'Bob', 250.75, FALSE, '2023-12-15', '2023-12-15 08:30:00', 987654321098765, 1.23, X'74657374'); + +statement ok +COPY all_types TO '__TEST_DIR__/all_types.arrow' (FORMAT 'arrow'); + +query I +SELECT COUNT(*) FROM arrow('__TEST_DIR__/all_types.arrow'); +---- +2 + +statement ok +CREATE TABLE all_types_verify (id INTEGER, name TEXT, amount DOUBLE, is_active BOOLEAN, create_date DATE, created_at TIMESTAMP, big_number BIGINT, ratio DECIMAL(10,2), binary_data BLOB); + +statement ok +COPY all_types_verify FROM '__TEST_DIR__/all_types.arrow' (FORMAT 'arrow'); + +query I +SELECT COUNT(*) FROM all_types_verify; +---- +2