diff --git a/hkmc2/shared/src/main/scala/hkmc2/MLsCompiler.scala b/hkmc2/shared/src/main/scala/hkmc2/MLsCompiler.scala index 6e1e90fa8..c599c86a5 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/MLsCompiler.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/MLsCompiler.scala @@ -74,7 +74,7 @@ class MLsCompiler(preludeFile: os.Path): val parsed = mainParse.resultBlk val (blk, newCtx) = elab.importFrom(parsed) val low = ltl.givenIn: - codegen.Lowering(lowerHandlers = false) + codegen.Lowering(lowerHandlers = false, stackLimit = None) // TODO: properly hook up stack limit val jsb = codegen.js.JSBuilder() val le = low.program(blk) val baseScp: utils.Scope = diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala index feaf4ad4a..b17ce159c 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala @@ -14,7 +14,6 @@ import semantics.* import semantics.Term.* import sem.Elaborator.State - case class Program( imports: Ls[Local -> Str], main: Block, @@ -99,62 +98,25 @@ sealed abstract class Block extends Product with AutoLocated: case HandleBlock(_, _, _, _, handlers, body, rest) => handlers.map(_.body) :+ body :+ rest case _: Return | _: Throw | _: Label | _: Break | _: Continue | _: End | _: HandleBlockReturn => Nil + // Moves definitions in a block to the top. Only scans the top-level definitions of the block; + // i.e, definitions inside other definitions are not moved out. Definitions inside `match`/`if` + // and `while` statements are moved out. + // + // Note that this returns the definitions in reverse order, with the bottommost definiton appearing + // last. This is so that using defns.foldLeft later to add the definitions to the front of a block, + // we don't need to reverse the list again to preserve the order of the definitions. def floatOutDefns = - def rec(b: Block, acc: List[Defn]): (Block, List[Defn]) = b match - case Match(scrut, arms, dflt, rest) => - val (armsRes, armsDefns) = arms.foldLeft[(List[(Case, Block)], List[Defn])](Nil, acc)( - (accc, d) => - val (accCases, accDefns) = accc - val (cse, blk) = d - val (resBlk, resDefns) = rec(blk, accDefns) - ((cse, resBlk) :: accCases, resDefns) - ) - dflt match - case None => - val (rstRes, rstDefns) = rec(rest, armsDefns) - (Match(scrut, armsRes, None, rstRes), rstDefns) - - case Some(dflt) => - val (dfltRes, dfltDefns) = rec(dflt, armsDefns) - val (rstRes, rstDefns) = rec(rest, dfltDefns) - (Match(scrut, armsRes, S(dfltRes), rstRes), rstDefns) - - case Return(res, implct) => (b, acc) - case Throw(exc) => (b, acc) - case Label(label, body, rest) => - val (bodyRes, bodyDefns) = rec(body, acc) - val (rstRes, rstDefns) = rec(rest, bodyDefns) - (Label(label, bodyRes, rstRes), rstDefns) - case Break(label) => (b, acc) - case Continue(label) => (b, acc) - case Begin(sub, rest) => - val (subRes, subDefns) = rec(sub, acc) - val (rstRes, rstDefns) = rec(rest, subDefns) - (Begin(subRes, rstRes), rstDefns) - case TryBlock(sub, finallyDo, rest) => - val (subRes, subDefns) = rec(sub, acc) - val (finallyRes, finallyDefns) = rec(rest, subDefns) - val (rstRes, rstDefns) = rec(rest, finallyDefns) - (TryBlock(subRes, finallyRes, rstRes), rstDefns) - case Assign(lhs, rhs, rest) => - val (rstRes, rstDefns) = rec(rest, acc) - (Assign(lhs, rhs, rstRes), rstDefns) - case a @ AssignField(path, nme, result, rest) => - val (rstRes, rstDefns) = rec(rest, acc) - (AssignField(path, nme, result, rstRes)(a.symbol), rstDefns) - case Define(defn, rest) => defn match - case ValDefn(owner, k, sym, rhs) => - val (rstRes, rstDefns) = rec(rest, acc) - (Define(defn, rstRes), rstDefns) - case _ => - val (rstRes, rstDefns) = rec(rest, defn :: acc) - (rstRes, rstDefns) - case HandleBlock(lhs, res, par, cls, handlers, body, rest) => - val (rstRes, rstDefns) = rec(rest, acc) - (HandleBlock(lhs, res, par, cls, handlers, body, rstRes), rstDefns) - case HandleBlockReturn(res) => (b, acc) - case End(msg) => (b, acc) - rec(this, Nil) + var defns: List[Defn] = Nil + val transformer = new BlockTransformerShallow(SymbolSubst()): + override def applyBlock(b: Block): Block = b match + case Define(defn, rest) => defn match + case v: ValDefn => super.applyBlock(b) + case _ => + defns ::= defn + applyBlock(rest) + case _ => super.applyBlock(b) + + (transformer.applyBlock(this), defns) end Block @@ -325,3 +287,5 @@ def blockBuilder: Block => Block = identity extension (l: Local) def asPath: Path = Value.Ref(l) + + diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala index c0d1409cb..8f2b8c85e 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala @@ -27,6 +27,15 @@ object HandlerLowering: private case class LinkState(res: Path, cls: Path, uid: StateId) + // isHandleFree: whether the current block is inside a function or top level directly free of any handler in scope + // isTopLevel: + // whether the current block is the top level block, as we do not emit code for continuation class on the top level + // since we cannot return an effect signature on the top level (we are not in a function so return statement are invalid) + // and we do not have any `return` statement in the top level block so we do not need the `ReturnCont` workarounds. + // ctorThis: the path to `this` in the constructor, this is used to insert `return this;` at the end of constructor. + // linkAndHandle: + // a function that takes a LinkState and returns a block that links the continuation class and handles the effect + // this is a convenience function which initializes the continuation class in function context or throw an error in top level private case class HandlerCtx( isHandleFree: Bool, isTopLevel: Bool, @@ -277,7 +286,7 @@ class HandlerLowering(using TL, Raise, Elaborator.State, Elaborator.Ctx): */ private def translateBlock(b: Block, h: HandlerCtx): Block = - given HandlerCtx = h + given HandlerCtx = h val stage1 = firstPass(b) val stage2 = secondPass(stage1) if h.isTopLevel then stage2 else thirdPass(stage2) @@ -288,25 +297,29 @@ class HandlerLowering(using TL, Raise, Elaborator.State, Elaborator.Ctx): case b: HandleBlock => val rest = applyBlock(b.rest) translateHandleBlock(b.copy(rest = rest)) - case Return(c: Call, implct) if handlerCtx.isHandleFree => - val fun2 = applyPath(c.fun) - val args2 = c.args.map(applyArg) - val c2 = if (fun2 is c.fun) && (args2 zip c.args).forall(_ is _) then c else Call(fun2, args2)(c.isMlsFun) - if c2 is c then b else Return(c2, implct) + // This block optimizes tail-calls in the handler transformation. We do not optimize implicit returns. + // Implicit returns are used in top level and constructor: + // For top level, this correspond to the last statement which should also be checked for effect. + // For constructor, we will append `return this;` after the implicit return so it is not a tail call. + case Return(c @ Call(fun, args), false) if handlerCtx.isHandleFree => + val fun2 = applyPath(fun) + val args2 = args.mapConserve(applyArg) + val c2 = if (fun2 is fun) && (args2 is args) then c else Call(fun2, args2)(c.isMlsFun) + if c2 is c then b else Return(c2, false) case _ => super.applyBlock(b) override def applyResult2(r: Result)(k: Result => Block): Block = r match case r @ Call(Value.Ref(_: BuiltinSymbol), _) => super.applyResult2(r)(k) case c @ Call(fun, args) => val res = freshTmp("res") val fun2 = applyPath(fun) - val args2 = c.args.map(applyArg) - val c2 = if (fun2 is fun) && (args2 zip args).forall(_ is _) then c else Call(fun2, args2)(c.isMlsFun) + val args2 = args.mapConserve(applyArg) + val c2 = if (fun2 is fun) && (args2 is args) then c else Call(fun2, args2)(c.isMlsFun) ResultPlaceholder(res, freshId(), false, c2, k(Value.Ref(res))) case c @ Instantiate(cls, args) => val res = freshTmp("res") val cls2 = applyPath(cls) - val args2 = c.args.map(applyPath) - val c2 = if (cls2 is cls) && (args2 zip args).forall(_ is _) then c else Instantiate(cls2, args2) + val args2 = args.mapConserve(applyPath) + val c2 = if (cls2 is cls) && (args2 is args) then c else Instantiate(cls2, args2) ResultPlaceholder(res, freshId(), false, c2, k(Value.Ref(res))) case r => super.applyResult2(r)(k) override def applyLam(lam: Value.Lam): Value.Lam = Value.Lam(lam.params, translateBlock(lam.body, functionHandlerCtx)) @@ -327,89 +340,7 @@ class HandlerLowering(using TL, Raise, Elaborator.State, Elaborator.Ctx): // to ensure the fun and class references in the continuation class are properly scoped, // we move all function defns to the top level of the handler block val (blk, defns) = b.floatOutDefns - val clsDefns = defns.collect: - case ClsLikeDefn(own, isym, sym, k, paramsOpt, parentPath, methods, privateFields, publicFields, preCtor, ctor) => sym - - val funDefns = defns.collect: - case FunDefn(own, sym, params, body) => sym - - def getBms = - var l: List[BlockMemberSymbol] = Nil - val subst = new SymbolSubst: - override def mapBlockMemberSym(b: BlockMemberSymbol) = - l = b :: l - b - BlockTransformer(subst).applyBlock(b) - l - - val toConvert = getBms - .map: b => - val clsDefn = b.asCls - val modDefn = b.asMod - // check if this BlockMemberSymbol belongs to a definition in this block - val isThisBlock = clsDefn match - case None => modDefn match - case None => false - case Some(value) => clsDefns.contains(value) - case Some(value) => clsDefns.contains(value) - if isThisBlock then Some(b) - else None - .collect: - case Some(b) => b - - val fnBmsMap = funDefns - .map: b => - b -> BlockMemberSymbol(b.nme, b.trees) - .toMap - - val clsBmsMap = toConvert - .map: b => - b -> BlockMemberSymbol(b.nme, b.trees) - .toMap - - val bmsMap = (fnBmsMap ++ clsBmsMap).toMap - - val clsMap = clsBmsMap - .map: - case b1 -> b2 => b1.asCls match - case Some(value) => - val newSym = ClassSymbol(value.tree, Tree.Ident(b2.nme)) - newSym.defn = value.defn - S(value -> newSym) - case None => None - .collect: - case Some(x) => x - .toMap - - val modMap = clsBmsMap - .map: - case b1 -> b2 => b1.asMod match - case Some(value) => - val newSym = ModuleSymbol(value.tree, Tree.Ident(b2.nme)) - newSym.defn = value.defn - S(value -> newSym) - case None => None - .collect: - case Some(x) => x - .toMap - - val newBlk = defns.foldLeft(blk)((acc, defn) => Define(defn, acc)) - - val subst = new SymbolSubst: - override def mapBlockMemberSym(b: BlockMemberSymbol) = bmsMap.get(b) match - case None => b.asCls match - case None => b - case Some(cls) => - clsMap.get(cls) match - case None => b - case Some(sym) => - BlockMemberSymbol(sym.nme, b.trees) // TODO: properly map trees - case Some(value) => value - override def mapClsSym(s: ClassSymbol): ClassSymbol = clsMap.get(s).getOrElse(s) - override def mapModuleSym(s: ModuleSymbol): ModuleSymbol = modMap.get(s).getOrElse(s) - override def mapTermSym(s: TermSymbol): TermSymbol = TermSymbol(s.k, s.owner.map(_.subst(using this)), s.id) - - BlockTransformer(subst).applyBlock(newBlk) + defns.foldLeft(blk)((acc, defn) => Define(defn, acc)) private def translateFun(f: FunDefn): FunDefn = FunDefn(f.owner, f.sym, f.params, translateBlock(f.body, functionHandlerCtx)) @@ -455,7 +386,7 @@ class HandlerLowering(using TL, Raise, Elaborator.State, Elaborator.Ctx): h.cls, BlockMemberSymbol(h.cls.id.name, Nil), syntax.Cls, - h.cls.defn.get.paramsOpt, + N, S(h.par), handlers, Nil, Nil, Assign(freshTmp(), SimpleCall(Value.Ref(State.builtinOpsMap("super")), Nil), End()), End()) @@ -478,15 +409,6 @@ class HandlerLowering(using TL, Raise, Elaborator.State, Elaborator.Ctx): ) val pcVar = VarSymbol(Tree.Ident("pc")) - clsSym.defn = S(ClassDef( - N, - syntax.Cls, - clsSym, - BlockMemberSymbol(clsSym.nme, Nil), - Nil, - S(PlainParamList(Param(FldFlags.empty, pcVar, N) :: Nil)), - ObjBody(Term.Blk(Nil, Term.Lit(Tree.UnitLit(true)))), - List())) var trivial = true def prepareBlock(b: Block): Block = @@ -582,7 +504,7 @@ class HandlerLowering(using TL, Raise, Elaborator.State, Elaborator.Ctx): clsSym, BlockMemberSymbol(clsSym.nme, Nil), syntax.Cls, - clsSym.defn.get.paramsOpt, + S(PlainParamList(Param(FldFlags.empty, pcVar, N) :: Nil)), S(contClsPath), resumeFnDef :: Nil, Nil, diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala index 7ca6c9b43..4c1d9e2f5 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala @@ -49,7 +49,7 @@ end Subst import Subst.subst -class Lowering(lowerHandlers: Bool)(using TL, Raise, State, Ctx): +class Lowering(lowerHandlers: Bool, stackLimit: Option[Int])(using TL, Raise, State, Ctx): def returnedTerm(t: st)(using Subst): Block = term(t)(Ret) @@ -503,8 +503,12 @@ class Lowering(lowerHandlers: Bool)(using TL, Raise, State, Ctx): def topLevel(t: st): Block = val res = term(t)(ImplctRet)(using Subst.empty) - if lowerHandlers then HandlerLowering().translateTopLevel(res) - else res + val stackSafe = stackLimit match + case None => res + case Some(lim) => StackSafeTransform(lim).transformTopLevel(res) + + if lowerHandlers then HandlerLowering().translateTopLevel(stackSafe) + else stackSafe def program(main: st): Program = def go(acc: Ls[Local -> Str], trm: st): Program = diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/StackSafeTransform.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/StackSafeTransform.scala new file mode 100644 index 000000000..7baab171d --- /dev/null +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/StackSafeTransform.scala @@ -0,0 +1,169 @@ +package hkmc2 + +import mlscript.utils.*, shorthands.* +import utils.* + +import hkmc2.codegen.* +import hkmc2.semantics.Elaborator.State +import hkmc2.semantics.* +import hkmc2.syntax.Tree + +class StackSafeTransform(depthLimit: Int)(using State): + private val STACK_LIMIT_IDENT: Tree.Ident = Tree.Ident("__stackLimit") + private val STACK_DEPTH_IDENT: Tree.Ident = Tree.Ident("__stackDepth") + private val STACK_OFFSET_IDENT: Tree.Ident = Tree.Ident("__stackOffset") + private val STACK_HANDLER_IDENT: Tree.Ident = Tree.Ident("__stackHandler") + + private val predefPath: Path = State.globalThisSymbol.asPath.selN(Tree.Ident("Predef")) + private val stackDelayClsPath: Path = predefPath.selN(Tree.Ident("__StackDelay")).selN(Tree.Ident("class")) + private val stackLimitPath: Path = predefPath.selN(STACK_LIMIT_IDENT) + private val stackDepthPath: Path = predefPath.selN(STACK_DEPTH_IDENT) + private val stackOffsetPath: Path = predefPath.selN(STACK_OFFSET_IDENT) + private val stackHandlerPath: Path = predefPath.selN(STACK_HANDLER_IDENT) + + private def intLit(n: BigInt) = Value.Lit(Tree.IntLit(n)) + + private def op(op: String, a: Path, b: Path) = + Call(State.builtinOpsMap(op).asPath, a.asArg :: b.asArg :: Nil)(true) + + // Increases the stack depth, assigns the call to a value, then decreases the stack depth + // then binds that value to a desired block + def extractRes(res: Result, isTailCall: Bool, f: Result => Block) = + if isTailCall then + blockBuilder + .assignFieldN(predefPath, STACK_DEPTH_IDENT, op("+", stackDepthPath, intLit(1))) + .ret(res) + else + val tmp = TempSymbol(None, "tmp") + val prevDepth = TempSymbol(None, "prevDepth") + blockBuilder + .assign(prevDepth, stackDepthPath) + .assignFieldN(predefPath, STACK_DEPTH_IDENT, op("+", stackDepthPath, intLit(1))) + .assign(tmp, res) + .assignFieldN(predefPath, STACK_DEPTH_IDENT, prevDepth.asPath) + .rest(f(tmp.asPath)) + + def extractResTopLevel(res: Result, isTailCall: Bool, f: Result => Block) = + val resumeSym = VarSymbol(Tree.Ident("resume")) + val handlerSym = TempSymbol(None, "stackHandler") + val resSym = TempSymbol(None, "res") + val handlerRes = TempSymbol(None, "res") + val curOffsetSym = TempSymbol(None, "curOffset") + + val clsSym = ClassSymbol( + Tree.TypeDef(syntax.Cls, Tree.Error(), N, N), + Tree.Ident("StackDelay$") + ) + + // the global stack handler is created here + HandleBlock( + handlerSym, resSym, + stackDelayClsPath, clsSym, + Handler( + BlockMemberSymbol("perform", Nil), resumeSym, ParamList(ParamListFlags.empty, Nil, N) :: Nil, + /* + fun perform() = + let curOffset = stackOffset + stackOffset = stackDepth + let ret = resume() + stackOffset = curOffset + ret + */ + blockBuilder + .assign(curOffsetSym, stackOffsetPath) + .assignFieldN(predefPath, STACK_OFFSET_IDENT, stackDepthPath) + .assign(handlerRes, Call(Value.Ref(resumeSym), Nil)(true)) + .assignFieldN(predefPath, STACK_OFFSET_IDENT, curOffsetSym.asPath) + .ret(handlerRes.asPath) + ) :: Nil, + blockBuilder + .assignFieldN(predefPath, STACK_LIMIT_IDENT, intLit(depthLimit)) // set stackLimit before call + .assignFieldN(predefPath, STACK_DEPTH_IDENT, intLit(1)) // set stackDepth = 1 before call + .assignFieldN(predefPath, STACK_HANDLER_IDENT, handlerSym.asPath) // assign stack handler + .rest(HandleBlockReturn(res)), + blockBuilder // reset the stack safety values + .assignFieldN(predefPath, STACK_DEPTH_IDENT, intLit(0)) // set stackDepth = 0 after call + .assignFieldN(predefPath, STACK_HANDLER_IDENT, Value.Lit(Tree.UnitLit(false))) // set stackHandler = null + .rest(f(resSym.asPath)) + ) + + // Rewrites anything that can contain a Call to increase the stack depth + def transform(b: Block, isTopLevel: Bool = false): Block = + def usesStack(r: Result) = r match + case Call(Value.Ref(_: BuiltinSymbol), _) => false + case _: Call | _: Instantiate => true + case _ => false + + val extract = if isTopLevel then extractResTopLevel else extractRes + + val transform = new BlockTransformer(SymbolSubst()): + + override def applyFunDefn(fun: FunDefn): FunDefn = rewriteFn(fun) + + override def applyDefn(defn: Defn): Defn = defn match + case defn: ClsLikeDefn => rewriteCls(defn) + case _: FunDefn | _: ValDefn => super.applyDefn(defn) + + override def applyBlock(b: Block): Block = b match + case Return(res, implct) if usesStack(res) => + extract(applyResult(res), true, Return(_, implct)) + case _ => super.applyBlock(b) + + override def applyResult2(r: Result)(k: Result => Block): Block = + if usesStack(r) then + extract(r, false, k) + else + super.applyResult2(r)(k) + + override def applyLam(lam: Value.Lam): Value.Lam = + Value.Lam(lam.params, rewriteBlk(lam.body)) + + transform.applyBlock(b) + + def isTrivial(b: Block): Boolean = + var trivial = true + val walker = new BlockTransformerShallow(SymbolSubst()): + override def applyResult(r: Result): Result = r match + case Call(Value.Ref(_: BuiltinSymbol), _) => r + case _: Call | _: Instantiate => trivial = false; r + case _ => r + walker.applyBlock(b) + trivial + + def rewriteCls(defn: ClsLikeDefn): ClsLikeDefn = + val ClsLikeDefn(owner, isym, sym, k, paramsOpt, + parentPath, methods, privateFields, publicFields, preCtor, ctor) = defn + ClsLikeDefn( + owner, isym, sym, k, paramsOpt, parentPath, methods.map(rewriteFn), privateFields, + publicFields, rewriteBlk(preCtor), rewriteBlk(ctor) + ) + + def rewriteBlk(blk: Block) = + val newBody = transform(blk) + + if isTrivial(blk) then + newBody + else + val diffSym = TempSymbol(None, "diff") + val diffGeqLimitSym = TempSymbol(None, "diffGeqLimit") + val handlerExistsSym = TempSymbol(None, "handlerExists") + val scrutSym = TempSymbol(None, "scrut") + val diff = op("-", stackDepthPath, stackOffsetPath) + val diffGeqLimit = op(">=", diffSym.asPath, stackLimitPath) + val handlerExists = op("!==", stackHandlerPath, Value.Lit(Tree.UnitLit(false))) + val scrutVal = op("&&", diffGeqLimitSym.asPath, handlerExistsSym.asPath) + blockBuilder + .assign(diffSym, diff) // diff = stackDepth - stackOffset + .assign(diffGeqLimitSym, diffGeqLimit) // diff >= depthLimit + .assign(handlerExistsSym, handlerExists) // stackHandler !== null + .assign(scrutSym, scrutVal) // diff >= depthLimit && stackHandler !== null + .ifthen( + scrutSym.asPath, Case.Lit(Tree.BoolLit(true)), + blockBuilder.assign( // dummy = perform(undefined) (is called `dummy` as the value is not used) + TempSymbol(None, "dummy"), + Call(Select(stackHandlerPath, Tree.Ident("perform"))(N), Nil)(true)).end) + .rest(newBody) + + def rewriteFn(defn: FunDefn) = FunDefn(defn.owner, defn.sym, defn.params, rewriteBlk(defn.body)) + + def transformTopLevel(b: Block) = transform(b, true) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/js/JSBuilder.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/js/JSBuilder.scala index 3a6bd6b54..dc9ed4da0 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/js/JSBuilder.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/js/JSBuilder.scala @@ -40,10 +40,18 @@ class JSBuilder(using Elaborator.State, Elaborator.Ctx) extends CodeBuilder: case Argument case Operand(prec: Int) - def err(errMsg: Message)(using Raise, Scope): Document = + def mkErr(errMsg: Message)(using Raise, Scope): Document = + doc"throw globalThis.Error(${result(Value.Lit(syntax.Tree.StrLit(errMsg.show)))})" + + def errExpr(errMsg: Message)(using Raise, Scope): Document = + raise(ErrorReport(errMsg -> N :: Nil, + source = Diagnostic.Source.Compilation)) + doc"(()=>{${mkErr(errMsg)}})()" + + def errStmt(errMsg: Message)(using Raise, Scope): Document = raise(ErrorReport(errMsg -> N :: Nil, source = Diagnostic.Source.Compilation)) - doc"(()=>{throw globalThis.Error(${result(Value.Lit(syntax.Tree.StrLit(errMsg.show)))})})()" + doc" # ${mkErr(errMsg)};" def getVar(l: Local)(using Raise, Scope): Document = l match case ts: semantics.TermSymbol => @@ -75,24 +83,24 @@ class JSBuilder(using Elaborator.State, Elaborator.Ctx) extends CodeBuilder: case Value.Lit(lit) => lit.idStr case Value.Ref(l: BuiltinSymbol) => if l.nullary then l.nme - else err(msg"Illegal reference to builtin symbol '${l.nme}'") + else errExpr(msg"Illegal reference to builtin symbol '${l.nme}'") case Value.Ref(l) => getVar(l) case Call(Value.Ref(l: BuiltinSymbol), lhs :: rhs :: Nil) if !l.functionLike => if l.binary then val res = doc"${operand(lhs)} ${l.nme} ${operand(rhs)}" if needsParens(l.nme) then doc"(${res})" else res - else err(msg"Cannot call non-binary builtin symbol '${l.nme}'") + else errExpr(msg"Cannot call non-binary builtin symbol '${l.nme}'") case Call(Value.Ref(l: BuiltinSymbol), rhs :: Nil) if !l.functionLike => if l.unary then val res = doc"${l.nme} ${operand(rhs)}" if needsParens(l.nme) then doc"(${res})" else res - else err(msg"Cannot call non-unary builtin symbol '${l.nme}'") + else errExpr(msg"Cannot call non-unary builtin symbol '${l.nme}'") case Call(Value.Ref(l: BuiltinSymbol), args) => if l.functionLike then val argsDoc = args.map(argument).mkDocument(", ") doc"${l.nme}(${argsDoc})" - else err(msg"Illegal arity for builtin symbol '${l.nme}'") + else errExpr(msg"Illegal arity for builtin symbol '${l.nme}'") case Call(s @ Select(_, id), lhs :: rhs :: Nil) => Elaborator.ctx.Builtins.getBuiltinOp(id.name) match @@ -122,7 +130,8 @@ class JSBuilder(using Elaborator.State, Elaborator.Ctx) extends CodeBuilder: case Value.Arr(es) => doc"[ #{ # ${es.map(argument).mkDocument(doc", # ")} #} # ]" def returningTerm(t: Block)(using Raise, Scope): Document = t match - case _: (HandleBlockReturn | HandleBlock) => die + case _: (HandleBlockReturn | HandleBlock) => + errStmt(msg"This code requires effect handler instrumentation but was compiled without it.") case Assign(l, r, rst) => doc" # ${getVar(l)} = ${result(r)};${returningTerm(rst)}" case AssignField(p, n, r, rst) => diff --git a/hkmc2/shared/src/test/mlscript-compile/Predef.mjs b/hkmc2/shared/src/test/mlscript-compile/Predef.mjs index dcbae6033..983aca7c0 100644 --- a/hkmc2/shared/src/test/mlscript-compile/Predef.mjs +++ b/hkmc2/shared/src/test/mlscript-compile/Predef.mjs @@ -127,6 +127,15 @@ const Predef$class = class Predef { } toString() { return "__Return(" + this.value + ")"; } }; + this.__stackLimit = 0; + this.__stackDepth = 0; + this.__stackOffset = 0; + this.__stackHandler = null; + this.__StackDelay = function __StackDelay() { return new __StackDelay.class(); }; + this.__StackDelay.class = class __StackDelay { + constructor() {} + toString() { return "__StackDelay(" + + ")"; } + }; } id(x) { return x; @@ -411,7 +420,7 @@ const Predef$class = class Predef { } __resume(cur2, tail) { return (value) => { - let scrut, cont, scrut1, scrut2, tmp, tmp1, tmp2, tmp3, tmp4, tmp5, tmp6; + let scrut, cont, scrut1, scrut2, scrut3, scrut4, scrut5, scrut6, tmp, tmp1, tmp2, tmp3, tmp4, tmp5, tmp6, tmp7, tmp8, tmp9; scrut = cur2.resumed; if (scrut === true) { throw globalThis.Error("Multiple resumption"); @@ -420,86 +429,121 @@ const Predef$class = class Predef { } cur2.resumed = true; cont = cur2.next; - tmp7: while (true) { + tmp10: while (true) { if (cont instanceof this.__Cont.class) { tmp1 = cont.resume(value) ?? null; value = tmp1; if (value instanceof this.__EffectSig.class) { - value.tail = tail; - scrut1 = cur2.handleBlockList.next !== null; + scrut1 = value.tail.next !== cont; if (scrut1 === true) { + scrut2 = cont.next !== null; + if (scrut2 === true) { + scrut3 = value.tail.next !== null; + if (scrut3 === true) { + throw globalThis.Error("Internal Error: unexpected continuation"); + } else { + tmp2 = null; + } + } else { + tmp2 = null; + } + tmp3 = tmp2; + } else { + tmp3 = null; + } + scrut4 = value.tail.next === null; + if (scrut4 === true) { + value.tail.next = cont.next; + tmp4 = null; + } else { + tmp4 = null; + } + value.tail = tail; + scrut5 = cur2.handleBlockList.next !== null; + if (scrut5 === true) { value.handleBlockList.tail.next = cur2.handleBlockList.next; value.handleBlockList.tail = cur2.handleBlockList.tail; - tmp2 = null; + tmp5 = null; } else { - tmp2 = null; + tmp5 = null; } return value; } else { cont = cont.next; - tmp3 = null; + tmp6 = null; } - tmp4 = tmp3; - continue tmp7; + tmp7 = tmp6; + continue tmp10; } else { - tmp4 = null; + tmp7 = null; } break; } - scrut2 = cur2.handleBlockList.next === null; - if (scrut2 === true) { + scrut6 = cur2.handleBlockList.next === null; + if (scrut6 === true) { return value; } else { - tmp5 = this.__resumeHandleBlocks(cur2.handleBlockList.next, cur2.handleBlockList.tail, value); - cur2 = tmp5; + tmp8 = this.__resumeHandleBlocks(cur2.handleBlockList.next, cur2.handleBlockList.tail, value); + cur2 = tmp8; if (cur2 instanceof this.__EffectSig.class) { cur2.tail = tail; - tmp6 = null; + tmp9 = null; } else { - tmp6 = null; + tmp9 = null; } return cur2; } }; } __resumeHandleBlocks(handleBlock, tailHandleBlock, value) { - let scrut, scrut1, scrut2, tmp, tmp1, tmp2, tmp3; - tmp4: while (true) { + let scrut, scrut1, scrut2, scrut3, scrut4, tmp, tmp1, tmp2, tmp3, tmp4; + tmp5: while (true) { scrut1 = handleBlock.contHead.next; if (scrut1 instanceof this.__Cont.class) { tmp = handleBlock.contHead.next.resume(value) ?? null; value = tmp; if (value instanceof this.__EffectSig.class) { - scrut2 = handleBlock.contHead.next !== value.tail.next; + scrut2 = value.tail.next !== handleBlock.contHead.next; if (scrut2 === true) { - handleBlock.contHead.next = handleBlock.contHead.next.next; - tmp1 = null; + scrut3 = value.tail.next !== null; + if (scrut3 === true) { + throw globalThis.Error("Internal Error: unexpected continuation during handle block resumption"); + } else { + tmp1 = null; + } } else { tmp1 = null; } + scrut4 = value.tail.next !== handleBlock.contHead.next; + if (scrut4 === true) { + handleBlock.contHead.next = handleBlock.contHead.next.next; + tmp2 = null; + } else { + tmp2 = null; + } value.tail.next = null; value.handleBlockList.tail.next = handleBlock; value.handleBlockList.tail = tailHandleBlock; return value; } else { handleBlock.contHead.next = handleBlock.contHead.next.next; - tmp2 = null; + tmp3 = null; } - tmp3 = tmp2; - continue tmp4; + tmp4 = tmp3; + continue tmp5; } else { scrut = handleBlock.next; if (scrut instanceof this.__HandleBlock.class) { handleBlock = handleBlock.next; - tmp3 = null; - continue tmp4; + tmp4 = null; + continue tmp5; } else { return value; } } break; } - return tmp3; + return tmp4; } toString() { return "Predef"; } }; Predef1 = new Predef$class; diff --git a/hkmc2/shared/src/test/mlscript-compile/Predef.mls b/hkmc2/shared/src/test/mlscript-compile/Predef.mls index ec7f1e4ea..2b9bde9df 100644 --- a/hkmc2/shared/src/test/mlscript-compile/Predef.mls +++ b/hkmc2/shared/src/test/mlscript-compile/Predef.mls @@ -207,6 +207,13 @@ fun __resume(cur, tail)(value) = cont is __Cont do set value = cont.resume(value) if value is __EffectSig then + if value.tail.next !== cont do + if cont.next !== null and value.tail.next !== null do + throw Error("Internal Error: unexpected continuation") + // in tail call optimization, the continuation is not appended, we append it here + if value.tail.next === null do + // since it is a tail call, it is already completed so we append the continuation after it + set value.tail.next = cont.next // we are returning to handler, which performs unwinding so tail is needed set value.tail = tail if cur.handleBlockList.next !== null do @@ -233,8 +240,11 @@ fun __resumeHandleBlocks(handleBlock, tailHandleBlock, value) = // resuming tailHandlerList or post handler continuations set value = handleBlock.contHead.next.resume(value) if value is __EffectSig then + if value.tail.next !== handleBlock.contHead.next and value.tail.next !== null do + throw Error("Internal Error: unexpected continuation during handle block resumption") // this checks when continuation resume results in tail call to effectful func - if handleBlock.contHead.next !== value.tail.next do + // when a tail call happens, the continuation will not be appended so this will be null + if value.tail.next !== handleBlock.contHead.next do // if this is a tail call that results in effect, the continuation is already completed // and should be removed set handleBlock.contHead.next = handleBlock.contHead.next.next @@ -249,3 +259,11 @@ fun __resumeHandleBlocks(handleBlock, tailHandleBlock, value) = set handleBlock = handleBlock.next else return value + +// stack safety +val __stackLimit = 0 // How deep the stack can go before heapifying the stack +val __stackDepth = 0 // Tracks the virtual + real stack depth +val __stackOffset = 0 // How much to offset __stackDepth by to get the true stack depth (i.e. the virtual depth) +val __stackHandler = null +abstract class __StackDelay() with + fun perform() diff --git a/hkmc2/shared/src/test/mlscript/handlers/Effects.mls b/hkmc2/shared/src/test/mlscript/handlers/Effects.mls index 124df6b2b..d21f30665 100644 --- a/hkmc2/shared/src/test/mlscript/handlers/Effects.mls +++ b/hkmc2/shared/src/test/mlscript/handlers/Effects.mls @@ -158,7 +158,7 @@ if true do //│ JS (unsanitized): //│ let tmp11, handleBlock$23; //│ handleBlock$23 = function handleBlock$() { -//│ let h, scrut, tmp12, res, Cont$, f, Effect$h$; +//│ let h, scrut, f, tmp12, res, Cont$, Effect$h$; //│ Effect$h$ = class Effect$h$ extends Effect1 { //│ constructor() { //│ let tmp13; diff --git a/hkmc2/shared/src/test/mlscript/handlers/EffectsInClasses.mls b/hkmc2/shared/src/test/mlscript/handlers/EffectsInClasses.mls index d568c1d32..d417f13bd 100644 --- a/hkmc2/shared/src/test/mlscript/handlers/EffectsInClasses.mls +++ b/hkmc2/shared/src/test/mlscript/handlers/EffectsInClasses.mls @@ -15,7 +15,7 @@ class Lol(h) with //│ Lol1.class = class Lol { //│ constructor(h) { //│ this.h = h; -//│ let tmp, res, Cont$; +//│ let tmp, res, res1, Cont$; //│ const this$Lol = this; //│ Cont$ = function Cont$(pc1) { return new Cont$.class(pc1); }; //│ Cont$.class = class Cont$ extends globalThis.Predef.__Cont.class { @@ -27,14 +27,25 @@ class Lol(h) with //│ resume(value$) { //│ if (this.pc === 0) { //│ res = value$; +//│ } else if (this.pc === 1) { +//│ res1 = value$; //│ } //│ contLoop: while (true) { -//│ if (this.pc === 1) { +//│ if (this.pc === 2) { //│ return this$Lol; //│ } else if (this.pc === 0) { //│ tmp = res; +//│ res1 = Predef.print(tmp); +//│ if (res1 instanceof globalThis.Predef.__EffectSig.class) { +//│ res1.tail.next = this; +//│ this.pc = 1; +//│ return res1; +//│ } //│ this.pc = 1; //│ continue contLoop; +//│ } else if (this.pc === 1) { +//│ this.pc = 2; +//│ continue contLoop; //│ } //│ break; //│ } @@ -48,7 +59,13 @@ class Lol(h) with //│ return res; //│ } //│ tmp = res; -//│ Predef.print(tmp) +//│ res1 = Predef.print(tmp); +//│ if (res1 instanceof globalThis.Predef.__EffectSig.class) { +//│ res1.tail.next = new Cont$.class(1); +//│ res1.tail = res1.tail.next; +//│ return res1; +//│ } +//│ res1 //│ } //│ toString() { return "Lol(" + this.h + ")"; } //│ }; @@ -89,6 +106,7 @@ let oops = k("b") Lol(h) //│ > k +//│ > b //│ oops = Lol { h: Effect$h$ {} } oops.h diff --git a/hkmc2/shared/src/test/mlscript/handlers/Generators.mls b/hkmc2/shared/src/test/mlscript/handlers/Generators.mls index 7627b30f4..75acb3fd8 100644 --- a/hkmc2/shared/src/test/mlscript/handlers/Generators.mls +++ b/hkmc2/shared/src/test/mlscript/handlers/Generators.mls @@ -54,3 +54,31 @@ permutations_foreach([1, 2, 3], print) //│ > 2,3,1 //│ > 3,1,2 //│ > 3,2,1 + + +fun permutations_impl(gen, l1, l2) = + if l2 is + [f, ...t] do + handle genWithPrefix = Generator with + fun produce(result)(resume) = + result.unshift(f) + gen.produce(result) + let x = resume(()) + x + permutations_impl(genWithPrefix, [], l1.concat(t)) + l1.push(f) + permutations_impl(gen, l1, t) + [] and l1 is [] do + gen.produce([]) +fun permutations(gen, l) = + permutations_impl(gen, [], l) + +// FIXME: wrong code +let res = [] +handle gen = Generator with + fun produce(result)(resume) = + res.push(result) + let x = resume(()) + x +in permutations(gen, [1, 2, 3]) +//│ res = [ [ 1, 2, 3 ], [ 1, 3, 2 ] ] diff --git a/hkmc2/shared/src/test/mlscript/handlers/ManualStackSafety.mls b/hkmc2/shared/src/test/mlscript/handlers/ManualStackSafety.mls new file mode 100644 index 000000000..3ef7b374f --- /dev/null +++ b/hkmc2/shared/src/test/mlscript/handlers/ManualStackSafety.mls @@ -0,0 +1,121 @@ +:js +:handler + +// * This file demonstrates the handler-based mechanism of our stack safety implementation +// * by manually applying the transformation of a recursive factorial function +// * defined through the Z combinator. +// * It is notably interesting in that it demonstrates the ability to preserve tail calls. +// * The original function can be found in `hkmc2/shared/src/test/mlscript/handlers/ZCombinator.mls` + + +fun test(n, stackLimit) = + let stackDepth = 0 + let stackOffset = 0 + abstract class StackDelay with + fun perform() + handle h = StackDelay with + fun perform()(resume) = + let curOffset = stackOffset + set stackOffset = stackDepth + console.log("resuming at offset:", curOffset) + let tmp = resume() + console.log("finished at offset:", curOffset) + set stackOffset = curOffset + tmp + fun selfApp(f) = + if stackDepth - stackOffset >= stackLimit do h.perform() + set stackDepth += 1 + f(f) + fun mkrec(g) = + if stackDepth - stackOffset >= stackLimit do h.perform() + fun a(self) = + if stackDepth - stackOffset >= stackLimit do h.perform() + fun b(y) = + if stackDepth - stackOffset >= stackLimit do h.perform() + let tmp = stackDepth + set stackDepth += 1 + let res = selfApp(self) + set stackDepth = tmp + set stackDepth += 1 + res(y) + set stackDepth += 1 + g(b) + set stackDepth += 1 + selfApp(a) + let fact = + fun a(self) = + if stackDepth - stackOffset >= stackLimit do h.perform() + fun b(x) = + if stackDepth - stackOffset >= stackLimit do h.perform() + if x == 0 then 1 else + console.log(stackDepth, stackOffset) + let tmp = stackDepth + set stackDepth += 1 + let prev = self(x - 1) + set stackDepth = tmp + console.log("resumed:", x) + x * prev + b + mkrec(a) + set stackDepth = 1 + let ans = fact(n) + set stackDepth = 0 + ans + +:expect 3628800 +test(10, 100) +//│ > 1 0 +//│ > 3 0 +//│ > 5 0 +//│ > 7 0 +//│ > 9 0 +//│ > 11 0 +//│ > 13 0 +//│ > 15 0 +//│ > 17 0 +//│ > 19 0 +//│ > resumed: 1 +//│ > resumed: 2 +//│ > resumed: 3 +//│ > resumed: 4 +//│ > resumed: 5 +//│ > resumed: 6 +//│ > resumed: 7 +//│ > resumed: 8 +//│ > resumed: 9 +//│ > resumed: 10 +//│ = 3628800 + +:expect 3628800 +test(10, 5) +//│ > 1 0 +//│ > resuming at offset: 0 +//│ > 3 5 +//│ > 5 5 +//│ > 7 5 +//│ > resuming at offset: 5 +//│ > 9 10 +//│ > 11 10 +//│ > resuming at offset: 10 +//│ > 13 15 +//│ > 15 15 +//│ > 17 15 +//│ > resuming at offset: 15 +//│ > 19 20 +//│ > resumed: 1 +//│ > resumed: 2 +//│ > resumed: 3 +//│ > resumed: 4 +//│ > resumed: 5 +//│ > resumed: 6 +//│ > resumed: 7 +//│ > resumed: 8 +//│ > resumed: 9 +//│ > resumed: 10 +//│ > finished at offset: 15 +//│ > finished at offset: 10 +//│ > finished at offset: 5 +//│ > finished at offset: 0 +//│ = 3628800 + + diff --git a/hkmc2/shared/src/test/mlscript/handlers/RecursiveHandlers.mls b/hkmc2/shared/src/test/mlscript/handlers/RecursiveHandlers.mls index 4080b9bc8..88dc9733e 100644 --- a/hkmc2/shared/src/test/mlscript/handlers/RecursiveHandlers.mls +++ b/hkmc2/shared/src/test/mlscript/handlers/RecursiveHandlers.mls @@ -126,7 +126,7 @@ str //│ scrut = true; //│ if (scrut === true) { //│ handleBlock$9 = function handleBlock$() { -//│ let h1, tmp8, Cont$, handleBlock$10, Effect$h1$; +//│ let h1, tmp8, handleBlock$10, Cont$, Effect$h1$; //│ Effect$h1$ = class Effect$h1$ extends Effect1 { //│ constructor() { //│ let tmp9; @@ -134,7 +134,7 @@ str //│ } //│ perform(arg) { //│ return globalThis.Predef.__mkEffect(h1, (k) => { -//│ let tmp9, tmp10, tmp11, res3, Cont$1; +//│ let tmp9, tmp10, tmp11, res5, Cont$1; //│ Cont$1 = function Cont$(pc1) { return new Cont$.class(pc1); }; //│ Cont$1.class = class Cont$ extends globalThis.Predef.__Cont.class { //│ constructor(pc) { @@ -144,11 +144,11 @@ str //│ } //│ resume(value$) { //│ if (this.pc === 5) { -//│ res3 = value$; +//│ res5 = value$; //│ } //│ contLoop: while (true) { //│ if (this.pc === 5) { -//│ tmp10 = res3; +//│ tmp10 = res5; //│ tmp11 = str + "A"; //│ str = tmp11; //│ return null; @@ -160,13 +160,13 @@ str //│ }; //│ tmp9 = str + "A"; //│ str = tmp9; -//│ res3 = k(arg) ?? null; -//│ if (res3 instanceof globalThis.Predef.__EffectSig.class) { -//│ res3.tail.next = new Cont$1.class(5); -//│ res3.tail = res3.tail.next; -//│ return res3; +//│ res5 = k(arg) ?? null; +//│ if (res5 instanceof globalThis.Predef.__EffectSig.class) { +//│ res5.tail.next = new Cont$1.class(5); +//│ res5.tail = res5.tail.next; +//│ return res5; //│ } -//│ tmp10 = res3; +//│ tmp10 = res5; //│ tmp11 = str + "A"; //│ str = tmp11; //│ return null; @@ -202,7 +202,7 @@ str //│ toString() { return "Cont$(" + this.pc + ")"; } //│ }; //│ handleBlock$10 = function handleBlock$() { -//│ let h2, tmp9, res3, res4, Cont$1, Effect$h2$; +//│ let h2, tmp9, res5, res6, Cont$1, Effect$h2$; //│ Effect$h2$ = class Effect$h2$ extends Effect1 { //│ constructor() { //│ let tmp10; @@ -210,7 +210,7 @@ str //│ } //│ perform(arg) { //│ return globalThis.Predef.__mkEffect(h2, (k) => { -//│ let tmp10, tmp11, tmp12, tmp13, tmp14, res5, Cont$2; +//│ let tmp10, tmp11, tmp12, tmp13, tmp14, res7, Cont$2; //│ Cont$2 = function Cont$(pc1) { return new Cont$.class(pc1); }; //│ Cont$2.class = class Cont$ extends globalThis.Predef.__Cont.class { //│ constructor(pc) { @@ -220,11 +220,11 @@ str //│ } //│ resume(value$) { //│ if (this.pc === 2) { -//│ res5 = value$; +//│ res7 = value$; //│ } //│ contLoop: while (true) { //│ if (this.pc === 2) { -//│ tmp12 = res5; +//│ tmp12 = res7; //│ tmp13 = str + "B"; //│ tmp14 = str + tmp13; //│ str = tmp14; @@ -238,13 +238,13 @@ str //│ tmp10 = str + "B"; //│ tmp11 = str + tmp10; //│ str = tmp11; -//│ res5 = k(arg) ?? null; -//│ if (res5 instanceof globalThis.Predef.__EffectSig.class) { -//│ res5.tail.next = new Cont$2.class(2); -//│ res5.tail = res5.tail.next; -//│ return res5; +//│ res7 = k(arg) ?? null; +//│ if (res7 instanceof globalThis.Predef.__EffectSig.class) { +//│ res7.tail.next = new Cont$2.class(2); +//│ res7.tail = res7.tail.next; +//│ return res7; //│ } -//│ tmp12 = res5; +//│ tmp12 = res7; //│ tmp13 = str + "B"; //│ tmp14 = str + tmp13; //│ str = tmp14; @@ -263,41 +263,41 @@ str //│ } //│ resume(value$) { //│ if (this.pc === 0) { -//│ res3 = value$; +//│ res5 = value$; //│ } else if (this.pc === 1) { -//│ res4 = value$; +//│ res6 = value$; //│ } //│ contLoop: while (true) { //│ if (this.pc === 0) { -//│ tmp9 = res3; -//│ res4 = h1.perform(null) ?? null; -//│ if (res4 instanceof globalThis.Predef.__EffectSig.class) { -//│ res4.tail.next = this; +//│ tmp9 = res5; +//│ res6 = h1.perform(null) ?? null; +//│ if (res6 instanceof globalThis.Predef.__EffectSig.class) { +//│ res6.tail.next = this; //│ this.pc = 1; -//│ return res4; +//│ return res6; //│ } //│ this.pc = 1; //│ continue contLoop; //│ } else if (this.pc === 1) { -//│ return res4; +//│ return res6; //│ } //│ break; //│ } //│ } //│ toString() { return "Cont$(" + this.pc + ")"; } //│ }; -//│ res3 = h2.perform(null) ?? null; -//│ if (res3 instanceof globalThis.Predef.__EffectSig.class) { -//│ res3.tail.next = new Cont$1(0); -//│ return globalThis.Predef.__handleBlockImpl(res3, h2); +//│ res5 = h2.perform(null) ?? null; +//│ if (res5 instanceof globalThis.Predef.__EffectSig.class) { +//│ res5.tail.next = new Cont$1(0); +//│ return globalThis.Predef.__handleBlockImpl(res5, h2); //│ } -//│ tmp9 = res3; -//│ res4 = h1.perform(null) ?? null; -//│ if (res4 instanceof globalThis.Predef.__EffectSig.class) { -//│ res4.tail.next = new Cont$1(1); -//│ return globalThis.Predef.__handleBlockImpl(res4, h2); +//│ tmp9 = res5; +//│ res6 = h1.perform(null) ?? null; +//│ if (res6 instanceof globalThis.Predef.__EffectSig.class) { +//│ res6.tail.next = new Cont$1(1); +//│ return globalThis.Predef.__handleBlockImpl(res6, h2); //│ } -//│ return res4; +//│ return res6; //│ }; //│ tmp8 = handleBlock$10(); //│ if (tmp8 instanceof globalThis.Predef.__EffectSig.class) { diff --git a/hkmc2/shared/src/test/mlscript/handlers/StackSafety.mls b/hkmc2/shared/src/test/mlscript/handlers/StackSafety.mls new file mode 100644 index 000000000..e6444f1ce --- /dev/null +++ b/hkmc2/shared/src/test/mlscript/handlers/StackSafety.mls @@ -0,0 +1,478 @@ +:js + +// sanity check +:expect 5050 +fun sum(n) = + if n == 0 then 0 + else + n + sum(n - 1) +sum(100) +//│ = 5050 + +// preserve tail calls +// MUST see "return hi1(tmp)" in the output +:stackSafe 5 +:handler +:expect 0 +:sjs +fun hi(n) = + if n == 0 then 0 + else hi(n - 1) +hi(0) +//│ JS (unsanitized): +//│ let hi1, res, handleBlock$1; +//│ hi1 = function hi(n) { +//│ let scrut, tmp, diff, diffGeqLimit, handlerExists, scrut1, dummy, res1, Cont$; +//│ Cont$ = function Cont$(pc1) { return new Cont$.class(pc1); }; +//│ Cont$.class = class Cont$ extends globalThis.Predef.__Cont.class { +//│ constructor(pc) { +//│ let tmp1; +//│ tmp1 = super(null, null); +//│ this.pc = pc; +//│ } +//│ resume(value$) { +//│ if (this.pc === 0) { +//│ res1 = value$; +//│ } +//│ contLoop: while (true) { +//│ if (this.pc === 2) { +//│ scrut = n == 0; +//│ if (scrut === true) { +//│ return 0; +//│ } else { +//│ tmp = n - 1; +//│ globalThis.Predef.__stackDepth = globalThis.Predef.__stackDepth + 1; +//│ return hi1(tmp); +//│ } +//│ this.pc = 1; +//│ continue contLoop; +//│ } else if (this.pc === 1) { +//│ break contLoop; +//│ } else if (this.pc === 0) { +//│ dummy = res1; +//│ this.pc = 2; +//│ continue contLoop; +//│ } +//│ break; +//│ } +//│ } +//│ toString() { return "Cont$(" + this.pc + ")"; } +//│ }; +//│ diff = globalThis.Predef.__stackDepth - globalThis.Predef.__stackOffset; +//│ diffGeqLimit = diff >= globalThis.Predef.__stackLimit; +//│ handlerExists = globalThis.Predef.__stackHandler !== undefined; +//│ scrut1 = diffGeqLimit && handlerExists; +//│ if (scrut1 === true) { +//│ res1 = globalThis.Predef.__stackHandler.perform(); +//│ if (res1 instanceof globalThis.Predef.__EffectSig.class) { +//│ res1.tail.next = new Cont$.class(0); +//│ res1.tail = res1.tail.next; +//│ return res1; +//│ } +//│ dummy = res1; +//│ } +//│ scrut = n == 0; +//│ if (scrut === true) { +//│ return 0; +//│ } else { +//│ tmp = n - 1; +//│ globalThis.Predef.__stackDepth = globalThis.Predef.__stackDepth + 1; +//│ return hi1(tmp); +//│ } +//│ }; +//│ handleBlock$1 = function handleBlock$() { +//│ let stackHandler, res1, Cont$, StackDelay$; +//│ StackDelay$ = class StackDelay$ extends globalThis.Predef.__StackDelay.class { +//│ constructor() { +//│ let tmp; +//│ tmp = super(); +//│ } +//│ perform() { +//│ return globalThis.Predef.__mkEffect(stackHandler, (resume) => { +//│ let res2, curOffset, res3, Cont$1; +//│ Cont$1 = function Cont$(pc1) { return new Cont$.class(pc1); }; +//│ Cont$1.class = class Cont$ extends globalThis.Predef.__Cont.class { +//│ constructor(pc) { +//│ let tmp; +//│ tmp = super(null, null); +//│ this.pc = pc; +//│ } +//│ resume(value$) { +//│ if (this.pc === 4) { +//│ res3 = value$; +//│ } +//│ contLoop: while (true) { +//│ if (this.pc === 4) { +//│ res2 = res3; +//│ globalThis.Predef.__stackOffset = curOffset; +//│ return res2; +//│ } +//│ break; +//│ } +//│ } +//│ toString() { return "Cont$(" + this.pc + ")"; } +//│ }; +//│ curOffset = globalThis.Predef.__stackOffset; +//│ globalThis.Predef.__stackOffset = globalThis.Predef.__stackDepth; +//│ res3 = resume(); +//│ if (res3 instanceof globalThis.Predef.__EffectSig.class) { +//│ res3.tail.next = new Cont$1.class(4); +//│ res3.tail = res3.tail.next; +//│ return res3; +//│ } +//│ res2 = res3; +//│ globalThis.Predef.__stackOffset = curOffset; +//│ return res2; +//│ }); +//│ } +//│ toString() { return "StackDelay$"; } +//│ }; +//│ stackHandler = new StackDelay$(); +//│ Cont$ = function Cont$(pc1) { return new Cont$.class(pc1); }; +//│ Cont$.class = class Cont$ extends globalThis.Predef.__Cont.class { +//│ constructor(pc) { +//│ let tmp; +//│ tmp = super(null, null); +//│ this.pc = pc; +//│ } +//│ resume(value$) { +//│ if (this.pc === 3) { +//│ res1 = value$; +//│ } +//│ contLoop: while (true) { +//│ if (this.pc === 3) { +//│ return res1; +//│ } +//│ break; +//│ } +//│ } +//│ toString() { return "Cont$(" + this.pc + ")"; } +//│ }; +//│ globalThis.Predef.__stackLimit = 5; +//│ globalThis.Predef.__stackDepth = 1; +//│ globalThis.Predef.__stackHandler = stackHandler; +//│ res1 = hi1(0); +//│ if (res1 instanceof globalThis.Predef.__EffectSig.class) { +//│ res1.tail.next = new Cont$(3); +//│ return globalThis.Predef.__handleBlockImpl(res1, stackHandler); +//│ } +//│ return res1; +//│ }; +//│ res = handleBlock$1(); +//│ if (res instanceof this.Predef.__EffectSig.class) { +//│ throw new this.Error("Unhandled effects"); +//│ } +//│ this.Predef.__stackDepth = 0; +//│ this.Predef.__stackHandler = undefined; +//│ res +//│ = 0 + +:sjs +:stackSafe 1000 +:handler +:expect 50005000 +fun sum(n) = + if n == 0 then 0 + else + n + sum(n - 1) +sum(10000) +//│ JS (unsanitized): +//│ let sum3, res1, handleBlock$3; +//│ sum3 = function sum(n) { +//│ let scrut, tmp, tmp1, tmp2, prevDepth, diff, diffGeqLimit, handlerExists, scrut1, dummy, res2, res3, Cont$; +//│ Cont$ = function Cont$(pc1) { return new Cont$.class(pc1); }; +//│ Cont$.class = class Cont$ extends globalThis.Predef.__Cont.class { +//│ constructor(pc) { +//│ let tmp3; +//│ tmp3 = super(null, null); +//│ this.pc = pc; +//│ } +//│ resume(value$) { +//│ if (this.pc === 1) { +//│ res3 = value$; +//│ } else if (this.pc === 0) { +//│ res2 = value$; +//│ } +//│ contLoop: while (true) { +//│ if (this.pc === 3) { +//│ scrut = n == 0; +//│ if (scrut === true) { +//│ return 0; +//│ } else { +//│ tmp = n - 1; +//│ prevDepth = globalThis.Predef.__stackDepth; +//│ globalThis.Predef.__stackDepth = globalThis.Predef.__stackDepth + 1; +//│ res3 = sum3(tmp); +//│ if (res3 instanceof globalThis.Predef.__EffectSig.class) { +//│ res3.tail.next = this; +//│ this.pc = 1; +//│ return res3; +//│ } +//│ this.pc = 1; +//│ continue contLoop; +//│ } +//│ this.pc = 2; +//│ continue contLoop; +//│ } else if (this.pc === 2) { +//│ break contLoop; +//│ } else if (this.pc === 1) { +//│ tmp2 = res3; +//│ globalThis.Predef.__stackDepth = prevDepth; +//│ tmp1 = tmp2; +//│ return n + tmp1; +//│ } else if (this.pc === 0) { +//│ dummy = res2; +//│ this.pc = 3; +//│ continue contLoop; +//│ } +//│ break; +//│ } +//│ } +//│ toString() { return "Cont$(" + this.pc + ")"; } +//│ }; +//│ diff = globalThis.Predef.__stackDepth - globalThis.Predef.__stackOffset; +//│ diffGeqLimit = diff >= globalThis.Predef.__stackLimit; +//│ handlerExists = globalThis.Predef.__stackHandler !== undefined; +//│ scrut1 = diffGeqLimit && handlerExists; +//│ if (scrut1 === true) { +//│ res2 = globalThis.Predef.__stackHandler.perform(); +//│ if (res2 instanceof globalThis.Predef.__EffectSig.class) { +//│ res2.tail.next = new Cont$.class(0); +//│ res2.tail = res2.tail.next; +//│ return res2; +//│ } +//│ dummy = res2; +//│ } +//│ scrut = n == 0; +//│ if (scrut === true) { +//│ return 0; +//│ } else { +//│ tmp = n - 1; +//│ prevDepth = globalThis.Predef.__stackDepth; +//│ globalThis.Predef.__stackDepth = globalThis.Predef.__stackDepth + 1; +//│ res3 = sum3(tmp); +//│ if (res3 instanceof globalThis.Predef.__EffectSig.class) { +//│ res3.tail.next = new Cont$.class(1); +//│ res3.tail = res3.tail.next; +//│ return res3; +//│ } +//│ tmp2 = res3; +//│ globalThis.Predef.__stackDepth = prevDepth; +//│ tmp1 = tmp2; +//│ return n + tmp1; +//│ } +//│ }; +//│ handleBlock$3 = function handleBlock$() { +//│ let stackHandler, res2, Cont$, StackDelay$; +//│ StackDelay$ = class StackDelay$ extends globalThis.Predef.__StackDelay.class { +//│ constructor() { +//│ let tmp; +//│ tmp = super(); +//│ } +//│ perform() { +//│ return globalThis.Predef.__mkEffect(stackHandler, (resume) => { +//│ let res3, curOffset, res4, Cont$1; +//│ Cont$1 = function Cont$(pc1) { return new Cont$.class(pc1); }; +//│ Cont$1.class = class Cont$ extends globalThis.Predef.__Cont.class { +//│ constructor(pc) { +//│ let tmp; +//│ tmp = super(null, null); +//│ this.pc = pc; +//│ } +//│ resume(value$) { +//│ if (this.pc === 5) { +//│ res4 = value$; +//│ } +//│ contLoop: while (true) { +//│ if (this.pc === 5) { +//│ res3 = res4; +//│ globalThis.Predef.__stackOffset = curOffset; +//│ return res3; +//│ } +//│ break; +//│ } +//│ } +//│ toString() { return "Cont$(" + this.pc + ")"; } +//│ }; +//│ curOffset = globalThis.Predef.__stackOffset; +//│ globalThis.Predef.__stackOffset = globalThis.Predef.__stackDepth; +//│ res4 = resume(); +//│ if (res4 instanceof globalThis.Predef.__EffectSig.class) { +//│ res4.tail.next = new Cont$1.class(5); +//│ res4.tail = res4.tail.next; +//│ return res4; +//│ } +//│ res3 = res4; +//│ globalThis.Predef.__stackOffset = curOffset; +//│ return res3; +//│ }); +//│ } +//│ toString() { return "StackDelay$"; } +//│ }; +//│ stackHandler = new StackDelay$(); +//│ Cont$ = function Cont$(pc1) { return new Cont$.class(pc1); }; +//│ Cont$.class = class Cont$ extends globalThis.Predef.__Cont.class { +//│ constructor(pc) { +//│ let tmp; +//│ tmp = super(null, null); +//│ this.pc = pc; +//│ } +//│ resume(value$) { +//│ if (this.pc === 4) { +//│ res2 = value$; +//│ } +//│ contLoop: while (true) { +//│ if (this.pc === 4) { +//│ return res2; +//│ } +//│ break; +//│ } +//│ } +//│ toString() { return "Cont$(" + this.pc + ")"; } +//│ }; +//│ globalThis.Predef.__stackLimit = 1000; +//│ globalThis.Predef.__stackDepth = 1; +//│ globalThis.Predef.__stackHandler = stackHandler; +//│ res2 = sum3(10000); +//│ if (res2 instanceof globalThis.Predef.__EffectSig.class) { +//│ res2.tail.next = new Cont$(4); +//│ return globalThis.Predef.__handleBlockImpl(res2, stackHandler); +//│ } +//│ return res2; +//│ }; +//│ res1 = handleBlock$3(); +//│ if (res1 instanceof this.Predef.__EffectSig.class) { +//│ throw new this.Error("Unhandled effects"); +//│ } +//│ this.Predef.__stackDepth = 0; +//│ this.Predef.__stackHandler = undefined; +//│ res1 +//│ = 50005000 + +// stack-overflows without :stackSafe +:re +fun sum(n) = + if n == 0 then 0 + else + n + sum(n - 1) +sum(10000) +//│ ═══[RUNTIME ERROR] RangeError: Maximum call stack size exceeded + +:handler +:stackSafe 100 +mut val ctr = 0 +fun dummy(x) = x +fun foo(f) = + if ctr > 10000 then 0 + else + set ctr += 1 + dummy(f(f)) +foo(foo) +//│ = 0 +//│ ctr = 10001 + +:stackSafe 1000 +:handler +:expect 50005000 +val foo = + val f = n => + if n <= 0 then 0 + else n + f(n-1) + f(10000) +foo +//│ = 50005000 +//│ foo = 50005000 + +:re +fun foo() = + val f = n => + if n <= 0 then 0 + else n + f(n-1) + f(10000) +foo() +//│ ═══[RUNTIME ERROR] RangeError: Maximum call stack size exceeded + +abstract class Eff with + fun perform(a): () + +// functions and lambdas inside handlers +:handler +:stackSafe 100 +:expect 50005000 +fun foo(h) = h.perform +handle h = Eff with + fun perform(resume) = + val f = n => + if n <= 0 then 0 + else n + f(n-1) + resume(f(10000)) +foo(h) +//│ = 50005000 + +// function call and defn inside handler +:handler +:stackSafe 100 +:expect 50005000 +handle h = Eff with + fun perform(resume) = + val f = n => + if n <= 0 then 0 + else n + f(n-1) + resume(f(10000)) +in + fun foo(h) = h.perform + foo(h) +//│ = 50005000 + +:re +:handler +fun foo(h) = h.perform(2) +handle h = Eff with + fun perform(a)(resume) = + val f = n => + if n <= 0 then 0 + else n + f(n-1) + resume(f(10000)) +foo(h) +//│ ═══[RUNTIME ERROR] RangeError: Maximum call stack size exceeded + +:handler +:stackSafe +:sjs +fun max(a, b) = if a < b then b else a +//│ JS (unsanitized): +//│ let max1; +//│ max1 = function max(a, b) { +//│ let scrut; +//│ scrut = a < b; +//│ if (scrut === true) { +//│ return b; +//│ } else { +//│ return a; +//│ } +//│ }; +//│ null + + +// * Note that currently the `:sjs` command will not run the code if there is a compilation error +:sjs +:stackSafe 42 +:ge +fun hi(n) = n +hi(0) +//│ ═══[COMPILATION ERROR] This code requires effect handler instrumentation but was compiled without it. +//│ JS (unsanitized): +//│ let hi3, stackHandler; +//│ hi3 = function hi(n) { +//│ return n; +//│ }; +//│ throw globalThis.Error("This code requires effect handler instrumentation but was compiled without it."); + +:stackSafe 42 +:ge +:re +hi(0) +//│ ═══[COMPILATION ERROR] This code requires effect handler instrumentation but was compiled without it. +//│ ═══[RUNTIME ERROR] Error: This code requires effect handler instrumentation but was compiled without it. + + diff --git a/hkmc2/shared/src/test/mlscript/handlers/TailCallOptimization.mls b/hkmc2/shared/src/test/mlscript/handlers/TailCallOptimization.mls new file mode 100644 index 000000000..99ca1276b --- /dev/null +++ b/hkmc2/shared/src/test/mlscript/handlers/TailCallOptimization.mls @@ -0,0 +1,22 @@ +:js +:handler + +abstract class StackDelay with + fun perform(): () + +:expect 3 +handle h = StackDelay with + fun perform()(resume) = resume() +fun foo(x) = + h.perform() + x + 4 +fun bar(y) = + h.perform() + // tail call + foo(y + 2) +fun foobar(z) = + bar(z + 1) + // stuff after tail call is linked differently + 3 +foobar(0) +//│ = 3 diff --git a/hkmc2/shared/src/test/mlscript/handlers/ZCombinator.mls b/hkmc2/shared/src/test/mlscript/handlers/ZCombinator.mls new file mode 100644 index 000000000..7d7cb4122 --- /dev/null +++ b/hkmc2/shared/src/test/mlscript/handlers/ZCombinator.mls @@ -0,0 +1,114 @@ +:js +:handler +:stackSafe 1000 + +fun selfApp(f) = f(f) + +fun mkrec(g) = + selfApp of self => + g of y => self(self)(y) + +let fact = mkrec of self => x => + if x == 0 then 1 else self(x - 1) * x +//│ fact = [Function (anonymous)] + +fact(3) +//│ = 6 + +fact(10) +//│ = 3628800 + +:re +:stackSafe off +fact(10000) +//│ ═══[RUNTIME ERROR] RangeError: Maximum call stack size exceeded + + +fun mkrec(g) = + selfApp of self => + g of y => selfApp(self)(y) + +let fact = mkrec of self => x => + if x == 0 then 1 + else self(x - 1) * x +//│ fact = [Function (anonymous)] + +:stackSafe 1000 +fact(10000) +//│ = Infinity + +// * Without `:stackSafe`, gives `RangeError: Maximum call stack size exceeded` +:re +:stackSafe off +fact(10000) +//│ ═══[RUNTIME ERROR] RangeError: Maximum call stack size exceeded + +:stackSafe 1000 +fact(10000) +//│ = Infinity + +:re +:stackSafe off +fact(10000) +//│ ═══[RUNTIME ERROR] RangeError: Maximum call stack size exceeded + + +:stackSafe 5 +set selfApp = f => f(f) + +:expect 3628800 +fact(10) +//│ = 3628800 + + +:stackSafe 10 +set selfApp = f => f(f) + +fact(1000) +//│ = Infinity + +// simplified version without lambdas for easier debugging +:stackSafe 5 +fun mkrec(g) = + fun a(self) = + fun b(y) = selfApp(self)(y) + g(b) + selfApp(a) + +:stackSafe 5 +let fact = + fun a(self) = + fun b(x) = + if x == 0 then 1 else + console.log(__stackDepth, __stackOffset) + let prev = self(x - 1) + console.log("resumed:", x) + x * prev + b + mkrec(a) +//│ fact = [Function: b] + +:expect 3628800 +:stackSafe 5 +fact(10) +//│ > 2 0 +//│ > 4 0 +//│ > 6 5 +//│ > 8 5 +//│ > 10 10 +//│ > 12 10 +//│ > 14 10 +//│ > 16 15 +//│ > 18 15 +//│ > 20 20 +//│ > resumed: 1 +//│ > resumed: 2 +//│ > resumed: 3 +//│ > resumed: 4 +//│ > resumed: 5 +//│ > resumed: 6 +//│ > resumed: 7 +//│ > resumed: 8 +//│ > resumed: 9 +//│ > resumed: 10 +//│ = 3628800 diff --git a/hkmc2DiffTests/src/test/scala/hkmc2/JSBackendDiffMaker.scala b/hkmc2DiffTests/src/test/scala/hkmc2/JSBackendDiffMaker.scala index abe2a40b0..58cc0601a 100644 --- a/hkmc2DiffTests/src/test/scala/hkmc2/JSBackendDiffMaker.scala +++ b/hkmc2DiffTests/src/test/scala/hkmc2/JSBackendDiffMaker.scala @@ -27,6 +27,8 @@ abstract class JSBackendDiffMaker extends MLsDiffMaker: val handler = NullaryCommand("handler") val expect = Command("expect"): ln => ln.trim + val stackSafe = Command("stackSafe"): ln => + ln.trim private val baseScp: utils.Scope = utils.Scope.empty @@ -48,11 +50,25 @@ abstract class JSBackendDiffMaker extends MLsDiffMaker: private var hostCreated = false override def run(): Unit = try super.run() finally if hostCreated then host.terminate() + + private val DEFAULT_STACK_LIMT = 500 override def processTerm(blk: semantics.Term.Blk, inImport: Bool)(using Raise): Unit = super.processTerm(blk, inImport) val outerRaise: Raise = summon var showingJSYieldedCompileError = false + val stackLimit = stackSafe.get match + case None => None + case Some("off") => None + case Some(value) => value.toIntOption match + case None => Some(DEFAULT_STACK_LIMT) + case Some(value) => + if value < 0 then + failures += 1 + output("/!\\ Stack limit must be positive, but the stack limit here is set to " + value) + Some(DEFAULT_STACK_LIMT) + else + Some(value) if showJS.isSet then given Raise = case d @ ErrorReport(source = Source.Compilation) => @@ -61,7 +77,7 @@ abstract class JSBackendDiffMaker extends MLsDiffMaker: case d => outerRaise(d) given Elaborator.Ctx = curCtx val low = ltl.givenIn: - new codegen.Lowering(lowerHandlers = handler.isSet) + new codegen.Lowering(lowerHandlers = handler.isSet, stackLimit = stackLimit) with codegen.LoweringSelSanityChecks(instrument = false) with codegen.LoweringTraceLog(instrument = false) val jsb = new JSBuilder @@ -76,7 +92,7 @@ abstract class JSBackendDiffMaker extends MLsDiffMaker: if js.isSet && !showingJSYieldedCompileError then given Elaborator.Ctx = curCtx val low = ltl.givenIn: - new codegen.Lowering(lowerHandlers = handler.isSet) + new codegen.Lowering(lowerHandlers = handler.isSet, stackLimit = stackLimit) with codegen.LoweringSelSanityChecks(noSanityCheck.isUnset) with codegen.LoweringTraceLog(traceJS.isSet) val jsb = new JSBuilder diff --git a/hkmc2DiffTests/src/test/scala/hkmc2/LlirDiffMaker.scala b/hkmc2DiffTests/src/test/scala/hkmc2/LlirDiffMaker.scala index ddd431923..802382911 100644 --- a/hkmc2DiffTests/src/test/scala/hkmc2/LlirDiffMaker.scala +++ b/hkmc2DiffTests/src/test/scala/hkmc2/LlirDiffMaker.scala @@ -37,7 +37,7 @@ abstract class LlirDiffMaker extends BbmlDiffMaker: super.processTerm(trm, inImport) if llir.isSet then val low = ltl.givenIn: - codegen.Lowering(false) + codegen.Lowering(lowerHandlers = false, stackLimit = None) val le = low.program(trm) given Scope = Scope.empty val fresh = Fresh()