Skip to content

Commit

Permalink
Rewrite all built-in methods as specialized nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
HugoGGuerrier committed Jan 6, 2025
1 parent 524564a commit 9539af9
Show file tree
Hide file tree
Showing 5 changed files with 325 additions and 271 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@
import com.adacore.lkql_jit.LKQLTypeSystemGen;
import com.adacore.lkql_jit.built_ins.BuiltInMethodFactory;
import com.adacore.lkql_jit.built_ins.BuiltInsHolder;
import com.adacore.lkql_jit.built_ins.SpecializedBuiltInBody;
import com.adacore.lkql_jit.built_ins.functions.UniqueFunction;
import com.adacore.lkql_jit.exception.LKQLRuntimeException;
import com.adacore.lkql_jit.nodes.expressions.Expr;
import com.adacore.lkql_jit.nodes.expressions.FunCall;
import com.adacore.lkql_jit.runtime.values.lists.LKQLList;
import com.adacore.lkql_jit.utils.LKQLTypesHelper;
import com.oracle.truffle.api.frame.VirtualFrame;
import com.oracle.truffle.api.dsl.Fallback;
import com.oracle.truffle.api.dsl.Specialization;
import java.util.Arrays;
import java.util.Map;

Expand All @@ -27,58 +28,70 @@
*/
public class ListMethods {

private static final Map.Entry<String, BuiltInMethodFactory> sublistFunction =
createMethod(
"sublist",
"Return a sublist of `list` from `low_bound` to `high_bound`",
new String[] {"low_bound", "high_bound"},
new Expr[] {null, null},
(VirtualFrame frame, FunCall call) -> {
var args = frame.getArguments();

if (!LKQLTypeSystemGen.isLKQLList(args[0])) {
throw LKQLRuntimeException.wrongType(
LKQLTypesHelper.LKQL_LIST,
LKQLTypesHelper.fromJava(args[0]),
call.getArgList().getArgs()[0]);
}

if (!LKQLTypeSystemGen.isLong(args[1])) {
throw LKQLRuntimeException.wrongType(
LKQLTypesHelper.LKQL_INTEGER,
LKQLTypesHelper.fromJava(args[1]),
call.getArgList().getArgs()[1]);
}

if (!LKQLTypeSystemGen.isLong(args[2])) {
throw LKQLRuntimeException.wrongType(
LKQLTypesHelper.LKQL_INTEGER,
LKQLTypesHelper.fromJava(args[2]),
call.getArgList().getArgs()[2]);
}

LKQLList list = LKQLTypeSystemGen.asLKQLList(args[0]);
long lowBound = LKQLTypeSystemGen.asLong(args[1]);
long highBound = LKQLTypeSystemGen.asLong(args[2]);

if (lowBound < 1) {
throw LKQLRuntimeException.invalidIndex((int) lowBound, call);
} else if (highBound > list.getContent().length) {
throw LKQLRuntimeException.invalidIndex((int) highBound, call);
}

return new LKQLList(
Arrays.copyOfRange(
list.getContent(), (int) lowBound - 1, (int) highBound));
});

public static final Map<String, BuiltInMethodFactory> methods =
BuiltInsHolder.combine(
Map.ofEntries(
Map.entry(
UniqueFunction.NAME,
BuiltInMethodFactory.fromFunctionValue(
UniqueFunction.getValue(), true)),
sublistFunction),
createMethod(
"sublist",
"Return a sublist of `list` from "
+ "`low_bound` to `high_bound`",
new String[] {"low_bound", "high_bound"},
new Expr[] {null, null},
new SpecializedBuiltInBody<>(
ListMethodsFactory.SublistExprNodeGen.create()) {
@Override
protected Object dispatch(Object[] args) {
return this.specializedNode.executeSublist(
LKQLTypeSystemGen.asLKQLList(args[0]),
args[1],
args[2]);
}
})),
IterableMethods.methods);

// ----- Inner classes -----

/** Expression of the "sublist" method. */
public abstract static class SublistExpr extends SpecializedBuiltInBody.SpecializedBuiltInNode {

public abstract LKQLList executeSublist(LKQLList list, Object low, Object high);

@Specialization
protected LKQLList onValid(LKQLList list, long low, long high) {
// Offset the low bound by 1 since LKQL is 1-indexed
low = low - 1;

// Check bounds validity
if (low < 0) {
throw LKQLRuntimeException.invalidIndex((int) low + 1, body.argNode(0));
} else if (high > list.getContent().length) {
throw LKQLRuntimeException.invalidIndex((int) high, body.argNode(1));
}

// Return the sublist
return new LKQLList(Arrays.copyOfRange(list.getContent(), (int) low, (int) high));
}

@Specialization
protected LKQLList onInvalidHigh(
@SuppressWarnings("unused") LKQLList list,
@SuppressWarnings("unused") long low,
Object high) {
throw LKQLRuntimeException.wrongType(
LKQLTypesHelper.LKQL_INTEGER, LKQLTypesHelper.fromJava(high), body.argNode(1));
}

@Fallback
protected LKQLList onInvalidLow(
@SuppressWarnings("unused") LKQLList list,
Object low,
@SuppressWarnings("unused") Object high) {
throw LKQLRuntimeException.wrongType(
LKQLTypesHelper.LKQL_INTEGER, LKQLTypesHelper.fromJava(low), body.argNode(0));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import com.adacore.lkql_jit.LKQLTypeSystemGen;
import com.adacore.lkql_jit.built_ins.AbstractBuiltInFunctionBody;
import com.adacore.lkql_jit.built_ins.BuiltInMethodFactory;
import com.adacore.lkql_jit.built_ins.SpecializedBuiltInBody;
import com.adacore.lkql_jit.exception.LKQLRuntimeException;
import com.adacore.lkql_jit.nodes.expressions.Expr;
import com.adacore.lkql_jit.runtime.values.LKQLNull;
Expand All @@ -22,8 +23,9 @@
import com.adacore.lkql_jit.utils.functions.ObjectUtils;
import com.adacore.lkql_jit.utils.functions.ReflectionUtils;
import com.adacore.lkql_jit.utils.functions.StringUtils;
import com.oracle.truffle.api.dsl.Fallback;
import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.api.frame.VirtualFrame;
import com.oracle.truffle.api.nodes.UnexpectedResultException;
import java.util.ArrayList;
import java.util.Map;

Expand Down Expand Up @@ -66,7 +68,16 @@ public final class NodeMethods {
"Return whether two nodes have the same tokens, ignoring trivias",
new String[] {"other"},
new Expr[] {null},
new SameTokensExpr()));
new SpecializedBuiltInBody<>(
NodeMethodsFactory.SameTokensExprNodeGen.create()) {
@Override
protected Object dispatch(Object[] args) {
return this.specializedNode.executeSameTokens(
LKQLTypeSystemGen.asAdaNode(args[0]), args[1]);
}
}));

// ----- Inner classes -----

/** Expression of the "children" method. */
public static final class ChildrenExpr extends AbstractBuiltInFunctionBody {
Expand Down Expand Up @@ -171,21 +182,13 @@ public Object executeGeneric(VirtualFrame frame) {
}

/** Expression of the "same_tokens" method. */
public static final class SameTokensExpr extends AbstractBuiltInFunctionBody {
@Override
public Object executeGeneric(VirtualFrame frame) {
// Get the nodes to compare
Libadalang.AdaNode leftNode = LKQLTypeSystemGen.asAdaNode(frame.getArguments()[0]);
Libadalang.AdaNode rightNode;
try {
rightNode = LKQLTypeSystemGen.expectAdaNode(frame.getArguments()[1]);
} catch (UnexpectedResultException e) {
throw LKQLRuntimeException.wrongType(
LKQLTypesHelper.ADA_NODE,
LKQLTypesHelper.fromJava(e.getResult()),
this.callNode.getArgList().getArgs()[0]);
}
public abstract static class SameTokensExpr
extends SpecializedBuiltInBody.SpecializedBuiltInNode {

public abstract boolean executeSameTokens(Libadalang.AdaNode leftNode, Object rightNode);

@Specialization
protected boolean onAdaNode(Libadalang.AdaNode leftNode, Libadalang.AdaNode rightNode) {
// Get the tokens
Libadalang.Token leftToken = leftNode.tokenStart();
Libadalang.Token rightToken = rightNode.tokenStart();
Expand Down Expand Up @@ -218,12 +221,16 @@ public Object executeGeneric(VirtualFrame frame) {
return true;
}

/**
* Get the next token from the given one ignoring the trivias
*
* @param t The token to get the next from
* @return The next token
*/
@Fallback
protected boolean onInvalid(
@SuppressWarnings("unused") Libadalang.AdaNode leftNode, Object rightValue) {
throw LKQLRuntimeException.wrongType(
LKQLTypesHelper.ADA_NODE,
LKQLTypesHelper.fromJava(rightValue),
body.argNode(0));
}

/** Get the next token from the given one ignoring the trivias. */
private static Libadalang.Token next(Libadalang.Token t) {
Libadalang.Token res = t.next();
while (!res.isNone() && res.triviaIndex != 0) {
Expand Down
Loading

0 comments on commit 9539af9

Please sign in to comment.