From 524564af62646397a96e49fee19d7d35cf6d3f58 Mon Sep 17 00:00:00 2001 From: Hugo Guerrier Date: Fri, 14 Jun 2024 14:45:55 +0200 Subject: [PATCH 1/2] Rewrite all built-ins functions as specialized nodes --- .../AbstractBuiltInFunctionBody.java | 13 ++ .../built_ins/SpecializedBuiltInBody.java | 50 ++++++ .../built_ins/functions/BaseNameFunction.java | 48 +++--- .../built_ins/functions/ConcatFunction.java | 155 +++++++++--------- .../built_ins/functions/DocFunction.java | 6 +- .../built_ins/functions/DocumentBuiltins.java | 22 +-- .../functions/DocumentNamespace.java | 2 + .../built_ins/functions/HelpFunction.java | 6 +- .../built_ins/functions/ImgFunction.java | 40 +++-- .../built_ins/functions/MapFunction.java | 100 ++++++----- .../built_ins/functions/PatternFunction.java | 66 ++++---- .../built_ins/functions/PrintFunction.java | 70 ++++---- .../built_ins/functions/ProfileFunction.java | 6 +- .../built_ins/functions/ReduceFunction.java | 105 ++++++------ .../functions/SpecifiedUnitsFunction.java | 9 +- .../built_ins/functions/UniqueFunction.java | 50 +++--- .../built_ins/functions/UnitsFunction.java | 4 +- 17 files changed, 409 insertions(+), 343 deletions(-) create mode 100644 lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/SpecializedBuiltInBody.java diff --git a/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/AbstractBuiltInFunctionBody.java b/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/AbstractBuiltInFunctionBody.java index 41a4553f5..4ac9737b8 100644 --- a/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/AbstractBuiltInFunctionBody.java +++ b/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/AbstractBuiltInFunctionBody.java @@ -5,6 +5,7 @@ package com.adacore.lkql_jit.built_ins; +import com.adacore.lkql_jit.nodes.LKQLNode; import com.adacore.lkql_jit.nodes.expressions.Expr; import com.adacore.lkql_jit.nodes.expressions.FunCall; import com.oracle.truffle.api.frame.VirtualFrame; @@ -24,12 +25,24 @@ protected AbstractBuiltInFunctionBody() { super(null); } + // ----- Getters ----- + + public FunCall getCallNode() { + return callNode; + } + // ----- Setters ----- public void setCallNode(FunCall callNode) { this.callNode = callNode; } + // ----- Instance methods ----- + + public LKQLNode argNode(int index) { + return this.callNode.getArgList().getArgs()[index]; + } + // ----- Class methods ----- /** Create a new built-in function body from the given callback representing its execution. */ diff --git a/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/SpecializedBuiltInBody.java b/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/SpecializedBuiltInBody.java new file mode 100644 index 000000000..02f5b0a19 --- /dev/null +++ b/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/SpecializedBuiltInBody.java @@ -0,0 +1,50 @@ +// +// Copyright (C) 2005-2024, AdaCore +// SPDX-License-Identifier: GPL-3.0-or-later +// + +package com.adacore.lkql_jit.built_ins; + +import com.adacore.lkql_jit.LKQLTypeSystem; +import com.oracle.truffle.api.dsl.TypeSystemReference; +import com.oracle.truffle.api.frame.VirtualFrame; +import com.oracle.truffle.api.nodes.Node; + +public abstract class SpecializedBuiltInBody< + T extends SpecializedBuiltInBody.SpecializedBuiltInNode> + extends AbstractBuiltInFunctionBody { + + // ----- Attributes ----- + + /** This node represents the execution of the built-in function. */ + @Child protected T specializedNode; + + // ----- Constructors ----- + + /** Create a new specialized body with its corresponding execution node. */ + public SpecializedBuiltInBody(T specializedNode) { + this.specializedNode = specializedNode; + this.specializedNode.body = this; + } + + // ----- Instance methods ----- + + /** Dispatch the function arguments to the specialized execution node. */ + protected abstract Object dispatch(Object[] args); + + @Override + public Object executeGeneric(VirtualFrame frame) { + return this.dispatch(frame.getArguments()); + } + + // ----- Inner classes ----- + + /** This class represents an execution node, payload of a built-in body. */ + @TypeSystemReference(LKQLTypeSystem.class) + public abstract static class SpecializedBuiltInNode extends Node { + // ----- Attributes ----- + + /** The built-in body that owns this specialized node. */ + protected SpecializedBuiltInBody body; + } +} diff --git a/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/functions/BaseNameFunction.java b/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/functions/BaseNameFunction.java index 1dc69d3f4..95c90a706 100644 --- a/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/functions/BaseNameFunction.java +++ b/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/functions/BaseNameFunction.java @@ -5,14 +5,14 @@ package com.adacore.lkql_jit.built_ins.functions; -import com.adacore.lkql_jit.LKQLTypeSystemGen; import com.adacore.lkql_jit.built_ins.BuiltInFunctionValue; +import com.adacore.lkql_jit.built_ins.SpecializedBuiltInBody; import com.adacore.lkql_jit.exception.LKQLRuntimeException; import com.adacore.lkql_jit.nodes.expressions.Expr; -import com.adacore.lkql_jit.nodes.expressions.FunCall; import com.adacore.lkql_jit.utils.LKQLTypesHelper; import com.adacore.lkql_jit.utils.functions.FileUtils; -import com.oracle.truffle.api.frame.VirtualFrame; +import com.oracle.truffle.api.dsl.Fallback; +import com.oracle.truffle.api.dsl.Specialization; /** * This class represents the "base_name" built-in function in the LKQL language. @@ -21,33 +21,39 @@ */ public final class BaseNameFunction { - // ----- Attributes ----- - - /** The name of the built-in. */ public static final String NAME = "base_name"; - // ----- Class methods ----- - + /** Get a brand new "base_name" function value. */ public static BuiltInFunctionValue getValue() { return new BuiltInFunctionValue( NAME, "Given a string that represents a file name, returns the basename", new String[] {"str"}, new Expr[] {null}, - (VirtualFrame frame, FunCall call) -> { - // Get the file full path - Object path = frame.getArguments()[0]; - - // Check the argument type - if (!LKQLTypeSystemGen.isString(path)) { - throw LKQLRuntimeException.wrongType( - LKQLTypesHelper.LKQL_STRING, - LKQLTypesHelper.fromJava(path), - call.getArgList().getArgs()[0]); + new SpecializedBuiltInBody<>(BaseNameFunctionFactory.BaseNameExprNodeGen.create()) { + @Override + protected Object dispatch(Object[] args) { + return specializedNode.executeBaseName(args[0]); } - - // Return the base name of the file - return FileUtils.baseName(LKQLTypeSystemGen.asString(path)); }); } + + /** Expression of the "base_name" function. */ + abstract static class BaseNameExpr extends SpecializedBuiltInBody.SpecializedBuiltInNode { + + public abstract String executeBaseName(Object fileName); + + @Specialization + protected String executeOnString(String fileName) { + return FileUtils.baseName(fileName); + } + + @Fallback + protected String invalidType(Object notValid) { + throw LKQLRuntimeException.wrongType( + LKQLTypesHelper.LKQL_STRING, + LKQLTypesHelper.fromJava(notValid), + body.argNode(0)); + } + } } diff --git a/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/functions/ConcatFunction.java b/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/functions/ConcatFunction.java index 025a07ea1..b494393ec 100644 --- a/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/functions/ConcatFunction.java +++ b/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/functions/ConcatFunction.java @@ -7,14 +7,17 @@ import com.adacore.lkql_jit.LKQLTypeSystemGen; import com.adacore.lkql_jit.built_ins.BuiltInFunctionValue; +import com.adacore.lkql_jit.built_ins.SpecializedBuiltInBody; import com.adacore.lkql_jit.exception.LKQLRuntimeException; import com.adacore.lkql_jit.nodes.expressions.Expr; -import com.adacore.lkql_jit.nodes.expressions.FunCall; import com.adacore.lkql_jit.runtime.values.lists.LKQLList; import com.adacore.lkql_jit.utils.LKQLTypesHelper; import com.adacore.lkql_jit.utils.functions.ArrayUtils; import com.adacore.lkql_jit.utils.functions.StringUtils; -import com.oracle.truffle.api.frame.VirtualFrame; +import com.oracle.truffle.api.CompilerDirectives; +import com.oracle.truffle.api.dsl.Cached; +import com.oracle.truffle.api.dsl.Fallback; +import com.oracle.truffle.api.dsl.Specialization; /** * This class represents the "concat" built-in function in the LKQL language. @@ -23,92 +26,86 @@ */ public final class ConcatFunction { - // ----- Attributes ----- - - /** The name of the function. */ public static final String NAME = "concat"; - // ----- Class methods ----- - + /** Get a brand new "concat" function value. */ public static BuiltInFunctionValue getValue() { return new BuiltInFunctionValue( NAME, "Given a list of lists or strings, return a concatenated list or string", new String[] {"lists"}, new Expr[] {null}, - (VirtualFrame frame, FunCall call) -> { - - // Get the argument - Object lists = frame.getArguments()[0]; - - // Check the type of the argument - if (!LKQLTypeSystemGen.isLKQLList(lists)) { - throw LKQLRuntimeException.wrongType( - LKQLTypesHelper.LKQL_LIST, - LKQLTypesHelper.fromJava(lists), - call.getArgList().getArgs()[0]); - } - - // Cast the argument to list - LKQLList listValue = LKQLTypeSystemGen.asLKQLList(lists); - - // If the list is not empty - if (listValue.size() > 0) { - final Object firstItem = listValue.get(0); - - // If the first value is a string look for strings in the list - if (LKQLTypeSystemGen.isString(firstItem)) { - // Create a string builder and add all strings in the list - String result = LKQLTypeSystemGen.asString(firstItem); - for (int i = 1; i < listValue.size(); i++) { - final Object item = listValue.get(i); - if (!LKQLTypeSystemGen.isString(item)) { - throw LKQLRuntimeException.wrongType( - LKQLTypesHelper.LKQL_STRING, - LKQLTypesHelper.fromJava(item), - call.getArgList().getArgs()[0]); - } - result = - StringUtils.concat( - result, LKQLTypeSystemGen.asString(item)); - } - - // Return the result - return result; - } - - // If the first item is a list look for lists in the list - if (LKQLTypeSystemGen.isLKQLList(firstItem)) { - // Create a result array and add all list of the argument - Object[] result = LKQLTypeSystemGen.asLKQLList(firstItem).getContent(); - for (int i = 1; i < listValue.size(); i++) { - final Object item = listValue.get(i); - if (!LKQLTypeSystemGen.isLKQLList(item)) { - throw LKQLRuntimeException.wrongType( - LKQLTypesHelper.LKQL_LIST, - LKQLTypesHelper.fromJava(item), - call.getArgList().getArgs()[0]); - } - result = - ArrayUtils.concat( - result, - LKQLTypeSystemGen.asLKQLList(item).getContent()); - } - return new LKQLList(result); - } - - // Else there is an error - throw LKQLRuntimeException.wrongType( - LKQLTypesHelper.typeUnion( - LKQLTypesHelper.LKQL_LIST, LKQLTypesHelper.LKQL_STRING), - LKQLTypesHelper.fromJava(firstItem), - call.getArgList().getArgs()[0]); - } - - // If the list is empty just return an empty list - else { - return new LKQLList(new Object[0]); + new SpecializedBuiltInBody<>(ConcatFunctionFactory.ConcatExprNodeGen.create()) { + @Override + protected Object dispatch(Object[] args) { + return specializedNode.executeConcat(args[0]); } }); } + + /** Expression of the "concat" function. */ + abstract static class ConcatExpr extends SpecializedBuiltInBody.SpecializedBuiltInNode { + + public abstract Object executeConcat(Object list); + + protected static boolean isString(Object o) { + return LKQLTypeSystemGen.isString(o); + } + + protected static boolean isList(Object o) { + return LKQLTypeSystemGen.isLKQLList(o); + } + + @Specialization(guards = {"list.size() > 0", "isString(list.get(0))"}) + protected String onListOfStrings(LKQLList list) { + // Create a string builder and add all strings in the list + String result = LKQLTypeSystemGen.asString(list.get(0)); + for (int i = 1; i < list.size(); i++) { + final Object item = list.get(i); + if (!LKQLTypeSystemGen.isString(item)) { + this.invalidElemType(list, item); + } + result = StringUtils.concat(result, LKQLTypeSystemGen.asString(item)); + } + return result; + } + + @Specialization(guards = {"list.size() > 0", "isList(list.get(0))"}) + protected LKQLList onListOfLists(LKQLList list) { + Object[] result = LKQLTypeSystemGen.asLKQLList(list.get(0)).getContent(); + for (int i = 1; i < list.size(); i++) { + final Object item = list.get(i); + if (!LKQLTypeSystemGen.isLKQLList(item)) { + this.invalidElemType(list, item); + } + result = ArrayUtils.concat(result, LKQLTypeSystemGen.asLKQLList(item).getContent()); + } + return new LKQLList(result); + } + + @Specialization(guards = "notValidElem.size() > 0") + @CompilerDirectives.TruffleBoundary + protected LKQLList invalidElemType( + @SuppressWarnings("unused") LKQLList notValidElem, + @Cached("notValidElem.get(0)") Object elem) { + throw LKQLRuntimeException.wrongType( + LKQLTypesHelper.LKQL_LIST + + " of " + + LKQLTypesHelper.typeUnion( + LKQLTypesHelper.LKQL_LIST, LKQLTypesHelper.LKQL_STRING), + LKQLTypesHelper.fromJava(elem) + " element", + body.argNode(0)); + } + + @Specialization(guards = "emptyList.size() == 0") + protected LKQLList onEmptyList(@SuppressWarnings("unused") LKQLList emptyList) { + return new LKQLList(new Object[0]); + } + + @Fallback + protected LKQLList invalidType(Object notValid) { + throw LKQLRuntimeException.wrongType( + LKQLTypesHelper.LKQL_LIST, LKQLTypesHelper.fromJava(notValid), body.argNode(0)); + } + } } diff --git a/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/functions/DocFunction.java b/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/functions/DocFunction.java index ce2cbccc8..2f14bc5d4 100644 --- a/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/functions/DocFunction.java +++ b/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/functions/DocFunction.java @@ -19,13 +19,9 @@ */ public final class DocFunction { - // ----- Attributes ----- - - /** The name of the function. */ public static final String NAME = "doc"; - // ----- Class methods ----- - + /** Get a brand new "doc" function value. */ public static BuiltInFunctionValue getValue() { return new BuiltInFunctionValue( NAME, diff --git a/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/functions/DocumentBuiltins.java b/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/functions/DocumentBuiltins.java index 3232a0d94..5d688d0af 100644 --- a/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/functions/DocumentBuiltins.java +++ b/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/functions/DocumentBuiltins.java @@ -19,6 +19,18 @@ public class DocumentBuiltins { public static final String NAME = "document_builtins"; + /** Get a brand new "document_builtins" function value. */ + public static BuiltInFunctionValue getValue() { + return new BuiltInFunctionValue( + NAME, + "Return a string in the RsT format containing documentation for all built-ins", + new String[] {}, + new Expr[] {}, + (VirtualFrame frame, FunCall call) -> + documentBuiltinsImpl(frame.materialize(), call)); + } + + /** Function for the "document_builtins" execution. */ @CompilerDirectives.TruffleBoundary public static String documentBuiltinsImpl( @SuppressWarnings("unused") MaterializedFrame frame, @@ -93,14 +105,4 @@ public static String documentBuiltinsImpl( throw new RuntimeException(e); } } - - public static BuiltInFunctionValue getValue() { - return new BuiltInFunctionValue( - NAME, - "Return a string in the RsT format containing documentation for all built-ins", - new String[] {}, - new Expr[] {}, - (VirtualFrame frame, FunCall call) -> - documentBuiltinsImpl(frame.materialize(), call)); - } } diff --git a/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/functions/DocumentNamespace.java b/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/functions/DocumentNamespace.java index 0d5608c1b..d5cec3aa7 100644 --- a/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/functions/DocumentNamespace.java +++ b/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/functions/DocumentNamespace.java @@ -25,6 +25,7 @@ public class DocumentNamespace { public static final String NAME = "document_namespace"; + /** Get a brand new "document_namespace" function value. */ public static BuiltInFunctionValue getValue() { return new BuiltInFunctionValue( NAME, @@ -34,6 +35,7 @@ public static BuiltInFunctionValue getValue() { (VirtualFrame frame, FunCall call) -> impl(frame.materialize(), call)); } + /** Function for the "document_namespace" execution. */ @CompilerDirectives.TruffleBoundary private static Object impl(MaterializedFrame frame, FunCall call) { Object nsObj = frame.getArguments()[0]; diff --git a/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/functions/HelpFunction.java b/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/functions/HelpFunction.java index 8f0ddd3d9..149704289 100644 --- a/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/functions/HelpFunction.java +++ b/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/functions/HelpFunction.java @@ -22,13 +22,9 @@ */ public final class HelpFunction { - // ----- Attributes ----- - - /** The name of the function. */ public static final String NAME = "help"; - // ----- Class methods ----- - + /** Get a brand new "help" function value. */ public static BuiltInFunctionValue getValue() { return new BuiltInFunctionValue( NAME, diff --git a/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/functions/ImgFunction.java b/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/functions/ImgFunction.java index 0bc442fe7..7b68ed2c6 100644 --- a/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/functions/ImgFunction.java +++ b/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/functions/ImgFunction.java @@ -6,11 +6,13 @@ package com.adacore.lkql_jit.built_ins.functions; import com.adacore.lkql_jit.built_ins.BuiltInFunctionValue; +import com.adacore.lkql_jit.built_ins.SpecializedBuiltInBody; import com.adacore.lkql_jit.nodes.expressions.Expr; -import com.adacore.lkql_jit.nodes.expressions.FunCall; -import com.adacore.lkql_jit.utils.functions.ObjectUtils; +import com.adacore.lkql_jit.utils.Constants; import com.adacore.lkql_jit.utils.functions.StringUtils; -import com.oracle.truffle.api.frame.VirtualFrame; +import com.oracle.truffle.api.dsl.Specialization; +import com.oracle.truffle.api.interop.InteropLibrary; +import com.oracle.truffle.api.library.CachedLibrary; /** * This class represents the "img" built-in function in the LKQL language. @@ -19,26 +21,36 @@ */ public final class ImgFunction { - // ----- Attributes ----- - - /** The name of the built-in. */ public static final String NAME = "img"; - // ----- Class methods ----- - + /** Get a brand new "img" function value. */ public static BuiltInFunctionValue getValue() { return new BuiltInFunctionValue( NAME, "Return a string representation of an object", new String[] {"val"}, new Expr[] {null}, - (VirtualFrame frame, FunCall call) -> { - // Return the string representation of the argument - if (frame.getArguments()[0] instanceof String s) { - return StringUtils.toRepr(s); - } else { - return ObjectUtils.toString(frame.getArguments()[0]); + new SpecializedBuiltInBody<>(ImgFunctionFactory.ImgExprNodeGen.create()) { + @Override + protected Object dispatch(Object[] args) { + return specializedNode.executeImg(args[0]); } }); } + + /** Expression of the "img" function. */ + abstract static class ImgExpr extends SpecializedBuiltInBody.SpecializedBuiltInNode { + + public abstract String executeImg(Object obj); + + @Specialization + protected String onString(String string) { + return StringUtils.toRepr(string); + } + + @Specialization(limit = Constants.SPECIALIZED_LIB_LIMIT) + protected String onObject(Object obj, @CachedLibrary("obj") InteropLibrary objLibrary) { + return (String) objLibrary.toDisplayString(obj); + } + } } diff --git a/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/functions/MapFunction.java b/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/functions/MapFunction.java index f1bfdc134..79511cd86 100644 --- a/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/functions/MapFunction.java +++ b/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/functions/MapFunction.java @@ -5,22 +5,23 @@ package com.adacore.lkql_jit.built_ins.functions; -import com.adacore.lkql_jit.LKQLTypeSystemGen; -import com.adacore.lkql_jit.built_ins.AbstractBuiltInFunctionBody; import com.adacore.lkql_jit.built_ins.BuiltInFunctionValue; +import com.adacore.lkql_jit.built_ins.SpecializedBuiltInBody; import com.adacore.lkql_jit.exception.LKQLRuntimeException; import com.adacore.lkql_jit.nodes.expressions.Expr; import com.adacore.lkql_jit.runtime.values.LKQLFunction; import com.adacore.lkql_jit.runtime.values.interfaces.Iterable; import com.adacore.lkql_jit.runtime.values.lists.LKQLList; +import com.adacore.lkql_jit.utils.Constants; import com.adacore.lkql_jit.utils.Iterator; import com.adacore.lkql_jit.utils.LKQLTypesHelper; -import com.oracle.truffle.api.frame.VirtualFrame; +import com.oracle.truffle.api.dsl.Fallback; +import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.api.interop.ArityException; import com.oracle.truffle.api.interop.InteropLibrary; import com.oracle.truffle.api.interop.UnsupportedMessageException; import com.oracle.truffle.api.interop.UnsupportedTypeException; -import com.oracle.truffle.api.nodes.UnexpectedResultException; +import com.oracle.truffle.api.library.CachedLibrary; /** * This class represents the "map" built-in function in the LKQL language. @@ -29,13 +30,9 @@ */ public final class MapFunction { - // ----- Attributes ----- - - /** The name of the function. */ public static final String NAME = "map"; - // ----- Class methods ----- - + /** Get a brand new "map" function value. */ public static BuiltInFunctionValue getValue() { return new BuiltInFunctionValue( NAME, @@ -44,71 +41,68 @@ public static BuiltInFunctionValue getValue() { + "map(lst, f) -> [f(lst[1]), f(lst[2]), ...]", new String[] {"indexable", "fn"}, new Expr[] {null, null}, - new MapExpr()); + new SpecializedBuiltInBody<>(MapFunctionFactory.MaxExprNodeGen.create()) { + @Override + protected Object dispatch(Object[] args) { + return specializedNode.executeMap(args[0], args[1]); + } + }); } - // ----- Inner classes ----- - /** Expression of the "map" function. */ - public static final class MapExpr extends AbstractBuiltInFunctionBody { - - /** An uncached interop library for the checker functions execution. */ - private InteropLibrary interopLibrary = InteropLibrary.getUncached(); - - @Override - public Object executeGeneric(VirtualFrame frame) { - // Get the arguments - Iterable iterable; - LKQLFunction mapFunction; - - try { - iterable = LKQLTypeSystemGen.expectIterable(frame.getArguments()[0]); - } catch (UnexpectedResultException e) { - throw LKQLRuntimeException.wrongType( - LKQLTypesHelper.LKQL_ITERABLE, - LKQLTypesHelper.fromJava(e.getResult()), - this.callNode.getArgList().getArgs()[0]); - } - - try { - mapFunction = LKQLTypeSystemGen.expectLKQLFunction(frame.getArguments()[1]); - } catch (UnexpectedResultException e) { - throw LKQLRuntimeException.wrongType( - LKQLTypesHelper.LKQL_FUNCTION, - LKQLTypesHelper.fromJava(e.getResult()), - this.callNode.getArgList().getArgs()[1]); - } + abstract static class MaxExpr extends SpecializedBuiltInBody.SpecializedBuiltInNode { - // Verify the function arity - if (mapFunction.parameterNames.length != 1) { - throw LKQLRuntimeException.fromMessage( - "Function passed to map should have arity of one", - this.callNode.getArgList().getArgs()[1]); - } + public abstract LKQLList executeMap(Object iterable, Object function); - // Prepare the result + @Specialization( + limit = Constants.SPECIALIZED_LIB_LIMIT, + guards = "function.parameterNames.length == 1") + protected LKQLList onValidArgs( + Iterable iterable, + LKQLFunction function, + @CachedLibrary("function") InteropLibrary functionLibrary) { Object[] res = new Object[(int) iterable.size()]; - - // Apply the mapping function int i = 0; Iterator iterator = iterable.iterator(); + while (iterator.hasNext()) { try { res[i] = - this.interopLibrary.execute( - mapFunction, mapFunction.closure.getContent(), iterator.next()); + functionLibrary.execute( + function, function.closure.getContent(), iterator.next()); } catch (ArityException | UnsupportedTypeException | UnsupportedMessageException e) { // TODO: Implement runtime checks in the LKQLFunction class and base computing // on them (#138) - throw LKQLRuntimeException.fromJavaException(e, this.callNode); + throw LKQLRuntimeException.fromJavaException(e, body.argNode(1)); } i++; } - // Return the result return new LKQLList(res); } + + @Specialization + protected LKQLList onInvalidFunction(Iterable iterable, LKQLFunction function) { + throw LKQLRuntimeException.wrongArity( + 1, function.parameterNames.length, body.argNode(1)); + } + + @Specialization + protected LKQLList invalidType(Iterable iterable, Object notValid) { + throw LKQLRuntimeException.wrongType( + LKQLTypesHelper.LKQL_FUNCTION, + LKQLTypesHelper.fromJava(notValid), + body.argNode(1)); + } + + @Fallback + protected LKQLList invalidType(Object notValid, Object function) { + throw LKQLRuntimeException.wrongType( + LKQLTypesHelper.LKQL_ITERABLE, + LKQLTypesHelper.fromJava(notValid), + body.argNode(0)); + } } } diff --git a/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/functions/PatternFunction.java b/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/functions/PatternFunction.java index 96bfbe8ec..4c6e84304 100644 --- a/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/functions/PatternFunction.java +++ b/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/functions/PatternFunction.java @@ -5,16 +5,15 @@ package com.adacore.lkql_jit.built_ins.functions; -import com.adacore.lkql_jit.LKQLTypeSystemGen; import com.adacore.lkql_jit.built_ins.BuiltInFunctionValue; +import com.adacore.lkql_jit.built_ins.SpecializedBuiltInBody; import com.adacore.lkql_jit.exception.LKQLRuntimeException; import com.adacore.lkql_jit.nodes.expressions.Expr; -import com.adacore.lkql_jit.nodes.expressions.FunCall; import com.adacore.lkql_jit.nodes.expressions.literals.BooleanLiteral; import com.adacore.lkql_jit.runtime.values.LKQLPattern; import com.adacore.lkql_jit.utils.LKQLTypesHelper; -import com.oracle.truffle.api.frame.VirtualFrame; -import com.oracle.truffle.api.nodes.UnexpectedResultException; +import com.oracle.truffle.api.dsl.Fallback; +import com.oracle.truffle.api.dsl.Specialization; /** * This class represents the "pattern" built-in function in the LKQL language. @@ -23,44 +22,47 @@ */ public final class PatternFunction { - // ----- Attributes ----- - - /** The name of the built-in. */ public static final String NAME = "pattern"; - // ----- Class methods ----- - + /** Get a brand new "pattern" function value. */ public static BuiltInFunctionValue getValue() { return new BuiltInFunctionValue( NAME, "Given a regex pattern string, create a pattern object", new String[] {"regex", "case_sensitive"}, new Expr[] {null, new BooleanLiteral(null, true)}, - (VirtualFrame frame, FunCall call) -> { - // Get the string parameter - String regexString; - try { - regexString = LKQLTypeSystemGen.expectString(frame.getArguments()[0]); - } catch (UnexpectedResultException e) { - throw LKQLRuntimeException.wrongType( - LKQLTypesHelper.LKQL_STRING, - LKQLTypesHelper.fromJava(e.getResult()), - call.getArgList().getArgs()[0]); + new SpecializedBuiltInBody<>(PatternFunctionFactory.PatternExprNodeGen.create()) { + @Override + protected Object dispatch(Object[] args) { + return specializedNode.executePattern(args[0], args[1]); } + }); + } - // Get the case sensitiveness parameter - boolean caseSensitive; - try { - caseSensitive = LKQLTypeSystemGen.expectBoolean(frame.getArguments()[1]); - } catch (UnexpectedResultException e) { - throw LKQLRuntimeException.wrongType( - LKQLTypesHelper.LKQL_BOOLEAN, - LKQLTypesHelper.fromJava(e.getResult()), - call.getArgList().getArgs()[1]); - } + /** Expression of the "pattern" function. */ + abstract static class PatternExpr extends SpecializedBuiltInBody.SpecializedBuiltInNode { - // Create the pattern and return it - return new LKQLPattern(call, regexString, caseSensitive); - }); + public abstract LKQLPattern executePattern(Object regex, Object caseSensitive); + + @Specialization + protected LKQLPattern onValidArgs(String regex, boolean caseSensitive) { + return new LKQLPattern(body.getCallNode(), regex, caseSensitive); + } + + @Specialization + protected LKQLPattern invalidType(String regex, Object notValid) { + throw LKQLRuntimeException.wrongType( + LKQLTypesHelper.LKQL_BOOLEAN, + LKQLTypesHelper.fromJava(notValid), + body.argNode(1)); + } + + @Fallback + protected LKQLPattern invalidType(Object notValid, Object caseSensitive) { + throw LKQLRuntimeException.wrongType( + LKQLTypesHelper.LKQL_STRING, + LKQLTypesHelper.fromJava(notValid), + body.argNode(0)); + } } } diff --git a/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/functions/PrintFunction.java b/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/functions/PrintFunction.java index 1e6afb79c..32ab1fd72 100644 --- a/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/functions/PrintFunction.java +++ b/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/functions/PrintFunction.java @@ -6,17 +6,18 @@ package com.adacore.lkql_jit.built_ins.functions; import com.adacore.lkql_jit.LKQLLanguage; -import com.adacore.lkql_jit.LKQLTypeSystemGen; import com.adacore.lkql_jit.built_ins.BuiltInFunctionValue; +import com.adacore.lkql_jit.built_ins.SpecializedBuiltInBody; import com.adacore.lkql_jit.exception.LKQLRuntimeException; import com.adacore.lkql_jit.nodes.expressions.Expr; -import com.adacore.lkql_jit.nodes.expressions.FunCall; import com.adacore.lkql_jit.nodes.expressions.literals.BooleanLiteral; import com.adacore.lkql_jit.runtime.values.LKQLUnit; +import com.adacore.lkql_jit.utils.Constants; import com.adacore.lkql_jit.utils.LKQLTypesHelper; -import com.adacore.lkql_jit.utils.functions.ObjectUtils; -import com.oracle.truffle.api.frame.VirtualFrame; -import com.oracle.truffle.api.nodes.UnexpectedResultException; +import com.oracle.truffle.api.dsl.Fallback; +import com.oracle.truffle.api.dsl.Specialization; +import com.oracle.truffle.api.interop.InteropLibrary; +import com.oracle.truffle.api.library.CachedLibrary; /** * This class represents the "print" built-in function in the LKQL language. @@ -25,42 +26,49 @@ */ public final class PrintFunction { - // ----- Attributes ----- - - /** The name of the function. */ public static final String NAME = "print"; - // ----- Class methods ----- - + /** Get a brand new "print" function value. */ public static BuiltInFunctionValue getValue() { return new BuiltInFunctionValue( NAME, "Built-in print function. Prints whatever is passed as an argument", new String[] {"val", "new_line"}, new Expr[] {null, new BooleanLiteral(null, true)}, - (VirtualFrame frame, FunCall call) -> { - // Get the arguments - Object toPrint = frame.getArguments()[0]; - - boolean newLine; - try { - newLine = LKQLTypeSystemGen.expectBoolean(frame.getArguments()[1]); - } catch (UnexpectedResultException e) { - throw LKQLRuntimeException.wrongType( - LKQLTypesHelper.LKQL_BOOLEAN, - LKQLTypesHelper.fromJava(e.getResult()), - call.getArgList().getArgs()[1]); + new SpecializedBuiltInBody<>(PrintFunctionFactory.PrintExprNodeGen.create()) { + @Override + protected Object dispatch(Object[] args) { + return specializedNode.executePrint(args[0], args[1]); } + }); + } - // Print the value - if (newLine) { - LKQLLanguage.getContext(call).println(ObjectUtils.toString(toPrint)); - } else { - LKQLLanguage.getContext(call).print(ObjectUtils.toString(toPrint)); - } + /** Expression of the "print" function. */ + abstract static class PrintExpr extends SpecializedBuiltInBody.SpecializedBuiltInNode { - // Return the unit value - return LKQLUnit.INSTANCE; - }); + public abstract LKQLUnit executePrint(Object toPrint, Object newline); + + @Specialization(limit = Constants.SPECIALIZED_LIB_LIMIT) + protected LKQLUnit onBoolean( + Object toPrint, + boolean newline, + @CachedLibrary("toPrint") InteropLibrary printingLibrary) { + if (newline) { + LKQLLanguage.getContext(null) + .println((String) printingLibrary.toDisplayString(toPrint)); + } else { + LKQLLanguage.getContext(null) + .print((String) printingLibrary.toDisplayString(toPrint)); + } + return LKQLUnit.INSTANCE; + } + + @Fallback + protected LKQLUnit invalidType(Object toPrint, Object notValid) { + throw LKQLRuntimeException.wrongType( + LKQLTypesHelper.LKQL_BOOLEAN, + LKQLTypesHelper.fromJava(notValid), + body.argNode(1)); + } } } diff --git a/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/functions/ProfileFunction.java b/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/functions/ProfileFunction.java index 67cff7abc..b60d35da2 100644 --- a/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/functions/ProfileFunction.java +++ b/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/functions/ProfileFunction.java @@ -19,13 +19,9 @@ */ public final class ProfileFunction { - // ----- Attributes ----- - - /** The name of the function. */ public static final String NAME = "profile"; - // ----- Class methods ----- - + /** Get a brand new "profile" function value. */ public static BuiltInFunctionValue getValue() { return new BuiltInFunctionValue( NAME, diff --git a/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/functions/ReduceFunction.java b/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/functions/ReduceFunction.java index ce7f4467d..fd6a1060f 100644 --- a/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/functions/ReduceFunction.java +++ b/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/functions/ReduceFunction.java @@ -5,21 +5,22 @@ package com.adacore.lkql_jit.built_ins.functions; -import com.adacore.lkql_jit.LKQLTypeSystemGen; -import com.adacore.lkql_jit.built_ins.AbstractBuiltInFunctionBody; import com.adacore.lkql_jit.built_ins.BuiltInFunctionValue; +import com.adacore.lkql_jit.built_ins.SpecializedBuiltInBody; import com.adacore.lkql_jit.exception.LKQLRuntimeException; import com.adacore.lkql_jit.nodes.expressions.Expr; import com.adacore.lkql_jit.runtime.values.LKQLFunction; import com.adacore.lkql_jit.runtime.values.interfaces.Iterable; +import com.adacore.lkql_jit.utils.Constants; import com.adacore.lkql_jit.utils.Iterator; import com.adacore.lkql_jit.utils.LKQLTypesHelper; -import com.oracle.truffle.api.frame.VirtualFrame; +import com.oracle.truffle.api.dsl.Fallback; +import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.api.interop.ArityException; import com.oracle.truffle.api.interop.InteropLibrary; import com.oracle.truffle.api.interop.UnsupportedMessageException; import com.oracle.truffle.api.interop.UnsupportedTypeException; -import com.oracle.truffle.api.nodes.UnexpectedResultException; +import com.oracle.truffle.api.library.CachedLibrary; /** * This class represents the "reduce" built-in function in the LKQL language. @@ -28,70 +29,43 @@ */ public final class ReduceFunction { - // ----- Attributes ----- - - /** The name of the function. */ public static final String NAME = "reduce"; - // ----- Class methods ----- - + /** Get a brand new "reduce" function value. */ public static BuiltInFunctionValue getValue() { return new BuiltInFunctionValue( NAME, "Given a collection, a reduction function, and an initial value reduce the result", new String[] {"indexable", "fn", "init"}, new Expr[] {null, null, null}, - new ReduceExpr()); + new SpecializedBuiltInBody<>(ReduceFunctionFactory.ReduceExprNodeGen.create()) { + @Override + protected Object dispatch(Object[] args) { + return specializedNode.executeReduce(args[0], args[1], args[2]); + } + }); } - // ----- Inner classes ----- - - /** Expression for the "reduce" function. */ - public static final class ReduceExpr extends AbstractBuiltInFunctionBody { + /** Expression of the "reduce" function. */ + abstract static class ReduceExpr extends SpecializedBuiltInBody.SpecializedBuiltInNode { - /** An uncached interop library for the checker functions execution. */ - private InteropLibrary interopLibrary = InteropLibrary.getUncached(); + public abstract Object executeReduce(Object iterable, Object function, Object initValue); - @Override - public Object executeGeneric(VirtualFrame frame) { - // Get the arguments - Iterable iterable; - LKQLFunction reduceFunction; - Object initValue = frame.getArguments()[2]; - - try { - iterable = LKQLTypeSystemGen.expectIterable(frame.getArguments()[0]); - } catch (UnexpectedResultException e) { - throw LKQLRuntimeException.wrongType( - LKQLTypesHelper.LKQL_ITERABLE, - LKQLTypesHelper.fromJava(e.getResult()), - this.callNode.getArgList().getArgs()[0]); - } - - try { - reduceFunction = LKQLTypeSystemGen.expectLKQLFunction(frame.getArguments()[1]); - } catch (UnexpectedResultException e) { - throw LKQLRuntimeException.wrongType( - LKQLTypesHelper.LKQL_FUNCTION, - LKQLTypesHelper.fromJava(e.getResult()), - this.callNode.getArgList().getArgs()[1]); - } - - // Verify the function arity - if (reduceFunction.parameterNames.length != 2) { - throw LKQLRuntimeException.fromMessage( - "Function passed to reduce should have arity of two", - this.callNode.getArgList().getArgs()[1]); - } - - // Execute the reducing + @Specialization( + limit = Constants.SPECIALIZED_LIB_LIMIT, + guards = "function.parameterNames.length == 2") + protected Object onValidArgs( + Iterable iterable, + LKQLFunction function, + Object initValue, + @CachedLibrary("function") InteropLibrary functionLibrary) { Iterator iterator = iterable.iterator(); while (iterator.hasNext()) { try { initValue = - this.interopLibrary.execute( - reduceFunction, - reduceFunction.closure.getContent(), + functionLibrary.execute( + function, + function.closure.getContent(), initValue, iterator.next()); } catch (ArityException @@ -99,12 +73,33 @@ public Object executeGeneric(VirtualFrame frame) { | UnsupportedMessageException e) { // TODO: Implement runtime checks in the LKQLFunction class and base computing // on them (#138) - throw LKQLRuntimeException.fromJavaException(e, this.callNode); + throw LKQLRuntimeException.fromJavaException(e, body.argNode(1)); } } - - // Return the result return initValue; } + + @Specialization + protected Object onInvalidFunction( + Iterable iterable, LKQLFunction function, Object initValue) { + throw LKQLRuntimeException.wrongArity( + 2, function.parameterNames.length, body.argNode(1)); + } + + @Specialization + protected Object invalidType(Iterable iterable, Object notValid, Object initValue) { + throw LKQLRuntimeException.wrongType( + LKQLTypesHelper.LKQL_FUNCTION, + LKQLTypesHelper.fromJava(notValid), + body.argNode(1)); + } + + @Fallback + protected Object invalidType(Object notValid, Object function, Object initValue) { + throw LKQLRuntimeException.wrongType( + LKQLTypesHelper.LKQL_ITERABLE, + LKQLTypesHelper.fromJava(notValid), + body.argNode(0)); + } } } diff --git a/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/functions/SpecifiedUnitsFunction.java b/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/functions/SpecifiedUnitsFunction.java index da0b10853..c7d43b974 100644 --- a/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/functions/SpecifiedUnitsFunction.java +++ b/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/functions/SpecifiedUnitsFunction.java @@ -19,19 +19,16 @@ */ public final class SpecifiedUnitsFunction { - // ----- Attributes ----- - - /** The name of the built-in. */ public static final String NAME = "specified_units"; + /** Get a brand new "specified_units" function value. */ public static BuiltInFunctionValue getValue() { return new BuiltInFunctionValue( NAME, "Return an iterator on units specified by the user", new String[] {}, new Expr[] {}, - (VirtualFrame frame, FunCall call) -> { - return new LKQLList(LKQLLanguage.getContext(call).getSpecifiedUnits()); - }); + (VirtualFrame frame, FunCall call) -> + new LKQLList(LKQLLanguage.getContext(call).getSpecifiedUnits())); } } diff --git a/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/functions/UniqueFunction.java b/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/functions/UniqueFunction.java index 37510959d..5d287fb5b 100644 --- a/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/functions/UniqueFunction.java +++ b/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/functions/UniqueFunction.java @@ -5,16 +5,16 @@ package com.adacore.lkql_jit.built_ins.functions; -import com.adacore.lkql_jit.LKQLTypeSystemGen; import com.adacore.lkql_jit.built_ins.BuiltInFunctionValue; +import com.adacore.lkql_jit.built_ins.SpecializedBuiltInBody; import com.adacore.lkql_jit.exception.LKQLRuntimeException; import com.adacore.lkql_jit.nodes.expressions.Expr; -import com.adacore.lkql_jit.nodes.expressions.FunCall; import com.adacore.lkql_jit.runtime.values.interfaces.Indexable; import com.adacore.lkql_jit.runtime.values.lists.LKQLList; import com.adacore.lkql_jit.utils.LKQLTypesHelper; import com.adacore.lkql_jit.utils.functions.ArrayUtils; -import com.oracle.truffle.api.frame.VirtualFrame; +import com.oracle.truffle.api.dsl.Fallback; +import com.oracle.truffle.api.dsl.Specialization; /** * This class represents the "unique" built-in function in the LKQL language. @@ -23,13 +23,9 @@ */ public final class UniqueFunction { - // ----- Attributes ----- - - /** The name of the function. */ public static final String NAME = "unique"; - // ----- Class methods ----- - + /** Get a brand new "unique" function value. */ public static BuiltInFunctionValue getValue() { return new BuiltInFunctionValue( NAME, @@ -37,24 +33,30 @@ public static BuiltInFunctionValue getValue() { + " of each", new String[] {"indexable"}, new Expr[] {null}, - (VirtualFrame frame, FunCall call) -> { - // Get the argument - Object indexableObject = frame.getArguments()[0]; - - // Verify the argument type - if (!LKQLTypeSystemGen.isIndexable(indexableObject)) { - throw LKQLRuntimeException.wrongType( - LKQLTypesHelper.LKQL_LIST, - LKQLTypesHelper.fromJava(indexableObject), - call.getArgList().getArgs()[0]); + new SpecializedBuiltInBody<>(UniqueFunctionFactory.UniqueExprNodeGen.create()) { + @Override + protected Object dispatch(Object[] args) { + return specializedNode.executeUnique(args[0]); } + }); + } - // Cast the argument - Indexable indexable = LKQLTypeSystemGen.asIndexable(indexableObject); + /** Expression of the "unique" function. */ + abstract static class UniqueExpr extends SpecializedBuiltInBody.SpecializedBuiltInNode { - // Return the result list - return new LKQLList( - ArrayUtils.unique(indexable.getContent()).toArray(new Object[0])); - }); + public abstract LKQLList executeUnique(Object indexable); + + @Specialization + protected LKQLList onIndexable(Indexable indexable) { + return new LKQLList(ArrayUtils.unique(indexable.getContent()).toArray(new Object[0])); + } + + @Fallback + protected LKQLList invalidType(Object notValid) { + throw LKQLRuntimeException.wrongType( + LKQLTypesHelper.LKQL_ITERABLE, + LKQLTypesHelper.fromJava(notValid), + body.argNode(0)); + } } } diff --git a/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/functions/UnitsFunction.java b/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/functions/UnitsFunction.java index f77372016..53ed697bf 100644 --- a/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/functions/UnitsFunction.java +++ b/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/functions/UnitsFunction.java @@ -19,11 +19,9 @@ */ public final class UnitsFunction { - // ----- Attributes ----- - - /** The name of the built-in. */ public static final String NAME = "units"; + /** Get a brand new "units" function value. */ public static BuiltInFunctionValue getValue() { return new BuiltInFunctionValue( NAME, From 9539af9dd327d21c4ba86b44407da235382851f7 Mon Sep 17 00:00:00 2001 From: Hugo Guerrier Date: Mon, 16 Sep 2024 14:45:07 +0200 Subject: [PATCH 2/2] Rewrite all built-in methods as specialized nodes --- .../built_ins/methods/ListMethods.java | 109 +++--- .../built_ins/methods/NodeMethods.java | 51 +-- .../built_ins/methods/StrMethods.java | 309 +++++++++--------- .../built_ins/methods/TokenMethods.java | 123 ++++--- .../interpreter/sublist_builtin/test.out | 4 +- 5 files changed, 325 insertions(+), 271 deletions(-) diff --git a/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/methods/ListMethods.java b/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/methods/ListMethods.java index 319660fe2..482074a2b 100644 --- a/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/methods/ListMethods.java +++ b/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/methods/ListMethods.java @@ -10,13 +10,14 @@ import com.adacore.lkql_jit.LKQLTypeSystemGen; import com.adacore.lkql_jit.built_ins.BuiltInMethodFactory; import com.adacore.lkql_jit.built_ins.BuiltInsHolder; +import com.adacore.lkql_jit.built_ins.SpecializedBuiltInBody; import com.adacore.lkql_jit.built_ins.functions.UniqueFunction; import com.adacore.lkql_jit.exception.LKQLRuntimeException; import com.adacore.lkql_jit.nodes.expressions.Expr; -import com.adacore.lkql_jit.nodes.expressions.FunCall; import com.adacore.lkql_jit.runtime.values.lists.LKQLList; import com.adacore.lkql_jit.utils.LKQLTypesHelper; -import com.oracle.truffle.api.frame.VirtualFrame; +import com.oracle.truffle.api.dsl.Fallback; +import com.oracle.truffle.api.dsl.Specialization; import java.util.Arrays; import java.util.Map; @@ -27,51 +28,6 @@ */ public class ListMethods { - private static final Map.Entry sublistFunction = - createMethod( - "sublist", - "Return a sublist of `list` from `low_bound` to `high_bound`", - new String[] {"low_bound", "high_bound"}, - new Expr[] {null, null}, - (VirtualFrame frame, FunCall call) -> { - var args = frame.getArguments(); - - if (!LKQLTypeSystemGen.isLKQLList(args[0])) { - throw LKQLRuntimeException.wrongType( - LKQLTypesHelper.LKQL_LIST, - LKQLTypesHelper.fromJava(args[0]), - call.getArgList().getArgs()[0]); - } - - if (!LKQLTypeSystemGen.isLong(args[1])) { - throw LKQLRuntimeException.wrongType( - LKQLTypesHelper.LKQL_INTEGER, - LKQLTypesHelper.fromJava(args[1]), - call.getArgList().getArgs()[1]); - } - - if (!LKQLTypeSystemGen.isLong(args[2])) { - throw LKQLRuntimeException.wrongType( - LKQLTypesHelper.LKQL_INTEGER, - LKQLTypesHelper.fromJava(args[2]), - call.getArgList().getArgs()[2]); - } - - LKQLList list = LKQLTypeSystemGen.asLKQLList(args[0]); - long lowBound = LKQLTypeSystemGen.asLong(args[1]); - long highBound = LKQLTypeSystemGen.asLong(args[2]); - - if (lowBound < 1) { - throw LKQLRuntimeException.invalidIndex((int) lowBound, call); - } else if (highBound > list.getContent().length) { - throw LKQLRuntimeException.invalidIndex((int) highBound, call); - } - - return new LKQLList( - Arrays.copyOfRange( - list.getContent(), (int) lowBound - 1, (int) highBound)); - }); - public static final Map methods = BuiltInsHolder.combine( Map.ofEntries( @@ -79,6 +35,63 @@ public class ListMethods { UniqueFunction.NAME, BuiltInMethodFactory.fromFunctionValue( UniqueFunction.getValue(), true)), - sublistFunction), + createMethod( + "sublist", + "Return a sublist of `list` from " + + "`low_bound` to `high_bound`", + new String[] {"low_bound", "high_bound"}, + new Expr[] {null, null}, + new SpecializedBuiltInBody<>( + ListMethodsFactory.SublistExprNodeGen.create()) { + @Override + protected Object dispatch(Object[] args) { + return this.specializedNode.executeSublist( + LKQLTypeSystemGen.asLKQLList(args[0]), + args[1], + args[2]); + } + })), IterableMethods.methods); + + // ----- Inner classes ----- + + /** Expression of the "sublist" method. */ + public abstract static class SublistExpr extends SpecializedBuiltInBody.SpecializedBuiltInNode { + + public abstract LKQLList executeSublist(LKQLList list, Object low, Object high); + + @Specialization + protected LKQLList onValid(LKQLList list, long low, long high) { + // Offset the low bound by 1 since LKQL is 1-indexed + low = low - 1; + + // Check bounds validity + if (low < 0) { + throw LKQLRuntimeException.invalidIndex((int) low + 1, body.argNode(0)); + } else if (high > list.getContent().length) { + throw LKQLRuntimeException.invalidIndex((int) high, body.argNode(1)); + } + + // Return the sublist + return new LKQLList(Arrays.copyOfRange(list.getContent(), (int) low, (int) high)); + } + + @Specialization + protected LKQLList onInvalidHigh( + @SuppressWarnings("unused") LKQLList list, + @SuppressWarnings("unused") long low, + Object high) { + throw LKQLRuntimeException.wrongType( + LKQLTypesHelper.LKQL_INTEGER, LKQLTypesHelper.fromJava(high), body.argNode(1)); + } + + @Fallback + protected LKQLList onInvalidLow( + @SuppressWarnings("unused") LKQLList list, + Object low, + @SuppressWarnings("unused") Object high) { + throw LKQLRuntimeException.wrongType( + LKQLTypesHelper.LKQL_INTEGER, LKQLTypesHelper.fromJava(low), body.argNode(0)); + } + } } diff --git a/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/methods/NodeMethods.java b/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/methods/NodeMethods.java index b2b785192..7829d1aec 100644 --- a/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/methods/NodeMethods.java +++ b/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/methods/NodeMethods.java @@ -13,6 +13,7 @@ import com.adacore.lkql_jit.LKQLTypeSystemGen; import com.adacore.lkql_jit.built_ins.AbstractBuiltInFunctionBody; import com.adacore.lkql_jit.built_ins.BuiltInMethodFactory; +import com.adacore.lkql_jit.built_ins.SpecializedBuiltInBody; import com.adacore.lkql_jit.exception.LKQLRuntimeException; import com.adacore.lkql_jit.nodes.expressions.Expr; import com.adacore.lkql_jit.runtime.values.LKQLNull; @@ -22,8 +23,9 @@ import com.adacore.lkql_jit.utils.functions.ObjectUtils; import com.adacore.lkql_jit.utils.functions.ReflectionUtils; import com.adacore.lkql_jit.utils.functions.StringUtils; +import com.oracle.truffle.api.dsl.Fallback; +import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.api.frame.VirtualFrame; -import com.oracle.truffle.api.nodes.UnexpectedResultException; import java.util.ArrayList; import java.util.Map; @@ -66,7 +68,16 @@ public final class NodeMethods { "Return whether two nodes have the same tokens, ignoring trivias", new String[] {"other"}, new Expr[] {null}, - new SameTokensExpr())); + new SpecializedBuiltInBody<>( + NodeMethodsFactory.SameTokensExprNodeGen.create()) { + @Override + protected Object dispatch(Object[] args) { + return this.specializedNode.executeSameTokens( + LKQLTypeSystemGen.asAdaNode(args[0]), args[1]); + } + })); + + // ----- Inner classes ----- /** Expression of the "children" method. */ public static final class ChildrenExpr extends AbstractBuiltInFunctionBody { @@ -171,21 +182,13 @@ public Object executeGeneric(VirtualFrame frame) { } /** Expression of the "same_tokens" method. */ - public static final class SameTokensExpr extends AbstractBuiltInFunctionBody { - @Override - public Object executeGeneric(VirtualFrame frame) { - // Get the nodes to compare - Libadalang.AdaNode leftNode = LKQLTypeSystemGen.asAdaNode(frame.getArguments()[0]); - Libadalang.AdaNode rightNode; - try { - rightNode = LKQLTypeSystemGen.expectAdaNode(frame.getArguments()[1]); - } catch (UnexpectedResultException e) { - throw LKQLRuntimeException.wrongType( - LKQLTypesHelper.ADA_NODE, - LKQLTypesHelper.fromJava(e.getResult()), - this.callNode.getArgList().getArgs()[0]); - } + public abstract static class SameTokensExpr + extends SpecializedBuiltInBody.SpecializedBuiltInNode { + public abstract boolean executeSameTokens(Libadalang.AdaNode leftNode, Object rightNode); + + @Specialization + protected boolean onAdaNode(Libadalang.AdaNode leftNode, Libadalang.AdaNode rightNode) { // Get the tokens Libadalang.Token leftToken = leftNode.tokenStart(); Libadalang.Token rightToken = rightNode.tokenStart(); @@ -218,12 +221,16 @@ public Object executeGeneric(VirtualFrame frame) { return true; } - /** - * Get the next token from the given one ignoring the trivias - * - * @param t The token to get the next from - * @return The next token - */ + @Fallback + protected boolean onInvalid( + @SuppressWarnings("unused") Libadalang.AdaNode leftNode, Object rightValue) { + throw LKQLRuntimeException.wrongType( + LKQLTypesHelper.ADA_NODE, + LKQLTypesHelper.fromJava(rightValue), + body.argNode(0)); + } + + /** Get the next token from the given one ignoring the trivias. */ private static Libadalang.Token next(Libadalang.Token t) { Libadalang.Token res = t.next(); while (!res.isNone() && res.triviaIndex != 0) { diff --git a/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/methods/StrMethods.java b/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/methods/StrMethods.java index 4477d6709..8fb08facd 100644 --- a/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/methods/StrMethods.java +++ b/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/methods/StrMethods.java @@ -11,16 +11,17 @@ import com.adacore.lkql_jit.LKQLTypeSystemGen; import com.adacore.lkql_jit.built_ins.AbstractBuiltInFunctionBody; import com.adacore.lkql_jit.built_ins.BuiltInMethodFactory; +import com.adacore.lkql_jit.built_ins.SpecializedBuiltInBody; import com.adacore.lkql_jit.built_ins.functions.BaseNameFunction; import com.adacore.lkql_jit.exception.LKQLRuntimeException; import com.adacore.lkql_jit.nodes.expressions.Expr; import com.adacore.lkql_jit.runtime.values.LKQLPattern; import com.adacore.lkql_jit.runtime.values.lists.LKQLList; import com.adacore.lkql_jit.utils.LKQLTypesHelper; -import com.adacore.lkql_jit.utils.functions.BigIntegerUtils; import com.adacore.lkql_jit.utils.functions.StringUtils; +import com.oracle.truffle.api.dsl.Fallback; +import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.api.frame.VirtualFrame; -import java.math.BigInteger; import java.util.Map; /** @@ -68,40 +69,82 @@ public class StrMethods { + " contained between indices from and to (both included)", new String[] {"from", "to"}, new Expr[] {null, null}, - new SubstringExpr()), + new SpecializedBuiltInBody<>( + StrMethodsFactory.SubstringExprNodeGen.create()) { + @Override + protected Object dispatch(Object[] args) { + return this.specializedNode.executeSubstring( + LKQLTypeSystemGen.asString(args[0]), args[1], args[2]); + } + }), createMethod( "split", "Given a string, return an iterator on the words contained by str" + " separated by separator", new String[] {"separator"}, new Expr[] {null}, - new SplitExpr()), + new SpecializedBuiltInBody<>( + StrMethodsFactory.SplitExprNodeGen.create()) { + @Override + protected Object dispatch(Object[] args) { + return this.specializedNode.executeSplit( + LKQLTypeSystemGen.asString(args[0]), args[1]); + } + }), createMethod( "contains", "Search for to_find in the given string. Return whether a match is" + " found. to_find can be either a pattern or a string", new String[] {"to_find"}, new Expr[] {null}, - new ContainsExpr()), + new SpecializedBuiltInBody<>( + StrMethodsFactory.ContainsExprNodeGen.create()) { + @Override + protected Object dispatch(Object[] args) { + return this.specializedNode.executeContains( + LKQLTypeSystemGen.asString(args[0]), args[1]); + } + }), createMethod( "find", "Search for to_find in the given string. Return position of the match," + " or -1 if no match. to_find can be either a pattern or a string", new String[] {"to_find"}, new Expr[] {null}, - new FindExpr()), + new SpecializedBuiltInBody<>( + StrMethodsFactory.FindExprNodeGen.create()) { + @Override + protected Object dispatch(Object[] args) { + return this.specializedNode.executeFind( + LKQLTypeSystemGen.asString(args[0]), args[1]); + } + }), createMethod( "starts_with", "Given a string, returns whether it starts with the given prefix", new String[] {"prefix"}, new Expr[] {null}, - new StartsWithExpr()), + new SpecializedBuiltInBody<>( + StrMethodsFactory.StartsWithExprNodeGen.create()) { + @Override + protected Object dispatch(Object[] args) { + return this.specializedNode.executeStartsWith( + LKQLTypeSystemGen.asString(args[0]), args[1]); + } + }), createMethod( "ends_with", "Given a string, returns whether it ends with the given suffix", new String[] {"suffix"}, new Expr[] {null}, - new EndsWithExpr())); + new SpecializedBuiltInBody<>( + StrMethodsFactory.EndsWithExprNodeGen.create()) { + @Override + protected Object dispatch(Object[] args) { + return this.specializedNode.executeEndsWith( + LKQLTypeSystemGen.asString(args[0]), args[1]); + } + })); // ----- Inner classes ----- @@ -186,192 +229,154 @@ public Object executeGeneric(VirtualFrame frame) { } /** Expression of the "substring" method. */ - public static final class SubstringExpr extends AbstractBuiltInFunctionBody { - @Override - public Object executeGeneric(VirtualFrame frame) { - // Get the arguments - Object startObject = frame.getArguments()[1]; - Object endObject = frame.getArguments()[2]; - - // Verify the type of arguments - if (!LKQLTypeSystemGen.isImplicitBigInteger(startObject)) { - throw LKQLRuntimeException.wrongType( - LKQLTypesHelper.LKQL_INTEGER, - LKQLTypesHelper.fromJava(startObject), - this.callNode.getArgList().getArgs()[0]); - } + abstract static class SubstringExpr extends SpecializedBuiltInBody.SpecializedBuiltInNode { - if (!LKQLTypeSystemGen.isImplicitBigInteger(endObject)) { - throw LKQLRuntimeException.wrongType( - LKQLTypesHelper.LKQL_INTEGER, - LKQLTypesHelper.fromJava(endObject), - this.callNode.getArgList().getArgs()[1]); - } - - // Cast the arguments - BigInteger startBig = - BigIntegerUtils.subtract( - LKQLTypeSystemGen.asImplicitBigInteger(startObject), BigInteger.ONE); - BigInteger endBig = LKQLTypeSystemGen.asImplicitBigInteger(endObject); + public abstract String executeSubstring(String source, Object start, Object end); - int start = BigIntegerUtils.intValue(startBig); - int end = BigIntegerUtils.intValue(endBig); + @Specialization + protected String onValid(String source, long start, long end) { + // Offset the start index by 1 since LKQL is 1-indexed + start = start - 1; - // Verify the start and end + // Verify start and end bounds if (start < 0) { - throw LKQLRuntimeException.invalidIndex( - start, this.callNode.getArgList().getArgs()[0]); + throw LKQLRuntimeException.invalidIndex((int) start + 1, this.body.argNode(0)); } - if (end > LKQLTypeSystemGen.asString(frame.getArguments()[0]).length()) { - throw LKQLRuntimeException.invalidIndex( - end, this.callNode.getArgList().getArgs()[1]); + if (end > source.length()) { + throw LKQLRuntimeException.invalidIndex((int) end, this.body.argNode(1)); } // Return the substring - return LKQLTypeSystemGen.asString(frame.getArguments()[0]).substring(start, end); + return source.substring((int) start, (int) end); + } + + @Specialization + protected String invalidEnd( + @SuppressWarnings("unused") String source, + @SuppressWarnings("unused") long start, + Object end) { + throw LKQLRuntimeException.wrongType( + LKQLTypesHelper.LKQL_INTEGER, + LKQLTypesHelper.fromJava(end), + this.body.argNode(1)); + } + + @Fallback + protected String invalidStart( + @SuppressWarnings("unused") String source, + Object start, + @SuppressWarnings("unused") Object end) { + throw LKQLRuntimeException.wrongType( + LKQLTypesHelper.LKQL_INTEGER, + LKQLTypesHelper.fromJava(start), + this.body.argNode(0)); } } /** Expression of the "split" method. */ - public static final class SplitExpr extends AbstractBuiltInFunctionBody { - @Override - public Object executeGeneric(VirtualFrame frame) { - // Get the argument - Object toSplit = frame.getArguments()[0]; - Object separatorObject = frame.getArguments()[1]; - - // Verify the argument type - if (!LKQLTypeSystemGen.isString(separatorObject)) { - throw LKQLRuntimeException.wrongType( - LKQLTypesHelper.LKQL_STRING, - LKQLTypesHelper.fromJava(separatorObject), - this.callNode.getArgList().getArgs()[0]); - } + abstract static class SplitExpr extends SpecializedBuiltInBody.SpecializedBuiltInNode { + + public abstract LKQLList executeSplit(String source, Object sep); - // Split the string - String[] separated = - StringUtils.split( - LKQLTypeSystemGen.asString(toSplit), - LKQLTypeSystemGen.asString(separatorObject)); + @Specialization + protected LKQLList onValid(String source, String sep) { + return new LKQLList(StringUtils.split(source, sep)); + } - // Return the list value of the split string - return new LKQLList(separated); + @Fallback + protected LKQLList onInvalid(@SuppressWarnings("unused") String source, Object sep) { + throw LKQLRuntimeException.wrongType( + LKQLTypesHelper.LKQL_STRING, + LKQLTypesHelper.fromJava(sep), + this.body.argNode(0)); } } /** Expression of the "contains" method. */ - public static final class ContainsExpr extends AbstractBuiltInFunctionBody { - @Override - public Object executeGeneric(VirtualFrame frame) { - // Get the arguments - String receiver = LKQLTypeSystemGen.asString(frame.getArguments()[0]); - Object toFindObject = frame.getArguments()[1]; - boolean contains; + abstract static class ContainsExpr extends SpecializedBuiltInBody.SpecializedBuiltInNode { - // If the argument is a string - if (LKQLTypeSystemGen.isString(toFindObject)) { - String toFind = LKQLTypeSystemGen.asString(toFindObject); - contains = StringUtils.contains(receiver, toFind); - } + public abstract boolean executeContains(String source, Object toFind); - // If the argument is a pattern - else if (LKQLTypeSystemGen.isLKQLPattern(toFindObject)) { - LKQLPattern pattern = LKQLTypeSystemGen.asLKQLPattern(toFindObject); - contains = pattern.contains(receiver); - } + @Specialization + protected boolean onString(String source, String toFind) { + return StringUtils.contains(source, toFind); + } - // Else, just thrown an error - else { - throw LKQLRuntimeException.wrongType( - LKQLTypesHelper.LKQL_STRING, - LKQLTypesHelper.fromJava(toFindObject), - this.callNode.getArgList().getArgs()[0]); - } + @Specialization + protected boolean onPattern(String source, LKQLPattern toFind) { + return toFind.contains(source); + } - // Return if the receiver contains the to find - return contains; + @Fallback + protected boolean onInvalid(@SuppressWarnings("unused") String source, Object toFind) { + throw LKQLRuntimeException.wrongType( + LKQLTypesHelper.typeUnion( + LKQLTypesHelper.LKQL_STRING, LKQLTypesHelper.LKQL_PATTERN), + LKQLTypesHelper.fromJava(toFind), + this.body.argNode(0)); } } /** Expression of the "find" method. */ - public static final class FindExpr extends AbstractBuiltInFunctionBody { - @Override - public Object executeGeneric(VirtualFrame frame) { - // Get the arguments - String receiver = LKQLTypeSystemGen.asString(frame.getArguments()[0]); - Object toFindObject = frame.getArguments()[1]; - int index; + abstract static class FindExpr extends SpecializedBuiltInBody.SpecializedBuiltInNode { - // If the argument is a string - if (LKQLTypeSystemGen.isString(toFindObject)) { - String toFind = LKQLTypeSystemGen.asString(toFindObject); - index = StringUtils.indexOf(receiver, toFind); - } + public abstract long executeFind(String source, Object toFind); - // If the argument is a pattern - else if (LKQLTypeSystemGen.isLKQLPattern(toFindObject)) { - LKQLPattern pattern = LKQLTypeSystemGen.asLKQLPattern(toFindObject); - index = pattern.find(receiver); - } + @Specialization + protected long onString(String source, String toFind) { + return StringUtils.indexOf(source, toFind) + 1; + } - // Else, just throw an error - else { - throw LKQLRuntimeException.wrongType( - LKQLTypesHelper.LKQL_STRING, - LKQLTypesHelper.fromJava(toFindObject), - this.callNode.getArgList().getArgs()[0]); - } + @Specialization + protected long onPattern(String source, LKQLPattern toFind) { + return toFind.find(source) + 1; + } - // Return the index - return (long) index + 1; + @Fallback + protected long onInvalid(@SuppressWarnings("unused") String source, Object toFind) { + throw LKQLRuntimeException.wrongType( + LKQLTypesHelper.typeUnion( + LKQLTypesHelper.LKQL_STRING, LKQLTypesHelper.LKQL_PATTERN), + LKQLTypesHelper.fromJava(toFind), + this.body.argNode(0)); } } /** Expression of the "starts_with" method. */ - public static final class StartsWithExpr extends AbstractBuiltInFunctionBody { - @Override - public Object executeGeneric(VirtualFrame frame) { - // Get the argument - Object prefixObject = frame.getArguments()[1]; - - // Verify the argument type - if (!LKQLTypeSystemGen.isString(prefixObject)) { - throw LKQLRuntimeException.wrongType( - LKQLTypesHelper.LKQL_STRING, - LKQLTypesHelper.fromJava(prefixObject), - this.callNode.getArgList().getArgs()[0]); - } + abstract static class StartsWithExpr extends SpecializedBuiltInBody.SpecializedBuiltInNode { - // Cast the arguments - String receiver = LKQLTypeSystemGen.asString(frame.getArguments()[0]); - String prefix = LKQLTypeSystemGen.asString(prefixObject); + public abstract boolean executeStartsWith(String source, Object prefix); - // Return if the receiver has the prefix - return receiver.startsWith(prefix); + @Specialization + protected boolean onValid(String source, String prefix) { + return source.startsWith(prefix); + } + + @Fallback + protected boolean onInvalid(@SuppressWarnings("unused") String source, Object prefix) { + throw LKQLRuntimeException.wrongType( + LKQLTypesHelper.LKQL_STRING, + LKQLTypesHelper.fromJava(prefix), + this.body.argNode(0)); } } /** Expression of the "ends_with" method. */ - public static final class EndsWithExpr extends AbstractBuiltInFunctionBody { - @Override - public Object executeGeneric(VirtualFrame frame) { - // Get the argument - Object suffixObject = frame.getArguments()[1]; - - // Verify the argument type - if (!LKQLTypeSystemGen.isString(suffixObject)) { - throw LKQLRuntimeException.wrongType( - LKQLTypesHelper.LKQL_STRING, - LKQLTypesHelper.fromJava(suffixObject), - this.callNode.getArgList().getArgs()[0]); - } + abstract static class EndsWithExpr extends SpecializedBuiltInBody.SpecializedBuiltInNode { - // Cast the arguments - String receiver = LKQLTypeSystemGen.asString(frame.getArguments()[0]); - String suffix = LKQLTypeSystemGen.asString(suffixObject); + public abstract boolean executeEndsWith(String source, Object suffix); + + @Specialization + protected boolean onValid(String source, String suffix) { + return source.endsWith(suffix); + } - // Return if the receiver has the prefix - return receiver.endsWith(suffix); + @Specialization + protected boolean onInvalid(@SuppressWarnings("unused") String source, Object suffix) { + throw LKQLRuntimeException.wrongType( + LKQLTypesHelper.LKQL_STRING, + LKQLTypesHelper.fromJava(suffix), + this.body.argNode(0)); } } } diff --git a/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/methods/TokenMethods.java b/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/methods/TokenMethods.java index 3860b0023..3f2c951e6 100644 --- a/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/methods/TokenMethods.java +++ b/lkql_jit/language/src/main/java/com/adacore/lkql_jit/built_ins/methods/TokenMethods.java @@ -12,14 +12,16 @@ import com.adacore.lkql_jit.LKQLTypeSystemGen; import com.adacore.lkql_jit.built_ins.AbstractBuiltInFunctionBody; import com.adacore.lkql_jit.built_ins.BuiltInMethodFactory; +import com.adacore.lkql_jit.built_ins.SpecializedBuiltInBody; import com.adacore.lkql_jit.exception.LKQLRuntimeException; import com.adacore.lkql_jit.nodes.expressions.Expr; import com.adacore.lkql_jit.nodes.expressions.literals.BooleanLiteral; import com.adacore.lkql_jit.utils.LKQLTypesHelper; import com.adacore.lkql_jit.utils.functions.ObjectUtils; import com.adacore.lkql_jit.utils.functions.StringUtils; +import com.oracle.truffle.api.dsl.Fallback; +import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.api.frame.VirtualFrame; -import com.oracle.truffle.api.nodes.UnexpectedResultException; import java.util.Map; /** @@ -40,7 +42,14 @@ public final class TokenMethods { "Return whether two tokens are structurally equivalent", new String[] {"other"}, new Expr[] {null}, - new IsEquivalentExpr()), + new SpecializedBuiltInBody<>( + TokenMethodsFactory.IsEquivalentExprNodeGen.create()) { + @Override + protected Object dispatch(Object[] args) { + return this.specializedNode.executeIsEquivalent( + LKQLTypeSystemGen.asToken(args[0]), args[1]); + } + }), createAttribute( "is_trivia", "Return whether this token is a trivia", @@ -50,13 +59,27 @@ public final class TokenMethods { "Return the next token", new String[] {"exclude_trivia"}, new Expr[] {new BooleanLiteral(null, false)}, - new NextExpr()), + new SpecializedBuiltInBody<>( + TokenMethodsFactory.NextExprNodeGen.create()) { + @Override + protected Object dispatch(Object[] args) { + return this.specializedNode.executeNext( + LKQLTypeSystemGen.asToken(args[0]), args[1]); + } + }), createMethod( "previous", "Return the previous token", new String[] {"exclude_trivia"}, new Expr[] {new BooleanLiteral(null, false)}, - new PrevExpr()), + new SpecializedBuiltInBody<>( + TokenMethodsFactory.PrevExprNodeGen.create()) { + @Override + protected Object dispatch(Object[] args) { + return this.specializedNode.executePrev( + LKQLTypeSystemGen.asToken(args[0]), args[1]); + } + }), createAttribute("unit", "Return the unit for this token", new UnitExpr()), createAttribute("text", "Return the text of the token", new TextExpr()), createAttribute("kind", "Return the kind of the token", new KindExpr())); @@ -109,22 +132,20 @@ public Object executeGeneric(VirtualFrame frame) { } /** Expression of the "is_equivalent" method. */ - public static final class IsEquivalentExpr extends AbstractBuiltInFunctionBody { - @Override - public Object executeGeneric(VirtualFrame frame) { - // Get the other token to compare - Libadalang.Token other; - try { - other = LKQLTypeSystemGen.expectToken(frame.getArguments()[1]); - } catch (UnexpectedResultException e) { - throw LKQLRuntimeException.wrongType( - LKQLTypesHelper.TOKEN, - LKQLTypesHelper.fromJava(e.getResult()), - this.callNode.getArgList().getArgs()[1]); - } + abstract static class IsEquivalentExpr extends SpecializedBuiltInBody.SpecializedBuiltInNode { - // Return the comparison - return LKQLTypeSystemGen.asToken(frame.getArguments()[0]).isEquivalent(other); + public abstract boolean executeIsEquivalent(Libadalang.Token receiver, Object other); + + @Specialization + protected boolean onValid(Libadalang.Token receiver, Libadalang.Token other) { + return receiver.isEquivalent(other); + } + + @Fallback + protected boolean onInvalid( + @SuppressWarnings("unused") Libadalang.Token receiver, Object other) { + throw LKQLRuntimeException.wrongType( + LKQLTypesHelper.TOKEN, LKQLTypesHelper.fromJava(other), this.body.argNode(0)); } } @@ -137,55 +158,63 @@ public Object executeGeneric(VirtualFrame frame) { } /** Expression of the "next" method. */ - public static final class NextExpr extends AbstractBuiltInFunctionBody { - @Override - public Object executeGeneric(VirtualFrame frame) { - // Get if the trivia tokens should be ignored - boolean ignoreTrivia; - try { - ignoreTrivia = LKQLTypeSystemGen.expectBoolean(frame.getArguments()[1]); - } catch (UnexpectedResultException e) { - throw LKQLRuntimeException.wrongType( - LKQLTypesHelper.LKQL_BOOLEAN, - LKQLTypesHelper.fromJava(e.getResult()), - this.callNode.getArgList().getArgs()[1]); - } + abstract static class NextExpr extends SpecializedBuiltInBody.SpecializedBuiltInNode { - Libadalang.Token res = LKQLTypeSystemGen.asToken(frame.getArguments()[0]).next(); + public abstract Libadalang.Token executeNext( + Libadalang.Token receiver, Object ignoreTrivia); + + @Specialization + protected Libadalang.Token onValid(Libadalang.Token receiver, boolean ignoreTrivia) { + // Skip trivia if required + Libadalang.Token res = receiver.next(); if (ignoreTrivia) { while (!res.isNone() && res.triviaIndex != 0) { res = res.next(); } } + // Return the result return res; } + + @Fallback + protected Libadalang.Token onInvalid( + @SuppressWarnings("unused") Libadalang.Token receiver, Object ignoreTrivia) { + throw LKQLRuntimeException.wrongType( + LKQLTypesHelper.LKQL_BOOLEAN, + LKQLTypesHelper.fromJava(ignoreTrivia), + this.body.argNode(0)); + } } /** Expression of the "previous" method. */ - public static final class PrevExpr extends AbstractBuiltInFunctionBody { - @Override - public Object executeGeneric(VirtualFrame frame) { - // Get if the trivia tokens should be ignored - boolean ignoreTrivia; - try { - ignoreTrivia = LKQLTypeSystemGen.expectBoolean(frame.getArguments()[1]); - } catch (UnexpectedResultException e) { - throw LKQLRuntimeException.wrongType( - LKQLTypesHelper.LKQL_BOOLEAN, - LKQLTypesHelper.fromJava(e.getResult()), - this.callNode.getArgList().getArgs()[1]); - } + abstract static class PrevExpr extends SpecializedBuiltInBody.SpecializedBuiltInNode { + + public abstract Libadalang.Token executePrev( + Libadalang.Token receiver, Object ignoreTrivia); - Libadalang.Token res = LKQLTypeSystemGen.asToken(frame.getArguments()[0]).previous(); + @Specialization + protected Libadalang.Token onValid(Libadalang.Token receiver, boolean ignoreTrivia) { + // Skip trivia if required + Libadalang.Token res = receiver.previous(); if (ignoreTrivia) { while (!res.isNone() && res.triviaIndex != 0) { res = res.previous(); } } + // Return the result return res; } + + @Fallback + protected Libadalang.Token onInvalid( + @SuppressWarnings("unused") Libadalang.Token receiver, Object ignoreTrivia) { + throw LKQLRuntimeException.wrongType( + LKQLTypesHelper.LKQL_BOOLEAN, + LKQLTypesHelper.fromJava(ignoreTrivia), + this.body.argNode(0)); + } } /** Expression of the "unit" method. */ diff --git a/testsuite/tests/interpreter/sublist_builtin/test.out b/testsuite/tests/interpreter/sublist_builtin/test.out index b81307611..81aa71f74 100644 --- a/testsuite/tests/interpreter/sublist_builtin/test.out +++ b/testsuite/tests/interpreter/sublist_builtin/test.out @@ -2,7 +2,7 @@ 3 [2, 3] [1, 2, 3, 4, 5] -script.lkql:6:7: error: Invalid index: 0 +script.lkql:6:17: error: Invalid index: 0 6 | print(l.sublist(0, 5)) - | ^^^^^^^^^^^^^^^ + | ^ in sublist (called at script.lkql:6:7)