Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify and test wrap_in_list_array #8998

Merged
merged 2 commits into from
Feb 12, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 2 additions & 36 deletions crates/store/re_sorbet/src/chunk_batch.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,9 @@
use std::sync::Arc;

use arrow::{
array::{
Array as ArrowArray, ArrayRef as ArrowArrayRef, AsArray, ListArray as ArrowListArray,
RecordBatch as ArrowRecordBatch, RecordBatchOptions, StructArray as ArrowStructArray,
},
datatypes::{
DataType as ArrowDataType, Field as ArrowField, FieldRef as ArrowFieldRef,
Fields as ArrowFields, Schema as ArrowSchema,
},
datatypes::{FieldRef as ArrowFieldRef, Fields as ArrowFields, Schema as ArrowSchema},
};

use re_arrow_util::{into_arrow_ref, ArrowArrayDowncastRef};
Expand Down Expand Up @@ -276,7 +271,7 @@ fn make_all_data_columns_list_arrays(batch: &ArrowRecordBatch) -> ArrowRecordBat
.get("rerun.kind")
.is_some_and(|kind| kind == "data");
if is_data_column && !is_list_array {
let (field, array) = wrap_in_list_array(field, array);
let (field, array) = re_arrow_util::wrap_in_list_array(field, array.clone());
fields.push(field.into());
columns.push(into_arrow_ref(array));
} else {
Expand All @@ -294,32 +289,3 @@ fn make_all_data_columns_list_arrays(batch: &ArrowRecordBatch) -> ArrowRecordBat
)
.expect("Can't fail")
}

// TODO(cmc): we can do something faster/simpler here; see https://github.com/rerun-io/rerun/pull/8945#discussion_r1950689060
/// Wraps every row of `data` in its own single-element list.
///
/// Returns the new list-typed field together with the list array:
/// - the inner item field is named `"item"`, has the data type of `field`, and is nullable;
/// - the outer field keeps the name and metadata of `field`, and is marked non-nullable.
fn wrap_in_list_array(field: &ArrowField, data: &dyn ArrowArray) -> (ArrowField, ArrowListArray) {
    re_tracing::profile_function!();

    // Inner ("item") field of the resulting list type.
    let inner_field = ArrowField::new("item", field.data_type().clone(), true /* nullable */);
    let inner_datatype = inner_field.data_type().clone();

    // Outer field: same name and metadata as the input, but now list-typed.
    let outer_field = ArrowField::new(
        field.name().clone(),
        ArrowDataType::List(Arc::new(inner_field)),
        false, /* not nullable */
    )
    .with_metadata(field.metadata().clone());

    // One length-1 slice per row; the slices are then assembled into a single list array.
    let rows: Vec<ArrowArrayRef> = (0..data.len()).map(|row| data.slice(row, 1)).collect();
    let row_refs: Vec<_> = rows.iter().map(|row| Some(row.as_ref())).collect();

    #[allow(clippy::unwrap_used)] // we know we've given the right field type
    let list_array: ArrowListArray =
        re_arrow_util::arrays_to_list_array(inner_datatype, &row_refs).unwrap();

    (outer_field, list_array)
}
90 changes: 87 additions & 3 deletions crates/utils/re_arrow_util/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
//! Helpers for working with arrow

use std::sync::Arc;

use arrow::{
array::{Array, ArrayRef, ArrowPrimitiveType, BooleanArray, ListArray, PrimitiveArray},
buffer::{NullBuffer, OffsetBuffer},
Expand Down Expand Up @@ -35,9 +37,7 @@ pub fn into_arrow_ref(array: impl Array + 'static) -> ArrayRef {
}

/// Returns an iterator with the lengths of the offsets.
pub fn offsets_lengths(
offsets: &arrow::buffer::OffsetBuffer<i32>,
) -> impl Iterator<Item = usize> + '_ {
pub fn offsets_lengths(offsets: &OffsetBuffer<i32>) -> impl Iterator<Item = usize> + '_ {
// TODO(emilk): remove when we update to Arrow 54 (which has an API for this)
offsets.windows(2).map(|w| {
let start = w[0];
Expand Down Expand Up @@ -384,3 +384,87 @@ where
array.shrink_to_fit(); // VERY IMPORTANT! https://github.com/rerun-io/rerun/issues/7222
array
}

// ----------------------------------------------------------------------------

/// Convert `[A, B, null, D, …]` into `[[A], [B], null, [D], …]`
///
/// Every row of `array` becomes a list of length one. The validity (null) buffer of
/// `array` is reused as the validity of the outer list, so a null input row yields a
/// null list rather than a list containing a null.
///
/// The returned field keeps the name and metadata of `field`, takes the list data type
/// of the returned array, and is always nullable. The inner `"item"` field inherits the
/// nullability of `field`.
///
/// In debug builds this asserts that `field` and `array` agree on the data type.
// NOTE(review): `ListArray::new` is expected to panic on inconsistent inputs — confirm against the arrow docs.
pub fn wrap_in_list_array(field: &Field, array: ArrayRef) -> (Field, ListArray) {
    re_tracing::profile_function!();

    debug_assert_eq!(field.data_type(), array.data_type());

    let item_field = Arc::new(Field::new(
        "item",
        field.data_type().clone(),
        field.is_nullable(), // TODO(emilk): it would be nice to remove the "inner nullability", and just have outer nullability.
    ));

    // One offset step of length 1 per input row: each row becomes its own list.
    let offsets = OffsetBuffer::from_lengths(std::iter::repeat(1).take(array.len()));
    // Share the input validity so that null rows map to null lists.
    let nulls = array.nulls().cloned();
    let list_array = ListArray::new(item_field, offsets, array, nulls);

    let list_field = Field::new(
        field.name().clone(),
        list_array.data_type().clone(),
        true, // All components in Rerun have "outer nullability"
    )
    .with_metadata(field.metadata().clone());

    (list_field, list_array)
}

#[cfg(test)]
mod tests {
    use arrow::{
        array::{Array as _, AsArray as _, Int32Array},
        buffer::{NullBuffer, ScalarBuffer},
        datatypes::{DataType, Int32Type},
    };

    use super::*;

    /// Checks that `wrap_in_list_array` wraps each valid row in a single-element list
    /// and keeps the null row as a null list at the same index.
    #[test]
    fn test_wrap_in_list_array() {
        // Convert [42, 69, null, 1337] into [[42], [69], null, [1337]]
        let original_field = Field::new("item", DataType::Int32, true);
        let original_array = Int32Array::new(
            // The value stored in the null slot (-1) is masked out by the validity buffer below.
            ScalarBuffer::from(vec![42, 69, -1, 1337]),
            Some(NullBuffer::from(vec![true, true, false, true])),
        );
        assert_eq!(original_array.len(), 4);
        assert_eq!(original_array.null_count(), 1);

        let (new_field, new_array) =
            wrap_in_list_array(&original_field, into_arrow_ref(original_array.clone()));

        assert_eq!(new_field.data_type(), new_array.data_type());
        assert_eq!(new_array.len(), original_array.len());
        assert_eq!(new_array.null_count(), original_array.null_count());

        // The null must stay at the same position, not merely be counted.
        assert!(new_array.is_null(2));

        // Every valid row must be wrapped in a list holding exactly that one value.
        for (row, expected) in [(0_usize, 42_i32), (1, 69), (3, 1337)] {
            assert!(new_array.is_valid(row));
            let item = new_array.value(row);
            assert_eq!(
                item.as_primitive::<Int32Type>().values().as_ref(),
                &[expected]
            );
        }
    }
}
Loading