Skip to content

Commit

Permalink
Draft hash_pivot API
Browse files Browse the repository at this point in the history
  • Loading branch information
pitrou committed Jan 14, 2025
1 parent ef00568 commit 6bc1ee7
Show file tree
Hide file tree
Showing 2 changed files with 325 additions and 2 deletions.
69 changes: 69 additions & 0 deletions cpp/src/arrow/compute/api_aggregate.h
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,75 @@ class ARROW_EXPORT TDigestOptions : public FunctionOptions {
uint32_t min_count;
};

/// \brief Control Pivot kernel behavior
///
/// These options apply to the "pivot" (TODO) and "hash_pivot" (TODO) functions.
///
/// Constraints:
/// - The corresponding `Aggregate::target` must have two FieldRef elements;
/// the first one points to the pivot key column, the second points to the
/// pivoted data column.
/// - The pivot key column must be string-like; its values will be matched
/// against `key_names` in order to dispatch the pivoted data into the
/// output.
///
/// "hash_pivot" example
/// --------------------
///
/// Assuming the following input with schema
/// `{"group": int32, "key": utf8, "value": int16}`:
/// ```
/// group | key | value
/// -----------------------------
/// 1 | height | 11
/// 1 | width | 12
/// 2 | width | 13
/// 3 | height | 14
/// 3 | depth | 15
/// ```
/// and the following settings:
/// - a hash grouping key "group"
/// - Aggregate(
/// .function = "hash_pivot",
/// .options = PivotOptions(.key_names = {"height", "width"}),
/// .target = {"key", "value"},
/// .name = {"props"})
///
/// then the output will have the schema
/// `{"group": int32, "props": struct{"height": int16, "width": int16}}`
/// and the following value:
/// ```
/// group | props
/// | height | width
/// -----------------------------
/// 1 | 11 | 12
/// 2 | null | 13
/// 3 | 14 | null
/// ```
class ARROW_EXPORT PivotOptions : public FunctionOptions {
public:
/// Configure the behavior of pivot keys not in `key_names`
enum UnexpectedKeyBehavior {
/// Unexpected pivot keys are ignored silently
kIgnore,
/// Unexpected pivot keys return a KeyError
kRaise
};
// TODO should duplicate key behavior be configurable as well?

/// Construct options from the expected pivot key values.
///
/// \param key_names the values expected in the pivot key column
/// \param unexpected_key_behavior what to do with pivot keys that are not
///        in `key_names`; defaults to `kIgnore`
explicit PivotOptions(std::vector<std::string> key_names,
UnexpectedKeyBehavior unexpected_key_behavior = kIgnore);
/// Default constructor for serialization
PivotOptions();
/// Name identifying this options type
static constexpr char const kTypeName[] = "PivotOptions";
/// Return default-constructed options
static PivotOptions Defaults() { return PivotOptions{}; }

/// The values expected in the pivot key column
std::vector<std::string> key_names;
/// The behavior when pivot keys not in `key_names` are encountered
UnexpectedKeyBehavior unexpected_key_behavior = kIgnore;
};

/// \brief Control Index kernel behavior
class ARROW_EXPORT IndexOptions : public FunctionOptions {
public:
Expand Down
258 changes: 256 additions & 2 deletions cpp/src/arrow/compute/kernels/hash_aggregate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
#include "arrow/util/int128_internal.h"
#include "arrow/util/int_util_overflow.h"
#include "arrow/util/ree_util.h"
#include "arrow/util/span.h"
#include "arrow/util/task_group.h"
#include "arrow/util/tdigest.h"
#include "arrow/util/thread_pool.h"
Expand All @@ -56,6 +57,7 @@ namespace arrow {

using internal::checked_cast;
using internal::FirstTimeBitmapWriter;
using util::span;

namespace compute {
namespace internal {
Expand Down Expand Up @@ -3319,9 +3321,242 @@ struct GroupedListFactory {
HashAggregateKernel kernel;
InputType argument_type;
};
} // namespace

namespace {
// ----------------------------------------------------------------------
// Pivot implementation

// Index type for pivot keys: PivotKeyMapper assigns each entry of
// `PivotOptions::key_names` its position in that list as a PivotKeyIndex.
using PivotKeyIndex = uint8_t;

// Sentinel index (the largest PivotKeyIndex value). PivotKeyMapper returns it
// for a pivot key that is not in `key_names` when unexpected keys are ignored.
constexpr PivotKeyIndex kNullPivotKey = std::numeric_limits<PivotKeyIndex>::max();

// TODO move this into pivot_internal.h
// Base class for mapping the values of a pivot key column to indices into
// `PivotOptions::key_names`.
//
// Subclasses implement MapKeys() for a concrete pivot key type.
struct PivotKeyMapper {
  virtual ~PivotKeyMapper() = default;

  // Build the key name -> key index lookup table from `options`.
  //
  // `options` must be non-null and must outlive this mapper: the lookup table
  // stores views into `options->key_names`, not copies of the strings.
  //
  // Returns NotImplemented if there are more key names than can be indexed by
  // PivotKeyIndex (one value, kNullPivotKey, is reserved as a sentinel), and
  // KeyError if the same key name is listed twice.
  virtual Status Init(const PivotOptions* options) {
    const auto& key_names = options->key_names;
    // Valid indices are [0, kNullPivotKey): reject inputs that would either
    // collide with the sentinel or wrap around the index type.
    constexpr size_t kMaxKeyCount = static_cast<size_t>(kNullPivotKey);
    if (key_names.size() > kMaxKeyCount) {
      return Status::NotImplemented("Pivoting to more than ", kMaxKeyCount,
                                    " columns: got ", key_names.size());
    }
    // Start from a clean table so that calling Init() again does not report
    // spurious duplicates.
    key_name_map_.clear();
    key_name_map_.reserve(key_names.size());
    PivotKeyIndex index = 0;
    for (const auto& key_name : key_names) {
      const bool inserted =
          key_name_map_.emplace(std::string_view(key_name), index).second;
      if (!inserted) {
        return Status::KeyError("Duplicate key name '", key_name,
                                "' in PivotOptions");
      }
      ++index;
    }
    unexpected_key_behavior_ = options->unexpected_key_behavior;
    return Status::OK();
  }

  // Map each pivot key in the given value to its index in `key_names`.
  virtual Result<span<PivotKeyIndex>> MapKeys(const ExecValue&) = 0;

 protected:
  // Handle a pivot key that is absent from the lookup table: either return the
  // kNullPivotKey sentinel (kIgnore) or a KeyError (any other behavior).
  Result<PivotKeyIndex> KeyNotFound(std::string_view key_name) {
    if (unexpected_key_behavior_ == PivotOptions::kIgnore) {
      return kNullPivotKey;
    }
    return Status::KeyError("Unexpected pivot key: ", key_name);
  }

  static constexpr int kBatchLength = 512;
  // Key name -> index in `PivotOptions::key_names`; keys are views into the
  // options passed to Init().
  std::unordered_map<std::string_view, PivotKeyIndex> key_name_map_;
  // Initialized so that KeyNotFound() is well-defined even before Init().
  PivotOptions::UnexpectedKeyBehavior unexpected_key_behavior_ = PivotOptions::kIgnore;
  // Scratch storage for the indices returned by MapKeys().
  TypedBufferBuilder<PivotKeyIndex> key_indices_buffer_;
};

// PivotKeyMapper implementation for pivot key columns of type `KeyType`.
template <typename KeyType>
struct TypedPivotKeyMapper : public PivotKeyMapper {
  // Map every pivot key in `value` (a scalar or an array) to its key index.
  //
  // The returned span points into `key_indices_buffer_`, so it is only valid
  // until the next call to MapKeys() on this mapper.
  //
  // Returns KeyError if a pivot key is null, or if a key is unexpected and the
  // mapper was configured to raise on unexpected keys.
  Result<span<PivotKeyIndex>> MapKeys(const ExecValue& value) override {
    RETURN_NOT_OK(key_indices_buffer_.Reserve(value.length()));
    PivotKeyIndex* key_indices = key_indices_buffer_.mutable_data();

    auto process_key = [&](std::string_view key_name) -> Result<PivotKeyIndex> {
      const auto it = key_name_map_.find(key_name);
      if (ARROW_PREDICT_FALSE(it == key_name_map_.end())) {
        return KeyNotFound(key_name);
      }
      return it->second;
    };
    // Shared by the scalar and array paths so that both report null keys
    // identically.
    auto null_key_error = []() { return Status::KeyError("key name cannot be null"); };

    if (value.is_scalar()) {
      const auto& scalar = value.scalar_as<BaseBinaryScalar>();
      // A null scalar key must be rejected like a null array element, not
      // looked up as if it were a regular key.
      if (ARROW_PREDICT_FALSE(!scalar.is_valid)) {
        return null_key_error();
      }
      ARROW_ASSIGN_OR_RAISE(key_indices[0], process_key(scalar.view()));
      return span(key_indices, 1);
    }

    const ArraySpan& array = value.array;
    int64_t i = 0;
    RETURN_NOT_OK(VisitArrayValuesInline<KeyType>(
        array,
        [&](std::string_view key_name) {
          ARROW_ASSIGN_OR_RAISE(key_indices[i], process_key(key_name));
          ++i;
          return Status::OK();
        },
        null_key_error));
    return span(key_indices, array.length);
  }
};

// Type visitor that instantiates the PivotKeyMapper matching a pivot key type.
struct PivotKeyMapperFactory {
  // Called by VisitTypeInline with the concrete pivot key type. Base binary
  // like types get an initialized TypedPivotKeyMapper; any other type yields
  // NotImplemented.
  template <typename T>
  Status Visit(const T& key_type) {
    constexpr bool kIsSupportedKeyType = is_base_binary_like(T::type_id);
    if constexpr (kIsSupportedKeyType) {
      instance = std::make_unique<TypedPivotKeyMapper<T>>();
      return instance->Init(options);
    }
    return Status::NotImplemented("Pivot key type: ", key_type);
  }

  // Create and initialize a mapper for `key_type` using `options`.
  static Result<std::unique_ptr<PivotKeyMapper>> Make(const DataType& key_type,
                                                      const PivotOptions* options) {
    PivotKeyMapperFactory visitor{options};
    RETURN_NOT_OK(VisitTypeInline(key_type, &visitor));
    return std::move(visitor.instance);
  }

  // Options forwarded to the mapper's Init()
  const PivotOptions* options;
  // The mapper created by Visit(), if any
  std::unique_ptr<PivotKeyMapper> instance{};
};

/*
TODO
would probably like to write:
Result<std::unique_ptr<PivotKeyMapper>> MakePivotKeyMapper(const DataType& key_type,
const PivotOptions* options) {
std::unique_ptr<PivotKeyMapper> instance;
RETURN_NOT_OK(VisitTypeInline(key_type, [&](auto key_type) {
using T = std::decay_t<decltype(key_type)>;
if constexpr (is_base_binary_like(T::type_id)) {
instance = std::make_unique<TypedPivotKeyMapper<T>>();
return instance->Init(options);
}
return Status::NotImplemented("Pivot key type: ", key_type);
}));
return instance;
}
or even:
Result<std::unique_ptr<PivotKeyMapper>> MakePivotKeyMapper(const DataType& key_type,
const PivotOptions* options) {
return VisitTypeInline(key_type, [&](auto key_type) -> Result<std::unique_ptr<PivotKeyMapper>> {
using T = std::decay_t<decltype(key_type)>;
if constexpr (is_base_binary_like(T::type_id)) {
auto instance = std::make_unique<TypedPivotKeyMapper<T>>();
RETURN_NOT_OK(instance->Init(options));
return instance;
}
return Status::NotImplemented("Pivot key type: ", key_type);
});
}
*/

// Grouped aggregator for "hash_pivot" (draft).
//
// Only Init() and out_type() are implemented so far; Resize(), Merge(),
// Consume() and Finalize() all return NotImplemented. The commented-out bodies
// below are placeholders kept from the draft.
template <typename KeyType>
struct GroupedPivotImpl : public GroupedAggregator {
  // Record the input types and options, and compute the output struct type:
  // one field per entry of `PivotOptions::key_names`, each of the pivot value
  // type.
  //
  // Expects three inputs (pivot keys, pivot values, group ids). Returns
  // Invalid if no options were given.
  Status Init(ExecContext* ctx, const KernelInitArgs& args) override {
    DCHECK_EQ(args.inputs.size(), 3);
    key_type_ = args.inputs[0].GetSharedPtr();
    value_type_ = args.inputs[1].GetSharedPtr();
    options_ = checked_cast<const PivotOptions*>(args.options);
    // A debug-only check is not enough here: without this guard, a release
    // build would dereference a null pointer below when the caller passes no
    // options.
    if (options_ == nullptr) {
      return Status::Invalid("hash_pivot requires PivotOptions");
    }
    FieldVector fields;
    fields.reserve(options_->key_names.size());
    for (const auto& key_name : options_->key_names) {
      fields.push_back(field(key_name, value_type_));
    }
    out_type_ = struct_(std::move(fields));
    out_struct_type_ = checked_cast<const StructType*>(out_type_.get());
    // counts_ = BufferBuilder(ctx->memory_pool());
    return Status::OK();
  }

  // Not implemented yet.
  Status Resize(int64_t new_num_groups) override {
    return Status::NotImplemented("GroupedPivotImpl::Resize");
    // auto added_groups = new_num_groups - num_groups_;
    // num_groups_ = new_num_groups;
    // return counts_.Append(added_groups * sizeof(int64_t), 0);
  }

  // Not implemented yet.
  Status Merge(GroupedAggregator&& raw_other,
               const ArrayData& group_id_mapping) override {
    return Status::NotImplemented("GroupedPivotImpl::Merge");
    // auto other = checked_cast<GroupedCountAllImpl*>(&raw_other);
    //
    // auto* counts = counts_.mutable_data_as<int64_t>();
    // const auto* other_counts = other->counts_.data_as<int64_t>();
    //
    // auto* g = group_id_mapping.GetValues<uint32_t>(1);
    // for (int64_t other_g = 0; other_g < group_id_mapping.length; ++other_g, ++g) {
    //   counts[*g] += other_counts[other_g];
    // }
    // return Status::OK();
  }

  // Not implemented yet.
  Status Consume(const ExecSpan& batch) override {
    return Status::NotImplemented("GroupedPivotImpl::Consume");
    // auto* counts = counts_.mutable_data_as<int64_t>();
    // auto* g_begin = batch[0].array.GetValues<uint32_t>(1);
    // for (auto g_itr = g_begin, end = g_itr + batch.length; g_itr != end; g_itr++) {
    //   counts[*g_itr] += 1;
    // }
    // return Status::OK();
  }

  // Not implemented yet.
  Result<Datum> Finalize() override {
    return Status::NotImplemented("GroupedPivotImpl::Finalize");
    // ARROW_ASSIGN_OR_RAISE(auto counts, counts_.Finish());
    // return std::make_shared<Int64Array>(num_groups_, std::move(counts));
  }

  // The struct type computed by Init().
  std::shared_ptr<DataType> out_type() const override { return out_type_; }

  // Type of the pivot key input (first argument)
  std::shared_ptr<DataType> key_type_;
  // Type of the pivot value input (second argument)
  std::shared_ptr<DataType> value_type_;
  // Output struct type, with one field per key name
  std::shared_ptr<DataType> out_type_;
  // `out_type_` downcast to StructType; null until Init() succeeds
  const StructType* out_struct_type_ = nullptr;
  // Options passed to Init(); not owned, null until Init() is called
  const PivotOptions* options_ = nullptr;
};

// TODO simplify this away?
// Kernel init function for "hash_pivot": creates and initializes a
// GroupedPivotImpl<KeyType> through HashAggregateInit.
template <typename KeyType>
Result<std::unique_ptr<KernelState>> GroupedPivotInit(KernelContext* ctx,
                                                      const KernelInitArgs& args) {
  using Impl = GroupedPivotImpl<KeyType>;
  ARROW_ASSIGN_OR_RAISE(auto state, HashAggregateInit<Impl>(ctx, args));
  return state;
}

// Type visitor that builds the "hash_pivot" kernel for a given pivot key type.
struct GroupedPivotFactory {
  // Base binary key types: build a kernel taking (pivot keys of this type,
  // pivot values of any type, uint32 group ids).
  template <typename KeyType>
  enable_if_base_binary<KeyType, Status> Visit(const KeyType& type) {
    // TODO replace Any() with a more selective matcher for the value type
    kernel = MakeKernel(
        KernelSignature::Make({type.id(), InputType::Any(), InputType(Type::UINT32)},
                              OutputType(ResolveGroupOutputType)),
        GroupedPivotInit<KeyType>);
    return Status::OK();
  }

  // Every other key type is rejected.
  Status Visit(const DataType& type) {
    return Status::TypeError("Unsupported pivot key type: ", type);
  }

  // Build the kernel for `pivot_key_type`, or return the visitor's error.
  static Result<HashAggregateKernel> Make(
      const std::shared_ptr<DataType>& pivot_key_type) {
    GroupedPivotFactory visitor;
    RETURN_NOT_OK(VisitTypeInline(*pivot_key_type, &visitor));
    return std::move(visitor.kernel);
  }

  // The kernel produced by the last successful Visit()
  HashAggregateKernel kernel;
};

// ----------------------------------------------------------------------
// Docstrings

const FunctionDoc hash_count_doc{
"Count the number of null / non-null values in each group",
("By default, only non-null values are counted.\n"
Expand Down Expand Up @@ -3456,6 +3691,18 @@ const FunctionDoc hash_one_doc{"Get one value from each group",
const FunctionDoc hash_list_doc{"List all values in each group",
("Null values are also returned."),
{"array", "group_id_array"}};

// Documentation for the "hash_pivot" function: summary, description, the three
// argument names (pivot keys, pivot values, group ids) and the name of its
// options class.
const FunctionDoc hash_pivot_doc{
"Pivot values according to a pivot key column",
("Output is a struct array with as many fields as `PivotOptions.key_names`.\n"
"All output struct fields have the same type as `pivot_values`.\n"
"Each pivot key decides in which output field the corresponding pivot value\n"
"is emitted. If a pivot key doesn't appear in a given group, null is emitted.\n"
"If a pivot key appears twice in a given group, KeyError is raised.\n"
"Behavior of unexpected pivot keys is controlled by PivotOptions."),
{"pivot_keys", "pivot_values", "group_id_array"},
"PivotOptions"};

} // namespace

void RegisterHashAggregateBasic(FunctionRegistry* registry) {
Expand Down Expand Up @@ -3705,6 +3952,13 @@ void RegisterHashAggregateBasic(FunctionRegistry* registry) {
GroupedListFactory::Make, func.get()));
DCHECK_OK(registry->AddFunction(std::move(func)));
}

// Register the ternary "hash_pivot" hash aggregate function. Its kernels are
// built by GroupedPivotFactory::Make for the types returned by
// BaseBinaryTypes().
{
auto func = std::make_shared<HashAggregateFunction>("hash_pivot", Arity::Ternary(),
hash_pivot_doc);
DCHECK_OK(AddHashAggKernels(BaseBinaryTypes(), GroupedPivotFactory::Make, func.get()));
DCHECK_OK(registry->AddFunction(std::move(func)));
}
}

} // namespace internal
Expand Down

0 comments on commit 6bc1ee7

Please sign in to comment.