Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

arrow ipc copy #2

Open
wants to merge 3 commits into
base: devin/1739874746-arrow-ipc-copy
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions .github/workflows/sql_tests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
name: SQL Unit tests

on: [push, pull_request,repository_dispatch]

defaults:
run:
shell: bash

jobs:
unitTests:
name: SQL unit tests
runs-on: macos-latest
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why osx out of curiosity?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No particular reason — we were building both Linux and macOS in the other repo. I thought we would start with one. Let me know if you want me to change this to Linux.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would prefer Linux. It gives me more confidence in the platform we will actually run on.

env:
GEN: ninja

steps:
- name: Install Ninja
run: brew install ninja

- uses: actions/checkout@v4
with:
fetch-depth: 0
submodules: true

- name: Build Arrow extension
run: make release

- name: Run SQL unit tests
run: make test
91 changes: 72 additions & 19 deletions src/arrow_copy_functions.cpp
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
#include "arrow_copy_functions.hpp"
#include "arrow_to_ipc.hpp"
#include "duckdb/common/types/value.hpp"
#include "duckdb/main/config.hpp"
#include "duckdb/common/file_system.hpp"
#include "duckdb/function/table/arrow.hpp"
#include "arrow/c/bridge.h"
#include "arrow/io/file.h"
#include "arrow/ipc/writer.h"
#include "arrow_to_ipc.hpp"
#include "duckdb/common/file_system.hpp"
#include "duckdb/function/table/arrow.hpp"
#include "duckdb/main/config.hpp"

namespace duckdb {

Expand Down Expand Up @@ -48,6 +47,37 @@ ArrowIPCWriteInitializeGlobal(ClientContext &context, FunctionData &bind_data,
case LogicalTypeId::BIGINT:
arrow_type = arrow::int64();
break;
case LogicalTypeId::VARCHAR:
arrow_type = arrow::utf8();
break;
case LogicalTypeId::DOUBLE:
arrow_type = arrow::float64();
break;
case LogicalTypeId::BOOLEAN:
arrow_type = arrow::boolean();
break;
case LogicalTypeId::DATE:
arrow_type = arrow::date32();
break;
case LogicalTypeId::TIMESTAMP_SEC:
arrow_type = arrow::timestamp(arrow::TimeUnit::SECOND);
break;
case LogicalTypeId::TIMESTAMP_MS:
arrow_type = arrow::timestamp(arrow::TimeUnit::MILLI);
break;
case LogicalTypeId::TIMESTAMP:
arrow_type = arrow::timestamp(arrow::TimeUnit::MICRO);
break;
case LogicalTypeId::TIMESTAMP_NS:
arrow_type = arrow::timestamp(arrow::TimeUnit::NANO);
break;
case LogicalTypeId::BLOB:
arrow_type = arrow::binary();
break;
case LogicalTypeId::DECIMAL:
arrow_type = arrow::decimal(DecimalType::GetWidth(sql_type),
DecimalType::GetScale(sql_type));
break;
// Add more type conversions as needed
default:
throw IOException("Unsupported type for Arrow IPC: " +
Expand Down Expand Up @@ -87,11 +117,11 @@ ArrowIPCWriteInitializeGlobal(ClientContext &context, FunctionData &bind_data,

vector<unique_ptr<Expression>> ArrowIPCWriteSelect(CopyToSelectInput &input) {
vector<unique_ptr<Expression>> result;
bool any_change = false;
// bool any_change = false;
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we remove the commented code?


for (auto &expr : input.select_list) {
const auto &type = expr->return_type;
const auto &name = expr->GetAlias();
// const auto &type = expr->return_type;
// const auto &name = expr->GetAlias();

// All types supported by Arrow IPC
result.push_back(std::move(expr));
Expand Down Expand Up @@ -147,7 +177,7 @@ void ArrowIPCWriteSink(ExecutionContext &context, FunctionData &bind_data,
void ArrowIPCWriteCombine(ExecutionContext &context, FunctionData &bind_data,
GlobalFunctionData &gstate,
LocalFunctionData &lstate) {
auto &arrow_bind = bind_data.Cast<ArrowIPCWriteBindData>();
// auto &arrow_bind = bind_data.Cast<ArrowIPCWriteBindData>();
auto &global_state = gstate.Cast<ArrowIPCWriteGlobalState>();
auto &local_state = lstate.Cast<ArrowIPCWriteLocalState>();

Expand Down Expand Up @@ -226,19 +256,29 @@ unique_ptr<FunctionData> ArrowIPCCopyDeserialize(Deserializer &deserializer,

unique_ptr<FunctionData>
ArrowIPCCopyFromBind(ClientContext &context, CopyInfo &info,
vector<string> &names,
vector<string> &expected_names,
vector<LogicalType> &expected_types) {
auto &fs = FileSystem::GetFileSystem(context);
auto handle = fs.OpenFile(info.file_path, FileFlags::FILE_FLAGS_READ);
auto file_size = fs.GetFileSize(*handle);

// Read file into memory
vector<uint8_t> file_buffer(file_size);
handle->Read(file_buffer.data(), file_size);
auto s_file_buffer = make_shared_ptr<vector<uint8_t>>(file_size);
auto file_buffer = s_file_buffer.get();
auto bytes_read = handle->Read(file_buffer->data(), file_size);

if (bytes_read != file_size) {
throw IOException("Failed to read Arrow IPC file");
}

// Create stream decoder and buffer
auto stream_decoder = make_uniq<BufferingArrowIPCStreamDecoder>();
auto consume_result = stream_decoder->Consume(file_buffer.data(), file_size);
if (std::string(reinterpret_cast<char*>(file_buffer->data()), 6) == ARROW_MAGIC) {
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you explain what this is doing? Is the written data wrong and if so can we fix it on the write side?

Is the write side compatible w/ snowflake's arrow or not?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is trying to handle both the file and stream formats. An IPC stream doesn't have the magic bytes and footer, whereas the file format has them.
I will fix the write side to keep it the same as Snowflake's.
I'm just thinking that the extension should support both formats, at least on the read side. I will make a similar fix in arrow_scan_ipc.cpp.
I will add a comment.

// Remove magic and footer
file_buffer->erase(file_buffer->begin(), file_buffer->begin() + 8);
file_buffer->erase(file_buffer->end() - 8, file_buffer->end());
}
auto consume_result = stream_decoder->Consume(file_buffer->data(), file_buffer->size());
if (!consume_result.ok()) {
throw IOException("Invalid Arrow IPC file");
}
Expand All @@ -251,16 +291,17 @@ ArrowIPCCopyFromBind(ClientContext &context, CopyInfo &info,
auto stream_factory_ptr = (uintptr_t)&stream_decoder->buffer();
auto stream_factory_produce =
(stream_factory_produce_t)&ArrowIPCStreamBufferReader::CreateStream;
auto stream_factory_get_schema =
(stream_factory_get_schema_t)&ArrowIPCStreamBufferReader::GetSchema;

// Store decoder and get buffer pointer
auto result = make_uniq<ArrowIPCScanFunctionData>(stream_factory_produce, stream_factory_ptr);
result->stream_decoder = std::move(stream_decoder);
result->file_buffer = s_file_buffer;

auto &data = *result;
stream_factory_get_schema((ArrowArrayStream *)stream_factory_ptr,
data.schema_root.arrow_schema);
ArrowIPCStreamBufferReader::GetSchema((uintptr_t)&result->stream_decoder->buffer(),
data.schema_root);
vector<string> names;
vector<LogicalType> types;
for (idx_t col_idx = 0;
col_idx < (idx_t)data.schema_root.arrow_schema.n_children; col_idx++) {
auto &schema = *data.schema_root.arrow_schema.children[col_idx];
Expand All @@ -273,10 +314,10 @@ ArrowIPCCopyFromBind(ClientContext &context, CopyInfo &info,
if (schema.dictionary) {
auto dictionary_type = ArrowType::GetArrowLogicalType(
DBConfig::GetConfig(context), *schema.dictionary);
expected_types.emplace_back(dictionary_type->GetDuckType());
types.emplace_back(dictionary_type->GetDuckType());
arrow_type->SetDictionary(std::move(dictionary_type));
} else {
expected_types.emplace_back(arrow_type->GetDuckType());
types.emplace_back(arrow_type->GetDuckType());
}
result->arrow_table.AddColumn(col_idx, std::move(arrow_type));
auto format = string(schema.format);
Expand All @@ -287,6 +328,18 @@ ArrowIPCCopyFromBind(ClientContext &context, CopyInfo &info,
names.push_back(name);
}
QueryResult::DeduplicateColumns(names);
if (expected_types.empty()) {
expected_names = names;
expected_types = types;
} else {
if (expected_names != names) {
throw IOException("Arrow IPC schema mismatch, column names mismatch");
}
if (expected_types != types) {
// TODO add more detailed error message
throw IOException("Arrow IPC schema mismatch, column types mismatch");
}
}
return std::move(result);
}

Expand Down
3 changes: 3 additions & 0 deletions src/include/arrow_scan_ipc.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,13 @@

namespace duckdb {

const std::string ARROW_MAGIC = "ARROW1";

struct ArrowIPCScanFunctionData : public ArrowScanFunctionData {
public:
using ArrowScanFunctionData::ArrowScanFunctionData;
unique_ptr<BufferingArrowIPCStreamDecoder> stream_decoder = nullptr;
shared_ptr<vector<uint8_t>> file_buffer;
};

// IPC Table scan is identical to ArrowTableFunction arrow scan except instead
Expand Down
15 changes: 4 additions & 11 deletions src/include/arrow_stream_buffer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,12 @@ struct ArrowIPCStreamBuffer : public arrow::ipc::Listener {
bool is_eos_;

/// Decoded a record batch
arrow::Status OnSchemaDecoded(std::shared_ptr<arrow::Schema> schema);
arrow::Status OnSchemaDecoded(std::shared_ptr<arrow::Schema> schema) override;
/// Decoded a record batch
arrow::Status
OnRecordBatchDecoded(std::shared_ptr<arrow::RecordBatch> record_batch);
OnRecordBatchDecoded(std::shared_ptr<arrow::RecordBatch> record_batch) override;
/// Reached end of stream
arrow::Status OnEOS();
arrow::Status OnEOS() override;

public:
/// Constructor
Expand Down Expand Up @@ -71,14 +71,7 @@ struct ArrowIPCStreamBufferReader : public arrow::RecordBatchReader,
ArrowIPCStreamBufferReader(std::shared_ptr<ArrowIPCStreamBuffer> buffer);

/// Destructor
~ArrowIPCStreamBufferReader() override {
// Clear batches first
if (buffer_) {
buffer_->batches().clear();
// Let schema cleanup happen through ArrowSchemaWrapper
buffer_.reset();
}
}
~ArrowIPCStreamBufferReader() = default;

/// Get the schema
std::shared_ptr<arrow::Schema> schema() const override;
Expand Down
24 changes: 23 additions & 1 deletion test/sql/arrow_ipc/write_arrow_ipc.test
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ statement ok
COPY test_data TO '__TEST_DIR__/test.arrow' (FORMAT 'arrow');

statement ok
CREATE TABLE test_verify(i INTEGER);
CREATE TABLE test_verify(i BIGINT);

statement ok
COPY test_verify FROM '__TEST_DIR__/test.arrow' (FORMAT 'arrow');
Expand All @@ -20,3 +20,25 @@ query I
SELECT COUNT(*) FROM test_verify;
----
1000

statement ok
CREATE TABLE all_types (id INTEGER, name TEXT, amount DOUBLE, is_active BOOLEAN, create_date DATE, created_at TIMESTAMP, big_number BIGINT, ratio DECIMAL(10,2), binary_data BLOB);

statement ok
INSERT INTO all_types VALUES
(1, 'Alice', 99.99, TRUE, '2024-02-01', '2024-02-01 12:00:00', 123456789012345, 3.14, X'68656C6C6F'),
(2, 'Bob', 250.75, FALSE, '2023-12-15', '2023-12-15 08:30:00', 987654321098765, 1.23, X'74657374');

statement ok
COPY all_types TO '__TEST_DIR__/all_types.arrow' (FORMAT 'arrow');

statement ok
CREATE TABLE all_types_verify (id INTEGER, name TEXT, amount DOUBLE, is_active BOOLEAN, create_date DATE, created_at TIMESTAMP, big_number BIGINT, ratio DECIMAL(10,2), binary_data BLOB);

statement ok
COPY all_types_verify FROM '__TEST_DIR__/all_types.arrow' (FORMAT 'arrow');

query I
SELECT COUNT(*) FROM all_types_verify;
----
2