diff --git a/cpp/src/arrow/compute/api_aggregate.h b/cpp/src/arrow/compute/api_aggregate.h index 2e5210b073ee4..886d992640c5f 100644 --- a/cpp/src/arrow/compute/api_aggregate.h +++ b/cpp/src/arrow/compute/api_aggregate.h @@ -175,6 +175,75 @@ class ARROW_EXPORT TDigestOptions : public FunctionOptions { uint32_t min_count; }; +/// \brief Control Pivot kernel behavior +/// +/// These options apply to the "pivot" (TODO) and "hash_pivot" (TODO) functions. +/// +/// Constraints: +/// - The corresponding `Aggregate::target` must have two FieldRef elements; +/// the first one points to the pivot key column, the second points to the +/// pivoted data column. +/// - The pivot key column must be string-like; its values will be matched +/// against `key_names` in order to dispatch the pivoted data into the +/// output. +/// +/// "hash_pivot" example +/// -------------------- +/// +/// Assuming the following input with schema +/// `{"group": int32, "key": utf8, "value": int16}`: +/// ``` +/// group | key | value +/// ----------------------------- +/// 1 | height | 11 +/// 1 | width | 12 +/// 2 | width | 13 +/// 3 | height | 14 +/// 3 | depth | 15 +/// ``` +/// and the following settings: +/// - a hash grouping key "group" +/// - Aggregate( +/// .function = "hash_pivot", +/// .options = PivotOptions(.key_names = {"height", "width"}), +/// .target = {"key", "value"}, +/// .name = {"props"}) +/// +/// then the output will have the schema +/// `{"group": int32, "props": struct{"height": int16, "width": int16}}` +/// and the following value (the unexpected "depth" key is ignored): +/// ``` +/// group | props +/// | height | width +/// ----------------------------- +/// 1 | 11 | 12 +/// 2 | null | 13 +/// 3 | 14 | null +/// ``` +class ARROW_EXPORT PivotOptions : public FunctionOptions { + public: + // Configure the behavior of pivot keys not in `key_names` + enum UnexpectedKeyBehavior { + // Unexpected pivot keys are ignored silently + kIgnore, + // Unexpected pivot keys cause a KeyError to be returned + kRaise + }; + // TODO should 
duplicate key behavior be configurable as well? + + explicit PivotOptions(std::vector key_names, + UnexpectedKeyBehavior unexpected_key_behavior = kIgnore); + // Default constructor for serialization + PivotOptions(); + static constexpr char const kTypeName[] = "PivotOptions"; + static PivotOptions Defaults() { return PivotOptions{}; } + + // The values expected in the pivot key column (one output struct field per name) + std::vector key_names; + // The behavior when pivot keys not in `key_names` are encountered + UnexpectedKeyBehavior unexpected_key_behavior = kIgnore; +}; + /// \brief Control Index kernel behavior class ARROW_EXPORT IndexOptions : public FunctionOptions { public: diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate.cc b/cpp/src/arrow/compute/kernels/hash_aggregate.cc index 21b7bd9bf6632..b93e535f9e73f 100644 --- a/cpp/src/arrow/compute/kernels/hash_aggregate.cc +++ b/cpp/src/arrow/compute/kernels/hash_aggregate.cc @@ -47,6 +47,7 @@ #include "arrow/util/int128_internal.h" #include "arrow/util/int_util_overflow.h" #include "arrow/util/ree_util.h" +#include "arrow/util/span.h" #include "arrow/util/task_group.h" #include "arrow/util/tdigest.h" #include "arrow/util/thread_pool.h" @@ -56,6 +57,7 @@ namespace arrow { using internal::checked_cast; using internal::FirstTimeBitmapWriter; +using util::span; namespace compute { namespace internal { @@ -3319,9 +3321,242 @@ struct GroupedListFactory { HashAggregateKernel kernel; InputType argument_type; }; -} // namespace -namespace { +// ---------------------------------------------------------------------- +// Pivot implementation + +using PivotKeyIndex = uint8_t; + +constexpr PivotKeyIndex kNullPivotKey = std::numeric_limits::max(); + +// Maps pivot key column values to their position in `PivotOptions::key_names`. TODO move this into pivot_internal.h +struct PivotKeyMapper { + virtual ~PivotKeyMapper() = default; + + virtual Status Init(const PivotOptions* options) { + key_name_map_.reserve(options->key_names.size()); + PivotKeyIndex index = 0; + for (const auto& key_name : options->key_names) { + // TODO check 
for key duplicates + // TODO check for index overflow: PivotKeyIndex is uint8_t, so `index++` wraps around when there are more than 256 key names, and the maximum value is presumably reserved for kNullPivotKey (NOTE(review): confirm) + key_name_map_[std::string_view(key_name)] = index++; + } + unexpected_key_behavior_ = options->unexpected_key_behavior; + return Status::OK(); + } + + virtual Result> MapKeys(const ExecValue&) = 0; + + protected: + Result KeyNotFound(std::string_view key_name) { + if (unexpected_key_behavior_ == PivotOptions::kIgnore) { + return kNullPivotKey; + } + return Status::KeyError("Unexpected pivot key: ", key_name); + } + + static constexpr int kBatchLength = 512; + std::unordered_map key_name_map_; + PivotOptions::UnexpectedKeyBehavior unexpected_key_behavior_; + TypedBufferBuilder key_indices_buffer_; +}; + +template +struct TypedPivotKeyMapper : public PivotKeyMapper { + Result> MapKeys(const ExecValue& value) override { + RETURN_NOT_OK(key_indices_buffer_.Reserve(value.length())); + PivotKeyIndex* key_indices = key_indices_buffer_.mutable_data(); + + auto process_key = [&](std::string_view key_name) -> Result { + const auto it = key_name_map_.find(key_name); + if (ARROW_PREDICT_FALSE(it == key_name_map_.end())) { + return KeyNotFound(key_name); + } else { + return it->second; + } + }; + + if (value.is_scalar()) { + const auto& scalar = value.scalar_as(); + ARROW_ASSIGN_OR_RAISE(key_indices[0], process_key(scalar.view())); + return span(key_indices, 1); + } + const ArraySpan& array = value.array; + int64_t i = 0; + RETURN_NOT_OK(VisitArrayValuesInline( + array, + [&](std::string_view key_name) { + ARROW_ASSIGN_OR_RAISE(key_indices[i], process_key(key_name)); + ++i; + return Status::OK(); + }, + [&]() { + return Status::KeyError("key name cannot be null"); + })); + return span(key_indices, array.length); + } +}; + +struct PivotKeyMapperFactory { + template + Status Visit(const T& key_type) { + if constexpr (is_base_binary_like(T::type_id)) { + instance = std::make_unique>(); + return instance->Init(options); + } + return Status::NotImplemented("Pivot key type: ", key_type); + } + + static Result> Make(const 
DataType& key_type, const PivotOptions* options) { + PivotKeyMapperFactory factory{options}; + RETURN_NOT_OK(VisitTypeInline(key_type, &factory)); + return std::move(factory).instance; + } + + const PivotOptions* options; + std::unique_ptr instance{}; +}; + +/* +TODO +would probably like to write: + +Result> MakePivotKeyMapper(const DataType& key_type, + const PivotOptions* options) { + std::unique_ptr instance; + RETURN_NOT_OK(VisitTypeInline(key_type, [&](auto key_type) { + using T = std::decay_t; + if constexpr (is_base_binary_like(T::type_id)) { + instance = std::make_unique>(); + return instance->Init(options); + } + return Status::NotImplemented("Pivot key type: ", key_type); + })); + return instance; +} + +or even: + +Result> MakePivotKeyMapper(const DataType& key_type, + const PivotOptions* options) { + return VisitTypeInline(key_type, [&](auto key_type) -> Result> { + using T = std::decay_t; + if constexpr (is_base_binary_like(T::type_id)) { + auto instance = std::make_unique>(); + RETURN_NOT_OK(instance->Init(options)); + return instance; + } + return Status::NotImplemented("Pivot key type: ", key_type); + }); +} +*/ + +template +struct GroupedPivotImpl : public GroupedAggregator { + Status Init(ExecContext* ctx, const KernelInitArgs& args) override { + DCHECK_EQ(args.inputs.size(), 3); + key_type_ = args.inputs[0].GetSharedPtr(); + value_type_ = args.inputs[1].GetSharedPtr(); + options_ = checked_cast(args.options); + DCHECK_NE(options_, nullptr); + FieldVector fields; + fields.reserve(options_->key_names.size()); + for (const auto& key_name : options_->key_names) { + fields.push_back(field(key_name, value_type_)); + } + out_type_ = struct_(std::move(fields)); + out_struct_type_ = checked_cast(out_type_.get()); +// counts_ = BufferBuilder(ctx->memory_pool()); + return Status::OK(); + } + + Status Resize(int64_t new_num_groups) override { + return Status::NotImplemented("GroupedPivotImpl::Resize"); +// auto added_groups = new_num_groups - num_groups_; +// 
num_groups_ = new_num_groups; +// return counts_.Append(added_groups * sizeof(int64_t), 0); + } + + Status Merge(GroupedAggregator&& raw_other, + const ArrayData& group_id_mapping) override { + return Status::NotImplemented("GroupedPivotImpl::Merge"); +// auto other = checked_cast(&raw_other); +// +// auto* counts = counts_.mutable_data_as(); +// const auto* other_counts = other->counts_.data_as(); +// +// auto* g = group_id_mapping.GetValues(1); +// for (int64_t other_g = 0; other_g < group_id_mapping.length; ++other_g, ++g) { +// counts[*g] += other_counts[other_g]; +// } +// return Status::OK(); + } + + Status Consume(const ExecSpan& batch) override { + return Status::NotImplemented("GroupedPivotImpl::Consume"); +// auto* counts = counts_.mutable_data_as(); +// auto* g_begin = batch[0].array.GetValues(1); +// for (auto g_itr = g_begin, end = g_itr + batch.length; g_itr != end; g_itr++) { +// counts[*g_itr] += 1; +// } +// return Status::OK(); + } + + Result Finalize() override { + return Status::NotImplemented("GroupedPivotImpl::Finalize"); +// ARROW_ASSIGN_OR_RAISE(auto counts, counts_.Finish()); +// return std::make_shared(num_groups_, std::move(counts)); + } + + std::shared_ptr out_type() const override { + return out_type_; + } + + std::shared_ptr key_type_; + std::shared_ptr value_type_; + std::shared_ptr out_type_; + const StructType* out_struct_type_; + const PivotOptions* options_; +}; + +// TODO simplify this away? 
+template +Result> GroupedPivotInit(KernelContext* ctx, + const KernelInitArgs& args) { + ARROW_ASSIGN_OR_RAISE(auto impl, HashAggregateInit>(ctx, args)); +// auto instance = checked_cast*>(impl.get()); +// DCHECK_EQ(args.inputs.size(), 3); +// instance->key_type_ = args.inputs[0].GetSharedPtr(); +// instance->value_type_ = args.inputs[1].GetSharedPtr(); + return impl; +} + +struct GroupedPivotFactory { + template + enable_if_base_binary Visit(const KeyType& type) { + // Kernel inputs are (pivot keys, pivot values, group ids). TODO replace Any() with a more selective matcher for the value type + auto sig = KernelSignature::Make({type.id(), InputType::Any(), InputType(Type::UINT32)}, + OutputType(ResolveGroupOutputType)); + kernel = MakeKernel(std::move(sig), GroupedPivotInit); + return Status::OK(); + } + + Status Visit(const DataType& type) { + return Status::TypeError("Unsupported pivot key type: ", type); + } + + static Result Make(const std::shared_ptr& pivot_key_type) { + GroupedPivotFactory factory; + RETURN_NOT_OK(VisitTypeInline(*pivot_key_type, &factory)); + return std::move(factory.kernel); + } + + HashAggregateKernel kernel; +}; + +// ---------------------------------------------------------------------- +// Docstrings + const FunctionDoc hash_count_doc{ "Count the number of null / non-null values in each group", ("By default, only non-null values are counted.\n" @@ -3456,6 +3691,18 @@ const FunctionDoc hash_one_doc{"Get one value from each group", const FunctionDoc hash_list_doc{"List all values in each group", ("Null values are also returned."), {"array", "group_id_array"}}; + +const FunctionDoc hash_pivot_doc{ + "Pivot values according to a pivot key column", + ("Output is a struct array with as many fields as `PivotOptions.key_names`.\n" + "All output struct fields have the same type as `pivot_values`.\n" + "Each pivot key decides in which output field the corresponding pivot value\n" + "is emitted. 
If a pivot key doesn't appear in a given group, null is emitted.\n" + "If a pivot key appears twice in a given group, KeyError is raised.\n" + "Behavior of unexpected pivot keys is controlled by PivotOptions."), + {"pivot_keys", "pivot_values", "group_id_array"}, + "PivotOptions"}; + } // namespace void RegisterHashAggregateBasic(FunctionRegistry* registry) { @@ -3705,6 +3952,13 @@ void RegisterHashAggregateBasic(FunctionRegistry* registry) { GroupedListFactory::Make, func.get())); DCHECK_OK(registry->AddFunction(std::move(func))); } + + { + auto func = std::make_shared("hash_pivot", Arity::Ternary(), + hash_pivot_doc); + DCHECK_OK(AddHashAggKernels(BaseBinaryTypes(), GroupedPivotFactory::Make, func.get())); + DCHECK_OK(registry->AddFunction(std::move(func))); + } } } // namespace internal