diff --git a/.github/workflows/nix.yml b/.github/workflows/nix.yml index 044e4df66..670030768 100644 --- a/.github/workflows/nix.yml +++ b/.github/workflows/nix.yml @@ -18,7 +18,7 @@ jobs: - name: Install TypeScript run: npm ci - name: Run test - run: sbt -J-Xmx4096M -J-Xss4M test + run: sbt -J-Xmx4096M -J-Xss8M test - name: Check no changes run: | git update-index -q --refresh diff --git a/flake.nix b/flake.nix index 5f81da24a..8e657face 100644 --- a/flake.nix +++ b/flake.nix @@ -11,7 +11,7 @@ (system: let sbtOverlay = self: super: { - sbt = super.sbt.override { jre = super.jdk8_headless; }; + sbt = super.sbt.override { jre = super.jdk11_headless; }; }; pkgs = import nixpkgs { inherit system; diff --git a/hkmc2/jvm/src/test/scala/hkmc2/DiffTestRunner.scala b/hkmc2/jvm/src/test/scala/hkmc2/DiffTestRunner.scala index 6390cf3c3..5d770db33 100644 --- a/hkmc2/jvm/src/test/scala/hkmc2/DiffTestRunner.scala +++ b/hkmc2/jvm/src/test/scala/hkmc2/DiffTestRunner.scala @@ -10,7 +10,7 @@ import mlscript.utils._, shorthands._ class MainDiffMaker(val rootPath: Str, val file: os.Path, val preludeFile: os.Path, val predefFile: os.Path, val relativeName: Str) - extends BbmlDiffMaker + extends LlirDiffMaker diff --git a/hkmc2/jvm/src/test/scala/hkmc2/LlirDiffMaker.scala b/hkmc2/jvm/src/test/scala/hkmc2/LlirDiffMaker.scala new file mode 100644 index 000000000..ef9317828 --- /dev/null +++ b/hkmc2/jvm/src/test/scala/hkmc2/LlirDiffMaker.scala @@ -0,0 +1,73 @@ +package hkmc2 + +import scala.collection.mutable + +import mlscript.utils.*, shorthands.* +import utils.* + +import document.* +import codegen.Block +import codegen.llir.* +import codegen.cpp.* +import hkmc2.syntax.Tree.Ident +import hkmc2.codegen.Path +import hkmc2.semantics.Term.Blk +import hkmc2.codegen.llir.Fresh +import hkmc2.utils.Scope +import hkmc2.codegen.llir.Ctx +import hkmc2.codegen.llir._ + +abstract class LlirDiffMaker extends BbmlDiffMaker: + val llir = NullaryCommand("llir") + val cpp = NullaryCommand("cpp") + val sllir = NullaryCommand("sllir") + val scpp = NullaryCommand("scpp") + val rcpp = NullaryCommand("rcpp") + val intl = NullaryCommand("intl") + val wcpp = Command[Str]("wcpp", false)(x => x.stripLeading()) + + def printToFile(f: java.io.File)(op: java.io.PrintWriter => Unit) = + val p = new java.io.PrintWriter(f) + try { op(p) } finally { p.close() } + + override def processTerm(trm: Blk, inImport: Bool)(using Raise): Unit = + super.processTerm(trm, inImport) + if llir.isSet then + val low = ltl.givenIn: + codegen.Lowering() + val le = low.program(trm) + given Scope = Scope.empty + val fresh = Fresh() + val fuid = FreshInt() + val cuid = FreshInt() + val llb = LlirBuilder(tl)(fresh, fuid, cuid) + given Ctx = Ctx.empty + try + val llirProg = llb.bProg(le) + if sllir.isSet then + output("LLIR:") + output(llirProg.show()) + if cpp.isSet || scpp.isSet || rcpp.isSet || wcpp.isSet then + val cpp = codegen.cpp.CppCodeGen.codegen(llirProg) + if scpp.isSet then + output("\nCpp:") + output(cpp.toDocument.toString) + val auxPath = os.Path(rootPath) / "hkmc2"/"shared"/"src"/"test"/"mlscript-compile"/"cpp" + if wcpp.isSet then + printToFile(java.io.File((auxPath / s"${wcpp.get.get}.cpp").toString)): + p => p.println(cpp.toDocument.toString) + if rcpp.isSet then + val cppHost = CppCompilerHost(auxPath.toString, output.apply) + if !cppHost.ready then + output("\nCpp Compilation Failed: Cpp compiler or GNU Make not found") + else + output("\n") + cppHost.compileAndRun(cpp.toDocument.toString) + if intl.isSet then + val intr = codegen.llir.Interpreter(verbose = true) + output("\nInterpreted:") + output(intr.interpret(llirProg)) + catch + case e: LowLevelIRError => + output("Stopped due to an error during the Llir generation") + diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Printer.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Printer.scala index b86f008df..2015624b4 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Printer.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Printer.scala @@ -64,7 +64,20 @@ object Printer: case ValDefn(owner, k, sym, rhs) => doc"val ${sym.nme} = ${mkDocument(rhs)}" case ClsLikeDefn(sym, k, parentSym, methods, privateFields, publicFields, preCtor, ctor) => - doc"class ${sym.nme} #{ #} " + def optFldBody(t: semantics.TermDefinition) = + t.body match + case Some(x) => doc" = ..." + case None => doc"" + val clsDefn = sym.defn.getOrElse(die) + val clsParams = clsDefn.paramsOpt.fold(Nil)(_.paramSyms) + val ctorParams = clsParams.map(p => summon[Scope].allocateName(p)) + val privFields = privateFields.map(x => doc"let ${x.id.name} = ...").mkDocument(sep = doc" # ") + val pubFields = publicFields.map(x => doc"${x.k.str} ${x.sym.nme}${optFldBody(x)}").mkDocument(sep = doc" # ") + val docPrivFlds = if privateFields.isEmpty then doc"" else doc" # ${privFields}" + val docPubFlds = if publicFields.isEmpty then doc"" else doc" # ${pubFields}" + val docBody = if publicFields.isEmpty && privateFields.isEmpty then doc"" else doc" { #{ ${docPrivFlds}${docPubFlds} #} # }" + val docCtorParams = if clsParams.isEmpty then doc"" else doc"(${ctorParams.mkString(", ")})" + doc"class ${sym.nme}${docCtorParams}${docBody}" def mkDocument(arg: Arg)(using Raise, Scope): Document = val doc = mkDocument(arg.value) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/cpp/Ast.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/cpp/Ast.scala new file mode 100644 index 000000000..c05e4f1d3 --- /dev/null +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/cpp/Ast.scala @@ -0,0 +1,220 @@ +package hkmc2.codegen.cpp + +import mlscript._ +import mlscript.utils._ +import mlscript.utils.shorthands._ + +import hkmc2.Message.MessageContext +import hkmc2.document._ + +import scala.language.implicitConversions + +private def raw(x: String): Document = doc"$x" +given Conversion[String, Document] = x => doc"$x" + +enum Specifier: + case Extern + case Static + case Inline + + def toDocument = + this match + case Extern => "extern" + case Static => "static" + case Inline => "inline" + + override def toString: Str = toDocument + +object Type: + def toDocuments(args: Ls[Type], sep: Document, extraTypename: Bool = false): Document = + args.map(_.toDocument(extraTypename)).mkDocument(sep) + + def toDocuments(args: Ls[(Str, Type)], sep: Document): Document = + args.map(x => doc"${x._2.toDocument()} ${x._1}").mkDocument(sep) + +enum Type: + case Prim(name: Str) + case Ptr(inner: Type) + case Ref(inner: Type) + case Array(inner: Type, size: Opt[Int]) + case FuncPtr(ret: Type, args: List[Type]) + case Struct(name: Str) + case Enum(name: Str) + case Template(name: Str, args: List[Type]) + case Var(name: Str) + case Qualifier(inner: Type, qual: Str) + + def toDocument(extraTypename: Bool = false): Document = + def aux(x: Type): Document = x match + case Prim(name) => name + case Ptr(inner) => doc"${aux(inner)}*" + case Ref(inner) => doc"${aux(inner)}&" + case Array(inner, size) => + doc"${aux(inner)}[${size.fold("")(x => x.toString)}]" + case FuncPtr(ret, args) => + doc"${aux(ret)}(${Type.toDocuments(args, sep = ", ")})" + case Struct(name) => doc"struct $name" + case Enum(name) => doc"enum $name" + case Template(name, args) => + doc"$name<${Type.toDocuments(args, sep = ", ")}>" + case Var(name) => name + case Qualifier(inner, qual) => doc"${aux(inner)} $qual" + aux(this) + + override def toString: Str = toDocument().toString + +object Stmt: + def toDocuments(decl: Ls[Decl], stmts: Ls[Stmt]): Document = + (decl.map(_.toDocument) ++ stmts.map(_.toDocument)).mkDocument(doc" # ") + +enum Stmt: + case AutoBind(lhs: Ls[Str], rhs: Expr) + case Assign(lhs: Str, rhs: Expr) + case Return(expr: Expr) + case If(cond: Expr, thenStmt: Stmt, elseStmt: Opt[Stmt]) + case While(cond: Expr, body: Stmt) + case For(init: Stmt, cond: Expr, update: Stmt, body: Stmt) + case ExprStmt(expr: Expr) + case Break + case Continue + case Block(decl: Ls[Decl], stmts: Ls[Stmt]) + case Switch(expr: Expr, cases: Ls[(Expr, Stmt)]) + case Raw(stmt: Str) + + def toDocument: Document = + def aux(x: Stmt): Document = x match + case AutoBind(lhs, rhs) => + lhs match + case Nil => rhs.toDocument + case x :: Nil => + doc"auto $x = ${rhs.toDocument};" + case _ => + doc"auto [${lhs.mkDocument(", ")}] = ${rhs.toDocument};" + case Assign(lhs, rhs) => + doc"$lhs = ${rhs.toDocument};" + case Return(expr) => + doc"return ${expr.toDocument};" + case If(cond, thenStmt, elseStmt) => + doc"if (${cond.toDocument}) ${thenStmt.toDocument}${elseStmt.fold(doc" ")(x => doc" else ${x.toDocument}")}" + case While(cond, body) => + doc"while (${cond.toDocument}) ${body.toDocument}" + case For(init, cond, update, body) => + doc"for (${init.toDocument}; ${cond.toDocument}; ${update.toDocument}) ${body.toDocument}" + case ExprStmt(expr) => + doc"${expr.toDocument};" + case Break => "break;" + case Continue => "continue;" + case Block(decl, stmts) => + doc"{ #{ # ${Stmt.toDocuments(decl, stmts)} #} # }" + case Switch(expr, cases) => + val docCases = cases.map { + case (cond, stmt) => doc"case ${cond.toDocument}: ${stmt.toDocument}" + }.mkDocument(doc" # ") + doc"switch (${expr.toDocument}) { #{ # ${docCases} #} # }" + case Raw(stmt) => stmt + aux(this) + +object Expr: + def toDocuments(args: Ls[Expr], sep: Document): Document = + args.map(_.toDocument).mkDocument(sep) + +enum Expr: + case Var(name: Str) + case IntLit(value: BigInt) + case DoubleLit(value: Double) + case StrLit(value: Str) + case CharLit(value: Char) + case Call(func: Expr, args: Ls[Expr]) + case Member(expr: Expr, member: Str) + case Index(expr: Expr, index: Expr) + case Unary(op: Str, expr: Expr) + case Binary(op: Str, lhs: Expr, rhs: Expr) + case Initializer(exprs: Ls[Expr]) + case Constructor(name: Str, init: Expr) + + def toDocument: Document = + def aux(x: Expr): Document = x match + case Var(name) => name + case IntLit(value) => value.toString + case DoubleLit(value) => value.toString + case StrLit(value) => s"\"$value\"" // need more reliable escape utils + case CharLit(value) => value.toInt.toString + case Call(func, args) => + doc"${func.toDocument}(${Expr.toDocuments(args, sep = ", ")})" + case Member(expr, member) => + doc"${expr.toDocument}->$member" + case Index(expr, index) => + doc"${expr.toDocument}[${index.toDocument}]" + case Unary(op, expr) => + doc"($op${expr.toDocument})" + case Binary(op, lhs, rhs) => + doc"(${lhs.toDocument} $op ${rhs.toDocument})" + case Initializer(exprs) => + doc"{${Expr.toDocuments(exprs, sep = ", ")}}" + case Constructor(name, init) => + doc"$name(${init.toDocument})" + aux(this) + +case class CompilationUnit(includes: Ls[Str], decls: Ls[Decl], defs: Ls[Def]): + def toDocument: Document = + (includes.map(raw) ++ decls.map(_.toDocument) ++ defs.map(_.toDocument)).mkDocument(doc" # ") + def toDocumentWithoutHidden: Document = + val hiddenNames: Set[Str] = Set() + defs.filterNot { + case Def.StructDef(name, _, _, _) => hiddenNames.contains(name.stripPrefix("_mls_")) + case _ => false + }.map(_.toDocument).mkDocument(doc" # ") + +enum Decl: + case StructDecl(name: Str) + case EnumDecl(name: Str) + case FuncDecl(ret: Type, name: Str, args: Ls[Type]) + case VarDecl(name: Str, typ: Type) + + def toDocument: Document = + def aux(x: Decl): Document = x match + case StructDecl(name) => doc"struct $name;" + case EnumDecl(name) => doc"enum $name;" + case FuncDecl(ret, name, args) => + doc"${ret.toDocument()} $name(${Type.toDocuments(args, sep = ", ")});" + case VarDecl(name, typ) => + doc"${typ.toDocument()} $name;" + aux(this) + +enum Def: + case StructDef(name: Str, fields: Ls[(Str, Type)], inherit: Opt[Ls[Str]], methods: Ls[Def] = Ls.empty) + case EnumDef(name: Str, fields: Ls[(Str, Opt[Int])]) + case FuncDef(specret: Type, name: Str, args: Ls[(Str, Type)], body: Stmt.Block, or: Bool = false, virt: Bool = false) + case VarDef(typ: Type, name: Str, init: Opt[Expr]) + case RawDef(raw: Str) + + def toDocument: Document = + def aux(x: Def): Document = x match + case StructDef(name, fields, inherit, defs) => + val docFirst = doc"struct $name${inherit.fold(doc"")(x => doc": public ${x.mkDocument(doc", ")}")} {" + val docFields = fields.map { + case (name, typ) => doc"${typ.toDocument()} $name;" + }.mkDocument(doc" # ") + val docDefs = defs.map(_.toDocument).mkDocument(doc" # ") + val docLast = "};" + doc"$docFirst #{ # $docFields # $docDefs #} # $docLast" + case EnumDef(name, fields) => + val docFirst = doc"enum $name {" + val docFields = fields.map { + case (name, value) => value.fold(s"$name")(x => s"$name = $x") + }.mkDocument(doc" # ") + val docLast = "};" + doc"$docFirst #{ # $docFields #} # $docLast" + case FuncDef(specret, name, args, body, or, virt) => + val docVirt = (if virt then doc"virtual " else doc"") + val docSpecRet = specret.toDocument() + val docArgs = Type.toDocuments(args, sep = ", ") + val docOverride = if or then doc" override" else doc"" + val docBody = body.toDocument + doc"$docVirt$docSpecRet $name($docArgs)$docOverride ${body.toDocument}" + case VarDef(typ, name, init) => + val docTyp = typ.toDocument() + val docInit = init.fold(raw(""))(x => doc" = ${x.toDocument}") + doc"$docTyp $name$docInit;" + case RawDef(x) => x + aux(this) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/cpp/CodeGen.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/cpp/CodeGen.scala new file mode 100644 index 000000000..80f83ae7c --- /dev/null +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/cpp/CodeGen.scala @@ -0,0 +1,240 @@ +package hkmc2.codegen.cpp + +import mlscript.utils._ +import mlscript.utils.shorthands._ +import scala.collection.mutable.ListBuffer + +import hkmc2.codegen.llir.{Expr => IExpr, _} +import hkmc2.codegen.cpp._ + +object CppCodeGen: + def mapName(name: Name): Str = "_mls_" + name.str.replace('$', '_').replace('\'', '_') + def mapName(name: Str): Str = "_mls_" + name.replace('$', '_').replace('\'', '_') + val freshName = Fresh(div = '_'); + val mlsValType = Type.Prim("_mlsValue") + val mlsUnitValue = Expr.Call(Expr.Var("_mlsValue::create<_mls_Unit>"), Ls()); + val mlsRetValue = "_mls_retval" + val mlsRetValueDecl = Decl.VarDecl(mlsRetValue, mlsValType) + val mlsMainName = "_mlsMain" + val mlsPrelude = "#include \"mlsprelude.h\"" + val mlsPreludeImpl = "#include \"mlsprelude.cpp\"" + val mlsInternalClass = Set("True", "False", "Boolean", "Callable") + val mlsObject = "_mlsObject" + val mlsBuiltin = "builtin" + val mlsEntryPoint = s"int main() { return _mlsLargeStack(_mlsMainWrapper); }"; + def mlsIntLit(x: BigInt) = Expr.Call(Expr.Var("_mlsValue::fromIntLit"), Ls(Expr.IntLit(x))) + def mlsStrLit(x: Str) = Expr.Call(Expr.Var("_mlsValue::fromStrLit"), Ls(Expr.StrLit(x))) + def mlsCharLit(x: Char) = Expr.Call(Expr.Var("_mlsValue::fromIntLit"), Ls(Expr.CharLit(x))) + def mlsNewValue(cls: Str, args: Ls[Expr]) = Expr.Call(Expr.Var(s"_mlsValue::create<$cls>"), args) + def mlsIsValueOf(cls: Str, scrut: Expr) = Expr.Call(Expr.Var(s"_mlsValue::isValueOf<$cls>"), Ls(scrut)) + def mlsIsBoolLit(scrut: Expr, lit: hkmc2.syntax.Tree.BoolLit) = Expr.Call(Expr.Var("_mlsValue::isIntLit"), Ls(scrut, Expr.IntLit(if lit.value then 1 else 0))) + def mlsIsIntLit(scrut: Expr, lit: hkmc2.syntax.Tree.IntLit) = Expr.Call(Expr.Var("_mlsValue::isIntLit"), Ls(scrut, Expr.IntLit(lit.value))) + def mlsDebugPrint(x: Expr) = Expr.Call(Expr.Var("_mlsValue::print"), Ls(x)) + def mlsTupleValue(init: Expr) = Expr.Constructor("_mlsValue::tuple", init) + def mlsAs(name: Str, cls: Str) = Expr.Var(s"_mlsValue::as<$cls>($name)") + def mlsAsUnchecked(name: Str, cls: Str) = Expr.Var(s"_mlsValue::cast<$cls>($name)") + def mlsObjectNameMethod(name: Str) = s"constexpr static inline const char *typeName = \"${name}\";" + def mlsTypeTag() = s"constexpr static inline uint32_t typeTag = nextTypeTag();" + def mlsTypeTag(n: Int) = s"constexpr static inline uint32_t typeTag = $n;" + def mlsCommonCreateMethod(cls: Str, fields: Ls[Str], id: Int) = + val parameters = fields.map{x => s"_mlsValue $x"}.mkString(", ") + val fieldsAssignment = fields.map{x => s"_mlsVal->$x = $x; "}.mkString + s"static _mlsValue create($parameters) { auto _mlsVal = new (std::align_val_t(_mlsAlignment)) $cls; _mlsVal->refCount = 1; _mlsVal->tag = typeTag; $fieldsAssignment return _mlsValue(_mlsVal); }" + def mlsCommonPrintMethod(fields: Ls[Str]) = + if fields.isEmpty then s"virtual void print() const override { std::printf(\"%s\", typeName); }" + else + val fieldsPrint = fields.map{x => s"this->$x.print(); "}.mkString("std::printf(\", \"); ") + s"virtual void print() const override { std::printf(\"%s\", typeName); std::printf(\"(\"); $fieldsPrint std::printf(\")\"); }" + def mlsCommonDestructorMethod(cls: Str, fields: Ls[Str]) = + val fieldsDeletion = fields.map{x => s"_mlsValue::destroy(this->$x); "}.mkString + s"virtual void destroy() override { $fieldsDeletion operator delete (this, std::align_val_t(_mlsAlignment)); }" + def mlsThrowNonExhaustiveMatch = Stmt.Raw("_mlsNonExhaustiveMatch();"); + def mlsCall(fn: Str, args: Ls[Expr]) = Expr.Call(Expr.Var("_mlsCall"), Expr.Var(fn) :: args) + def mlsMethodCall(cls: ClassRef, method: Str, args: Ls[Expr]) = + Expr.Call(Expr.Member(Expr.Call(Expr.Var(s"_mlsMethodCall<${cls.name |> mapName}>"), Ls(args.head)), method), args.tail) + def mlsFnWrapperName(fn: Str) = s"_mlsFn_$fn" + def mlsFnCreateMethod(fn: Str) = s"static _mlsValue create() { static _mlsFn_$fn mlsFn alignas(_mlsAlignment); mlsFn.refCount = stickyRefCount; mlsFn.tag = typeTag; return _mlsValue(&mlsFn); }" + def mlsNeverValue(n: Int) = if (n <= 1) then Expr.Call(Expr.Var(s"_mlsValue::never"), Ls()) else Expr.Call(Expr.Var(s"_mlsValue::never<$n>"), Ls()) + + case class Ctx( + defnCtx: Set[Str], + ) + + def codegenClassInfo(using ctx: Ctx)(cls: ClassInfo): (Opt[Def], Decl) = + val fields = cls.fields.map{x => (x |> mapName, mlsValType)} + val parents = if cls.parents.nonEmpty then cls.parents.toList.map(mapName) else mlsObject :: Nil + val decl = Decl.StructDecl(cls.name |> mapName) + if mlsInternalClass.contains(cls.name) then return (None, decl) + val theDef = Def.StructDef( + cls.name |> mapName, fields, + if parents.nonEmpty then Some(parents) else None, + Ls(Def.RawDef(mlsObjectNameMethod(cls.name)), + Def.RawDef(mlsTypeTag()), + Def.RawDef(mlsCommonPrintMethod(cls.fields.map(mapName))), + Def.RawDef(mlsCommonDestructorMethod(cls.name |> mapName, cls.fields.map(mapName))), + Def.RawDef(mlsCommonCreateMethod(cls.name |> mapName, cls.fields.map(mapName), cls.id))) + ++ cls.methods.map{case (name, defn) => { + val (theDef, decl) = codegenDefn(using Ctx(ctx.defnCtx + cls.name))(defn) + theDef match + case x @ Def.FuncDef(_, name, _, _, _, _) => x.copy(virt = true) + case _ => theDef + }} + ) + (S(theDef), decl) + + def toExpr(texpr: TrivialExpr, reifyUnit: Bool = false)(using ctx: Ctx): Opt[Expr] = texpr match + case IExpr.Ref(name) => S(Expr.Var(name |> mapName)) + case IExpr.Literal(hkmc2.syntax.Tree.BoolLit(x)) => S(mlsIntLit(if x then 1 else 0)) + case IExpr.Literal(hkmc2.syntax.Tree.IntLit(x)) => S(mlsIntLit(x)) + case IExpr.Literal(hkmc2.syntax.Tree.DecLit(x)) => S(mlsIntLit(x.toBigInt)) + case IExpr.Literal(hkmc2.syntax.Tree.StrLit(x)) => S(mlsStrLit(x)) + case IExpr.Literal(hkmc2.syntax.Tree.UnitLit(_)) => if reifyUnit then S(mlsUnitValue) else None + + def toExpr(texpr: TrivialExpr)(using ctx: Ctx): Expr = texpr match + case IExpr.Ref(name) => Expr.Var(name |> mapName) + case IExpr.Literal(hkmc2.syntax.Tree.BoolLit(x)) => mlsIntLit(if x then 1 else 0) + case IExpr.Literal(hkmc2.syntax.Tree.IntLit(x)) => mlsIntLit(x) + case IExpr.Literal(hkmc2.syntax.Tree.DecLit(x)) => mlsIntLit(x.toBigInt) + case IExpr.Literal(hkmc2.syntax.Tree.StrLit(x)) => mlsStrLit(x) + case IExpr.Literal(hkmc2.syntax.Tree.UnitLit(_)) => mlsUnitValue + + + def wrapMultiValues(exprs: Ls[TrivialExpr])(using ctx: Ctx): Expr = exprs match + case x :: Nil => toExpr(x, reifyUnit = true).get + case _ => + val init = Expr.Initializer(exprs.map{x => toExpr(x)}) + mlsTupleValue(init) + + def codegenCaseWithIfs(scrut: TrivialExpr, cases: Ls[(Pat, Node)], default: Opt[Node], storeInto: Str)(using decls: Ls[Decl], stmts: Ls[Stmt])(using ctx: Ctx): (Ls[Decl], Ls[Stmt]) = + val scrut2 = toExpr(scrut) + val init: Stmt = + default.fold(mlsThrowNonExhaustiveMatch)(x => { + val (decls2, stmts2) = codegen(x, storeInto)(using Ls.empty, Ls.empty[Stmt]) + Stmt.Block(decls2, stmts2) + }) + val stmt = cases.foldRight(S(init)) { + case ((Pat.Class(cls), arm), nextarm) => + val (decls2, stmts2) = codegen(arm, storeInto)(using Ls.empty, Ls.empty[Stmt]) + val stmt = Stmt.If(mlsIsValueOf(cls.name |> mapName, scrut2), Stmt.Block(decls2, stmts2), nextarm) + S(stmt) + case ((Pat.Lit(i @ hkmc2.syntax.Tree.IntLit(_)), arm), nextarm) => + val (decls2, stmts2) = codegen(arm, storeInto)(using Ls.empty, Ls.empty[Stmt]) + val stmt = Stmt.If(mlsIsIntLit(scrut2, i), Stmt.Block(decls2, stmts2), nextarm) + S(stmt) + case ((Pat.Lit(i @ hkmc2.syntax.Tree.BoolLit(_)), arm), nextarm) => + val (decls2, stmts2) = codegen(arm, storeInto)(using Ls.empty, Ls.empty[Stmt]) + val stmt = Stmt.If(mlsIsBoolLit(scrut2, i), Stmt.Block(decls2, stmts2), nextarm) + S(stmt) + case _ => ??? + } + (decls, stmt.fold(stmts)(x => stmts :+ x)) + + def codegenJumpWithCall(func: FuncRef, args: Ls[TrivialExpr], storeInto: Opt[Str])(using decls: Ls[Decl], stmts: Ls[Stmt])(using ctx: Ctx): (Ls[Decl], Ls[Stmt]) = + val call = Expr.Call(Expr.Var(func.name |> mapName), args.map(toExpr)) + val stmts2 = stmts ++ Ls(storeInto.fold(Stmt.Return(call))(x => Stmt.Assign(x, call))) + (decls, stmts2) + + def codegenOps(op: Str, args: Ls[TrivialExpr])(using ctx: Ctx) = op match + case "+" => Expr.Binary("+", toExpr(args(0)), toExpr(args(1))) + case "-" => Expr.Binary("-", toExpr(args(0)), toExpr(args(1))) + case "*" => Expr.Binary("*", toExpr(args(0)), toExpr(args(1))) + case "/" => Expr.Binary("/", toExpr(args(0)), toExpr(args(1))) + case "%" => Expr.Binary("%", toExpr(args(0)), toExpr(args(1))) + case "==" => Expr.Binary("==", toExpr(args(0)), toExpr(args(1))) + case "!=" => Expr.Binary("!=", toExpr(args(0)), toExpr(args(1))) + case "<" => Expr.Binary("<", toExpr(args(0)), toExpr(args(1))) + case "<=" => Expr.Binary("<=", toExpr(args(0)), toExpr(args(1))) + case ">" => Expr.Binary(">", toExpr(args(0)), toExpr(args(1))) + case ">=" => Expr.Binary(">=", toExpr(args(0)), toExpr(args(1))) + case "&&" => Expr.Binary("&&", toExpr(args(0)), toExpr(args(1))) + case "||" => Expr.Binary("||", toExpr(args(0)), toExpr(args(1))) + case "!" => Expr.Unary("!", toExpr(args(0))) + case _ => TODO("codegenOps") + + + def codegen(expr: IExpr)(using ctx: Ctx): Expr = expr match + case x @ (IExpr.Ref(_) | IExpr.Literal(_)) => toExpr(x, reifyUnit = true).get + case IExpr.CtorApp(cls, args) => mlsNewValue(cls.name |> mapName, args.map(toExpr)) + case IExpr.Select(name, cls, field) => Expr.Member(mlsAsUnchecked(name |> mapName, cls.name |> mapName), field |> mapName) + case IExpr.BasicOp(name, args) => codegenOps(name, args) + case IExpr.AssignField(assignee, cls, field, value) => TODO("codegen assign field") + + def codegenBuiltin(names: Ls[Name], builtin: Str, args: Ls[TrivialExpr])(using ctx: Ctx): Ls[Stmt] = builtin match + case "error" => Ls(Stmt.Raw("throw std::runtime_error(\"Error\");"), Stmt.AutoBind(names.map(mapName), mlsNeverValue(names.size))) + case _ => Ls(Stmt.AutoBind(names.map(mapName), Expr.Call(Expr.Var("_mls_builtin_" + builtin), args.map(toExpr)))) + + def codegen(body: Node, storeInto: Str)(using decls: Ls[Decl], stmts: Ls[Stmt])(using ctx: Ctx): (Ls[Decl], Ls[Stmt]) = body match + case Node.Result(res) => + val expr = wrapMultiValues(res) + val stmts2 = stmts ++ Ls(Stmt.Assign(storeInto, expr)) + (decls, stmts2) + case Node.Jump(defn, args) => + codegenJumpWithCall(defn, args, S(storeInto)) + case Node.Panic(msg) => (decls, stmts :+ Stmt.Raw(s"throw std::runtime_error(\"$msg\");")) + case Node.LetExpr(name, expr, body) => + val stmts2 = stmts ++ Ls(Stmt.AutoBind(Ls(name |> mapName), codegen(expr))) + codegen(body, storeInto)(using decls, stmts2) + case Node.LetMethodCall(names, cls, method, IExpr.Ref(Name("builtin")) :: args, body) => + val stmts2 = stmts ++ codegenBuiltin(names, args.head.toString.replace("\"", ""), args.tail) + codegen(body, storeInto)(using decls, stmts2) + case Node.LetMethodCall(names, cls, method, args, body) => + val call = mlsMethodCall(cls, method.str |> mapName, args.map(toExpr)) + val stmts2 = stmts ++ Ls(Stmt.AutoBind(names.map(mapName), call)) + codegen(body, storeInto)(using decls, stmts2) + case Node.LetCall(names, defn, args, body) => + val call = Expr.Call(Expr.Var(defn.name |> mapName), args.map(toExpr)) + val stmts2 = stmts ++ Ls(Stmt.AutoBind(names.map(mapName), call)) + codegen(body, storeInto)(using decls, stmts2) + case Node.Case(scrut, cases, default) => + codegenCaseWithIfs(scrut, cases, default, storeInto) + + def codegenDefn(using ctx: Ctx)(defn: Func): (Def, Decl) = defn match + case Func(id, name, params, resultNum, body) => + val decls = Ls(mlsRetValueDecl) + val stmts = Ls.empty[Stmt] + val (decls2, stmts2) = codegen(body, mlsRetValue)(using decls, stmts) + val stmtsWithReturn = stmts2 :+ Stmt.Return(Expr.Var(mlsRetValue)) + val theDef = Def.FuncDef(mlsValType, name |> mapName, params.map(x => (x |> mapName, mlsValType)), Stmt.Block(decls2, stmtsWithReturn)) + val decl = Decl.FuncDecl(mlsValType, name |> mapName, params.map(x => mlsValType)) + (theDef, decl) + + def codegenTopNode(node: Node)(using ctx: Ctx): (Def, Decl) = + val decls = Ls(mlsRetValueDecl) + val stmts = Ls.empty[Stmt] + val (decls2, stmts2) = codegen(node, mlsRetValue)(using decls, stmts) + val stmtsWithReturn = stmts2 :+ Stmt.Return(Expr.Var(mlsRetValue)) + val theDef = Def.FuncDef(mlsValType, mlsMainName, Ls(), Stmt.Block(decls2, stmtsWithReturn)) + val decl = Decl.FuncDecl(mlsValType, mlsMainName, Ls()) + (theDef, decl) + + // Topological sort of classes based on inheritance relationships + def sortClasses(prog: Program): Ls[ClassInfo] = + var depgraph = prog.classes.map(x => (x.name, x.parents)).toMap + var degree = depgraph.view.mapValues(_.size).toMap + def removeNode(node: Str) = + degree -= node + depgraph -= node + depgraph = depgraph.view.mapValues(_.filter(_ != node)).toMap + degree = depgraph.view.mapValues(_.size).toMap + val sorted = ListBuffer.empty[ClassInfo] + var work = degree.filter(_._2 == 0).keys.toSet + while work.nonEmpty do + val node = work.head + work -= node + sorted.addOne(prog.classes.find(_.name == node).get) + removeNode(node) + val next = degree.filter(_._2 == 0).keys + work ++= next + if depgraph.nonEmpty then + val cycle = depgraph.keys.mkString(", ") + throw new Exception(s"Cycle detected in class hierarchy: $cycle") + sorted.toList + + def codegen(prog: Program): CompilationUnit = + val sortedClasses = sortClasses(prog) + val defnCtx = prog.defs.map(_.name) + val (defs, decls) = sortedClasses.map(codegenClassInfo(using Ctx(defnCtx))).unzip + val (defs2, decls2) = prog.defs.map(codegenDefn(using Ctx(defnCtx))).unzip + val (defMain, declMain) = codegenTopNode(prog.main)(using Ctx(defnCtx)) + CompilationUnit(Ls(mlsPrelude), decls ++ decls2 :+ declMain, defs.flatten ++ defs2 :+ defMain :+ Def.RawDef(mlsEntryPoint)) + diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/cpp/CompilerHost.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/cpp/CompilerHost.scala new file mode 100644 index 000000000..867ffe4e6 --- /dev/null +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/cpp/CompilerHost.scala @@ -0,0 +1,44 @@ +package hkmc2.codegen.cpp + +import mlscript._ +import mlscript.utils.shorthands._ +import scala.collection.mutable.ListBuffer + +final class CppCompilerHost(val auxPath: Str, val output: Str => Unit): + import scala.sys.process._ + private def ifAnyCppCompilerExists(): Boolean = + Seq("g++", "--version").! == 0 || Seq("clang++", "--version").! == 0 + + private def isMakeExists(): Boolean = + import scala.sys.process._ + Seq("make", "--version").! == 0 + + val ready = ifAnyCppCompilerExists() && isMakeExists() + + def compileAndRun(src: Str): Unit = + if !ready then + return + val srcPath = os.temp(contents = src, suffix = ".cpp") + val binPath = os.temp(suffix = ".mls.out") + var stdout = ListBuffer[Str]() + var stderr = ListBuffer[Str]() + val buildLogger = ProcessLogger(stdout :+= _, stderr :+= _) + val buildResult = Seq("make", "-B", "-C", auxPath, "auto", s"SRC=$srcPath", s"DST=$binPath") ! buildLogger + if buildResult != 0 then + output("Compilation failed: ") + for line <- stdout do output(line) + for line <- stderr do output(line) + return + + stdout.clear() + stderr.clear() + val runCmd = Seq(binPath.toString) + val runResult = runCmd ! buildLogger + if runResult != 0 then + output("Execution failed: ") + for line <- stdout do output(line) + for line <- stderr do output(line) + return + + output("Execution succeeded: ") + for line <- stdout do output(line) \ No newline at end of file diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Analysis.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Analysis.scala new file mode 100644 index 000000000..042295874 --- /dev/null +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Analysis.scala @@ -0,0 +1,125 @@ +package hkmc2.codegen.llir + +import mlscript._ +import hkmc2.codegen._ +import hkmc2.codegen.llir.{ Program => LlirProgram, Node, Func } +import mlscript.utils._ +import mlscript.utils.shorthands._ +import hkmc2.semantics.BuiltinSymbol +import hkmc2.syntax.Tree.UnitLit +import hkmc2.{Raise, raise, Diagnostic, ErrorReport, Message} +import hkmc2.Message.MessageContext +import hkmc2.semantics.InnerSymbol +import hkmc2.codegen.llir.FuncRef.fromName +import scala.collection.mutable.ListBuffer + +import scala.annotation.tailrec +import scala.collection.immutable.* +import scala.collection.mutable.{HashMap => MutHMap} +import scala.collection.mutable.{HashSet => MutHSet, Set => MutSet} + +class UsefulnessAnalysis(verbose: Bool = false): + import Expr._ + import Node._ + + def log(x: Any) = if verbose then println(x) + + val uses = MutHMap[(Name, Int), Int]() + val defs = MutHMap[Name, Int]() + + private def addDef(x: Name) = + defs.update(x, defs.getOrElse(x, 0) + 1) + + private def addUse(x: Name) = + val def_count = defs.get(x) match + case None => throw Exception(s"Use of undefined variable $x") + case Some(value) => value + val key = (x, defs(x)) + uses.update(key, uses.getOrElse(key, 0) + 1) + + private def f(x: TrivialExpr): Unit = x match + case Ref(name) => addUse(name) + case _ => () + + private def f(x: Expr): Unit = x match + case Ref(name) => addUse(name) + case Literal(lit) => + case CtorApp(name, args) => args.foreach(f) + case Select(name, cls, field) => addUse(name) + case BasicOp(name, args) => args.foreach(f) + case AssignField(assignee, _, _, value) => + addUse(assignee) + f(value) + + private def f(x: Node): Unit = x match + case Result(res) => res.foreach(f) + case Jump(defn, args) => args.foreach(f) + case Case(scrut, cases, default) => + scrut match + case Ref(name) => addUse(name) + case _ => () + cases.foreach { case (cls, body) => f(body) }; default.foreach(f) + case LetMethodCall(names, cls, method, args, body) => addUse(method); args.foreach(f); names.foreach(addDef); f(body) + case LetExpr(name, expr, body) => f(expr); addDef(name); f(body) + case LetCall(names, defn, args, body) => args.foreach(f); names.foreach(addDef); f(body) + + def run(x: Func) = + x.params.foreach(addDef) + f(x.body) + uses.toMap + +class FreeVarAnalysis(extended_scope: Bool = true, verbose: Bool = false): + import Expr._ + import Node._ + + private val visited = MutHSet[Str]() + private def f(using defined: Set[Str])(defn: Func, fv: Set[Str]): Set[Str] = + val defined2 = defn.params.foldLeft(defined)((acc, param) => acc + param.str) + f(using defined2)(defn.body, fv) + private def f(using defined: Set[Str])(expr: Expr, fv: Set[Str]): Set[Str] = expr match + case Ref(name) => if defined.contains(name.str) then fv else fv + name.str + case Literal(lit) => fv + case CtorApp(name, args) => args.foldLeft(fv)((acc, arg) => f(using defined)(arg.toExpr, acc)) + case Select(name, cls, field) => if defined.contains(name.str) then fv else fv + name.str + case BasicOp(name, args) => args.foldLeft(fv)((acc, arg) => f(using defined)(arg.toExpr, acc)) + case AssignField(assignee, _, _, value) => f(using defined)( + value.toExpr, + if defined.contains(assignee.str) then fv + assignee.str else fv + ) + private def f(using defined: Set[Str])(node: Node, fv: Set[Str]): Set[Str] = node match + case Result(res) => res.foldLeft(fv)((acc, arg) => f(using defined)(arg.toExpr, acc)) + case Jump(defnref, args) => + var fv2 = args.foldLeft(fv)((acc, arg) => f(using defined)(arg.toExpr, acc)) + if extended_scope && !visited.contains(defnref.name) then + val defn = defnref.expectFn + visited.add(defn.name) + val defined2 = defn.params.foldLeft(defined)((acc, param) => acc + param.str) + fv2 = f(using defined2)(defn, fv2) + fv2 + case Case(scrut, cases, default) => + val fv2 = scrut match + case Ref(name) => if defined.contains(name.str) then fv else fv + name.str + case _ => fv + val fv3 = cases.foldLeft(fv2) { + case (acc, (cls, body)) => f(using defined)(body, acc) + } + fv3 + case LetMethodCall(resultNames, cls, method, args, body) => + var fv2 = args.foldLeft(fv)((acc, arg) => f(using defined)(arg.toExpr, acc)) + val defined2 = resultNames.foldLeft(defined)((acc, name) => acc + name.str) + f(using defined2)(body, fv2) + case LetExpr(name, expr, body) => + val fv2 = f(using defined)(expr, fv) + val defined2 = defined + name.str + f(using defined2)(body, fv2) + case LetCall(resultNames, defnref, args, body) => + var fv2 = args.foldLeft(fv)((acc, arg) => f(using defined)(arg.toExpr, acc)) + val defined2 = resultNames.foldLeft(defined)((acc, name) => acc + name.str) + if extended_scope && !visited.contains(defnref.name) then + val defn = defnref.expectFn + visited.add(defn.name) + val defined2 = defn.params.foldLeft(defined)((acc, param) => acc + param.str) + fv2 = f(using defined2)(defn, fv2) + f(using defined2)(body, fv2) + def run(node: Node) = f(using Set.empty)(node, Set.empty) + def run_with(node: Node, defined: Set[Str]) = f(using defined)(node, Set.empty) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Builder.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Builder.scala new file mode 100644 index 000000000..06caaffce --- /dev/null +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Builder.scala @@ -0,0 +1,326 @@ +package hkmc2 +package codegen +package llir + +import scala.collection.mutable.ListBuffer + +import mlscript.utils.* +import mlscript.utils.shorthands.* +import hkmc2.document.* +import hkmc2.utils.Scope +import hkmc2.utils.TraceLogger +import hkmc2.Message.MessageContext + +import hkmc2.syntax.Tree +import hkmc2.semantics.* +import hkmc2.codegen.llir.{ Program => LlirProgram, Node, Func } +import FuncRef.fromName +import hkmc2.codegen.Program + + +def err(msg: Message)(using Raise): Unit = + raise(ErrorReport(msg -> N :: Nil, + source = Diagnostic.Source.Compilation)) + +def errStop(msg: Message)(using Raise) = + err(msg) + throw LowLevelIRError("stopped") + +final case class Ctx( + def_acc: ListBuffer[Func], + class_acc: ListBuffer[ClassInfo], + symbol_ctx: Map[Str, Name] = Map.empty, + fn_ctx: Map[Local, Name] = Map.empty, // is a known function + class_ctx: Map[Local, Name] = Map.empty, + block_ctx: Map[Local, Name] = Map.empty, + is_top_level: Bool = true, +): + def addFuncName(n: Local, m: Name) = copy(fn_ctx = fn_ctx + (n -> m)) + def findFuncName(n: Local)(using Raise) = fn_ctx.get(n) match + case None => + errStop(msg"Function name not found: ${n.toString()}") + Name("error") + case Some(value) => value + def addClassName(n: Local, m: Name) = copy(class_ctx = class_ctx + (n -> m)) + def findClassName(n: Local)(using Raise) = class_ctx.get(n) match + case None => + errStop(msg"Class name not found: ${n.toString()}") + Name("error") + case Some(value) => value + def addName(n: Str, m: Name) = copy(symbol_ctx = symbol_ctx + (n -> m)) + def findName(n: Str)(using Raise): Name = symbol_ctx.get(n) match + case None => + errStop(msg"Name not found: $n") + Name("error") + case Some(value) => value + def reset = + def_acc.clear() + class_acc.clear() + def nonTopLevel = copy(is_top_level = false) + +object Ctx: + val empty = Ctx(ListBuffer.empty, ListBuffer.empty) + + +final class LlirBuilder(tl: TraceLogger)(fresh: Fresh, fnUid: FreshInt, clsUid: FreshInt): + import tl.{trace, log} + def er = Expr.Ref + def nr = Node.Result + def nme(x: Str) = Name(x) + def sr(x: Str) = er(Name(x)) + def sr(x: Name) = er(x) + def nsr(xs: Ls[Name]) = xs.map(er(_)) + + private def allocIfNew(l: Local)(using Raise, Scope): String = + trace[Str](s"allocIfNew begin: $l", x => s"allocIfNew end: $x"): + if summon[Scope].lookup(l).isDefined then + getVar_!(l) + else + summon[Scope].allocateName(l) + + private def getVar_!(l: Local)(using Raise, Scope): String = + trace[Str](s"getVar_! begin", x => s"getVar_! end: $x"): + l match + case ts: semantics.TermSymbol => + ts.owner match + case S(owner) => ts.id.name + case N => + ts.id.name + case ts: semantics.BlockMemberSymbol => // this means it's a locally-defined member + ts.nme + case ts: semantics.InnerSymbol => + summon[Scope].findThis_!(ts) + case _ => summon[Scope].lookup_!(l) + + private def bBind(name: Opt[Str], e: Result, body: Block)(k: TrivialExpr => Ctx ?=> Node)(using ctx: Ctx)(using Raise, Scope): Node = + trace[Node](s"bBind begin: $name", x => s"bBind end: ${x.show}"): + bResult(e): + case r: Expr.Ref => + given Ctx = ctx.addName(name.getOrElse(fresh.make.str), r.name) + log(s"bBind ref: $name -> $r") + bBlock(body)(k) + case l: Expr.Literal => + val v = fresh.make + given Ctx = ctx.addName(name.getOrElse(fresh.make.str), v) + log(s"bBind lit: $name -> $v") + Node.LetExpr(v, l, bBlock(body)(k)) + + private def bArgs(e: Ls[Arg])(k: Ls[TrivialExpr] => Ctx ?=> Node)(using ctx: Ctx)(using Raise, Scope): Node = + trace[Node](s"bArgs begin", x => s"bArgs end: ${x.show}"): + e match + case Nil => k(Nil) + case Arg(spread, x) :: xs => bPath(x): + case r: TrivialExpr => bArgs(xs): + case rs: Ls[TrivialExpr] => k(r :: rs) + + private def bPaths(e: Ls[Path])(k: Ls[TrivialExpr] => Ctx ?=> Node)(using ctx: Ctx)(using Raise, Scope): Node = + trace[Node](s"bArgs begin", x => s"bArgs end: ${x.show}"): + e match + case Nil => k(Nil) + case x :: xs => bPath(x): + case r: TrivialExpr => bPaths(xs): + case rs: Ls[TrivialExpr] => k(r :: rs) + + private def bFunDef(e: FunDefn)(using ctx: Ctx)(using Raise, Scope): Func = + trace[Func](s"bFunDef begin: ${e.sym}", x => s"bFunDef end: ${x.show}"): + val FunDefn(sym, params, body) = e + if !ctx.is_top_level then + errStop(msg"Non top-level definition ${sym.nme} not supported") + else if params.length != 1 then + errStop(msg"Curried function or zero arguments function not supported: ${params.length.toString}") + else + val paramsList = params.head.params.map(x => x -> summon[Scope].allocateName(x.sym)) + val ctx2 = paramsList.foldLeft(ctx)((acc, x) => acc.addName(getVar_!(x._1.sym), x._2 |> nme)) + val ctx3 = ctx2.nonTopLevel + val pl = paramsList.map(_._2).map(nme) + Func( + fnUid.make, + sym.nme, + params = pl, + resultNum = 1, + body = bBlock(body)(x => Node.Result(Ls(x)))(using ctx3) + ) + + private def bClsLikeDef(e: ClsLikeDefn)(using ctx: Ctx)(using Raise, Scope): ClassInfo = + trace[ClassInfo](s"bClsLikeDef begin", x => s"bClsLikeDef end: ${x.show}"): + val ClsLikeDefn(sym, kind, parentSym, methods, privateFields, publicFields, preCtor, ctor) = e + if !ctx.is_top_level then + errStop(msg"Non top-level definition ${sym.nme} not supported") + else + val clsDefn = sym.defn.getOrElse(die) + val clsParams = clsDefn.paramsOpt.fold(Nil)(_.paramSyms) + val clsFields = publicFields.map(_.sym) + ClassInfo( + clsUid.make, + sym.nme, + clsParams.map(_.nme) ++ clsFields.map(_.nme), + ) + + private def bValue(v: Value)(k: TrivialExpr => Ctx ?=> Node)(using ctx: Ctx)(using Raise, Scope) : Node = + trace[Node](s"bValue begin", x => s"bValue end: ${x.show}"): + v match + case Value.Ref(l) => k(ctx.findName(getVar_!(l)) |> sr) + case Value.This(sym) => errStop(msg"Unsupported value: This"); Node.Result(Ls()) + case Value.Lit(lit) => k(Expr.Literal(lit)) + case Value.Lam(params, body) => errStop(msg"Unsupported value: Lam"); Node.Result(Ls()) + case Value.Arr(elems) => errStop(msg"Unsupported value: Arr"); Node.Result(Ls()) + + private def getClassOfMem(p: FieldSymbol)(using ctx: Ctx)(using Raise, Scope): Local = + trace[Local](s"bMemSym begin", x => s"bMemSym end: $x"): + p match + case ts: TermSymbol => ts.owner.get + case ms: MemberSymbol[?] => ms.defn.get.sym + + private def bPath(p: Path)(k: TrivialExpr => Ctx ?=> Node)(using ctx: Ctx)(using Raise, Scope) : Node = + trace[Node](s"bPath begin", x => s"bPath end: ${x.show}"): + p match + case Select(Value.Ref(_: TopLevelSymbol), name) if name.name.head.isUpper => + val v = fresh.make + Node.LetExpr(v, Expr.CtorApp(ClassRef.fromName(name.name), Ls()), k(v |> sr)) + // field selection + case s @ Select(qual, name) => + log(s"bPath Select: $qual.$name with ${s.symbol}") + s.symbol match + case None => + errStop(msg"Unsupported selection by users") + case Some(value) => + bPath(qual): + case q: Expr.Ref => + val v = fresh.make + val cls = ClassRef.fromName(ctx.findClassName(getClassOfMem(s.symbol.get))) + val field = name.name + Node.LetExpr(v, Expr.Select(q.name, cls, field), k(v |> sr)) + case q: Expr.Literal => + errStop(msg"Unsupported select on literal") + Node.Result(Ls()) + case x: Value => bValue(x)(k) + + private def bResult(r: Result)(k: TrivialExpr => Ctx ?=> Node)(using ctx: Ctx)(using Raise, Scope) : Node = + trace[Node](s"bResult begin", x => s"bResult end: ${x.show}"): + r match + case Call(Value.Ref(sym: BuiltinSymbol), args) => + bArgs(args): + case args: Ls[TrivialExpr] => + val v = fresh.make + Node.LetExpr(v, Expr.BasicOp(sym.nme, args), k(v |> sr)) + case Call(Select(Value.Ref(_: TopLevelSymbol), name), args) if name.name.head.isUpper => + bArgs(args): + case args: Ls[TrivialExpr] => + val v = fresh.make + Node.LetExpr(v, Expr.CtorApp(ClassRef.fromName(name.name), args), k(v |> sr)) + case Call(Select(Value.Ref(_: TopLevelSymbol), name), args) => + bArgs(args): + case args: Ls[TrivialExpr] => + val v = fresh.make + Node.LetCall(Ls(v), FuncRef.fromName(name.name), args, k(v |> sr)) + case Call(Select(Value.Ref(_: BuiltinSymbol), name), args) => + bArgs(args): + case args: Ls[TrivialExpr] => + val v = fresh.make + Node.LetExpr(v, Expr.CtorApp(ClassRef.fromName(name.name), args), k(v |> sr)) + case Call(Value.Ref(name), args) if ctx.fn_ctx.contains(name) => + bArgs(args): + case args: Ls[TrivialExpr] => + val v = fresh.make + val fn = ctx.fn_ctx.get(name).get + Node.LetCall(Ls(v), FuncRef.fromName(fn), args, k(v |> sr)) + case Call(fn, args) => + bPath(fn): + case f: TrivialExpr => + bArgs(args): + case args: Ls[TrivialExpr] => + val v = fresh.make + Node.LetMethodCall(Ls(v), ClassRef(R("Callable")), Name("apply" + args.length), f :: args, k(v |> sr)) + case Instantiate( + Select(Select(Value.Ref(_: TopLevelSymbol), name), Tree.Ident("class")), args) => + bPaths(args): + case args: Ls[TrivialExpr] => + val v = fresh.make + Node.LetExpr(v, Expr.CtorApp(ClassRef.fromName(name.name), args), k(v |> sr)) + case Instantiate(cls, args) => + errStop(msg"Unsupported kind of Instantiate") + Node.Result(Ls()) + case x: Path => bPath(x)(k) + + private def bBlock(blk: Block)(k: TrivialExpr => Ctx ?=> Node)(using ctx: Ctx)(using Raise, Scope) : Node = + trace[Node](s"bBlock begin", x => s"bBlock end: ${x.show}"): + blk match + case Match(scrut, arms, dflt, rest) => + bPath(scrut): + case e: TrivialExpr => + val jp = fresh.make("j") + val fvset = (rest.freeVars -- rest.definedVars).map(allocIfNew) + val fvs1 = fvset.toList + val new_ctx = fvs1.foldLeft(ctx)((acc, x) => acc.addName(x, fresh.make)) + val fvs = fvs1.map(new_ctx.findName(_)) + def cont(x: TrivialExpr)(using ctx: Ctx) = Node.Jump( + FuncRef.fromName(jp), + fvs1.map(x => ctx.findName(x)).map(sr) + ) + val casesList: Ls[(Pat, Node)] = arms.map: + case (Case.Lit(lit), body) => + (Pat.Lit(lit), bBlock(body)(cont)(using ctx)) + case (Case.Cls(cls, _), body) => + (Pat.Class(ClassRef.fromName(cls.nme)), bBlock(body)(cont)(using ctx)) + case (Case.Tup(len, inf), body) => + (Pat.Class(ClassRef.fromName("Tuple" + len.toString())), bBlock(body)(cont)(using ctx)) + val defaultCase = dflt.map(bBlock(_)(cont)) + val jpdef = Func( + fnUid.make, + jp.str, + params = fvs, + resultNum = 1, + bBlock(rest)(k)(using new_ctx), + ) + summon[Ctx].def_acc += jpdef + Node.Case(e, casesList, defaultCase) + case Return(res, implct) => bResult(res)(x => Node.Result(Ls(x))) + case Throw(Instantiate(Select(Value.Ref(_), ident), Ls(Value.Lit(Tree.StrLit(e))))) if ident.name == "Error" => + Node.Panic(e) + case Label(label, body, rest) => ??? + case Break(label) => TODO("Break not supported") + case Continue(label) => TODO("Continue not supported") + case Begin(sub, rest) => + // re-associate rest blocks to correctly handle the continuation + sub match + case _: BlockTail => + val definedVars = sub.definedVars + definedVars.foreach(allocIfNew) + bBlock(sub): + x => bBlock(rest)(k) + case Assign(lhs, rhs, rest2) => + bBlock(Assign(lhs, rhs, Begin(rest2, rest)))(k) + case Begin(sub, rest2) => + bBlock(Begin(sub, Begin(rest2, rest)))(k) + case Define(defn, rest2) => + bBlock(Define(defn, Begin(rest2, rest)))(k) + case Match(scrut, arms, dflt, rest2) => + bBlock(Match(scrut, arms, dflt, Begin(rest2, rest)))(k) + case _ => TODO(s"Other non-tail sub components of Begin not supported $sub") + case TryBlock(sub, finallyDo, rest) => TODO("TryBlock not supported") + case Assign(lhs, rhs, rest) => + val name = allocIfNew(lhs) + bBind(S(name), rhs, rest)(k) + case AssignField(lhs, nme, rhs, rest) => TODO("AssignField not supported") + case Define(fd @ FunDefn(sym, params, body), rest) => + val f = bFunDef(fd) + ctx.def_acc += f + val new_ctx = ctx.addFuncName(sym, Name(f.name)) + bBlock(rest)(k)(using new_ctx) + case Define(cd @ ClsLikeDefn(sym, kind, parentSym, methods, privateFields, publicFields, preCtor, ctor), rest) => + val c = bClsLikeDef(cd) + ctx.class_acc += c + val new_ctx = ctx.addClassName(sym, Name(c.name)) + bBlock(rest)(k)(using new_ctx) + case End(msg) => k(Expr.Literal(Tree.UnitLit(false))) + case _: Block => + val docBlock = blk.showAsTree + errStop(msg"Unsupported block: $docBlock") + Node.Result(Ls()) + + def bProg(e: Program)(using Raise, Scope): LlirProgram = + val ctx = Ctx.empty + given Ctx = ctx + ctx.reset + val entry = bBlock(e.main)(x => Node.Result(Ls(x))) + LlirProgram(ctx.class_acc.toSet, ctx.def_acc.toSet, entry) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Fresh.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Fresh.scala new file mode 100644 index 000000000..0c5688eab --- /dev/null +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Fresh.scala @@ -0,0 +1,28 @@ +package hkmc2.codegen.llir + +import collection.mutable.{HashMap => MutHMap} +import mlscript.utils.shorthands._ + +final class Fresh(div : Char = '$'): + private val counter = MutHMap[Str, Int]() + private def gensym(s: Str) = { + val n = s.lastIndexOf(div) + val (ts, suffix) = s.splitAt(if n == -1 then s.length() else n) + var x = if suffix.stripPrefix(div.toString).forall(_.isDigit) then ts else s + val count = counter.getOrElse(x, 0) + val tmp = s"$x$div$count" + counter.update(x, count + 1) + Name(tmp) + } + + def make(s: Str) = gensym(s) + def make = gensym("x") + +final class FreshInt: + private var counter = 0 + def make: Int = { + val tmp = counter + counter += 1 + tmp + } + diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Interp.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Interp.scala new file mode 100644 index 000000000..4f5416d1f --- /dev/null +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Interp.scala @@ -0,0 +1,211 @@ +package hkmc2.codegen.llir + +import mlscript.* +import mlscript.utils.* +import scala.collection.immutable.* +import scala.collection.mutable.ListBuffer +import shorthands.* +import scala.util.boundary, boundary.break + +import hkmc2.codegen.llir.* +import hkmc2.syntax.Tree + +enum Stuck: + case StuckExpr(expr: Expr, msg: Str) + case StuckNode(node: Node, msg: Str) + + override def toString: String = + this match + case StuckExpr(expr, msg) => s"StuckExpr(${expr.show}, $msg)" + case StuckNode(node, msg) => s"StuckNode(${node.show}, $msg)" + +final case class InterpreterError(message: String) extends Exception(message) + +class Interpreter(verbose: Bool): + private def log(x: Any) = if verbose then println(x) + import Stuck._ + + private case class Configuration( + var context: Ctx + ) + + private type Result[T] = Either[Stuck, T] + + private enum Value: + case Class(cls: ClassInfo, var fields: Ls[Value]) + case Literal(lit: hkmc2.syntax.Literal) + + override def toString: String = + import hkmc2.syntax.Tree.* + this match + case Class(cls, fields) => s"${cls.name}(${fields.mkString(",")})" + case Literal(IntLit(lit)) => lit.toString + case Literal(BoolLit(lit)) => lit.toString + case Literal(DecLit(lit)) => lit.toString + case Literal(StrLit(lit)) => lit.toString + case Literal(UnitLit(undefinedOrNull)) => if undefinedOrNull then "undefined" else "null" + + private final case class Ctx( + bindingCtx: Map[Str, Value], + classCtx: Map[Str, ClassInfo], + funcCtx: Map[Str, Func], + ) + + import Node._ + import Expr._ + + private def getTrue(using ctx: Ctx) = Value.Literal(hkmc2.syntax.Tree.BoolLit(true)) + private def getFalse(using ctx: Ctx) = Value.Literal(hkmc2.syntax.Tree.BoolLit(false)) + + private def eval(op: Str, x1: Value, x2: Value)(using ctx: Ctx): Opt[Value] = + import Value.{Literal => Li} + import hkmc2.syntax.Tree.* + (op, x1, x2) match + case ("+", Li(IntLit(x)), Li(IntLit(y))) => S(Li(IntLit(x + y))) + case ("-", Li(IntLit(x)), Li(IntLit(y))) => S(Li(IntLit(x - y))) + case ("*", Li(IntLit(x)), Li(IntLit(y))) => S(Li(IntLit(x * y))) + case ("/", Li(IntLit(x)), Li(IntLit(y))) => S(Li(IntLit(x / y))) + case ("==", Li(IntLit(x)), Li(IntLit(y))) => S(if x == y then getTrue else getFalse) + case ("!=", Li(IntLit(x)), Li(IntLit(y))) => S(if x != y then getTrue else getFalse) + case ("<=", Li(IntLit(x)), Li(IntLit(y))) => S(if x <= y then getTrue else getFalse) + case (">=", Li(IntLit(x)), Li(IntLit(y))) => S(if x >= y then getTrue else getFalse) + case (">", Li(IntLit(x)), Li(IntLit(y))) => S(if x > y then getTrue else getFalse) + case ("<", Li(IntLit(x)), Li(IntLit(y))) => S(if x < y then getTrue else getFalse) + case _ => N + + private def evalArgs(using ctx: Ctx)(exprs: Ls[TrivialExpr]): Result[Ls[Value]] = + var values = ListBuffer.empty[Value] + var stuck: Opt[Stuck] = None + exprs foreach { expr => + stuck match + case None => eval(expr) match + case L(x) => stuck = Some(x) + case R(x) => values += x + case _ => () + } + stuck.toLeft(values.toList) + + private def eval(expr: TrivialExpr)(using ctx: Ctx): Result[Value] = expr match + case e @ Ref(name) => ctx.bindingCtx.get(name.str).toRight(StuckExpr(e, s"undefined variable $name")) + case Literal(lit) => R(Value.Literal(lit)) + + private def eval(expr: Expr)(using ctx: Ctx): Result[Value] = expr match + case Ref(Name(x)) => ctx.bindingCtx.get(x).toRight(StuckExpr(expr, s"undefined variable $x")) + case Literal(x) => R(Value.Literal(x)) + case CtorApp(cls, args) => + for + xs <- evalArgs(args) + cls <- ctx.classCtx.get(cls.name).toRight(StuckExpr(expr, s"undefined class ${cls.name}")) + yield Value.Class(cls, xs) + case Select(name, cls, field) => + ctx.bindingCtx.get(name.str).toRight(StuckExpr(expr, s"undefined variable $name")).flatMap { + case Value.Class(cls2, xs) if cls.name == cls2.name => + xs.zip(cls2.fields).find{_._2 == field} match + case Some((x, _)) => R(x) + case None => L(StuckExpr(expr, s"unable to find selected field $field")) + case Value.Class(cls2, xs) => L(StuckExpr(expr, s"unexpected class $cls2")) + case x => L(StuckExpr(expr, s"unexpected value $x")) + } + case BasicOp(name, args) => boundary: + evalArgs(args).flatMap( + xs => + name match + case "+" | "-" | "*" | "/" | "==" | "!=" | "<=" | ">=" | "<" | ">" => + if xs.length < 2 then break: + L(StuckExpr(expr, s"not enough arguments for basic operation $name")) + else eval(name, xs.head, xs.tail.head).toRight(StuckExpr(expr, s"unable to evaluate basic operation")) + case _ => L(StuckExpr(expr, s"unexpected basic operation $name"))) + case AssignField(assignee, cls, field, value) => + for + x <- eval(Ref(assignee): TrivialExpr) + y <- eval(value) + res <- x match + case obj @ Value.Class(cls2, xs) if cls.name == cls2.name => + xs.zip(cls2.fields).find{_._2 == field} match + case Some((_, _)) => + obj.fields = xs.map(x => if x == obj then y else x) + // Ideally, we should return a unit value here, but here we return the assignee value for simplicity. + R(obj) + case None => L(StuckExpr(expr, s"unable to find selected field $field")) + case Value.Class(cls2, xs) => L(StuckExpr(expr, s"unexpected class $cls2")) + case x => L(StuckExpr(expr, s"unexpected value $x")) + yield res + + private def eval(node: Node)(using ctx: Ctx): Result[Ls[Value]] = node match + case Result(res) => evalArgs(res) + case Jump(func, args) => + for + xs <- evalArgs(args) + func <- ctx.funcCtx.get(func.name).toRight(StuckNode(node, s"undefined function ${func.name}")) + ctx1 = ctx.copy(bindingCtx = ctx.bindingCtx ++ func.params.map{_.str}.zip(xs)) + res <- eval(func.body)(using ctx1) + yield res + case Case(scrut, cases, default) => + eval(scrut) flatMap { + case Value.Class(cls, fields) => + cases.find { + case (Pat.Class(cls2), _) => cls.name == cls2.name + case _ => false + } match { + case Some((_, x)) => eval(x) + case None => + default match + case S(x) => eval(x) + case N => L(StuckNode(node, s"can not find the matched case, ${cls.name} expected")) + } + case Value.Literal(lit) => + cases.find { + case (Pat.Lit(lit2), _) => lit == lit2 + case _ => false + } match { + case Some((_, x)) => eval(x) + case None => + default match + case S(x) => eval(x) + case N => L(StuckNode(node, s"can not find the matched case, $lit expected")) + } + } + case LetExpr(name, expr, body) => + for + x <- eval(expr) + ctx1 = ctx.copy(bindingCtx = ctx.bindingCtx + (name.str -> x)) + res <- eval(body)(using ctx1) + yield res + case LetMethodCall(names, cls, method, args, body) => + for + ys <- evalArgs(args).flatMap { + case Value.Class(cls2, xs) :: args => + cls2.methods.get(method.str).toRight(StuckNode(node, s"undefined method ${method.str}")).flatMap { method => + val ctx1 = ctx.copy(bindingCtx = ctx.bindingCtx ++ cls2.fields.zip(xs) ++ method.params.map{_.str}.zip(args)) + eval(method.body)(using ctx1) + } + case _ => L(StuckNode(node, s"not enough arguments for method call, or the first argument is not a class")) + } + ctx2 = ctx.copy(bindingCtx = ctx.bindingCtx ++ names.map{_.str}.zip(ys)) + res <- eval(body)(using ctx2) + yield res + case LetCall(names, func, args, body) => + for + xs <- evalArgs(args) + func <- ctx.funcCtx.get(func.name).toRight(StuckNode(node, s"undefined function ${func.name}")) + ctx1 = ctx.copy(bindingCtx = ctx.bindingCtx ++ func.params.map{_.str}.zip(xs)) + ys <- eval(func.body)(using ctx1) + ctx2 = ctx.copy(bindingCtx = ctx.bindingCtx ++ names.map{_.str}.zip(ys)) + res <- eval(body)(using ctx2) + yield res + case Panic(msg) => L(StuckNode(node, msg)) + + private def f(prog: Program): Ls[Value] = + val Program(classes, defs, main) = prog + given Ctx = Ctx( + bindingCtx = Map.empty, + classCtx = classes.map{cls => cls.name -> cls}.toMap, + funcCtx = defs.map{func => func.name -> func}.toMap, + ) + eval(main) match + case R(x) => x + case L(x) => throw InterpreterError("Stuck evaluation: " + x.toString) + + def interpret(prog: Program): Str = + val node = f(prog) + node.map(_.toString).mkString(",") diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Llir.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Llir.scala new file mode 100644 index 000000000..72be21209 --- /dev/null +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Llir.scala @@ -0,0 +1,214 @@ +package hkmc2.codegen.llir + +import mlscript._ +import mlscript.utils._ +import mlscript.utils.shorthands._ + +import hkmc2.syntax._ +import hkmc2.Message.MessageContext +import hkmc2.document._ + +import util.Sorting +import collection.immutable.SortedSet +import language.implicitConversions +import collection.mutable.{Map as MutMap, Set as MutSet, HashMap, ListBuffer} + +private def raw(x: String): Document = doc"$x" + +final case class LowLevelIRError(message: String) extends Exception(message) + +case class Program( + classes: Set[ClassInfo], + defs: Set[Func], + main: Node, +): + override def toString: String = + val t1 = classes.toArray + val t2 = defs.toArray + Sorting.quickSort(t1) + Sorting.quickSort(t2) + s"Program({${t1.mkString(",\n")}}, {\n${t2.mkString("\n")}\n},\n$main)" + + def show(hiddenNames: Set[Str] = Set.empty) = toDocument(hiddenNames).toString + def toDocument(hiddenNames: Set[Str] = Set.empty) : Document = + val t1 = classes.toArray + val t2 = defs.toArray + Sorting.quickSort(t1) + Sorting.quickSort(t2) + given Conversion[String, Document] = raw + val docClasses = t1.filter(x => !hiddenNames.contains(x.name)).map(_.toDocument).toList.mkDocument(doc" # ") + val docDefs = t2.map(_.toDocument).toList.mkDocument(doc" # ") + val docMain = main.toDocument + doc" #{ $docClasses\n$docDefs\n$docMain #} " + +implicit object ClassInfoOrdering extends Ordering[ClassInfo] { + def compare(a: ClassInfo, b: ClassInfo) = a.id.compare(b.id) +} + +case class ClassInfo( + id: Int, + name: Str, + fields: Ls[Str], +): + var parents: Set[Str] = Set.empty + var methods: Map[Str, Func] = Map.empty + override def hashCode: Int = id + override def toString: String = + s"ClassInfo($id, $name, [${fields mkString ","}], parents: ${parents mkString ","}, methods:\n${methods mkString ",\n"})" + + def show = toDocument.toString + def toDocument: Document = + given Conversion[String, Document] = raw + val ext = if parents.isEmpty then "" else " extends " + parents.mkString(", ") + if methods.isEmpty then + doc"class $name(${fields.mkString(",")})$ext" + else + val docFirst = doc"class $name (${fields.mkString(",")})$ext {" + val docMethods = methods.map { (_, func) => func.toDocument }.toList.mkDocument(doc" # ") + val docLast = doc"}" + doc"$docFirst #{ # $docMethods # #} $docLast" + +case class Name(str: Str): + def trySubst(map: Map[Str, Name]) = map.getOrElse(str, this) + override def toString: String = str + +object FuncRef: + def fromName(name: Str) = FuncRef(Right(name)) + def fromName(name: Name) = FuncRef(Right(name.str)) + def fromFunc(func: Func) = FuncRef(Left(func)) + +class FuncRef(var func: Either[Func, Str]): + def name: String = func.fold(_.name, x => x) + def expectFn: Func = func.fold(identity, x => throw Exception(s"Expected a def, but got $x")) + def getFunc: Opt[Func] = func.left.toOption + override def equals(o: Any): Bool = o match { + case o: FuncRef => o.name == this.name + case _ => false + } + +object ClassRef: + def fromName(name: Str) = ClassRef(Right(name)) + def fromName(name: Name) = ClassRef(Right(name.str)) + def fromClass(cls: ClassInfo) = ClassRef(Left(cls)) + +class ClassRef(var cls: Either[ClassInfo, Str]): + def name: String = cls.fold(_.name, x => x) + def expectCls: ClassInfo = cls.fold(identity, x => throw Exception(s"Expected a class, but got $x")) + def getClass: Opt[ClassInfo] = cls.left.toOption + override def equals(o: Any): Bool = o match { + case o: ClassRef => o.name == this.name + case _ => false + } + +implicit object FuncOrdering extends Ordering[Func] { + def compare(a: Func, b: Func) = a.id.compare(b.id) +} + +case class Func( + id: Int, + name: Str, + params: Ls[Name], + resultNum: Int, + body: Node +): + var recBoundary: Opt[Int] = None + override def hashCode: Int = id + + override def toString: String = + val ps = params.map(_.toString).mkString("[", ",", "]") + s"Def($id, $name, $ps, \n$resultNum, \n$body\n)" + + def show = toDocument + def toDocument: Document = + given Conversion[String, Document] = raw + val docFirst = doc"def $name(${params.map(_.toString).mkString(",")}) =" + val docBody = body.toDocument + doc"$docFirst #{ # $docBody #} " + +sealed trait TrivialExpr: + import Expr._ + override def toString: String + def show: String + def toDocument: Document + def toExpr: Expr = this match { case x: Expr => x } + +private def showArguments(args: Ls[TrivialExpr]) = args map (_.show) mkString "," + +enum Expr: + case Ref(name: Name) extends Expr, TrivialExpr + case Literal(lit: hkmc2.syntax.Literal) extends Expr, TrivialExpr + case CtorApp(cls: ClassRef, args: Ls[TrivialExpr]) + case Select(name: Name, cls: ClassRef, field: Str) + case BasicOp(name: Str, args: Ls[TrivialExpr]) + case AssignField(assignee: Name, cls: ClassRef, field: Str, value: TrivialExpr) + + override def toString: String = show + + def show: String = toDocument.toString + + def toDocument: Document = + given Conversion[String, Document] = raw + this match + case Ref(s) => s.toString + case Literal(Tree.BoolLit(lit)) => s"$lit" + case Literal(Tree.IntLit(lit)) => s"$lit" + case Literal(Tree.DecLit(lit)) => s"$lit" + case Literal(Tree.StrLit(lit)) => s"$lit" + case Literal(Tree.UnitLit(undefinedOrNull)) => if undefinedOrNull then "undefined" else "null" + case CtorApp(cls, args) => + doc"${cls.name}(${args.map(_.toString).mkString(",")})" + case Select(s, cls, fld) => + doc"${s.toString}.<${cls.name}:$fld>" + case BasicOp(name: Str, args) => + doc"$name(${args.map(_.toString).mkString(",")})" + case AssignField(assignee, clsInfo, fieldName, value) => + doc"${assignee.toString}.${fieldName} := ${value.toString}" + +enum Pat: + case Lit(lit: hkmc2.syntax.Literal) + case Class(cls: ClassRef) + + override def toString: String = this match + case Lit(lit) => s"$lit" + case Class(cls) => s"${cls.name}" + +enum Node: + // Terminal forms: + case Result(res: Ls[TrivialExpr]) + case Jump(func: FuncRef, args: Ls[TrivialExpr]) + case Case(scrutinee: TrivialExpr, cases: Ls[(Pat, Node)], default: Opt[Node]) + case Panic(msg: Str) + // Intermediate forms: + case LetExpr(name: Name, expr: Expr, body: Node) + case LetMethodCall(names: Ls[Name], cls: ClassRef, method: Name, args: Ls[TrivialExpr], body: Node) + case LetCall(names: Ls[Name], func: FuncRef, args: Ls[TrivialExpr], body: Node) + + override def toString: String = show + + def show: String = toDocument.toString + + def toDocument: Document = + given Conversion[String, Document] = raw + this match + case Result(res) => (res |> showArguments) + case Jump(jp, args) => + doc"jump ${jp.name}(${args |> showArguments})" + case Case(x, cases, default) => + val docFirst = doc"case ${x.toString} of" + val docCases = cases.map { + case (pat, node) => doc"${pat.toString} => #{ # ${node.toDocument} #} " + }.mkDocument(doc" # ") + default match + case N => doc"$docFirst #{ # $docCases #} " + case S(dc) => + val docDeft = doc"_ => #{ # ${dc.toDocument} #} " + doc"$docFirst #{ # $docCases # $docDeft #} " + case Panic(msg) => + doc"panic ${s"\"$msg\""}" + case LetExpr(x, expr, body) => + doc"let ${x.toString} = ${expr.toString} in # ${body.toDocument}" + case LetMethodCall(xs, cls, method, args, body) => + doc"let ${xs.map(_.toString).mkString(",")} = ${cls.name}.${method.toString}(${args.map(_.toString).mkString(",")}) in # ${body.toDocument}" + case LetCall(xs, func, args, body) => + doc"let* (${xs.map(_.toString).mkString(",")}) = ${func.name}(${args.map(_.toString).mkString(",")}) in # ${body.toDocument}" + diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/RefResolver.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/RefResolver.scala new file mode 100644 index 000000000..5b6da3eab --- /dev/null +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/RefResolver.scala @@ -0,0 +1,56 @@ +package hkmc2.codegen.llir + +import mlscript.utils.shorthands._ + +import Node._ + +// Resolves the definition references by turning them from Right(name) to Left(Func). +private final class RefResolver(defs: Map[Str, Func], classes: Map[Str, ClassInfo], allowInlineJp: Bool): + private def f(x: Expr): Unit = x match + case Expr.Ref(name) => + case Expr.Literal(lit) => + case Expr.CtorApp(cls, args) => classes.get(cls.name) match + case None => throw LowLevelIRError(f"unknown class ${cls.name} in ${classes.keySet.mkString(",")}") + case Some(value) => cls.cls = Left(value) + case Expr.Select(name, cls, field) => classes.get(cls.name) match + case None => throw LowLevelIRError(f"unknown class ${cls.name} in ${classes.keySet.mkString(",")}") + case Some(value) => cls.cls = Left(value) + case Expr.BasicOp(name, args) => + case Expr.AssignField(name, cls, field, value) => classes.get(cls.name) match + case None => throw LowLevelIRError(f"unknown class ${cls.name} in ${classes.keySet.mkString(",")}") + case Some(value) => cls.cls = Left(value) + + private def f(x: Pat): Unit = x match + case Pat.Lit(lit) => + case Pat.Class(cls) => classes.get(cls.name) match + case None => throw LowLevelIRError(f"unknown class ${cls.name} in ${classes.keySet.mkString(",")}") + case Some(value) => cls.cls = Left(value) + + private def f(x: Node): Unit = x match + case Result(res) => + case Case(scrut, cases, default) => cases foreach { (_, body) => f(body) }; default foreach f + case LetExpr(name, expr, body) => f(expr); f(body) + case LetMethodCall(names, cls, method, args, body) => f(body) + case LetCall(resultNames, defnref, args, body) => + defs.get(defnref.name) match + case Some(defn) => defnref.func = Left(defn) + case None => throw LowLevelIRError(f"unknown function ${defnref.name} in ${defs.keySet.mkString(",")}") + f(body) + case Jump(defnref, _) => + // maybe not promoted yet + defs.get(defnref.name) match + case Some(defn) => defnref.func = Left(defn) + case None => + if !allowInlineJp then + throw LowLevelIRError(f"unknown function ${defnref.name} in ${defs.keySet.mkString(",")}") + case Panic(_) => + def run(node: Node) = f(node) + def run(node: Func) = f(node.body) + +def resolveRef(entry: Node, defs: Set[Func], classes: Set[ClassInfo], allowInlineJp: Bool = false): Unit = + val defsMap = defs.map(x => x.name -> x).toMap + val classesMap = classes.map(x => x.name -> x).toMap + val rl = RefResolver(defsMap, classesMap, allowInlineJp) + rl.run(entry) + defs.foreach(rl.run(_)) + diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Validator.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Validator.scala new file mode 100644 index 000000000..a660f9f07 --- /dev/null +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Validator.scala @@ -0,0 +1,45 @@ +package hkmc2.codegen.llir + +import hkmc2.utils._ + +private final class FuncRefInSet(defs: Set[Func], classes: Set[ClassInfo]): + import Node._ + import Expr._ + + private def f(x: Expr): Unit = x match + case Ref(name) => + case Literal(lit) => + case CtorApp(name, args) => + case Select(name, ref, field) => ref.getClass match { + case Some(real_class) => if !classes.exists(_ eq real_class) then throw LowLevelIRError("ref is not in the set") + case _ => + } + case BasicOp(name, args) => + case AssignField(assignee, ref, _, value) => ref.getClass match { + case Some(real_class) => if !classes.exists(_ eq real_class) then throw LowLevelIRError("ref is not in the set") + case _ => + } + + private def f(x: Node): Unit = x match + case Result(res) => + case Jump(func, args) => + case Case(x, cases, default) => cases foreach { (_, body) => f(body) }; default foreach f + case Panic(_) => + case LetExpr(name, expr, body) => f(body) + case LetMethodCall(names, cls, method, args, body) => f(body) + case LetCall(res, ref, args, body) => + ref.getFunc match { + case Some(real_func) => if !defs.exists(_ eq real_func) then throw LowLevelIRError("ref is not in the set") + case _ => + } + f(body) + def run(node: Node) = f(node) + def run(func: Func) = f(func.body) + +def validateRefInSet(entry: Node, defs: Set[Func], classes: Set[ClassInfo]): Unit = + val funcRefInSet = FuncRefInSet(defs, classes) + defs.foreach(funcRefInSet.run(_)) + +def validate(entry: Node, defs: Set[Func], classes: Set[ClassInfo]): Unit = + validateRefInSet(entry, defs, classes) + diff --git a/hkmc2/shared/src/test/mlscript-compile/cpp/Makefile b/hkmc2/shared/src/test/mlscript-compile/cpp/Makefile new file mode 100644 index 000000000..45aae4802 --- /dev/null +++ b/hkmc2/shared/src/test/mlscript-compile/cpp/Makefile @@ -0,0 +1,27 @@ +CXX := g++ +CFLAGS += -O3 -Wall -Wextra -std=c++20 -I. -Wno-inconsistent-missing-override -I/opt/homebrew/include +LDFLAGS += -L/opt/homebrew/lib +LDLIBS := -lmimalloc -lgmp +SRC := +INCLUDES := mlsprelude.h +DST := +DEFAULT_TARGET := mls +TARGET := $(or $(DST),$(DEFAULT_TARGET)) + +.PHONY: pre all run clean auto + +all: $(TARGET) + +run: $(TARGET) + ./$(TARGET) + +pre: $(SRC) + sed -i '' 's#^//│ ##g' $(SRC) + +clean: + rm -r $(TARGET) $(TARGET).dSYM + +auto: $(TARGET) + +$(TARGET): $(SRC) $(INCLUDES) + $(CXX) $(CFLAGS) $(LDFLAGS) $(SRC) $(LDLIBS) -o $(TARGET) diff --git a/hkmc2/shared/src/test/mlscript-compile/cpp/mlsprelude.h b/hkmc2/shared/src/test/mlscript-compile/cpp/mlsprelude.h new file mode 100644 index 000000000..fed52549b --- /dev/null +++ b/hkmc2/shared/src/test/mlscript-compile/cpp/mlsprelude.h @@ -0,0 +1,529 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +constexpr std::size_t _mlsAlignment = 8; + +template class tuple_type { + template > struct impl; + template struct impl> { + template using wrap = T; + using type = std::tuple...>; + }; + +public: + using type = typename impl<>::type; +}; +template struct counter { + using tag = counter; + + struct generator { + friend consteval auto is_defined(tag) { return true; } + }; + friend consteval auto is_defined(tag); + + template + static consteval auto exists(auto) { + return true; + } + + static consteval auto exists(...) { return generator(), false; } +}; + +template +consteval auto nextTypeTag() { + if constexpr (not counter::exists(Id)) + return Id; + else + return nextTypeTag(); +} + +struct _mlsObject { + uint32_t refCount; + uint32_t tag; + constexpr static inline uint32_t stickyRefCount = + std::numeric_limits::max(); + + void incRef() { + if (refCount != stickyRefCount) + ++refCount; + } + bool decRef() { + if (refCount != stickyRefCount && --refCount == 0) + return true; + return false; + } + + virtual void print() const = 0; + virtual void destroy() = 0; +}; + +class _mlsValue { + using uintptr_t = std::uintptr_t; + using uint64_t = std::uint64_t; + + void *value alignas(_mlsAlignment); + + bool isInt63() const { return (reinterpret_cast(value) & 1) == 1; } + + bool isPtr() const { return (reinterpret_cast(value) & 1) == 0; } + + uint64_t asInt63() const { return reinterpret_cast(value) >> 1; } + + uintptr_t asRawInt() const { return reinterpret_cast(value); } + + static _mlsValue fromRawInt(uintptr_t i) { + return _mlsValue(reinterpret_cast(i)); + } + + static _mlsValue fromInt63(uint64_t i) { + return _mlsValue(reinterpret_cast((i << 1) | 1)); + } + + void *asPtr() const { + assert(!isInt63()); + return value; + } + + _mlsObject *asObject() const { + assert(isPtr()); + return static_cast<_mlsObject *>(value); + } + + bool eqInt63(const _mlsValue &other) const { + return asRawInt() == other.asRawInt(); + } + + _mlsValue addInt63(const _mlsValue &other) const { + return fromRawInt(asRawInt() + other.asRawInt() - 1); + } + + _mlsValue subInt63(const _mlsValue &other) const { + return fromRawInt(asRawInt() - other.asRawInt() + 1); + } + + _mlsValue mulInt63(const _mlsValue &other) const { + return fromInt63(asInt63() * other.asInt63()); + } + + _mlsValue divInt63(const _mlsValue &other) const { + return fromInt63(asInt63() / other.asInt63()); + } + + _mlsValue gtInt63(const _mlsValue &other) const { + return _mlsValue::fromBoolLit(asInt63() > other.asInt63()); + } + + _mlsValue ltInt63(const _mlsValue &other) const { + return _mlsValue::fromBoolLit(asInt63() < other.asInt63()); + } + + _mlsValue geInt63(const _mlsValue &other) const { + return _mlsValue::fromBoolLit(asInt63() >= other.asInt63()); + } + + _mlsValue leInt63(const _mlsValue &other) const { + return _mlsValue::fromBoolLit(asInt63() <= other.asInt63()); + } + +public: + explicit _mlsValue() : value(nullptr) {} + explicit _mlsValue(void *value) : value(value) {} + _mlsValue(const _mlsValue &other) : value(other.value) { + if (isPtr()) + asObject()->incRef(); + } + + _mlsValue &operator=(const _mlsValue &other) { + if (value != nullptr && isPtr()) + asObject()->decRef(); + value = other.value; + if (isPtr()) + asObject()->incRef(); + return *this; + } + + ~_mlsValue() { + if (isPtr()) + if (asObject()->decRef()) { + asObject()->destroy(); + value = nullptr; + } + } + + uint64_t asInt() const { + assert(isInt63()); + return asInt63(); + } + + static _mlsValue fromIntLit(uint64_t i) { return fromInt63(i); } + + static _mlsValue fromBoolLit(bool b) { return fromInt63(b); } + + template static tuple_type<_mlsValue, N> never() { + __builtin_unreachable(); + } + static _mlsValue never() { __builtin_unreachable(); } + + template static _mlsValue create(U... args) { + return _mlsValue(T::create(args...)); + } + + static void destroy(_mlsValue &v) { v.~_mlsValue(); } + + template static bool isValueOf(const _mlsValue &v) { + return v.asObject()->tag == T::typeTag; + } + + static bool isIntLit(const _mlsValue &v, uint64_t n) { + return v.asInt63() == n; + } + + static bool isIntLit(const _mlsValue &v) { return v.isInt63(); } + + template static T *as(const _mlsValue &v) { + return dynamic_cast(v.asObject()); + } + + template static T *cast(_mlsValue &v) { + return static_cast(v.asObject()); + } + + // Operators + + _mlsValue operator==(const _mlsValue &other) const { + if (isInt63() && other.isInt63()) + return _mlsValue::fromBoolLit(eqInt63(other)); + assert(false); + } + + _mlsValue operator+(const _mlsValue &other) const { + if (isInt63() && other.isInt63()) + return addInt63(other); + assert(false); + } + + _mlsValue operator-(const _mlsValue &other) const { + if (isInt63() && other.isInt63()) + return subInt63(other); + assert(false); + } + + _mlsValue operator*(const _mlsValue &other) const { + if (isInt63() && other.isInt63()) + return mulInt63(other); + assert(false); + } + + _mlsValue operator/(const _mlsValue &other) const { + if (isInt63() && other.isInt63()) + return divInt63(other); + assert(false); + } + + _mlsValue operator>(const _mlsValue &other) const { + if (isInt63() && other.isInt63()) + return gtInt63(other); + assert(false); + } + + _mlsValue operator<(const _mlsValue &other) const { + if (isInt63() && other.isInt63()) + return ltInt63(other); + assert(false); + } + + _mlsValue operator>=(const _mlsValue &other) const { + if (isInt63() && other.isInt63()) + return geInt63(other); + assert(false); + } + + _mlsValue operator<=(const _mlsValue &other) const { + if (isInt63() && other.isInt63()) + return leInt63(other); + assert(false); + } + + // Auxiliary functions + + void print() const { + if (isInt63()) + std::printf("%" PRIu64, asInt63()); + else if (isPtr() && asObject()) + asObject()->print(); + } +}; + +struct _mls_Callable : public _mlsObject { + virtual _mlsValue _mls_apply0() { throw std::runtime_error("Not implemented"); } + virtual _mlsValue _mls_apply1(_mlsValue) { + throw std::runtime_error("Not implemented"); + } + virtual _mlsValue _mls_apply2(_mlsValue, _mlsValue) { + throw std::runtime_error("Not implemented"); + } + virtual _mlsValue _mls_apply3(_mlsValue, _mlsValue, _mlsValue) { + throw std::runtime_error("Not implemented"); + } + virtual _mlsValue _mls_apply4(_mlsValue, _mlsValue, _mlsValue, _mlsValue) { + throw std::runtime_error("Not implemented"); + } + virtual void destroy() override {} +}; + +inline static _mls_Callable *_mlsToCallable(_mlsValue fn) { + auto *ptr = _mlsValue::as<_mls_Callable>(fn); + if (!ptr) + throw std::runtime_error("Not a callable object"); + return ptr; +} + +template +inline static _mlsValue _mlsCall(_mlsValue f, U... args) { + static_assert(sizeof...(U) <= 4, "Too many arguments"); + if constexpr (sizeof...(U) == 0) + return _mlsToCallable(f)->_mls_apply0(); + else if constexpr (sizeof...(U) == 1) + return _mlsToCallable(f)->_mls_apply1(args...); + else if constexpr (sizeof...(U) == 2) + return _mlsToCallable(f)->_mls_apply2(args...); + else if constexpr (sizeof...(U) == 3) + return _mlsToCallable(f)->_mls_apply3(args...); + else if constexpr (sizeof...(U) == 4) + return _mlsToCallable(f)->_mls_apply4(args...); +} + +template +inline static T *_mlsMethodCall(_mlsValue self) { + auto *ptr = _mlsValue::as(self); + if (!ptr) + throw std::runtime_error("unable to convert object for method calls"); + return ptr; +} + +inline int _mlsLargeStack(void *(*fn)(void *)) { + pthread_t thread; + pthread_attr_t attr; + + size_t stacksize = 512 * 1024 * 1024; + pthread_attr_init(&attr); + pthread_attr_setstacksize(&attr, stacksize); + + int rc = pthread_create(&thread, &attr, fn, nullptr); + if (rc) { + printf("ERROR: return code from pthread_create() is %d\n", rc); + return 1; + } + pthread_join(thread, NULL); + return 0; +} + +_mlsValue _mlsMain(); + +inline void *_mlsMainWrapper(void *) { + _mlsValue res = _mlsMain(); + res.print(); + return nullptr; +} + +struct _mls_Unit final : public _mlsObject { + constexpr static inline const char *typeName = "Unit"; + constexpr static inline uint32_t typeTag = nextTypeTag(); + virtual void print() const override { std::printf(typeName); } + static _mlsValue create() { + static _mls_Unit mlsUnit alignas(_mlsAlignment); + mlsUnit.refCount = stickyRefCount; + mlsUnit.tag = typeTag; + return _mlsValue(&mlsUnit); + } + virtual void destroy() override {} +}; + +#include + +struct _mls_ZInt final : public _mlsObject { + boost::multiprecision::mpz_int z; + constexpr static inline const char *typeName = "Z"; + constexpr static inline uint32_t typeTag = nextTypeTag(); + virtual void print() const override { + std::printf(typeName); + std::printf("("); + std::printf("%s", z.str().c_str()); + std::printf(")"); + } + virtual void destroy() override { + z.~number(); + operator delete(this, std::align_val_t(_mlsAlignment)); + } + static _mlsValue create() { + auto _mlsVal = new (std::align_val_t(_mlsAlignment)) _mls_ZInt; + _mlsVal->refCount = 1; + _mlsVal->tag = typeTag; + return _mlsValue(_mlsVal); + } + static _mlsValue create(_mlsValue z) { + auto _mlsVal = new (std::align_val_t(_mlsAlignment)) _mls_ZInt; + _mlsVal->z = z.asInt(); + _mlsVal->refCount = 1; + _mlsVal->tag = typeTag; + return _mlsValue(_mlsVal); + } + _mlsValue operator+(const _mls_ZInt &other) const { + auto _mlsVal = _mlsValue::create<_mls_ZInt>(); + _mlsValue::cast<_mls_ZInt>(_mlsVal)->z = z + other.z; + return _mlsVal; + } + + _mlsValue operator-(const _mls_ZInt &other) const { + auto _mlsVal = _mlsValue::create<_mls_ZInt>(); + _mlsValue::cast<_mls_ZInt>(_mlsVal)->z = z - other.z; + return _mlsVal; + } + + _mlsValue operator*(const _mls_ZInt &other) const { + auto _mlsVal = _mlsValue::create<_mls_ZInt>(); + _mlsValue::cast<_mls_ZInt>(_mlsVal)->z = z * other.z; + return _mlsVal; + } + + _mlsValue operator/(const _mls_ZInt &other) const { + auto _mlsVal = _mlsValue::create<_mls_ZInt>(); + _mlsValue::cast<_mls_ZInt>(_mlsVal)->z = z / other.z; + return _mlsVal; + } + + _mlsValue operator%(const _mls_ZInt &other) const { + auto _mlsVal = _mlsValue::create<_mls_ZInt>(); + _mlsValue::cast<_mls_ZInt>(_mlsVal)->z = z % other.z; + return _mlsVal; + } + + _mlsValue operator==(const _mls_ZInt &other) const { + return _mlsValue::fromBoolLit(z == other.z); + } + + _mlsValue operator>(const _mls_ZInt &other) const { + return _mlsValue::fromBoolLit(z > other.z); + } + + _mlsValue operator<(const _mls_ZInt &other) const { + return _mlsValue::fromBoolLit(z < other.z); + } + + _mlsValue operator>=(const _mls_ZInt &other) const { + return _mlsValue::fromBoolLit(z >= other.z); + } + + _mlsValue operator<=(const _mls_ZInt &other) const { + return _mlsValue::fromBoolLit(z <= other.z); + } + + _mlsValue toInt() const { + return _mlsValue::fromIntLit(z.convert_to()); + } + + static _mlsValue fromInt(uint64_t i) { + return _mlsValue::create<_mls_ZInt>(_mlsValue::fromIntLit(i)); + } +}; + +__attribute__((noinline)) inline void _mlsNonExhaustiveMatch() { + throw std::runtime_error("Non-exhaustive match"); +} + +inline _mlsValue _mls_builtin_z_add(_mlsValue a, _mlsValue b) { + assert(_mlsValue::isValueOf<_mls_ZInt>(a)); + assert(_mlsValue::isValueOf<_mls_ZInt>(b)); + return *_mlsValue::cast<_mls_ZInt>(a) + *_mlsValue::cast<_mls_ZInt>(b); +} + +inline _mlsValue _mls_builtin_z_sub(_mlsValue a, _mlsValue b) { + assert(_mlsValue::isValueOf<_mls_ZInt>(a)); + assert(_mlsValue::isValueOf<_mls_ZInt>(b)); + return *_mlsValue::cast<_mls_ZInt>(a) - *_mlsValue::cast<_mls_ZInt>(b); +} + +inline _mlsValue _mls_builtin_z_mul(_mlsValue a, _mlsValue b) { + assert(_mlsValue::isValueOf<_mls_ZInt>(a)); + assert(_mlsValue::isValueOf<_mls_ZInt>(b)); + return *_mlsValue::cast<_mls_ZInt>(a) * *_mlsValue::cast<_mls_ZInt>(b); +} + +inline _mlsValue _mls_builtin_z_div(_mlsValue a, _mlsValue b) { + assert(_mlsValue::isValueOf<_mls_ZInt>(a)); + assert(_mlsValue::isValueOf<_mls_ZInt>(b)); + return *_mlsValue::cast<_mls_ZInt>(a) / *_mlsValue::cast<_mls_ZInt>(b); +} + +inline _mlsValue _mls_builtin_z_mod(_mlsValue a, _mlsValue b) { + assert(_mlsValue::isValueOf<_mls_ZInt>(a)); + assert(_mlsValue::isValueOf<_mls_ZInt>(b)); + return *_mlsValue::cast<_mls_ZInt>(a) % *_mlsValue::cast<_mls_ZInt>(b); +} + +inline _mlsValue _mls_builtin_z_equal(_mlsValue a, _mlsValue b) { + assert(_mlsValue::isValueOf<_mls_ZInt>(a)); + assert(_mlsValue::isValueOf<_mls_ZInt>(b)); + return *_mlsValue::cast<_mls_ZInt>(a) == *_mlsValue::cast<_mls_ZInt>(b); +} + +inline _mlsValue _mls_builtin_z_gt(_mlsValue a, _mlsValue b) { + assert(_mlsValue::isValueOf<_mls_ZInt>(a)); + assert(_mlsValue::isValueOf<_mls_ZInt>(b)); + return *_mlsValue::cast<_mls_ZInt>(a) > *_mlsValue::cast<_mls_ZInt>(b); +} + +inline _mlsValue _mls_builtin_z_lt(_mlsValue a, _mlsValue b) { + assert(_mlsValue::isValueOf<_mls_ZInt>(a)); + assert(_mlsValue::isValueOf<_mls_ZInt>(b)); + return *_mlsValue::cast<_mls_ZInt>(a) < *_mlsValue::cast<_mls_ZInt>(b); +} + +inline _mlsValue _mls_builtin_z_geq(_mlsValue a, _mlsValue b) { + assert(_mlsValue::isValueOf<_mls_ZInt>(a)); + assert(_mlsValue::isValueOf<_mls_ZInt>(b)); + return *_mlsValue::cast<_mls_ZInt>(a) >= *_mlsValue::cast<_mls_ZInt>(b); +} + +inline _mlsValue _mls_builtin_z_leq(_mlsValue a, _mlsValue b) { + assert(_mlsValue::isValueOf<_mls_ZInt>(a)); + assert(_mlsValue::isValueOf<_mls_ZInt>(b)); + return *_mlsValue::cast<_mls_ZInt>(a) <= *_mlsValue::cast<_mls_ZInt>(b); +} + +inline _mlsValue _mls_builtin_z_to_int(_mlsValue a) { + assert(_mlsValue::isValueOf<_mls_ZInt>(a)); + return _mlsValue::cast<_mls_ZInt>(a)->toInt(); +} + +inline _mlsValue _mls_builtin_z_of_int(_mlsValue a) { + assert(_mlsValue::isIntLit(a)); + return _mlsValue::create<_mls_ZInt>(a); +} + +inline _mlsValue _mls_builtin_print(_mlsValue a) { + a.print(); + return _mlsValue::create<_mls_Unit>(); +} + +inline _mlsValue _mls_builtin_println(_mlsValue a) { + a.print(); + std::puts(""); + return _mlsValue::create<_mls_Unit>(); +} + +inline _mlsValue _mls_builtin_debug(_mlsValue a) { + a.print(); + std::puts(""); + return a; +} diff --git a/hkmc2/shared/src/test/mlscript/llir/BadPrograms.mls b/hkmc2/shared/src/test/mlscript/llir/BadPrograms.mls new file mode 100644 index 000000000..ec2a852a7 --- /dev/null +++ b/hkmc2/shared/src/test/mlscript/llir/BadPrograms.mls @@ -0,0 +1,27 @@ + +:global +:llir +:cpp + +:ge +fun oops(a) = + class A with + fun m = a + let x = 1 +//│ ═══[COMPILATION ERROR] Non top-level definition A not supported +//│ Stopped due to an error during the Llir generation + +:ge +let x = "oops" +x.m +//│ ═══[COMPILATION ERROR] Unsupported selection by users +//│ Stopped due to an error during the Llir generation + +:ge +fun foo(a) = + let x + if a > 0 do + x = 1 + x + 1 +//│ ═══[COMPILATION ERROR] Name not found: x +//│ Stopped due to an error during the Llir generation diff --git a/hkmc2/shared/src/test/mlscript/llir/Playground.mls b/hkmc2/shared/src/test/mlscript/llir/Playground.mls new file mode 100644 index 000000000..a0bce1b91 --- /dev/null +++ b/hkmc2/shared/src/test/mlscript/llir/Playground.mls @@ -0,0 +1,442 @@ +:js +:llir + + +:sllir +abstract class Option[out T]: Some[T] | None +class Some[out T](x: T) extends Option[T] +object None extends Option +fun fromSome(s) = if s is Some(x) then x +class Lazy[out A](init: () -> A) with + mut val cache: Option[A] = None +fun lazy(x) = Lazy(x) +//│ LLIR: +//│ class Option() +//│ class Some(x) +//│ class None() +//│ class Lazy(init,cache) +//│ def fromSome(s) = +//│ case s of +//│ Some => +//│ let x$0 = s. in +//│ x$0 +//│ _ => +//│ panic "match error" +//│ def j$0() = +//│ null +//│ def lazy(x1) = +//│ let x$1 = Lazy(x1) in +//│ x$1 +//│ undefined + +:sllir +fun testCtor1() = None +fun testCtor2() = new None +//│ LLIR: +//│ +//│ def testCtor1() = +//│ let x$0 = None() in +//│ x$0 +//│ def testCtor2() = +//│ let x$1 = None() in +//│ x$1 +//│ undefined + +:sllir +:intl +abstract class Option[out T]: Some[T] | None +class Some[out T](x: T) extends Option[T] +object None extends Option +fun fromSome(s) = if s is Some(x) then x +abstract class Nat: S[Nat] | O +class S(s: Nat) extends Nat +object O extends Nat +fun aaa() = + let m = 1 + let n = 2 + let p = 3 + let q = 4 + m + n - p + q +fun bbb() = + let x = aaa() + x * 100 + 4 +fun not(x) = + if x then false else true +fun foo(x) = + if x then None + else Some(foo(not(x))) +fun main() = + let x = foo(false) + if x is + None then aaa() + Some(b1) then bbb() +main() +//│ = 404 +//│ LLIR: +//│ class Option() +//│ class Some(x) +//│ class None() +//│ class Nat() +//│ class S(s) +//│ class O() +//│ def fromSome(s) = +//│ case s of +//│ Some => +//│ let x$0 = s. in +//│ x$0 +//│ _ => +//│ panic "match error" +//│ def j$0() = +//│ null +//│ def aaa() = +//│ let x$1 = 1 in +//│ let x$2 = 2 in +//│ let x$3 = 3 in +//│ let x$4 = 4 in +//│ let x$5 = +(x$1,x$2) in +//│ let x$6 = -(x$5,x$3) in +//│ let x$7 = +(x$6,x$4) in +//│ x$7 +//│ def bbb() = +//│ let* (x$8) = aaa() in +//│ let x$9 = *(x$8,100) in +//│ let x$10 = +(x$9,4) in +//│ x$10 +//│ def not(x2) = +//│ case x2 of +//│ BoolLit(true) => +//│ false +//│ _ => +//│ true +//│ def j$1() = +//│ null +//│ def foo(x3) = +//│ case x3 of +//│ BoolLit(true) => +//│ let x$11 = None() in +//│ x$11 +//│ _ => +//│ let* (x$12) = not(x3) in +//│ let* (x$13) = foo(x$12) in +//│ let x$14 = Some(x$13) in +//│ x$14 +//│ def j$2() = +//│ null +//│ def main() = +//│ let* (x$15) = foo(false) in +//│ case x$15 of +//│ None => +//│ let* (x$16) = aaa() in +//│ x$16 +//│ _ => +//│ case x$15 of +//│ Some => +//│ let x$17 = x$15. in +//│ let* (x$18) = bbb() in +//│ x$18 +//│ _ => +//│ panic "match error" +//│ def j$4() = +//│ jump j$3() +//│ def j$3() = +//│ null +//│ let* (x$19) = main() in +//│ x$19 +//│ +//│ Interpreted: +//│ 404 + +:sllir +:intl +fun f1() = + let x = 1 + let x = 2 + x +f1() +//│ = 2 +//│ LLIR: +//│ +//│ def f1() = +//│ let x$0 = 1 in +//│ let x$1 = 2 in +//│ x$1 +//│ let* (x$2) = f1() in +//│ x$2 +//│ +//│ Interpreted: +//│ 2 + +:sllir +:intl +fun f2() = + let x = 0 + if x == 1 then 2 else 3 +f2() +//│ = 3 +//│ LLIR: +//│ +//│ def f2() = +//│ let x$0 = 0 in +//│ let x$1 = ==(x$0,1) in +//│ case x$1 of +//│ BoolLit(true) => +//│ 2 +//│ _ => +//│ 3 +//│ def j$0() = +//│ null +//│ let* (x$2) = f2() in +//│ x$2 +//│ +//│ Interpreted: +//│ 3 + + +:sllir +fun f3() = + let x1 = 0 + let x2 = 1 + if true then x1 else x2 +f3() +//│ = 0 +//│ LLIR: +//│ +//│ def f3() = +//│ let x$0 = 0 in +//│ let x$1 = 1 in +//│ let x$2 = true in +//│ case x$2 of +//│ BoolLit(true) => +//│ x$0 +//│ _ => +//│ x$1 +//│ def j$0() = +//│ null +//│ let* (x$3) = f3() in +//│ x$3 + + +:sllir +:intl +fun f4() = + let x = 0 + let x = if x == 1 then 2 else 3 + x +f4() +//│ = 3 +//│ LLIR: +//│ +//│ def f4() = +//│ let x$0 = 0 in +//│ let x$1 = ==(x$0,1) in +//│ case x$1 of +//│ BoolLit(true) => +//│ let x$3 = 2 in +//│ jump j$0(x$3) +//│ _ => +//│ let x$4 = 3 in +//│ jump j$0(x$4) +//│ def j$0(x$2) = +//│ x$2 +//│ let* (x$5) = f4() in +//│ x$5 +//│ +//│ Interpreted: +//│ 3 + +:sllir +:intl +fun f5() = + let x = 0 + let x = if x == 1 then 2 else 3 + let x = if x == 2 then 4 else 5 + x +f5() +//│ = 5 +//│ LLIR: +//│ +//│ def f5() = +//│ let x$0 = 0 in +//│ let x$1 = ==(x$0,1) in +//│ case x$1 of +//│ BoolLit(true) => +//│ let x$3 = 2 in +//│ jump j$0(x$3) +//│ _ => +//│ let x$4 = 3 in +//│ jump j$0(x$4) +//│ def j$0(x$2) = +//│ let x$5 = ==(x$2,2) in +//│ case x$5 of +//│ BoolLit(true) => +//│ let x$7 = 4 in +//│ jump j$1(x$7) +//│ _ => +//│ let x$8 = 5 in +//│ jump j$1(x$8) +//│ def j$1(x$6) = +//│ x$6 +//│ let* (x$9) = f5() in +//│ x$9 +//│ +//│ Interpreted: +//│ 5 + +:sllir +:scpp +fun test() = + if true do test() +//│ LLIR: +//│ +//│ def test() = +//│ let x$0 = true in +//│ case x$0 of +//│ BoolLit(true) => +//│ let* (x$1) = test() in +//│ x$1 +//│ _ => +//│ undefined +//│ def j$0() = +//│ null +//│ undefined +//│ +//│ Cpp: +//│ #include "mlsprelude.h" +//│ _mlsValue _mls_j_0(); +//│ _mlsValue _mls_test(); +//│ _mlsValue _mlsMain(); +//│ _mlsValue _mls_j_0() { +//│ _mlsValue _mls_retval; +//│ _mls_retval = _mlsValue::create<_mls_Unit>(); +//│ return _mls_retval; +//│ } +//│ _mlsValue _mls_test() { +//│ _mlsValue _mls_retval; +//│ auto _mls_x_0 = _mlsValue::fromIntLit(1); +//│ if (_mlsValue::isIntLit(_mls_x_0, 1)) { +//│ auto _mls_x_1 = _mls_test(); +//│ _mls_retval = _mls_x_1; +//│ } else { +//│ _mls_retval = _mlsValue::create<_mls_Unit>(); +//│ } +//│ return _mls_retval; +//│ } +//│ _mlsValue _mlsMain() { +//│ _mlsValue _mls_retval; +//│ _mls_retval = _mlsValue::create<_mls_Unit>(); +//│ return _mls_retval; +//│ } +//│ int main() { return _mlsLargeStack(_mlsMainWrapper); } + +:sllir +:scpp +fun test() = + (if true then test()) + 1 +//│ LLIR: +//│ +//│ def test() = +//│ let x$0 = true in +//│ case x$0 of +//│ BoolLit(true) => +//│ let* (x$2) = test() in +//│ jump j$0(x$2) +//│ _ => +//│ panic "match error" +//│ def j$0(x$1) = +//│ let x$3 = +(x$1,1) in +//│ x$3 +//│ undefined +//│ +//│ Cpp: +//│ #include "mlsprelude.h" +//│ _mlsValue _mls_j_0(_mlsValue); +//│ _mlsValue _mls_test(); +//│ _mlsValue _mlsMain(); +//│ _mlsValue _mls_j_0(_mlsValue _mls_x_1) { +//│ _mlsValue _mls_retval; +//│ auto _mls_x_3 = (_mls_x_1 + _mlsValue::fromIntLit(1)); +//│ _mls_retval = _mls_x_3; +//│ return _mls_retval; +//│ } +//│ _mlsValue _mls_test() { +//│ _mlsValue _mls_retval; +//│ auto _mls_x_0 = _mlsValue::fromIntLit(1); +//│ if (_mlsValue::isIntLit(_mls_x_0, 1)) { +//│ auto _mls_x_2 = _mls_test(); +//│ _mls_retval = _mls_j_0(_mls_x_2); +//│ } else { +//│ throw std::runtime_error("match error"); +//│ } +//│ return _mls_retval; +//│ } +//│ _mlsValue _mlsMain() { +//│ _mlsValue _mls_retval; +//│ _mls_retval = _mlsValue::create<_mls_Unit>(); +//│ return _mls_retval; +//│ } +//│ int main() { return _mlsLargeStack(_mlsMainWrapper); } + + +:sllir +:intl +:scpp +fun f() = + let x = 10 + if true do + set x += 1 + x +f() +//│ = 11 +//│ LLIR: +//│ +//│ def f() = +//│ let x$0 = 10 in +//│ let x$1 = true in +//│ case x$1 of +//│ BoolLit(true) => +//│ let x$3 = +(x$0,1) in +//│ let x$4 = undefined in +//│ jump j$0(x$3) +//│ _ => +//│ let x$5 = undefined in +//│ jump j$0(x$0) +//│ def j$0(x$2) = +//│ x$2 +//│ let* (x$6) = f() in +//│ x$6 +//│ +//│ Cpp: +//│ #include "mlsprelude.h" +//│ _mlsValue _mls_j_0(_mlsValue); +//│ _mlsValue _mls_f(); +//│ _mlsValue _mlsMain(); +//│ _mlsValue _mls_j_0(_mlsValue _mls_x_2) { +//│ _mlsValue _mls_retval; +//│ _mls_retval = _mls_x_2; +//│ return _mls_retval; +//│ } +//│ _mlsValue _mls_f() { +//│ _mlsValue _mls_retval; +//│ auto _mls_x_0 = _mlsValue::fromIntLit(10); +//│ auto _mls_x_1 = _mlsValue::fromIntLit(1); +//│ if (_mlsValue::isIntLit(_mls_x_1, 1)) { +//│ auto _mls_x_3 = (_mls_x_0 + _mlsValue::fromIntLit(1)); +//│ auto _mls_x_4 = _mlsValue::create<_mls_Unit>(); +//│ _mls_retval = _mls_j_0(_mls_x_3); +//│ } else { +//│ auto _mls_x_5 = _mlsValue::create<_mls_Unit>(); +//│ _mls_retval = _mls_j_0(_mls_x_0); +//│ } +//│ return _mls_retval; +//│ } +//│ _mlsValue _mlsMain() { +//│ _mlsValue _mls_retval; +//│ auto _mls_x_6 = _mls_f(); +//│ _mls_retval = _mls_x_6; +//│ return _mls_retval; +//│ } +//│ int main() { return _mlsLargeStack(_mlsMainWrapper); } +//│ +//│ Interpreted: +//│ 11 +