From 157b179812adb8f29e5966682ff1937f85ce192a Mon Sep 17 00:00:00 2001 From: tianchen Date: Fri, 30 Aug 2019 19:55:04 -0700 Subject: [PATCH] ARROW-6078: [Java] Implement dictionary-encoded subfields for List type Related to [ARROW-6078](https://issues.apache.org/jira/browse/ARROW-6078). For example, an int type List vector (valueCount = 5) holding the lists [10, 20], [10, 20], [30, 40, 50], [30, 40, 50], [10, 20] could be encoded to [0, 1], [0, 1], [2, 3, 4], [2, 3, 4], [0, 1] with a list type dictionary whose data vector holds 10, 20, 30, 40, 50, stored either as a single list [10, 20, 30, 40, 50] or split across several lists, e.g. [10, 20], [30, 40, 50]. Closes #4972 from tianchen92/ARROW-1175 and squashes the following commits: 5d2f751e3 Update java/vector/src/main/java/org/apache/arrow/vector/dictionary/ListSubfieldEncoder.java fbd122bfb fix c51ec003f add replaceDataVector in BaseListVector 658958b27 make BaseListVector extend FieldVector 6c9d95db3 refactor BaseListVector 0b6cec51c resolve conflict a54ecd1c6 ARROW-6078: Implement dictionary-encoded subfields for List type Lead-authored-by: tianchen Co-authored-by: tianchen92 <875529044@qq.com> Signed-off-by: Micah Kornfield --- .../arrow/vector/complex/BaseListVector.java | 36 +++++ .../complex/BaseRepeatedValueVector.java | 3 +- .../vector/complex/FixedSizeListVector.java | 12 +- .../arrow/vector/complex/ListVector.java | 12 +- .../vector/dictionary/DictionaryEncoder.java | 81 ++++++---- .../dictionary/ListSubfieldEncoder.java | 131 ++++++++++++++++ .../arrow/vector/TestDictionaryVector.java | 141 ++++++++++++++++++ 7 files changed, 387 insertions(+), 29 deletions(-) create mode 100644 java/vector/src/main/java/org/apache/arrow/vector/complex/BaseListVector.java create mode 100644 java/vector/src/main/java/org/apache/arrow/vector/dictionary/ListSubfieldEncoder.java diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/BaseListVector.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/BaseListVector.java new file mode 100644 index 0000000000000..5f547b90176f4 --- /dev/null +++ 
b/java/vector/src/main/java/org/apache/arrow/vector/complex/BaseListVector.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.vector.complex; + +import org.apache.arrow.vector.FieldVector; + +/** + * Abstraction for all list type vectors. + */ +public interface BaseListVector extends FieldVector { + + /** + * Get data vector start index with the given list index. + */ + int getElementStartIndex(int index); + + /** + * Get data vector end index with the given list index. + */ + int getElementEndIndex(int index); +} diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/BaseRepeatedValueVector.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/BaseRepeatedValueVector.java index 581f5d83ea5ad..363c92533c3d7 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/complex/BaseRepeatedValueVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/BaseRepeatedValueVector.java @@ -44,7 +44,7 @@ import io.netty.buffer.ArrowBuf; /** Base class for Vectors that contain repeated values. 
*/ -public abstract class BaseRepeatedValueVector extends BaseValueVector implements RepeatedValueVector { +public abstract class BaseRepeatedValueVector extends BaseValueVector implements RepeatedValueVector, BaseListVector { public static final FieldVector DEFAULT_DATA_VECTOR = ZeroVector.INSTANCE; public static final String DATA_VECTOR_NAME = "$data$"; @@ -305,7 +305,6 @@ protected void replaceDataVector(FieldVector v) { vector = v; } - @Override public int getValueCount() { return valueCount; diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/FixedSizeListVector.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/FixedSizeListVector.java index 7cde40900e758..09ebf9e8067f7 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/complex/FixedSizeListVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/FixedSizeListVector.java @@ -57,7 +57,7 @@ import io.netty.buffer.ArrowBuf; /** A ListVector where every list value is of the same size. 
*/ -public class FixedSizeListVector extends BaseValueVector implements FieldVector, PromotableVector { +public class FixedSizeListVector extends BaseValueVector implements BaseListVector, PromotableVector { public static FixedSizeListVector empty(String name, int size, BufferAllocator allocator) { FieldType fieldType = FieldType.nullable(new ArrowType.FixedSizeList(size)); @@ -543,6 +543,16 @@ public OUT accept(VectorVisitor visitor, IN value) { return visitor.visit(this, value); } + @Override + public int getElementStartIndex(int index) { + return listSize * index; + } + + @Override + public int getElementEndIndex(int index) { + return listSize * (index + 1); + } + private class TransferImpl implements TransferPair { FixedSizeListVector to; diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/ListVector.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/ListVector.java index 74489361c5cdb..66f54c6b98181 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/complex/ListVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/ListVector.java @@ -64,7 +64,7 @@ * * The latter two are managed by its superclass. 
*/ -public class ListVector extends BaseRepeatedValueVector implements FieldVector, PromotableVector { +public class ListVector extends BaseRepeatedValueVector implements PromotableVector { public static ListVector empty(String name, BufferAllocator allocator) { return new ListVector(name, allocator, FieldType.nullable(ArrowType.List.INSTANCE), null); @@ -829,4 +829,14 @@ public void setLastSet(int value) { public int getLastSet() { return lastSet; } + + @Override + public int getElementStartIndex(int index) { + return offsetBuffer.getInt(index * OFFSET_WIDTH); + } + + @Override + public int getElementEndIndex(int index) { + return offsetBuffer.getInt((index + 1) * OFFSET_WIDTH); + } } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryEncoder.java b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryEncoder.java index 1171c807a0a3a..d431354c9dae6 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryEncoder.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryEncoder.java @@ -71,6 +71,59 @@ public static ValueVector decode(ValueVector indices, Dictionary dictionary) { return encoder.decode(indices); } + /** + * Populates indices between start and end with the encoded values of vector. 
+ * @param vector the vector to encode + * @param indices the index vector + * @param encoding the hash table for encoding + * @param start the start index + * @param end the end index + */ + static void buildIndexVector( + ValueVector vector, + BaseIntVector indices, + DictionaryHashTable encoding, + int start, + int end) { + + for (int i = start; i < end; i++) { + if (!vector.isNull(i)) { + // if it's null leave it null + // note: this may fail if value was not included in the dictionary + int encoded = encoding.getIndex(i, vector); + if (encoded == -1) { + throw new IllegalArgumentException("Dictionary encoding not defined for value:" + vector.getObject(i)); + } + indices.setWithPossibleTruncate(i, encoded); + } + } + } + + /** + * Retrieve values to target vector from index vector. + * @param indices the index vector + * @param transfer the {@link TransferPair} to copy dictionary data into target vector. + * @param dictionaryCount the value count of dictionary vector; valid indices are in [0, dictionaryCount). + * @param start the start index + * @param end the end index + */ + static void retrieveIndexVector( + BaseIntVector indices, + TransferPair transfer, + int dictionaryCount, + int start, + int end) { + for (int i = start; i < end; i++) { + if (!indices.isNull(i)) { + int indexAsInt = (int) indices.getValueAsLong(i); + if (indexAsInt < 0 || indexAsInt >= dictionaryCount) { + throw new IllegalArgumentException("Provided dictionary does not contain value for index " + indexAsInt); + } + transfer.copyValueSafe(indexAsInt, i); + } + } + } + /** * Encodes a vector with the built hash table in this encoder. 
*/ @@ -91,22 +144,8 @@ public ValueVector encode(ValueVector vector) { BaseIntVector indices = (BaseIntVector) createdVector; indices.allocateNew(); - int count = vector.getValueCount(); - - for (int i = 0; i < count; i++) { - if (!vector.isNull(i)) { // if it's null leave it null - // note: this may fail if value was not included in the dictionary - //int encoded = lookUps.get(value); - int encoded = hashTable.getIndex(i, vector); - if (encoded == -1) { - throw new IllegalArgumentException("Dictionary encoding not defined for value:" + vector.getObject(i)); - } - indices.setWithPossibleTruncate(i, encoded); - } - } - - indices.setValueCount(count); - + buildIndexVector(vector, indices, hashTable, 0, vector.getValueCount()); + indices.setValueCount(vector.getValueCount()); return indices; } @@ -122,15 +161,7 @@ public ValueVector decode(ValueVector indices) { transfer.getTo().allocateNewSafe(); BaseIntVector baseIntVector = (BaseIntVector) indices; - for (int i = 0; i < count; i++) { - if (!baseIntVector.isNull(i)) { - int indexAsInt = (int) baseIntVector.getValueAsLong(i); - if (indexAsInt > dictionaryCount) { - throw new IllegalArgumentException("Provided dictionary does not contain value for index " + indexAsInt); - } - transfer.copyValueSafe(indexAsInt, i); - } - } + retrieveIndexVector(baseIntVector, transfer, dictionaryCount, 0, count); ValueVector decoded = transfer.getTo(); decoded.setValueCount(count); return decoded; diff --git a/java/vector/src/main/java/org/apache/arrow/vector/dictionary/ListSubfieldEncoder.java b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/ListSubfieldEncoder.java new file mode 100644 index 0000000000000..9ffad8ca5a1ae --- /dev/null +++ b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/ListSubfieldEncoder.java @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.vector.dictionary; + +import java.util.Collections; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.BaseIntVector; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.complex.BaseListVector; +import org.apache.arrow.vector.ipc.message.ArrowFieldNode; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.util.TransferPair; + +/** + * Sub fields encoder/decoder for Dictionary encoded {@link BaseListVector}. + */ +public class ListSubfieldEncoder { + + private final DictionaryHashTable hashTable; + private final Dictionary dictionary; + private final BufferAllocator allocator; + + /** + * Construct an instance. 
+ */ + public ListSubfieldEncoder(Dictionary dictionary, BufferAllocator allocator) { + this.dictionary = dictionary; + this.allocator = allocator; + BaseListVector dictVector = (BaseListVector) dictionary.getVector(); + hashTable = new DictionaryHashTable(getDataVector(dictVector)); + } + + private FieldVector getDataVector(BaseListVector vector) { + return vector.getChildrenFromFields().get(0); + } + + private BaseListVector cloneVector(BaseListVector vector) { + + final FieldType fieldType = vector.getField().getFieldType(); + BaseListVector cloned = (BaseListVector) fieldType.createNewSingleVector(vector.getField().getName(), + allocator, /*schemaCallBack=*/null); + + final ArrowFieldNode fieldNode = new ArrowFieldNode(vector.getValueCount(), vector.getNullCount()); + cloned.loadFieldBuffers(fieldNode, vector.getFieldBuffers()); + + return cloned; + } + + /** + * Dictionary encodes subfields for complex vector with a provided dictionary. + * The dictionary must contain all values in the sub fields vector. 
+ * @param vector vector to encode + * @return dictionary encoded vector + */ + public BaseListVector encodeListSubField(BaseListVector vector) { + final int valueCount = vector.getValueCount(); + + FieldType indexFieldType = new FieldType(vector.getField().isNullable(), + dictionary.getEncoding().getIndexType(), dictionary.getEncoding(), vector.getField().getMetadata()); + Field valueField = new Field(vector.getField().getName(), indexFieldType,null); + + // clone list vector and initialize data vector + BaseListVector encoded = cloneVector(vector); + encoded.initializeChildrenFromFields(Collections.singletonList(valueField)); + BaseIntVector indices = (BaseIntVector) getDataVector(encoded); + + ValueVector dataVector = getDataVector(vector); + for (int i = 0; i < valueCount; i++) { + if (!vector.isNull(i)) { + int start = vector.getElementStartIndex(i); + int end = vector.getElementEndIndex(i); + + DictionaryEncoder.buildIndexVector(dataVector, indices, hashTable, start, end); + } + } + + return encoded; + } + + /** + * Decodes a dictionary subfields encoded vector using the provided dictionary. 
+ * @param vector dictionary encoded vector, its data vector must be int type + * @return vector with values restored from dictionary + */ + public BaseListVector decodeListSubField(BaseListVector vector) { + + int valueCount = vector.getValueCount(); + BaseListVector dictionaryVector = (BaseListVector) dictionary.getVector(); + int dictionaryValueCount = getDataVector(dictionaryVector).getValueCount(); + + // clone list vector and initialize data vector + BaseListVector decoded = cloneVector(vector); + Field dataVectorField = getDataVector(dictionaryVector).getField(); + decoded.initializeChildrenFromFields(Collections.singletonList(dataVectorField)); + + // get data vector + ValueVector dataVector = getDataVector(decoded); + + TransferPair transfer = getDataVector(dictionaryVector).makeTransferPair(dataVector); + BaseIntVector indices = (BaseIntVector) getDataVector(vector); + + for (int i = 0; i < valueCount; i++) { + + if (!vector.isNull(i)) { + int start = vector.getElementStartIndex(i); + int end = vector.getElementEndIndex(i); + + DictionaryEncoder.retrieveIndexVector(indices, transfer, dictionaryValueCount, start, end); + } + } + return decoded; + } +} diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestDictionaryVector.java b/java/vector/src/test/java/org/apache/arrow/vector/TestDictionaryVector.java index 46c0f0d166c67..761d727995fb8 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/TestDictionaryVector.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/TestDictionaryVector.java @@ -27,6 +27,7 @@ import java.util.Arrays; import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.complex.FixedSizeListVector; import org.apache.arrow.vector.complex.ListVector; import org.apache.arrow.vector.complex.StructVector; import org.apache.arrow.vector.complex.UnionVector; @@ -34,12 +35,14 @@ import org.apache.arrow.vector.complex.impl.UnionListWriter; import org.apache.arrow.vector.dictionary.Dictionary; 
import org.apache.arrow.vector.dictionary.DictionaryEncoder; +import org.apache.arrow.vector.dictionary.ListSubfieldEncoder; import org.apache.arrow.vector.holders.NullableIntHolder; import org.apache.arrow.vector.holders.NullableUInt4Holder; import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.DictionaryEncoding; import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.util.JsonStringArrayList; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -705,6 +708,144 @@ public void testEncodeMultiVectors() { } } + @Test + public void testEncodeListSubField() { + // Create a new value vector + try (final ListVector vector = ListVector.empty("vector", allocator); + final ListVector dictionaryVector = ListVector.empty("dict", allocator);) { + + UnionListWriter writer = vector.getWriter(); + writer.allocate(); + + //set some values + writeListVector(writer, new int[]{10, 20}); + writeListVector(writer, new int[]{10, 20}); + writeListVector(writer, new int[]{10, 20}); + writeListVector(writer, new int[]{30, 40, 50}); + writeListVector(writer, new int[]{30, 40, 50}); + writeListVector(writer, new int[]{10, 20}); + writer.setValueCount(6); + + UnionListWriter dictWriter = dictionaryVector.getWriter(); + dictWriter.allocate(); + writeListVector(dictWriter, new int[]{10, 20, 30, 40, 50}); + dictionaryVector.setValueCount(1); + + Dictionary dictionary = new Dictionary(dictionaryVector, new DictionaryEncoding(1L, false, null)); + ListSubfieldEncoder encoder = new ListSubfieldEncoder(dictionary, allocator); + + try (final ListVector encoded = (ListVector) encoder.encodeListSubField(vector)) { + // verify indices + assertEquals(ListVector.class, encoded.getClass()); + + assertEquals(6, encoded.getValueCount()); + int[] realValue1 = convertListToIntArray((JsonStringArrayList) encoded.getObject(0)); + assertTrue(Arrays.equals(new int[] {0,1}, 
realValue1)); + int[] realValue2 = convertListToIntArray((JsonStringArrayList) encoded.getObject(1)); + assertTrue(Arrays.equals(new int[] {0,1}, realValue2)); + int[] realValue3 = convertListToIntArray((JsonStringArrayList) encoded.getObject(2)); + assertTrue(Arrays.equals(new int[] {0,1}, realValue3)); + int[] realValue4 = convertListToIntArray((JsonStringArrayList) encoded.getObject(3)); + assertTrue(Arrays.equals(new int[] {2,3,4}, realValue4)); + int[] realValue5 = convertListToIntArray((JsonStringArrayList) encoded.getObject(4)); + assertTrue(Arrays.equals(new int[] {2,3,4}, realValue5)); + int[] realValue6 = convertListToIntArray((JsonStringArrayList) encoded.getObject(5)); + assertTrue(Arrays.equals(new int[] {0,1}, realValue6)); + + // now run through the decoder and verify we get the original back + try (ValueVector decoded = encoder.decodeListSubField(encoded)) { + assertEquals(vector.getClass(), decoded.getClass()); + assertEquals(vector.getValueCount(), decoded.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + assertEquals(vector.getObject(i), decoded.getObject(i)); + } + } + } + } + } + + @Test + public void testEncodeFixedSizeListSubField() { + // Create a new value vector + try (final FixedSizeListVector vector = FixedSizeListVector.empty("vector", 2, allocator); + final FixedSizeListVector dictionaryVector = FixedSizeListVector.empty("dict", 2, allocator)) { + + vector.allocateNew(); + vector.setValueCount(4); + + IntVector dataVector = + (IntVector) vector.addOrGetVector(FieldType.nullable(Types.MinorType.INT.getType())).getVector(); + dataVector.allocateNew(8); + dataVector.setValueCount(8); + // set value at index 0 + vector.setNotNull(0); + dataVector.set(0, 10); + dataVector.set(1, 20); + // set value at index 1 + vector.setNotNull(1); + dataVector.set(2, 10); + dataVector.set(3, 20); + // set value at index 2 + vector.setNotNull(2); + dataVector.set(4, 30); + dataVector.set(5, 40); + // set value at index 3 + vector.setNotNull(3); + 
dataVector.set(6, 10); + dataVector.set(7, 20); + + dictionaryVector.allocateNew(); + dictionaryVector.setValueCount(2); + IntVector dictDataVector = + (IntVector) dictionaryVector.addOrGetVector(FieldType.nullable(Types.MinorType.INT.getType())).getVector(); + dictDataVector.allocateNew(4); + dictDataVector.setValueCount(4); + + dictionaryVector.setNotNull(0); + dictDataVector.set(0, 10); + dictDataVector.set(1, 20); + dictionaryVector.setNotNull(1); + dictDataVector.set(2, 30); + dictDataVector.set(3, 40); + + Dictionary dictionary = new Dictionary(dictionaryVector, new DictionaryEncoding(1L, false, null)); + ListSubfieldEncoder encoder = new ListSubfieldEncoder(dictionary, allocator); + + try (final FixedSizeListVector encoded = + (FixedSizeListVector) encoder.encodeListSubField(vector)) { + // verify indices + assertEquals(FixedSizeListVector.class, encoded.getClass()); + + assertEquals(4, encoded.getValueCount()); + int[] realValue1 = convertListToIntArray((JsonStringArrayList) encoded.getObject(0)); + assertTrue(Arrays.equals(new int[] {0,1}, realValue1)); + int[] realValue2 = convertListToIntArray((JsonStringArrayList) encoded.getObject(1)); + assertTrue(Arrays.equals(new int[] {0,1}, realValue2)); + int[] realValue3 = convertListToIntArray((JsonStringArrayList) encoded.getObject(2)); + assertTrue(Arrays.equals(new int[] {2,3}, realValue3)); + int[] realValue4 = convertListToIntArray((JsonStringArrayList) encoded.getObject(3)); + assertTrue(Arrays.equals(new int[] {0,1}, realValue4)); + + // now run through the decoder and verify we get the original back + try (ValueVector decoded = encoder.decodeListSubField(encoded)) { + assertEquals(vector.getClass(), decoded.getClass()); + assertEquals(vector.getValueCount(), decoded.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + assertEquals(vector.getObject(i), decoded.getObject(i)); + } + } + } + } + } + + private int[] convertListToIntArray(JsonStringArrayList list) { + int[] values = new int[list.size()]; + for 
(int i = 0; i < list.size(); i++) { + values[i] = (int) list.get(i); + } + return values; + } + private void writeStructVector(NullableStructWriter writer, int value1, long value2) { writer.start(); writer.integer("f0").writeInt(value1);