Skip to content

Commit

Permalink
Implement and benchmark ArrowOperationPlus node (#10150)
Browse files Browse the repository at this point in the history
Prototype of #10056 showing `+` operation implemented in the _Arrow language_.
  • Loading branch information
JaroslavTulach authored Jun 11, 2024
1 parent 19c50ce commit aaaebca
Show file tree
Hide file tree
Showing 18 changed files with 1,006 additions and 368 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,36 +11,42 @@ private ArrowParser() {}
public record Result(PhysicalLayout physicalLayout, LogicalLayout logicalLayout, Mode mode) {}

public static Result parse(Source source) {
String src = source.getCharacters().toString();
Matcher m = NEW_ARRAY_CONSTR.matcher(src);
String src = source.getCharacters().toString().replace('\n', ' ').trim();
Matcher m = PATTERN.matcher(src);
if (m.find()) {
try {
var layout = LogicalLayout.valueOf(m.group(1));
return new Result(PhysicalLayout.Primitive, layout, Mode.Allocate);
var layout = LogicalLayout.valueOf(m.group(2));
var mode = Mode.parse(m.group(1));
if (layout != null && mode != null) {
return new Result(PhysicalLayout.Primitive, layout, mode);
}
} catch (IllegalArgumentException iae) {
// propagate warning
return null;
}
}

m = CAST_PATTERN.matcher(src);
if (m.find()) {
try {
var layout = LogicalLayout.valueOf(m.group(1));
return new Result(PhysicalLayout.Primitive, layout, Mode.Cast);
} catch (IllegalArgumentException iae) {
// propagate warning
return null;
}
}
return null;
}

private static final Pattern NEW_ARRAY_CONSTR = Pattern.compile("^new\\[(.+)\\]$");
private static final Pattern CAST_PATTERN = Pattern.compile("^cast\\[(.+)\\]$");
private static final Pattern PATTERN = Pattern.compile("^([a-z\\+]+)\\[(.+)\\]$");

public enum Mode {
Allocate,
Cast
Allocate("new"),
Cast("cast"),
Plus("+");

private final String op;

private Mode(String text) {
this.op = text;
}

static Mode parse(String operation) {
for (var m : values()) {
if (m.op.equals(operation)) {
return m;
}
}
return null;
}
}
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
import com.oracle.truffle.api.nodes.RootNode;
import org.enso.interpreter.arrow.ArrowLanguage;
import org.enso.interpreter.arrow.ArrowParser;
import org.enso.interpreter.arrow.runtime.ArrowCastToFixedSizeArrayFactory;
import org.enso.interpreter.arrow.runtime.ArrowFixedSizeArrayFactory;
import org.enso.interpreter.arrow.runtime.ArrowOperationPlus;

public class ArrowEvalNode extends RootNode {
private final ArrowParser.Result code;

@Child private ArrowFixedSizeNode fixedPhysicalLayout = ArrowFixedSizeNode.create();
@Child private ArrowCastFixedSizeNode castToFixedPhysicalLayout = ArrowCastFixedSizeNode.create();

public static ArrowEvalNode create(ArrowLanguage language, ArrowParser.Result code) {
return new ArrowEvalNode(language, code);
}
Expand All @@ -25,8 +25,10 @@ private ArrowEvalNode(ArrowLanguage language, ArrowParser.Result code) {
public Object execute(VirtualFrame frame) {
return switch (code.physicalLayout()) {
case Primitive -> switch (code.mode()) {
case Allocate -> fixedPhysicalLayout.execute(code.logicalLayout());
case Cast -> castToFixedPhysicalLayout.execute(code.logicalLayout());
case Allocate -> new ArrowFixedSizeArrayFactory(code.logicalLayout());
case Cast -> new ArrowCastToFixedSizeArrayFactory(code.logicalLayout());
case Plus -> new ArrowOperationPlus(code.logicalLayout());
default -> throw CompilerDirectives.shouldNotReachHere("unsupported mode");
};
default -> throw CompilerDirectives.shouldNotReachHere("unsupported physical layout");
};
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,13 +1,23 @@
package org.enso.interpreter.arrow.runtime;

import com.oracle.truffle.api.CompilerDirectives;
import com.oracle.truffle.api.dsl.Bind;
import com.oracle.truffle.api.dsl.Cached;
import com.oracle.truffle.api.dsl.ImportStatic;
import com.oracle.truffle.api.dsl.NeverDefault;
import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.api.interop.InteropLibrary;
import com.oracle.truffle.api.interop.InvalidArrayIndexException;
import com.oracle.truffle.api.interop.StopIterationException;
import com.oracle.truffle.api.interop.TruffleObject;
import com.oracle.truffle.api.interop.UnsupportedMessageException;
import com.oracle.truffle.api.library.CachedLibrary;
import com.oracle.truffle.api.library.ExportLibrary;
import com.oracle.truffle.api.library.ExportMessage;
import com.oracle.truffle.api.nodes.Node;
import com.oracle.truffle.api.profiles.InlinedExactClassProfile;
import java.nio.BufferOverflowException;
import java.nio.ByteBuffer;
import org.enso.interpreter.arrow.LogicalLayout;

@ExportLibrary(InteropLibrary.class)
Expand All @@ -27,7 +37,24 @@ public LogicalLayout getUnit() {
}

@ExportMessage
public boolean hasArrayElements() {
boolean hasArrayElements() {
return true;
}

@ExportMessage
Object getIterator(
@Cached(value = "this.getUnit()", allowUncached = true) LogicalLayout cachedUnit)
throws UnsupportedMessageException {
if (cachedUnit == LogicalLayout.Int64) {
var dataIt = new LongIterator(buffer.getDataBuffer(), cachedUnit.sizeInBytes());
var nullIt = new NullIterator(dataIt, buffer.getBitmapBuffer());
return nullIt;
}
return new GenericIterator(this);
}

@ExportMessage
boolean hasIterator() {
return true;
}

Expand Down Expand Up @@ -65,13 +92,18 @@ public static Object doInt(ArrowFixedArrayInt receiver, long index)
}

@Specialization(guards = "receiver.getUnit() == Int64")
public static Object doLong(ArrowFixedArrayInt receiver, long index)
public static Object doLong(
ArrowFixedArrayInt receiver,
long index,
@Bind("$node") Node node,
@CachedLibrary("receiver") InteropLibrary iop,
@Cached InlinedExactClassProfile bufferClazz)
throws UnsupportedMessageException, InvalidArrayIndexException {
var at = adjustedIndex(receiver.buffer, receiver.unit, receiver.size, index);
var at = adjustedIndex(receiver.buffer, LogicalLayout.Int64, receiver.size, index);
if (receiver.buffer.isNull((int) index)) {
return NullValue.get();
}
return receiver.buffer.getLong(at);
return receiver.buffer.getLong(at, iop, bufferClazz);
}
}

Expand All @@ -96,4 +128,133 @@ boolean isArrayElementReadable(long index) {
private static int typeAdjustedIndex(long index, SizeInBytes unit) {
return Math.toIntExact(index * unit.sizeInBytes());
}

@ExportLibrary(InteropLibrary.class)
static final class LongIterator implements TruffleObject {
private int at;
private final ByteBuffer buffer;
@NeverDefault final int step;

LongIterator(ByteBuffer buffer, int step) {
assert step != 0;
this.buffer = buffer;
this.step = step;
}

@ExportMessage
Object getIteratorNextElement(
@Bind("$node") Node node,
@Cached("this.step") int step,
@Cached InlinedExactClassProfile bufferTypeProfile)
throws StopIterationException {
var buf = bufferTypeProfile.profile(node, buffer);
try {
var res = buf.getLong(at);
at += step;
return res;
} catch (BufferOverflowException ex) {
CompilerDirectives.transferToInterpreter();
throw StopIterationException.create();
}
}

@ExportMessage
boolean isIterator() {
return true;
}

@ExportMessage
boolean hasIteratorNextElement() throws UnsupportedMessageException {
return at < buffer.limit();
}
}

@ExportLibrary(value = InteropLibrary.class)
static final class NullIterator implements TruffleObject {
private final TruffleObject it;
private final ByteBuffer buffer;
private byte byteMask;
private byte byteValue;

NullIterator(TruffleObject delegate, ByteBuffer buffer) {
this.it = delegate;
this.buffer = buffer;
}

final TruffleObject it() {
return it;
}

@ExportMessage(limit = "3")
Object getIteratorNextElement(
@Bind("$node") Node node,
@CachedLibrary("this.it()") InteropLibrary iopIt,
@Cached InlinedExactClassProfile bufferTypeProfile)
throws StopIterationException, UnsupportedMessageException {
var element = iopIt.getIteratorNextElement(it);
if (buffer != null) {
var buf = bufferTypeProfile.profile(node, buffer);
if (byteMask == 0) {
// (byte) (0x01 << 8) ==> 0
byteValue = buf.get();
byteMask = 0x01;
}
var include = byteValue & byteMask;
byteMask = (byte) (byteMask << 1);
if (include == 0) {
return NullValue.get();
}
}
return element;
}

@ExportMessage
boolean isIterator() {
return true;
}

@ExportMessage(limit = "3")
boolean hasIteratorNextElement(@CachedLibrary("this.it()") InteropLibrary iopIt)
throws UnsupportedMessageException {
return iopIt.hasIteratorNextElement(it);
}
}

@ExportLibrary(InteropLibrary.class)
static final class GenericIterator implements TruffleObject {
private int at;
private final TruffleObject array;

GenericIterator(TruffleObject array) {
assert InteropLibrary.getUncached().hasArrayElements(array);
this.array = array;
}

TruffleObject array() {
return array;
}

@ExportMessage(limit = "3")
Object getIteratorNextElement(@CachedLibrary("this.array()") InteropLibrary iop)
throws StopIterationException {
try {
var res = iop.readArrayElement(array, at);
at++;
return res;
} catch (UnsupportedMessageException | InvalidArrayIndexException ex) {
throw StopIterationException.create();
}
}

@ExportMessage
boolean isIterator() {
return true;
}

@ExportMessage(limit = "3")
boolean hasIteratorNextElement(@CachedLibrary("this.array()") InteropLibrary iop)
throws UnsupportedMessageException {
return at < iop.getArraySize(array);
}
}
}
Loading

0 comments on commit aaaebca

Please sign in to comment.