From aaaebcabf87214b61adc1894d0d74592566d9796 Mon Sep 17 00:00:00 2001 From: Jaroslav Tulach Date: Tue, 11 Jun 2024 14:50:59 +0200 Subject: [PATCH] Implement and benchmark `ArrowOperationPlus` node (#10150) Prototype of #10056 showing `+` operation implemented in the _Arrow language_. --- .../enso/interpreter/arrow/ArrowParser.java | 46 ++-- .../arrow/node/ArrowCastFixedSizeNode.java | 16 -- .../interpreter/arrow/node/ArrowEvalNode.java | 12 +- .../arrow/node/ArrowFixedSizeNode.java | 16 -- .../arrow/runtime/ArrowFixedArrayInt.java | 169 +++++++++++- .../runtime/ArrowFixedSizeArrayBuilder.java | 115 +++++--- .../runtime/ArrowFixedSizeArrayFactory.java | 102 +++---- .../arrow/runtime/ArrowOperationPlus.java | 97 +++++++ .../arrow/runtime/ByteBufferDirect.java | 250 ++++++++++++------ .../arrow/runtime/OperationPlus.java | 69 +++++ .../arrow/runtime/RoundingUtil.java | 21 +- .../arrow/runtime/ScalarOperationNode.java | 10 + ...uilderNode.java => ValueToNumberNode.java} | 146 ++++------ .../enso/interpreter/arrow/AddArrowTest.java | 142 ++++++++++ .../interpreter/arrow/VerifyArrowTest.java | 30 ++- .../node/callable/ApplicationNode.java | 9 +- .../node/callable/InvokeCallableNode.java | 19 +- test/Benchmarks/src/Table/Arithmetic.enso | 105 ++++++-- 18 files changed, 1006 insertions(+), 368 deletions(-) delete mode 100644 engine/runtime-language-arrow/src/main/java/org/enso/interpreter/arrow/node/ArrowCastFixedSizeNode.java delete mode 100644 engine/runtime-language-arrow/src/main/java/org/enso/interpreter/arrow/node/ArrowFixedSizeNode.java create mode 100644 engine/runtime-language-arrow/src/main/java/org/enso/interpreter/arrow/runtime/ArrowOperationPlus.java create mode 100644 engine/runtime-language-arrow/src/main/java/org/enso/interpreter/arrow/runtime/OperationPlus.java create mode 100644 engine/runtime-language-arrow/src/main/java/org/enso/interpreter/arrow/runtime/ScalarOperationNode.java rename engine/runtime-language-arrow/src/main/java/org/enso/interpreter/arrow/runtime/{WriteToBuilderNode.java => ValueToNumberNode.java} (56%) create mode 100644 engine/runtime-language-arrow/src/test/java/org/enso/interpreter/arrow/AddArrowTest.java diff --git a/engine/runtime-language-arrow/src/main/java/org/enso/interpreter/arrow/ArrowParser.java b/engine/runtime-language-arrow/src/main/java/org/enso/interpreter/arrow/ArrowParser.java index 5d4127bc3ef4..363a7474089b 100644 --- a/engine/runtime-language-arrow/src/main/java/org/enso/interpreter/arrow/ArrowParser.java +++ b/engine/runtime-language-arrow/src/main/java/org/enso/interpreter/arrow/ArrowParser.java @@ -11,36 +11,42 @@ private ArrowParser() {} public record Result(PhysicalLayout physicalLayout, LogicalLayout logicalLayout, Mode mode) {} public static Result parse(Source source) { - String src = source.getCharacters().toString(); - Matcher m = NEW_ARRAY_CONSTR.matcher(src); + String src = source.getCharacters().toString().replace('\n', ' ').trim(); + Matcher m = PATTERN.matcher(src); if (m.find()) { try { - var layout = LogicalLayout.valueOf(m.group(1)); - return new Result(PhysicalLayout.Primitive, layout, Mode.Allocate); + var layout = LogicalLayout.valueOf(m.group(2)); + var mode = Mode.parse(m.group(1)); + if (layout != null && mode != null) { + return new Result(PhysicalLayout.Primitive, layout, mode); + } } catch (IllegalArgumentException iae) { // propagate warning - return null; - } - } - - m = CAST_PATTERN.matcher(src); - if (m.find()) { - try { - var layout = LogicalLayout.valueOf(m.group(1)); - return new Result(PhysicalLayout.Primitive, layout, Mode.Cast); - } catch (IllegalArgumentException iae) { - // propagate warning - return null; } } return null; } - private static final Pattern NEW_ARRAY_CONSTR = Pattern.compile("^new\\[(.+)\\]$"); - private static final Pattern CAST_PATTERN = Pattern.compile("^cast\\[(.+)\\]$"); + private static final Pattern PATTERN = Pattern.compile("^([a-z\\+]+)\\[(.+)\\]$"); public enum Mode { - Allocate, - Cast + Allocate("new"), + Cast("cast"), + Plus("+"); + + private final String op; + + private Mode(String text) { + this.op = text; + } + + static Mode parse(String operation) { + for (var m : values()) { + if (m.op.equals(operation)) { + return m; + } + } + return null; + } } } diff --git a/engine/runtime-language-arrow/src/main/java/org/enso/interpreter/arrow/node/ArrowCastFixedSizeNode.java b/engine/runtime-language-arrow/src/main/java/org/enso/interpreter/arrow/node/ArrowCastFixedSizeNode.java deleted file mode 100644 index 7561d09c51ab..000000000000 --- a/engine/runtime-language-arrow/src/main/java/org/enso/interpreter/arrow/node/ArrowCastFixedSizeNode.java +++ /dev/null @@ -1,16 +0,0 @@ -package org.enso.interpreter.arrow.node; - -import com.oracle.truffle.api.nodes.Node; -import org.enso.interpreter.arrow.LogicalLayout; -import org.enso.interpreter.arrow.runtime.ArrowCastToFixedSizeArrayFactory; - -public class ArrowCastFixedSizeNode extends Node { - - static ArrowCastFixedSizeNode create() { - return new ArrowCastFixedSizeNode(); - } - - public Object execute(LogicalLayout layoutType) { - return new ArrowCastToFixedSizeArrayFactory(layoutType); - } -} diff --git a/engine/runtime-language-arrow/src/main/java/org/enso/interpreter/arrow/node/ArrowEvalNode.java b/engine/runtime-language-arrow/src/main/java/org/enso/interpreter/arrow/node/ArrowEvalNode.java index 61f430f9ec82..7619cb2fc396 100644 --- a/engine/runtime-language-arrow/src/main/java/org/enso/interpreter/arrow/node/ArrowEvalNode.java +++ b/engine/runtime-language-arrow/src/main/java/org/enso/interpreter/arrow/node/ArrowEvalNode.java @@ -6,13 +6,13 @@ import com.oracle.truffle.api.nodes.RootNode; import org.enso.interpreter.arrow.ArrowLanguage; import org.enso.interpreter.arrow.ArrowParser; +import org.enso.interpreter.arrow.runtime.ArrowCastToFixedSizeArrayFactory; +import org.enso.interpreter.arrow.runtime.ArrowFixedSizeArrayFactory; +import org.enso.interpreter.arrow.runtime.ArrowOperationPlus; public class ArrowEvalNode extends RootNode { private final ArrowParser.Result code; - @Child private ArrowFixedSizeNode fixedPhysicalLayout = ArrowFixedSizeNode.create(); - @Child private ArrowCastFixedSizeNode castToFixedPhysicalLayout = ArrowCastFixedSizeNode.create(); - public static ArrowEvalNode create(ArrowLanguage language, ArrowParser.Result code) { return new ArrowEvalNode(language, code); } @@ -25,8 +25,10 @@ private ArrowEvalNode(ArrowLanguage language, ArrowParser.Result code) { public Object execute(VirtualFrame frame) { return switch (code.physicalLayout()) { case Primitive -> switch (code.mode()) { - case Allocate -> fixedPhysicalLayout.execute(code.logicalLayout()); - case Cast -> castToFixedPhysicalLayout.execute(code.logicalLayout()); + case Allocate -> new ArrowFixedSizeArrayFactory(code.logicalLayout()); + case Cast -> new ArrowCastToFixedSizeArrayFactory(code.logicalLayout()); + case Plus -> new ArrowOperationPlus(code.logicalLayout()); + default -> throw CompilerDirectives.shouldNotReachHere("unsupported mode"); }; default -> throw CompilerDirectives.shouldNotReachHere("unsupported physical layout"); }; diff --git a/engine/runtime-language-arrow/src/main/java/org/enso/interpreter/arrow/node/ArrowFixedSizeNode.java b/engine/runtime-language-arrow/src/main/java/org/enso/interpreter/arrow/node/ArrowFixedSizeNode.java deleted file mode 100644 index 91e50fdc8456..000000000000 --- a/engine/runtime-language-arrow/src/main/java/org/enso/interpreter/arrow/node/ArrowFixedSizeNode.java +++ /dev/null @@ -1,16 +0,0 @@ -package org.enso.interpreter.arrow.node; - -import com.oracle.truffle.api.nodes.Node; -import org.enso.interpreter.arrow.LogicalLayout; -import org.enso.interpreter.arrow.runtime.ArrowFixedSizeArrayFactory; - -public class ArrowFixedSizeNode extends Node { - - static ArrowFixedSizeNode create() { - return new ArrowFixedSizeNode(); - } - - public Object execute(LogicalLayout layoutType) { - return new ArrowFixedSizeArrayFactory(layoutType); - } -} diff --git a/engine/runtime-language-arrow/src/main/java/org/enso/interpreter/arrow/runtime/ArrowFixedArrayInt.java b/engine/runtime-language-arrow/src/main/java/org/enso/interpreter/arrow/runtime/ArrowFixedArrayInt.java index efd263b9bd6a..86d3144bba4f 100644 --- a/engine/runtime-language-arrow/src/main/java/org/enso/interpreter/arrow/runtime/ArrowFixedArrayInt.java +++ b/engine/runtime-language-arrow/src/main/java/org/enso/interpreter/arrow/runtime/ArrowFixedArrayInt.java @@ -1,13 +1,23 @@ package org.enso.interpreter.arrow.runtime; +import com.oracle.truffle.api.CompilerDirectives; +import com.oracle.truffle.api.dsl.Bind; +import com.oracle.truffle.api.dsl.Cached; import com.oracle.truffle.api.dsl.ImportStatic; +import com.oracle.truffle.api.dsl.NeverDefault; import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.api.interop.InteropLibrary; import com.oracle.truffle.api.interop.InvalidArrayIndexException; +import com.oracle.truffle.api.interop.StopIterationException; import com.oracle.truffle.api.interop.TruffleObject; import com.oracle.truffle.api.interop.UnsupportedMessageException; +import com.oracle.truffle.api.library.CachedLibrary; import com.oracle.truffle.api.library.ExportLibrary; import com.oracle.truffle.api.library.ExportMessage; +import com.oracle.truffle.api.nodes.Node; +import com.oracle.truffle.api.profiles.InlinedExactClassProfile; +import java.nio.BufferOverflowException; +import java.nio.ByteBuffer; import org.enso.interpreter.arrow.LogicalLayout; @ExportLibrary(InteropLibrary.class) @@ -27,7 +37,24 @@ public LogicalLayout getUnit() { } @ExportMessage - public boolean hasArrayElements() { + boolean hasArrayElements() { + return true; + } + + @ExportMessage + Object getIterator( + @Cached(value = "this.getUnit()", allowUncached = true) LogicalLayout cachedUnit) + throws UnsupportedMessageException { + if (cachedUnit == LogicalLayout.Int64) { + var dataIt = new LongIterator(buffer.getDataBuffer(), cachedUnit.sizeInBytes()); + var nullIt = new NullIterator(dataIt, buffer.getBitmapBuffer()); + return nullIt; + } + return new GenericIterator(this); + } + + @ExportMessage + boolean hasIterator() { return true; } @@ -65,13 +92,18 @@ public static Object doInt(ArrowFixedArrayInt receiver, long index) } @Specialization(guards = "receiver.getUnit() == Int64") - public static Object doLong(ArrowFixedArrayInt receiver, long index) + public static Object doLong( + ArrowFixedArrayInt receiver, + long index, + @Bind("$node") Node node, + @CachedLibrary("receiver") InteropLibrary iop, + @Cached InlinedExactClassProfile bufferClazz) throws UnsupportedMessageException, InvalidArrayIndexException { - var at = adjustedIndex(receiver.buffer, receiver.unit, receiver.size, index); + var at = adjustedIndex(receiver.buffer, LogicalLayout.Int64, receiver.size, index); if (receiver.buffer.isNull((int) index)) { return NullValue.get(); } - return receiver.buffer.getLong(at); + return receiver.buffer.getLong(at, iop, bufferClazz); } } @@ -96,4 +128,133 @@ boolean isArrayElementReadable(long index) { private static int typeAdjustedIndex(long index, SizeInBytes unit) { return Math.toIntExact(index * unit.sizeInBytes()); } + + @ExportLibrary(InteropLibrary.class) + static final class LongIterator implements TruffleObject { + private int at; + private final ByteBuffer buffer; + @NeverDefault final int step; + + LongIterator(ByteBuffer buffer, int step) { + assert step != 0; + this.buffer = buffer; + this.step = step; + } + + @ExportMessage + Object getIteratorNextElement( + @Bind("$node") Node node, + @Cached("this.step") int step, + @Cached InlinedExactClassProfile bufferTypeProfile) + throws StopIterationException { + var buf = bufferTypeProfile.profile(node, buffer); + try { + var res = buf.getLong(at); + at += step; + return res; + } catch (BufferOverflowException ex) { + CompilerDirectives.transferToInterpreter(); + throw StopIterationException.create(); + } + } + + @ExportMessage + boolean isIterator() { + return true; + } + + @ExportMessage + boolean hasIteratorNextElement() throws UnsupportedMessageException { + return at < buffer.limit(); + } + } + + @ExportLibrary(value = InteropLibrary.class) + static final class NullIterator implements TruffleObject { + private final TruffleObject it; + private final ByteBuffer buffer; + private byte byteMask; + private byte byteValue; + + NullIterator(TruffleObject delegate, ByteBuffer buffer) { + this.it = delegate; + this.buffer = buffer; + } + + final TruffleObject it() { + return it; + } + + @ExportMessage(limit = "3") + Object getIteratorNextElement( + @Bind("$node") Node node, + @CachedLibrary("this.it()") InteropLibrary iopIt, + @Cached InlinedExactClassProfile bufferTypeProfile) + throws StopIterationException, UnsupportedMessageException { + var element = iopIt.getIteratorNextElement(it); + if (buffer != null) { + var buf = bufferTypeProfile.profile(node, buffer); + if (byteMask == 0) { + // (byte) (0x01 << 8) ==> 0 + byteValue = buf.get(); + byteMask = 0x01; + } + var include = byteValue & byteMask; + byteMask = (byte) (byteMask << 1); + if (include == 0) { + return NullValue.get(); + } + } + return element; + } + + @ExportMessage + boolean isIterator() { + return true; + } + + @ExportMessage(limit = "3") + boolean hasIteratorNextElement(@CachedLibrary("this.it()") InteropLibrary iopIt) + throws UnsupportedMessageException { + return iopIt.hasIteratorNextElement(it); + } + } + + @ExportLibrary(InteropLibrary.class) + static final class GenericIterator implements TruffleObject { + private int at; + private final TruffleObject array; + + GenericIterator(TruffleObject array) { + assert InteropLibrary.getUncached().hasArrayElements(array); + this.array = array; + } + + TruffleObject array() { + return array; + } + + @ExportMessage(limit = "3") + Object getIteratorNextElement(@CachedLibrary("this.array()") InteropLibrary iop) + throws StopIterationException { + try { + var res = iop.readArrayElement(array, at); + at++; + return res; + } catch (UnsupportedMessageException | InvalidArrayIndexException ex) { + throw StopIterationException.create(); + } + } + + @ExportMessage + boolean isIterator() { + return true; + } + + @ExportMessage(limit = "3") + boolean hasIteratorNextElement(@CachedLibrary("this.array()") InteropLibrary iop) + throws UnsupportedMessageException { + return at < iop.getArraySize(array); + } + } } diff --git a/engine/runtime-language-arrow/src/main/java/org/enso/interpreter/arrow/runtime/ArrowFixedSizeArrayBuilder.java b/engine/runtime-language-arrow/src/main/java/org/enso/interpreter/arrow/runtime/ArrowFixedSizeArrayBuilder.java index 8ad0315941b6..a0dbe79f2b43 100644 --- a/engine/runtime-language-arrow/src/main/java/org/enso/interpreter/arrow/runtime/ArrowFixedSizeArrayBuilder.java +++ b/engine/runtime-language-arrow/src/main/java/org/enso/interpreter/arrow/runtime/ArrowFixedSizeArrayBuilder.java @@ -1,22 +1,27 @@ package org.enso.interpreter.arrow.runtime; +import com.oracle.truffle.api.CompilerDirectives; import com.oracle.truffle.api.dsl.Cached; +import com.oracle.truffle.api.dsl.Cached.Shared; +import com.oracle.truffle.api.dsl.GenerateInline; +import com.oracle.truffle.api.dsl.GenerateUncached; +import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.api.interop.InteropLibrary; import com.oracle.truffle.api.interop.TruffleObject; import com.oracle.truffle.api.interop.UnknownIdentifierException; import com.oracle.truffle.api.interop.UnsupportedMessageException; import com.oracle.truffle.api.interop.UnsupportedTypeException; +import com.oracle.truffle.api.library.CachedLibrary; import com.oracle.truffle.api.library.ExportLibrary; import com.oracle.truffle.api.library.ExportMessage; +import com.oracle.truffle.api.nodes.Node; import org.enso.interpreter.arrow.LogicalLayout; @ExportLibrary(InteropLibrary.class) public final class ArrowFixedSizeArrayBuilder implements TruffleObject { - private final ByteBufferDirect buffer; private final LogicalLayout unit; private final int size; - private int index; - private boolean sealed; + private ByteBufferDirect buffer; private static final String APPEND_OP = "append"; private static final String BUILD_OP = "build"; @@ -25,8 +30,6 @@ public ArrowFixedSizeArrayBuilder(int size, LogicalLayout unit) { this.size = size; this.unit = unit; this.buffer = ByteBufferDirect.forSize(size, unit); - this.index = 0; - this.sealed = false; } public LogicalLayout getUnit() { @@ -34,7 +37,7 @@ public LogicalLayout getUnit() { } public boolean isSealed() { - return sealed; + return buffer == null; } public ByteBufferDirect getBuffer() { @@ -53,7 +56,7 @@ public boolean hasMembers() { @ExportMessage public boolean isMemberInvocable(String member) { return switch (member) { - case APPEND_OP -> !this.sealed; + case APPEND_OP -> buffer != null; case BUILD_OP -> true; default -> false; }; @@ -65,33 +68,83 @@ Object getMembers(boolean includeInternal) throws UnsupportedMessageException { } @ExportMessage - Object invokeMember( - String name, - Object[] args, - @Cached(value = "buildWriterOrNull(name)", neverDefault = true) - WriteToBuilderNode writeToBuilderNode) + Object invokeMember(String name, Object[] args, @Cached AppendNode append) throws UnsupportedMessageException, UnknownIdentifierException, UnsupportedTypeException { - switch (name) { - case BUILD_OP: - sealed = true; - return switch (unit) { - case Date32, Date64 -> new ArrowFixedArrayDate(buffer, size, unit); - case Int8, Int16, Int32, Int64 -> new ArrowFixedArrayInt(buffer, size, unit); - }; - case APPEND_OP: - if (sealed) { - throw UnsupportedMessageException.create(); - } - var current = index; - writeToBuilderNode.executeWrite(this, current, args[0]); - index += 1; - return NullValue.get(); - default: - throw UnknownIdentifierException.create(name); + return switch (name) { + case BUILD_OP -> build(); + case APPEND_OP -> { + append.executeAppend(this, args[0]); + yield NullValue.get(); + } + default -> throw UnknownIdentifierException.create(name); + }; + } + + private final TruffleObject build() throws UnsupportedMessageException { + var b = buffer; + if (b == null) { + throw UnsupportedMessageException.create(); + } + buffer = null; + return switch (unit) { + case Date32, Date64 -> new ArrowFixedArrayDate(b, size, unit); + case Int8, Int16, Int32, Int64 -> new ArrowFixedArrayInt(b, size, unit); + }; + } + + @GenerateUncached + @GenerateInline(false) + abstract static class AppendNode extends Node { + abstract void executeAppend(ArrowFixedSizeArrayBuilder builder, Object value) + throws UnsupportedTypeException, UnsupportedMessageException; + + @Specialization( + limit = "3", + guards = {"builder.getUnit() == cachedUnit"}) + static void writeToBuffer( + ArrowFixedSizeArrayBuilder builder, + Object value, + @Cached(value = "builder.getUnit()", allowUncached = true) LogicalLayout cachedUnit, + @Shared("put") @Cached ByteBufferDirect.PutNode put, + @Shared("value") @Cached ValueToNumberNode valueNode, + @Shared("iop") @CachedLibrary(limit = "3") InteropLibrary iop) + throws UnsupportedTypeException, UnsupportedMessageException { + if (iop.isNull(value)) { + put.putNull(builder.buffer, cachedUnit); + return; + } + var number = valueNode.executeAdjust(cachedUnit, value); + switch (number) { + case Byte b -> put.put(builder.buffer, b); + case Short s -> put.putShort(builder.buffer, s); + case Integer i -> put.putInt(builder.buffer, i); + case Long l -> put.putLong(builder.buffer, l); + default -> throw CompilerDirectives.shouldNotReachHere(); + } + } + + @Specialization(replaces = "writeToBuffer") + static void writeToBufferUncached( + ArrowFixedSizeArrayBuilder builder, + Object value, + @Shared("put") @Cached ByteBufferDirect.PutNode put, + @Shared("value") @Cached ValueToNumberNode valueNode, + @Shared("iop") @CachedLibrary(limit = "3") InteropLibrary iop) + throws UnsupportedTypeException, UnsupportedMessageException { + writeToBuffer(builder, value, builder.getUnit(), put, valueNode, iop); } } - static WriteToBuilderNode buildWriterOrNull(String op) { - return APPEND_OP.equals(op) ? WriteToBuilderNode.build() : WriteToBuilderNodeGen.getUncached(); + @GenerateUncached + @GenerateInline(false) + abstract static class BuildNode extends Node { + abstract TruffleObject executeBuild(ArrowFixedSizeArrayBuilder builder) + throws UnsupportedMessageException; + + @Specialization + static TruffleObject buildIt(ArrowFixedSizeArrayBuilder builder) + throws UnsupportedMessageException { + return builder.build(); + } } } diff --git a/engine/runtime-language-arrow/src/main/java/org/enso/interpreter/arrow/runtime/ArrowFixedSizeArrayFactory.java b/engine/runtime-language-arrow/src/main/java/org/enso/interpreter/arrow/runtime/ArrowFixedSizeArrayFactory.java index e1962422c975..7ca43828de17 100644 --- a/engine/runtime-language-arrow/src/main/java/org/enso/interpreter/arrow/runtime/ArrowFixedSizeArrayFactory.java +++ b/engine/runtime-language-arrow/src/main/java/org/enso/interpreter/arrow/runtime/ArrowFixedSizeArrayFactory.java @@ -2,8 +2,8 @@ import com.oracle.truffle.api.CompilerDirectives; import com.oracle.truffle.api.dsl.Cached; -import com.oracle.truffle.api.dsl.Fallback; -import com.oracle.truffle.api.dsl.ImportStatic; +import com.oracle.truffle.api.dsl.GenerateInline; +import com.oracle.truffle.api.dsl.GenerateUncached; import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.api.interop.InteropLibrary; import com.oracle.truffle.api.interop.TruffleObject; @@ -11,10 +11,11 @@ import com.oracle.truffle.api.library.CachedLibrary; import com.oracle.truffle.api.library.ExportLibrary; import com.oracle.truffle.api.library.ExportMessage; +import com.oracle.truffle.api.nodes.Node; import org.enso.interpreter.arrow.LogicalLayout; @ExportLibrary(InteropLibrary.class) -public class ArrowFixedSizeArrayFactory implements TruffleObject { +public final class ArrowFixedSizeArrayFactory implements TruffleObject { private final LogicalLayout logicalLayout; @@ -23,7 +24,7 @@ public ArrowFixedSizeArrayFactory(LogicalLayout logicalLayout) { } @ExportMessage - public boolean isInstantiable() { + boolean isInstantiable() { return true; } @@ -32,79 +33,36 @@ public LogicalLayout getLayout() { } @ExportMessage - @ImportStatic(LogicalLayout.class) - static class Instantiate { - @Specialization(guards = "receiver.getLayout() == Date32") - static Object doDate32( - ArrowFixedSizeArrayFactory receiver, - Object[] args, - @Cached.Shared("interop") @CachedLibrary(limit = "1") InteropLibrary iop) - throws UnsupportedMessageException { - return new ArrowFixedSizeArrayBuilder(arraySize(args, iop), receiver.logicalLayout); - } - - @Specialization(guards = "receiver.getLayout() == Date64") - static Object doDate64( - ArrowFixedSizeArrayFactory receiver, - Object[] args, - @Cached.Shared("interop") @CachedLibrary(limit = "1") InteropLibrary iop) - throws UnsupportedMessageException { - return new ArrowFixedSizeArrayBuilder(arraySize(args, iop), receiver.logicalLayout); - } - - @Specialization(guards = "receiver.getLayout() == Int8") - static Object doInt8( - ArrowFixedSizeArrayFactory receiver, - Object[] args, - @Cached.Shared("interop") @CachedLibrary(limit = "1") InteropLibrary iop) - throws UnsupportedMessageException { - return new ArrowFixedSizeArrayBuilder(arraySize(args, iop), receiver.logicalLayout); - } - - @Specialization(guards = "receiver.getLayout() == Int16") - static Object doInt16( - ArrowFixedSizeArrayFactory receiver, - Object[] args, - @Cached.Shared("interop") @CachedLibrary(limit = "1") InteropLibrary iop) - throws UnsupportedMessageException { - return new ArrowFixedSizeArrayBuilder(arraySize(args, iop), receiver.logicalLayout); - } + ArrowFixedSizeArrayBuilder instantiate( + Object[] args, + @Cached InstantiateNode instantiate, + @CachedLibrary(limit = "1") InteropLibrary iop) + throws UnsupportedMessageException { + var size = arraySize(args, iop); + return instantiate.allocateBuilder(logicalLayout, size); + } - @Specialization(guards = "receiver.getLayout() == Int32") - static Object doInt32( - ArrowFixedSizeArrayFactory receiver, - Object[] args, - @Cached.Shared("interop") @CachedLibrary(limit = "1") InteropLibrary iop) - throws UnsupportedMessageException { - return new ArrowFixedSizeArrayBuilder(arraySize(args, iop), receiver.logicalLayout); + private static int arraySize(Object[] args, InteropLibrary interop) + throws UnsupportedMessageException { + if (args.length != 1 || !interop.isNumber(args[0]) || !interop.fitsInInt(args[0])) { + throw UnsupportedMessageException.create(); } + return interop.asInt(args[0]); + } - @Specialization(guards = "receiver.getLayout() == Int64") - static Object doInt64( - ArrowFixedSizeArrayFactory receiver, - Object[] args, - @Cached.Shared("interop") @CachedLibrary(limit = "1") InteropLibrary iop) - throws UnsupportedMessageException { - return new ArrowFixedSizeArrayBuilder(arraySize(args, iop), receiver.logicalLayout); - } + @GenerateUncached + @GenerateInline(false) + abstract static class InstantiateNode extends Node { + abstract ArrowFixedSizeArrayBuilder executeNew(LogicalLayout logicalLayout, long size); - @CompilerDirectives.TruffleBoundary - private static int arraySize(Object[] args, InteropLibrary interop) - throws UnsupportedMessageException { - if (args.length != 1 || !interop.isNumber(args[0]) || !interop.fitsInInt(args[0])) { - throw UnsupportedMessageException.create(); + @Specialization + final ArrowFixedSizeArrayBuilder allocateBuilder(LogicalLayout logicalLayout, long size) { + try { + return new ArrowFixedSizeArrayBuilder(Math.toIntExact(size), logicalLayout); + } catch (ArithmeticException ex) { + CompilerDirectives.transferToInterpreter(); + throw ex; } - return interop.asInt(args[0]); - } - - @Fallback - static Object doOther(ArrowFixedSizeArrayFactory receiver, Object[] args) { - throw CompilerDirectives.shouldNotReachHere(unknownLayoutMessage(receiver.getLayout())); - } - - @CompilerDirectives.TruffleBoundary - private static String unknownLayoutMessage(SizeInBytes layout) { - return "unknown layout: " + layout.toString(); } } } diff --git a/engine/runtime-language-arrow/src/main/java/org/enso/interpreter/arrow/runtime/ArrowOperationPlus.java b/engine/runtime-language-arrow/src/main/java/org/enso/interpreter/arrow/runtime/ArrowOperationPlus.java new file mode 100644 index 000000000000..4351a328e67c --- /dev/null +++ b/engine/runtime-language-arrow/src/main/java/org/enso/interpreter/arrow/runtime/ArrowOperationPlus.java @@ -0,0 +1,97 @@ +package org.enso.interpreter.arrow.runtime; + +import com.oracle.truffle.api.dsl.Bind; +import com.oracle.truffle.api.dsl.Cached; +import com.oracle.truffle.api.interop.ArityException; +import com.oracle.truffle.api.interop.InteropLibrary; +import com.oracle.truffle.api.interop.StopIterationException; +import com.oracle.truffle.api.interop.TruffleObject; +import com.oracle.truffle.api.interop.UnsupportedMessageException; +import com.oracle.truffle.api.interop.UnsupportedTypeException; +import com.oracle.truffle.api.library.CachedLibrary; +import com.oracle.truffle.api.library.ExportLibrary; +import com.oracle.truffle.api.library.ExportMessage; +import com.oracle.truffle.api.nodes.Node; +import com.oracle.truffle.api.profiles.InlinedExactClassProfile; +import org.enso.interpreter.arrow.LogicalLayout; + +@ExportLibrary(InteropLibrary.class) +public final class ArrowOperationPlus implements TruffleObject { + private final LogicalLayout layout; + + public ArrowOperationPlus(LogicalLayout layout) { + this.layout = layout; + } + + final LogicalLayout layout() { + return layout; + } + + @ExportMessage + boolean isExecutable() { + return true; + } + + static Object args(Object[] args, int index) throws ArityException { + if (args.length != 2) { + throw ArityException.create(2, 2, args.length); + } + return args[index]; + } + + static Object it(Object[] args, InteropLibrary iop, int index) + throws ArityException, UnsupportedMessageException { + if (args.length != 2) { + throw ArityException.create(2, 2, args.length); + } + return iop.getIterator(args[index]); + } + + ScalarOperationNode createScalarOp(boolean cached) { + return cached ? OperationPlus.create() : OperationPlus.getUncached(); + } + + @ExportMessage(limit = "3") + Object execute( + Object[] args, + @Bind("$node") Node node, + @Cached(value = "this.layout()", allowUncached = true) LogicalLayout cachedLayout, + @Cached ArrowFixedSizeArrayFactory.InstantiateNode factory, + @CachedLibrary("args(args, 0)") InteropLibrary iopArray0, + @CachedLibrary("args(args, 1)") InteropLibrary iopArray1, + @CachedLibrary("it(args, iopArray0, 0)") InteropLibrary iopIt0, + @CachedLibrary("it(args, iopArray1, 1)") InteropLibrary iopIt1, + @CachedLibrary(limit = "3") InteropLibrary iopElem, + @Cached(value = "this.createScalarOp(true)", uncached = "this.createScalarOp(false)") + ScalarOperationNode opNode, + @Cached ArrowFixedSizeArrayBuilder.AppendNode append, + @Cached ArrowFixedSizeArrayBuilder.BuildNode build, + @Cached InlinedExactClassProfile typeOfBuf0, + @Cached InlinedExactClassProfile typeOfBuf1) + throws ArityException, UnsupportedTypeException, UnsupportedMessageException { + var arr0 = args[0]; + var arr1 = args[1]; + if (!iopArray0.hasArrayElements(arr0) || !iopArray1.hasArrayElements(arr1)) { + throw UnsupportedTypeException.create(args); + } + var len = iopArray0.getArraySize(arr0); + if (len != iopArray1.getArraySize(arr1)) { + throw UnsupportedTypeException.create(args, "Arrays must have the same length"); + } + var it0 = iopArray0.getIterator(arr0); + var it1 = iopArray1.getIterator(arr1); + var builder = factory.allocateBuilder(cachedLayout, len); + + for (long i = 0; i < len; i++) { + try { + var elem0 = iopIt0.getIteratorNextElement(it0); + var elem1 = iopIt1.getIteratorNextElement(it1); + var res = opNode.executeOp(elem0, elem1); + append.executeAppend(builder, res); + } catch (StopIterationException ex) { + throw UnsupportedTypeException.create(new Object[] {it0, it1}); + } + } + return build.executeBuild(builder); + } +} diff --git a/engine/runtime-language-arrow/src/main/java/org/enso/interpreter/arrow/runtime/ByteBufferDirect.java b/engine/runtime-language-arrow/src/main/java/org/enso/interpreter/arrow/runtime/ByteBufferDirect.java index a5462418b847..a419566e8eec 100644 --- a/engine/runtime-language-arrow/src/main/java/org/enso/interpreter/arrow/runtime/ByteBufferDirect.java +++ b/engine/runtime-language-arrow/src/main/java/org/enso/interpreter/arrow/runtime/ByteBufferDirect.java @@ -1,14 +1,25 @@ package org.enso.interpreter.arrow.runtime; +import com.oracle.truffle.api.CompilerDirectives; +import com.oracle.truffle.api.dsl.Bind; +import com.oracle.truffle.api.dsl.Cached; +import com.oracle.truffle.api.dsl.GenerateInline; +import com.oracle.truffle.api.dsl.GenerateUncached; +import com.oracle.truffle.api.dsl.NeverDefault; +import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.api.interop.UnsupportedMessageException; +import com.oracle.truffle.api.nodes.Node; +import com.oracle.truffle.api.profiles.InlinedExactClassProfile; import java.nio.ByteBuffer; import java.nio.ByteOrder; +import org.enso.interpreter.arrow.LogicalLayout; +import org.enso.interpreter.arrow.runtime.ByteBufferDirect.DataBufferNode; import org.enso.interpreter.arrow.util.MemoryUtil; final class ByteBufferDirect implements AutoCloseable { private final ByteBuffer allocated; private final ByteBuffer dataBuffer; - private final ByteBuffer bitmapBuffer; + private ByteBuffer bitmapBuffer; /** * Creates a fresh buffer with an empty non-null bitmap.. @@ -22,10 +33,7 @@ private ByteBufferDirect(int valueCount, SizeInBytes unit) { this.allocated = buffer; this.dataBuffer = buffer.slice(0, padded.getDataBufferSizeInBytes()); - this.bitmapBuffer = buffer.slice(dataBuffer.capacity(), padded.getValidityBitmapSizeInBytes()); - for (int i = 0; i < bitmapBuffer.capacity(); i++) { - bitmapBuffer.put(i, (byte) 0); - } + this.bitmapBuffer = null; } /** @@ -53,8 +61,13 @@ private ByteBufferDirect(ByteBuffer allocated, ByteBuffer dataBuffer, int bitmap this.dataBuffer = dataBuffer; this.bitmapBuffer = allocated.slice(dataBuffer.capacity(), bitmapSizeInBytes); for (int i = 0; i < bitmapBuffer.capacity(); i++) { - bitmapBuffer.put(i, (byte) 255); + bitmapBuffer.put(i, (byte) 0xff); } + bitmapBuffer.rewind(); + } + + static ByteBufferDirect forBuffer(ByteBuffer buf) { + return new ByteBufferDirect(buf, buf, null); } /** @@ -93,120 +106,201 @@ public static ByteBufferDirect fromAddress( return new ByteBufferDirect(allocated, dataBuffer, bitmapBuffer); } - public void put(byte b) throws UnsupportedMessageException { - setValidityBitmap(0, 1); - dataBuffer.put(b); + @CompilerDirectives.TruffleBoundary + ByteBuffer initializeBitmapBuffer() { + assert bitmapBuffer == null; + bitmapBuffer = + allocated.slice(dataBuffer.capacity(), allocated.capacity() - dataBuffer.capacity()); + for (var i = 0; i < bitmapBuffer.capacity(); i++) { + bitmapBuffer.put(i, (byte) 0xff); + } + return bitmapBuffer; } - public byte get(int index) throws UnsupportedMessageException { - return dataBuffer.get(index); + final ByteBuffer getDataBuffer() { + return dataBuffer; } - public void put(int index, byte b) throws UnsupportedMessageException { - setValidityBitmap(index, 1); - dataBuffer.put(index, b); + final ByteBuffer getBitmapBuffer() { + return bitmapBuffer; } - public void putShort(short value) throws UnsupportedMessageException { - setValidityBitmap(0, 2); - dataBuffer.putShort(value); - } + @GenerateInline(false) + @GenerateUncached + abstract static class DataBufferNode extends Node { + static DataBufferNode create() { + return ByteBufferDirectFactory.DataBufferNodeGen.create(); + } - public short getShort(int index) throws UnsupportedMessageException { - return dataBuffer.getShort(index); - } + static DataBufferNode getUncached() { + return ByteBufferDirectFactory.DataBufferNodeGen.getUncached(); + } - public void putShort(int index, short value) throws UnsupportedMessageException { - setValidityBitmap(index, 2); - dataBuffer.putShort(index, value); - } + abstract ByteBuffer executeDataBuffer(ByteBufferDirect direct); - public void putInt(int value) throws UnsupportedMessageException { - setValidityBitmap(0, 4); - dataBuffer.putInt(value); + @Specialization + static ByteBuffer profiledDataBuffer( + ByteBufferDirect direct, + @Bind("$node") Node node, + @Cached InlinedExactClassProfile bufferClazz) { + return bufferClazz.profile(node, direct.dataBuffer); + } } - public int getInt(int index) throws UnsupportedMessageException { - return dataBuffer.getInt(index); - } + @GenerateInline(false) + @GenerateUncached + abstract static class BitmapBufferNode extends Node { + static BitmapBufferNode create() { + return ByteBufferDirectFactory.BitmapBufferNodeGen.create(); + } - public void putInt(int index, int value) { - setValidityBitmap(index, 4); - dataBuffer.putInt(index, value); - } + static BitmapBufferNode getUncached() { + return ByteBufferDirectFactory.BitmapBufferNodeGen.getUncached(); + } - public void putLong(long value) throws UnsupportedMessageException { - setValidityBitmap(0, 8); - dataBuffer.putLong(value); + abstract ByteBuffer executeBitmapBuffer(ByteBufferDirect direct, boolean forceCreation); + + @Specialization + static ByteBuffer profiledBitmapBuffer( + ByteBufferDirect direct, + boolean forceCreation, + @Bind("$node") Node node, + @Cached InlinedExactClassProfile bufferClazz) { + + if (direct.bitmapBuffer == null) { + if (forceCreation) { + direct.bitmapBuffer = direct.initializeBitmapBuffer(); + } else { + return null; + } + } + return bufferClazz.profile(node, direct.bitmapBuffer); + } } - public long getLong(int index) throws UnsupportedMessageException { - return dataBuffer.getLong(index); - } + static final class PutNode extends Node { + private static final PutNode UNCACHED = + new PutNode(DataBufferNode.getUncached(), BitmapBufferNode.getUncached()); + private @Child DataBufferNode dataBuffer; + private @Child BitmapBufferNode bitmapBuffer; - public void putLong(int index, long value) { - setValidityBitmap(index, 8); - dataBuffer.putLong(index, value); - } + private PutNode(DataBufferNode dbn, BitmapBufferNode bbn) { + this.dataBuffer = dbn; + this.bitmapBuffer = bbn; + } + + @NeverDefault + static PutNode create() { + return new PutNode(DataBufferNode.create(), BitmapBufferNode.create()); + } + + @NeverDefault + static PutNode getUncached() { + return UNCACHED; + } + + final void put(ByteBufferDirect direct, byte b) { + var db = dataBuffer.executeDataBuffer(direct); + addValidityBitmap(direct, db.position(), 1); + db.put(b); + } + + final void putNull(ByteBufferDirect direct, LogicalLayout unit) { + var db = dataBuffer.executeDataBuffer(direct); + var index = db.position() / unit.sizeInBytes(); + + var bb = bitmapBuffer.executeBitmapBuffer(direct, true); + + var bufferIndex = index >> 3; + var slot = bb.get(bufferIndex); + var byteIndex = index & BYTE_MASK; + var mask = ~(1 << byteIndex); + bb.put(bufferIndex, (byte) (slot & mask)); + + db.position(db.position() + unit.sizeInBytes()); + } + + final void putShort(ByteBufferDirect direct, short value) { + var db = dataBuffer.executeDataBuffer(direct); + addValidityBitmap(direct, db.position(), 2); + db.putShort(value); + } - public void putFloat(float value) throws UnsupportedMessageException { - setValidityBitmap(0, 4); - dataBuffer.putFloat(value); + final void putInt(ByteBufferDirect direct, int value) { + var db = dataBuffer.executeDataBuffer(direct); + addValidityBitmap(direct, db.position(), 4); + db.putInt(value); + } + + final void putLong(ByteBufferDirect direct, long value) throws UnsupportedMessageException { + var db = dataBuffer.executeDataBuffer(direct); + addValidityBitmap(direct, db.position(), 8); + db.putLong(value); + } + + private void addValidityBitmap(ByteBufferDirect direct, int pos, int size) { + var bb = bitmapBuffer.executeBitmapBuffer(direct, false); + if (bb == null) { + return; + } + var index = pos / size; + var bufferIndex = index >> 3; + var slot = bb.get(bufferIndex); + var byteIndex = index & BYTE_MASK; + + var mask = 1 << byteIndex; + var updated = (slot | mask); + bb.put(bufferIndex, (byte) (updated)); + } } - public float getFloat(int index) throws UnsupportedMessageException { - return dataBuffer.getFloat(index); + public byte get(int index) throws UnsupportedMessageException { + return dataBuffer.get(index); } - public void putFloat(int index, float value) throws UnsupportedMessageException { - setValidityBitmap(index, 4); - dataBuffer.putFloat(index, value); + public short getShort(int index) throws UnsupportedMessageException { + return dataBuffer.getShort(index); } - public void putDouble(double value) throws UnsupportedMessageException { - setValidityBitmap(0, 8); - dataBuffer.putDouble(value); + public int getInt(int index) throws UnsupportedMessageException { + return dataBuffer.getInt(index); } - public double getDouble(int index) throws UnsupportedMessageException { - return dataBuffer.getDouble(index); + public long getLong(int index) throws UnsupportedMessageException { + return dataBuffer.getLong(index); } - public void putDouble(int index, double value) throws UnsupportedMessageException { - setValidityBitmap(index, 8); - dataBuffer.putDouble(index, value); + public long getLong(int index, Node node, InlinedExactClassProfile profile) + throws UnsupportedMessageException { + var buf = profile.profile(node, dataBuffer); + return buf.getLong(index); } public int capacity() throws UnsupportedMessageException { return dataBuffer.capacity(); } - public boolean isNull(int index) { - var bufferIndex = index >> 3; - var slot = bitmapBuffer.get(bufferIndex); - var byteIndex = index & ~(1 << 3); - var mask = 1 << byteIndex; - return (slot & mask) == 0; + boolean hasNulls() { + return bitmapBuffer != null; } - public void setNull(int index) { - var bufferIndex = index >> 3; - var slot = bitmapBuffer.get(bufferIndex); - var byteIndex = index & ~(1 << 3); - var mask = ~(1 << byteIndex); - bitmapBuffer.put(bufferIndex, (byte) (slot & mask)); + public boolean isNull(int index) { + if (bitmapBuffer == null) { + return false; + } + return checkForNull(index); } - private void setValidityBitmap(int index0, int unitSize) { - var index = index0 / unitSize; + private boolean checkForNull(int index) { var bufferIndex = index >> 3; var slot = bitmapBuffer.get(bufferIndex); - var byteIndex = index & ~(1 << 3); + var byteIndex = index & BYTE_MASK; var mask = 1 << byteIndex; - var updated = (slot | mask); - bitmapBuffer.put(bufferIndex, (byte) (updated)); + return (slot & mask) == 0; } + private static final int BYTE_MASK = ~(~(1 << 3) + 1); // 7 + @Override public void close() throws Exception { this.dataBuffer.clear(); diff --git a/engine/runtime-language-arrow/src/main/java/org/enso/interpreter/arrow/runtime/OperationPlus.java b/engine/runtime-language-arrow/src/main/java/org/enso/interpreter/arrow/runtime/OperationPlus.java new file mode 100644 index 000000000000..2a15415055ed --- /dev/null +++ b/engine/runtime-language-arrow/src/main/java/org/enso/interpreter/arrow/runtime/OperationPlus.java @@ -0,0 +1,69 @@ +package org.enso.interpreter.arrow.runtime; + +import com.oracle.truffle.api.dsl.Cached.Shared; +import com.oracle.truffle.api.dsl.GenerateInline; +import com.oracle.truffle.api.dsl.GenerateUncached; +import com.oracle.truffle.api.dsl.Specialization; +import com.oracle.truffle.api.interop.InteropLibrary; +import com.oracle.truffle.api.interop.UnsupportedMessageException; +import com.oracle.truffle.api.library.CachedLibrary; + +@GenerateUncached +@GenerateInline(false) +abstract class OperationPlus extends ScalarOperationNode { + @Override + abstract Object executeOp(Object a, Object b) throws UnsupportedMessageException; + + static OperationPlus create() { + return OperationPlusNodeGen.create(); + } + + static OperationPlus getUncached() { + return OperationPlusNodeGen.getUncached(); + } + + @Specialization(rewriteOn = ArithmeticException.class) + long doLongs(long a, long b) { + return Math.addExact(a, b); + } + + @Specialization(replaces = "doLongs") + Object doLongsWithOverflowCheck(long a, long b) { + long res = a + b; + long check1 = a ^ res; + long check2 = b ^ res; + long checkBoth = check1 & check2; + if (checkBoth < 0) { + return NullValue.get(); + } + return res; + } + + @Specialization( + guards = {"iop.fitsInLong(a)", "iop.fitsInLong(b)"}, + rewriteOn = ArithmeticException.class) + Object doFitInLong( + Object a, Object b, @Shared("iop") @CachedLibrary(limit = "3") InteropLibrary iop) + throws UnsupportedMessageException { + var la = iop.asLong(a); + var lb = iop.asLong(b); + return doLongs(la, lb); + } + + @Specialization( + guards = {"iop.fitsInLong(a)", "iop.fitsInLong(b)"}, + replaces = "doFitInLong") + Object doFitInLongWithOverflowCheck( + Object a, Object b, @Shared("iop") @CachedLibrary(limit = "3") InteropLibrary iop) + throws UnsupportedMessageException { + var la = iop.asLong(a); + var lb = iop.asLong(b); + return doLongsWithOverflowCheck(la, lb); + } + + @Specialization(guards = {"iop.isNull(a) || iop.isNull(b)"}) + NullValue nothing( + Object a, Object b, @Shared("iop") @CachedLibrary(limit = "3") InteropLibrary iop) { + return NullValue.get(); + } +} diff --git a/engine/runtime-language-arrow/src/main/java/org/enso/interpreter/arrow/runtime/RoundingUtil.java b/engine/runtime-language-arrow/src/main/java/org/enso/interpreter/arrow/runtime/RoundingUtil.java index b61aa5a4e2c2..894759a74483 100644 --- a/engine/runtime-language-arrow/src/main/java/org/enso/interpreter/arrow/runtime/RoundingUtil.java +++ b/engine/runtime-language-arrow/src/main/java/org/enso/interpreter/arrow/runtime/RoundingUtil.java @@ -19,16 +19,16 @@ package org.enso.interpreter.arrow.runtime; -class RoundingUtil { +final class RoundingUtil { /** The mask for rounding an integer to a multiple of 8. (i.e. clear the lowest 3 bits) */ - static int ROUND_8_MASK_INT = 0xFFFFFFF8; + static final int ROUND_8_MASK_INT = 0xFFFFFFF8; /** The mask for rounding a long integer to a multiple of 8. (i.e. clear the lowest 3 bits) */ - static long ROUND_8_MASK_LONG = 0xFFFFFFFFFFFFFFF8L; + static final long ROUND_8_MASK_LONG = 0xFFFFFFFFFFFFFFF8L; /** The number of bits to shift for dividing by 8. */ - static int DIVIDE_BY_8_SHIFT_BITS = 3; + static final int DIVIDE_BY_8_SHIFT_BITS = 3; private RoundingUtil() {} @@ -91,13 +91,15 @@ public int getTotalSizeInBytes() { return (int) (dataBufferSize + validityBitmapSize); } - private long validityBitmapSize; - private long dataBufferSize; + private final long validityBitmapSize; + private final long dataBufferSize; private PaddedSize(int valueCount, SizeInBytes unit) { this.valueCount = valueCount; this.unit = unit; - computeBufferSize(valueCount, unit); + var pair = computeBufferSize(valueCount, unit); + this.validityBitmapSize = pair[0]; + this.dataBufferSize = pair[1]; } private long defaultRoundedSize(long val) { @@ -127,7 +129,7 @@ long computeCombinedBufferSize(int valueCount, int typeWidth) { return defaultRoundedSize(bufferSize); } - private void computeBufferSize(int valueCount, SizeInBytes unit) { + private long[] computeBufferSize(int valueCount, SizeInBytes unit) { var typeWidth = unit.sizeInBytes(); long bufferSize = computeCombinedBufferSize(valueCount, typeWidth); assert bufferSize <= Long.MAX_VALUE; @@ -149,8 +151,7 @@ private void computeBufferSize(int valueCount, SizeInBytes unit) { --actualCount; } while (true); } - this.validityBitmapSize = validityBufferSize; - this.dataBufferSize = dataBufferSize; + return new long[] {validityBufferSize, dataBufferSize}; } } } diff --git a/engine/runtime-language-arrow/src/main/java/org/enso/interpreter/arrow/runtime/ScalarOperationNode.java b/engine/runtime-language-arrow/src/main/java/org/enso/interpreter/arrow/runtime/ScalarOperationNode.java new file mode 100644 index 000000000000..559853fd82b9 --- /dev/null +++ b/engine/runtime-language-arrow/src/main/java/org/enso/interpreter/arrow/runtime/ScalarOperationNode.java @@ -0,0 +1,10 @@ +package org.enso.interpreter.arrow.runtime; + +import com.oracle.truffle.api.dsl.GenerateInline; +import com.oracle.truffle.api.interop.UnsupportedMessageException; +import com.oracle.truffle.api.nodes.Node; + +@GenerateInline(false) +abstract class ScalarOperationNode extends Node { + abstract Object executeOp(Object a, Object b) throws UnsupportedMessageException; +} diff --git a/engine/runtime-language-arrow/src/main/java/org/enso/interpreter/arrow/runtime/WriteToBuilderNode.java b/engine/runtime-language-arrow/src/main/java/org/enso/interpreter/arrow/runtime/ValueToNumberNode.java similarity index 56% rename from engine/runtime-language-arrow/src/main/java/org/enso/interpreter/arrow/runtime/WriteToBuilderNode.java rename to engine/runtime-language-arrow/src/main/java/org/enso/interpreter/arrow/runtime/ValueToNumberNode.java index acff598d0b7f..4ca15c1c1ebc 100644 --- a/engine/runtime-language-arrow/src/main/java/org/enso/interpreter/arrow/runtime/WriteToBuilderNode.java +++ b/engine/runtime-language-arrow/src/main/java/org/enso/interpreter/arrow/runtime/ValueToNumberNode.java @@ -1,7 +1,13 @@ package org.enso.interpreter.arrow.runtime; import com.oracle.truffle.api.CompilerDirectives; -import com.oracle.truffle.api.dsl.*; +import com.oracle.truffle.api.dsl.Cached; +import com.oracle.truffle.api.dsl.Fallback; +import com.oracle.truffle.api.dsl.GenerateInline; +import com.oracle.truffle.api.dsl.GenerateUncached; +import com.oracle.truffle.api.dsl.ImportStatic; +import com.oracle.truffle.api.dsl.NeverDefault; +import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.api.interop.InteropLibrary; import com.oracle.truffle.api.interop.UnsupportedMessageException; import com.oracle.truffle.api.interop.UnsupportedTypeException; @@ -17,58 +23,56 @@ @ImportStatic(LogicalLayout.class) @GenerateUncached @GenerateInline(value = false) -abstract class WriteToBuilderNode extends Node { +abstract class ValueToNumberNode extends Node { + /** + * Converts {@code value} to a suitable representation to be stored in an appropriate + * DirectBuffer. + * + * @param unit type of layout + * @param value a value to convert + * @return byte, short, int or long + * @throws UnsupportedTypeException if the conversion isn't possible + */ + abstract Number executeAdjust(LogicalLayout unit, Object value) throws UnsupportedTypeException; - public abstract void executeWrite(ArrowFixedSizeArrayBuilder receiver, long index, Object value) - throws UnsupportedTypeException; + @NeverDefault + static ValueToNumberNode build() { + return ValueToNumberNodeGen.create(); + } @NeverDefault - static WriteToBuilderNode build() { - return WriteToBuilderNodeGen.create(); + static ValueToNumberNode getUncached() { + return ValueToNumberNodeGen.getUncached(); } - @Specialization(guards = "receiver.getUnit() == Date32") - void doWriteDay( - ArrowFixedSizeArrayBuilder receiver, - long index, + @Specialization(guards = "unit == Date32") + Integer doDay( + LogicalLayout unit, Object value, @Cached.Shared("interop") @CachedLibrary(limit = "1") InteropLibrary iop) throws UnsupportedTypeException { - validAccess(receiver, index); - if (iop.isNull(value)) { - receiver.getBuffer().setNull((int) index); - return; - } if (!iop.isDate(value)) { throw UnsupportedTypeException.create(new Object[] {value}, "value is not a date"); } - var at = ArrowFixedArrayDate.typeAdjustedIndex(index, 4); long time; try { time = iop.asDate(value).toEpochDay(); } catch (UnsupportedMessageException e) { throw UnsupportedTypeException.create(new Object[] {value}, "value is not a date"); } - receiver.getBuffer().putInt(at, Math.toIntExact(time)); + return Math.toIntExact(time); } - @Specialization(guards = {"receiver.getUnit() == Date64"}) - void doWriteMilliseconds( - ArrowFixedSizeArrayBuilder receiver, - long index, + @Specialization(guards = {"unit == Date64"}) + Long doMilliseconds( + LogicalLayout unit, Object value, @Cached.Shared("interop") @CachedLibrary(limit = "1") InteropLibrary iop) throws UnsupportedTypeException { - validAccess(receiver, index); - if (iop.isNull(value)) { - receiver.getBuffer().setNull((int) index); - return; - } if (!iop.isDate(value) || !iop.isTime(value)) { throw UnsupportedTypeException.create(new Object[] {value}, "value is not a date and a time"); } - var at = ArrowFixedArrayDate.typeAdjustedIndex(index, 8); if (iop.isTimeZone(value)) { Instant zoneDateTimeInstant; try { @@ -84,7 +88,7 @@ void doWriteMilliseconds( var secondsPlusNano = zoneDateTimeInstant.getEpochSecond() * ArrowFixedArrayDate.NANO_DIV + zoneDateTimeInstant.getNano(); - receiver.getBuffer().putLong(at, secondsPlusNano); + return secondsPlusNano; } else { Instant dateTime; try { @@ -94,7 +98,7 @@ void doWriteMilliseconds( } var secondsPlusNano = dateTime.getEpochSecond() * ArrowFixedArrayDate.NANO_DIV + dateTime.getNano(); - receiver.getBuffer().putLong(at, secondsPlusNano); + return secondsPlusNano; } } @@ -109,109 +113,75 @@ private static Instant instantForOffset(LocalDate date, LocalTime time, ZoneOffs return date.atTime(time).toInstant(offset); } - @Specialization(guards = "receiver.getUnit() == Int8") - void doWriteByte( - ArrowFixedSizeArrayBuilder receiver, - long index, + @Specialization(guards = "unit == Int8") + Byte doByte( + LogicalLayout unit, Object value, @Cached.Shared("interop") @CachedLibrary(limit = "1") InteropLibrary iop) throws UnsupportedTypeException { - validAccess(receiver, index); - if (iop.isNull(value)) { - receiver.getBuffer().setNull((int) index); - return; - } if (!iop.fitsInByte(value)) { throw UnsupportedTypeException.create(new Object[] {value}, "value does not fit a byte"); } try { - receiver.getBuffer().put(typeAdjustedIndex(index, receiver.getUnit()), (iop.asByte(value))); + return iop.asByte(value); } catch (UnsupportedMessageException e) { throw UnsupportedTypeException.create(new Object[] {value}, "value is not a byte"); } } - @Specialization(guards = "receiver.getUnit() == Int16") - void doWriteShort( - ArrowFixedSizeArrayBuilder receiver, - long index, + @Specialization(guards = "unit == Int16") + Short doShort( + LogicalLayout unit, Object value, @Cached.Shared("interop") @CachedLibrary(limit = "1") InteropLibrary iop) throws UnsupportedTypeException { - validAccess(receiver, index); - if (iop.isNull(value)) { - receiver.getBuffer().setNull((int) index); - return; - } if (!iop.fitsInShort(value)) { throw UnsupportedTypeException.create( new Object[] {value}, "value does not fit a 2 byte short"); } try { - receiver - .getBuffer() - .putShort(typeAdjustedIndex(index, receiver.getUnit()), (iop.asShort(value))); + return iop.asShort(value); } catch (UnsupportedMessageException e) { throw UnsupportedTypeException.create(new Object[] {value}, "value is not a short"); } } - @Specialization(guards = "receiver.getUnit() == Int32") - void doWriteInt( - ArrowFixedSizeArrayBuilder receiver, - long index, + @Specialization(guards = "unit == Int32") + Integer doInt( + LogicalLayout unit, int value, @Cached.Shared("interop") @CachedLibrary(limit = "1") InteropLibrary iop) throws UnsupportedTypeException { - validAccess(receiver, index); - if (iop.isNull(value)) { - receiver.getBuffer().setNull((int) index); - return; - } if (!iop.fitsInInt(value)) { throw UnsupportedTypeException.create( new Object[] {value}, "value does not fit a 4 byte int"); } try { - receiver.getBuffer().putInt(typeAdjustedIndex(index, receiver.getUnit()), (iop.asInt(value))); + return iop.asInt(value); } catch (UnsupportedMessageException e) { throw UnsupportedTypeException.create(new Object[] {value}, "value is not an int"); } } - @Specialization(guards = "receiver.getUnit() == Int64") - public static void doWriteLong( - ArrowFixedSizeArrayBuilder receiver, - long index, - long value, + @Specialization(guards = "unit == Int64") + static Long doLong( + LogicalLayout unit, + Object value, @Cached.Shared("interop") @CachedLibrary(limit = "1") InteropLibrary iop) throws UnsupportedTypeException { - validAccess(receiver, index); - if (iop.isNull(value)) { - receiver.getBuffer().setNull((int) index); - return; - } - receiver.getBuffer().putLong(typeAdjustedIndex(index, receiver.getUnit()), value); - } - - @Fallback - void doWriteOther(ArrowFixedSizeArrayBuilder receiver, long index, Object value) - throws UnsupportedTypeException { - throw UnsupportedTypeException.create(new Object[] {index, value}, "unknown type of receiver"); - } - - private static void validAccess(ArrowFixedSizeArrayBuilder receiver, long index) - throws UnsupportedTypeException { - if (receiver.isSealed()) { + if (!iop.fitsInLong(value)) { throw UnsupportedTypeException.create( - new Object[] {receiver}, "receiver is not an unsealed buffer"); + new Object[] {value}, "value does not fit a 8 byte int"); } - if (index >= receiver.getSize() || index < 0) { - throw UnsupportedTypeException.create(new Object[] {index}, "index is out of range"); + try { + return iop.asLong(value); + } catch (UnsupportedMessageException e) { + throw UnsupportedTypeException.create(new Object[] {value}, "value is not a long"); } } - private static int typeAdjustedIndex(long index, SizeInBytes unit) { - return ArrowFixedArrayDate.typeAdjustedIndex(index, unit.sizeInBytes()); + @Fallback + Number doOther(LogicalLayout unit, Object value) throws UnsupportedTypeException { + throw UnsupportedTypeException.create(new Object[] {unit, value}, "unknown type"); } } diff --git a/engine/runtime-language-arrow/src/test/java/org/enso/interpreter/arrow/AddArrowTest.java b/engine/runtime-language-arrow/src/test/java/org/enso/interpreter/arrow/AddArrowTest.java new file mode 100644 index 000000000000..b527bec60bb1 --- /dev/null +++ b/engine/runtime-language-arrow/src/test/java/org/enso/interpreter/arrow/AddArrowTest.java @@ -0,0 +1,142 @@ +package org.enso.interpreter.arrow; + +import static org.junit.Assert.*; + +import org.graalvm.polyglot.Context; +import org.graalvm.polyglot.io.IOAccess; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +public class AddArrowTest { + private static Context ctx; + + @BeforeClass + public static void initEnsoContext() { + ctx = + Context.newBuilder() + .allowExperimentalOptions(true) + .allowIO(IOAccess.ALL) + .out(System.out) + .err(System.err) + .allowAllAccess(true) + .build(); + } + + @AfterClass + public static void closeEnsoContext() throws Exception { + if (ctx != null) { + ctx.close(); + } + } + + @Test + public void addTwoInt8ArrowArrays() { + var arrow = ctx.getEngine().getLanguages().get("arrow"); + assertNotNull("Arrow is available", arrow); + var int8Constr = ctx.eval("arrow", "new[Int8]"); + assertNotNull(int8Constr); + + var arrLength = 10; + var builder1 = int8Constr.newInstance(arrLength); + var builder2 = int8Constr.newInstance(arrLength); + + for (var i = 0; i < arrLength; i++) { + var ni = arrLength - i - 1; + var v = i * i; + builder1.invokeMember("append", i, (byte) v); + builder2.invokeMember("append", ni, (byte) v); + } + + var arr1 = builder1.invokeMember("build"); + assertEquals("Right size of arr1", arrLength, arr1.getArraySize()); + var arr2 = builder2.invokeMember("build"); + assertEquals("Right size of arr2", arrLength, arr2.getArraySize()); + + var int8Plus = ctx.eval("arrow", "+[Int8]"); + var resultArr = int8Plus.execute(arr1, arr2); + + assertTrue("Result is an array", resultArr.hasArrayElements()); + assertEquals("Right size", arrLength, resultArr.getArraySize()); + + for (var i = 0; i < arrLength; i++) { + var ni = arrLength - i - 1; + var v1 = resultArr.getArrayElement(i).asLong(); + var v2 = resultArr.getArrayElement(ni).asLong(); + + assertEquals("Values at " + i + " and " + ni + " are the same", v1, v2); + assertTrue("Values are always bigger than zero: " + v1, v1 > 0); + } + } + + @Test + public void addTwoInt64ArrowArraysWithNulls() { + var arrow = ctx.getEngine().getLanguages().get("arrow"); + assertNotNull("Arrow is available", arrow); + var constr = ctx.eval("arrow", "new[Int64]"); + assertNotNull(constr); + + var arrLength = 10; + var builder1 = constr.newInstance(arrLength); + for (int i = 0; i < arrLength; i++) { + if (i % 7 < 2) { + builder1.invokeMember("append", Long.MAX_VALUE); + } else { + builder1.invokeMember("append", i); + } + } + + var builder2 = constr.newInstance(arrLength); + for (var i = 0; i < arrLength; i++) { + builder2.invokeMember("append", 10 + i); + } + + var arr1 = builder1.invokeMember("build"); + assertEquals("Right size of arr1", arrLength, arr1.getArraySize()); + var addArr = builder2.invokeMember("build"); + assertEquals("Right size of arr2", arrLength, addArr.getArraySize()); + + var plus = ctx.eval("arrow", "+[Int64]"); + var res1 = plus.execute(arr1, addArr); + + assertTrue("Result is an array", res1.hasArrayElements()); + assertEquals("Right size", arrLength, res1.getArraySize()); + + assertTrue("is null", res1.getArrayElement(0).isNull()); + assertTrue("is null", res1.getArrayElement(1).isNull()); + assertTrue("is null", res1.getArrayElement(7).isNull()); + assertTrue("is null", res1.getArrayElement(8).isNull()); + + var countNulls = 0; + for (var i = 0; i < arrLength; i++) { + var v = res1.getArrayElement(i); + if (v.isNull()) { + countNulls++; + } else { + assertEquals(i * 2 + 10, v.asLong()); + } + } + assertEquals("Four nulls", 4, countNulls); + + var res2 = plus.execute(res1, addArr); + + assertTrue("Result is an array", res2.hasArrayElements()); + assertEquals("Right size", arrLength, res2.getArraySize()); + + assertTrue("is null", res2.getArrayElement(0).isNull()); + assertTrue("is null", res2.getArrayElement(1).isNull()); + assertTrue("is null", res2.getArrayElement(7).isNull()); + assertTrue("is null", res2.getArrayElement(8).isNull()); + + var countNullsAgain = 0; + for (var i = 0; i < arrLength; i++) { + var v = res2.getArrayElement(i); + if (v.isNull()) { + countNullsAgain++; + } else { + assertEquals(i * 3 + 20, v.asLong()); + } + } + assertEquals("Four nulls", 4, countNullsAgain); + } +} diff --git a/engine/runtime-language-arrow/src/test/java/org/enso/interpreter/arrow/VerifyArrowTest.java b/engine/runtime-language-arrow/src/test/java/org/enso/interpreter/arrow/VerifyArrowTest.java index 7578a5f2cd48..f829ae71a612 100644 --- a/engine/runtime-language-arrow/src/test/java/org/enso/interpreter/arrow/VerifyArrowTest.java +++ b/engine/runtime-language-arrow/src/test/java/org/enso/interpreter/arrow/VerifyArrowTest.java @@ -19,6 +19,7 @@ import org.apache.arrow.vector.BigIntVector; import org.apache.arrow.vector.IntVector; import org.graalvm.polyglot.Context; +import org.graalvm.polyglot.PolyglotException; import org.graalvm.polyglot.Value; import org.graalvm.polyglot.io.IOAccess; import org.junit.AfterClass; @@ -84,7 +85,7 @@ public void arrowDate32() { date32ArrayBuilder.invokeMember("build"); assertFalse(date32ArrayBuilder.canInvokeMember("append")); assertThrows( - UnsupportedOperationException.class, + PolyglotException.class, () -> finalDate32ArrayBuilder.invokeMember("append", startDateTime)); assertFalse(date32Array.canInvokeMember("append")); } @@ -153,6 +154,33 @@ public void arrowInt8() { assertEquals((byte) 5, v.asByte()); } + @Test + public void arrowInt64() { + var arrow = ctx.getEngine().getLanguages().get("arrow"); + assertNotNull("Arrow is available", arrow); + var constr = ctx.eval("arrow", "new[Int64]"); + assertNotNull(constr); + + var arrLength = 48; + Value builder = constr.newInstance(arrLength); + for (var i = 0; i < arrLength; i++) { + builder.invokeMember("append", i); + } + var arr = builder.invokeMember("build"); + assertEquals(arrLength, arr.getArraySize()); + for (var i = 0; i < arrLength; i++) { + var ith = arr.getArrayElement(i); + assertEquals("Checking value at " + i, i, ith.asLong()); + } + + var plus = ctx.eval("arrow", "+[Int64]"); + var doubled = plus.execute(arr, arr); + for (var i = 0; i < arrLength; i++) { + var ith = doubled.getArrayElement(i); + assertEquals("Checking double value at " + i, 2 * i, ith.asInt()); + } + } + @Test public void castInt() { var typeLength = LogicalLayout.Int32; diff --git a/engine/runtime/src/main/java/org/enso/interpreter/node/callable/ApplicationNode.java b/engine/runtime/src/main/java/org/enso/interpreter/node/callable/ApplicationNode.java index 622f0d65526c..f521619e3ed4 100644 --- a/engine/runtime/src/main/java/org/enso/interpreter/node/callable/ApplicationNode.java +++ b/engine/runtime/src/main/java/org/enso/interpreter/node/callable/ApplicationNode.java @@ -9,7 +9,6 @@ import org.enso.interpreter.runtime.callable.argument.CallArgument; import org.enso.interpreter.runtime.callable.argument.CallArgumentInfo; import org.enso.interpreter.runtime.callable.function.Function; -import org.enso.interpreter.runtime.state.State; /** * This node is responsible for organising callable calls so that they are ready to be made. @@ -92,10 +91,10 @@ private Object[] evaluateArguments(VirtualFrame frame) { */ @Override public Object executeGeneric(VirtualFrame frame) { - State state = Function.ArgumentsHelper.getState(frame.getArguments()); - Object[] evaluatedArguments = evaluateArguments(frame); - return this.invokeCallableNode.execute( - this.callable.executeGeneric(frame), frame, state, evaluatedArguments); + var state = Function.ArgumentsHelper.getState(frame.getArguments()); + var evaluatedArguments = evaluateArguments(frame); + var self = this.callable.executeGeneric(frame); + return this.invokeCallableNode.execute(self, frame, state, evaluatedArguments); } /** diff --git a/engine/runtime/src/main/java/org/enso/interpreter/node/callable/InvokeCallableNode.java b/engine/runtime/src/main/java/org/enso/interpreter/node/callable/InvokeCallableNode.java index 1ecd1e028457..fb85d1542c78 100644 --- a/engine/runtime/src/main/java/org/enso/interpreter/node/callable/InvokeCallableNode.java +++ b/engine/runtime/src/main/java/org/enso/interpreter/node/callable/InvokeCallableNode.java @@ -1,6 +1,7 @@ package org.enso.interpreter.node.callable; import com.oracle.truffle.api.CompilerDirectives; +import com.oracle.truffle.api.dsl.Bind; import com.oracle.truffle.api.dsl.Cached; import com.oracle.truffle.api.dsl.Cached.Shared; import com.oracle.truffle.api.dsl.Fallback; @@ -12,6 +13,7 @@ import com.oracle.truffle.api.interop.UnsupportedTypeException; import com.oracle.truffle.api.library.CachedLibrary; import com.oracle.truffle.api.nodes.Node; +import com.oracle.truffle.api.profiles.InlinedBranchProfile; import com.oracle.truffle.api.source.SourceSection; import java.util.UUID; import java.util.concurrent.locks.Lock; @@ -337,32 +339,37 @@ public Object invokeWarnings( "!types.hasSpecialDispatch(self)", "iop.isExecutable(self)", }) - Object doPolyglot( + static Object doPolyglot( Object self, VirtualFrame frame, State state, Object[] arguments, + @Bind("$node") Node node, @CachedLibrary(limit = "3") InteropLibrary iop, @Shared("warnings") @CachedLibrary(limit = "3") WarningsLibrary warnings, @CachedLibrary(limit = "3") TypesLibrary types, - @Cached ThunkExecutorNode thunkNode) { - var errors = EnsoContext.get(this).getBuiltins().error(); + @Cached ThunkExecutorNode thunkNode, + @Cached InlinedBranchProfile errorNeedsToBeReported) { + var errors = EnsoContext.get(node).getBuiltins().error(); try { for (int i = 0; i < arguments.length; i++) { arguments[i] = thunkNode.executeThunk(frame, arguments[i], state, TailStatus.NOT_TAIL); } return iop.execute(self, arguments); } catch (UnsupportedTypeException ex) { + errorNeedsToBeReported.enter(node); var err = errors.makeUnsupportedArgumentsError(ex.getSuppliedValues(), ex.getMessage()); - throw new PanicException(err, this); + throw new PanicException(err, node); } catch (ArityException ex) { + errorNeedsToBeReported.enter(node); var err = errors.makeArityError( ex.getExpectedMinArity(), ex.getExpectedMaxArity(), arguments.length); - throw new PanicException(err, this); + throw new PanicException(err, node); } catch (UnsupportedMessageException ex) { + errorNeedsToBeReported.enter(node); var err = errors.makeNotInvokable(self); - throw new PanicException(err, this); + throw new PanicException(err, node); } } diff --git a/test/Benchmarks/src/Table/Arithmetic.enso b/test/Benchmarks/src/Table/Arithmetic.enso index 2393255bd462..8bf04d1e9557 100644 --- a/test/Benchmarks/src/Table/Arithmetic.enso +++ b/test/Benchmarks/src/Table/Arithmetic.enso @@ -8,45 +8,118 @@ polyglot java import java.lang.Long as Java_Long options = Bench.options . set_warmup (Bench.phase_conf 3 5) . set_measure (Bench.phase_conf 3 5) - -create_table : Table -create_table num_rows = +create_vectors num_rows = x = Vector.new num_rows i-> i+1 y = Vector.new num_rows i-> if i % 10 < 2 then Java_Long.MAX_VALUE else i+1 u = Vector.new num_rows i-> 10 + (i % 100) + z = Vector.new num_rows i-> + if i % 10 < 2 then Nothing else i+1 - t = Table.new [["X", x], ["Y", y], ["U", u]] - - assert condition = - if condition.not then Panic.throw "Assertion failed" + [x, y, u, z] - assert ((t.at "X" . value_type) == Value_Type.Integer) - assert ((t.at "Y" . value_type) == Value_Type.Integer) - assert ((t.at "U" . value_type) == Value_Type.Integer) +create_table : Table +create_table num_rows = + v = create_vectors num_rows + x = v.at 0 + y = v.at 1 + u = v.at 2 + z = v.at 3 + + t = Table.new [["X", x], ["Y", y], ["U", u], ["Z", z]] + + Runtime.assert ((t.at "X" . value_type) == Value_Type.Integer) + Runtime.assert ((t.at "Y" . value_type) == Value_Type.Integer) + Runtime.assert ((t.at "U" . value_type) == Value_Type.Integer) + Runtime.assert ((t.at "Z" . value_type) == Value_Type.Integer) t +create_arrow_columns num_rows = + column_to_arrow v:Vector -> Array = + builder = int64_new.new v.length + v.map e-> builder.append e + builder.build + + v = create_vectors num_rows + + x = column_to_arrow (v.at 0) + y = column_to_arrow (v.at 1) + u = column_to_arrow (v.at 2) + z = column_to_arrow (v.at 3) + [int64_plus, x, y, u, z] + +foreign arrow int64_new = """ + new[Int64] + +foreign arrow int64_plus = """ + +[Int64] type Data - Value ~table + private Value ~table ~arrow - create num_rows = Data.Value (create_table num_rows) + arrow_plus self = self.arrow.at 0 + arrow_x self = self.arrow.at 1 + arrow_y self = self.arrow.at 2 + arrow_u self = self.arrow.at 3 + arrow_z self = self.arrow.at 4 + create num_rows = Data.Value (create_table num_rows) (create_arrow_columns num_rows) collect_benches = Bench.build builder-> + column_arithmetic_plus_fitting d = + (d.table.at "X") + (d.table.at "U") + + column_arithmetic_plus_overflowing d = + (d.table.at "Y") + (d.table.at "U") + + column_arithmetic_plus_nothing d = + (d.table.at "Z") + (d.table.at "U") + + column_arithmetic_multiply_fitting d = + (d.table.at "X") * (d.table.at "U") + + column_arithmetic_multiply_overflowing d = + (d.table.at "Y") * (d.table.at "U") + + arrow_arithmetic_plus_fitting d = + d.arrow_plus d.arrow_x d.arrow_u + + arrow_arithmetic_plus_overflowing d = + d.arrow_plus d.arrow_y d.arrow_u + + arrow_arithmetic_plus_nothing d = + d.arrow_plus d.arrow_y d.arrow_u + num_rows = 1000000 data = Data.create num_rows + Runtime.assert ((column_arithmetic_plus_fitting data . to_vector) == (arrow_arithmetic_plus_fitting data)) "Column and arrow correctness check one" + Runtime.assert ((column_arithmetic_plus_overflowing data . to_vector) == (arrow_arithmetic_plus_overflowing data)) "Column and arrow correctness check two" + Runtime.assert ((column_arithmetic_plus_nothing data . to_vector) == (arrow_arithmetic_plus_nothing data)) "Column and arrow correctness check three" + builder.group ("Column_Arithmetic_" + num_rows.to_text) options group_builder-> group_builder.specify "Plus_Fitting" <| - (data.table.at "X") + (data.table.at "U") + column_arithmetic_plus_fitting data group_builder.specify "Plus_Overflowing" <| - (data.table.at "Y") + (data.table.at "U") + column_arithmetic_plus_overflowing data + group_builder.specify "Plus_Nothing" <| + column_arithmetic_plus_nothing data + group_builder.specify "Multiply_Fitting" <| - (data.table.at "X") * (data.table.at "U") + column_arithmetic_multiply_fitting data group_builder.specify "Multiply_Overflowing" <| - (data.table.at "Y") * (data.table.at "U") + column_arithmetic_multiply_overflowing data + + builder.group ("Arrow_Arithmetic_" + num_rows.to_text) options group_builder-> + group_builder.specify "Plus_Fitting" <| + arrow_arithmetic_plus_fitting data + + group_builder.specify "Plus_Overflowing" <| + arrow_arithmetic_plus_overflowing data + + group_builder.specify "Plus_Nothing" <| + arrow_arithmetic_plus_nothing data main = collect_benches . run_main