From 986a3cfeb9cba0d5cedb6c297f1f3704de2db111 Mon Sep 17 00:00:00 2001 From: waterlens Date: Fri, 15 Nov 2024 18:41:30 +0800 Subject: [PATCH 01/23] Copy files to new sub project --- .../main/scala/hkmc2/codegen/cpp/Ast.scala | 212 +++++++++++++++ .../scala/hkmc2/codegen/cpp/CodeGen.scala | 238 ++++++++++++++++ .../hkmc2/codegen/cpp/CompilerHost.scala | 44 +++ .../main/scala/hkmc2/codegen/llir/Fresh.scala | 28 ++ .../main/scala/hkmc2/codegen/llir/Llir.scala | 255 ++++++++++++++++++ .../hkmc2/codegen/llir/RefResolver.scala | 55 ++++ .../scala/hkmc2/codegen/llir/Validator.scala | 44 +++ .../hkmc2/utils/document/LegacyDocument.scala | 52 ++++ 8 files changed, 928 insertions(+) create mode 100644 hkmc2/shared/src/main/scala/hkmc2/codegen/cpp/Ast.scala create mode 100644 hkmc2/shared/src/main/scala/hkmc2/codegen/cpp/CodeGen.scala create mode 100644 hkmc2/shared/src/main/scala/hkmc2/codegen/cpp/CompilerHost.scala create mode 100644 hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Fresh.scala create mode 100644 hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Llir.scala create mode 100644 hkmc2/shared/src/main/scala/hkmc2/codegen/llir/RefResolver.scala create mode 100644 hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Validator.scala create mode 100644 hkmc2/shared/src/main/scala/hkmc2/utils/document/LegacyDocument.scala diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/cpp/Ast.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/cpp/Ast.scala new file mode 100644 index 000000000..94787c2fa --- /dev/null +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/cpp/Ast.scala @@ -0,0 +1,212 @@ +package hkmc2.codegen.cpp + +import mlscript._ +import mlscript.utils._ +import mlscript.utils.shorthands._ + +import hkmc2.utils.legacy_document._ + +import scala.language.implicitConversions + +given Conversion[String, Document] = raw + +enum Specifier: + case Extern + case Static + case Inline + + def toDocument = raw( + this match + case Extern => "extern" + case Static => "static" + case Inline => "inline" + ) + + override def toString: Str = toDocument.print + +object Type: + def toDocuments(args: Ls[Type], sep: Document, extraTypename: Bool = false): Document = + args.iterator.zipWithIndex.map { + case (x, 0) => + x.toDocument(extraTypename) + case (x, _) => + sep <#> x.toDocument(extraTypename) + }.fold(raw(""))(_ <#> _) + + def toDocuments(args: Ls[(Str, Type)], sep: Document): Document = + args.iterator.zipWithIndex.map { + case (x, 0) => + x._2.toDocument() <:> raw(x._1) + case (x, _) => + sep <#> x._2.toDocument() <:> raw(x._1) + }.fold(raw(""))(_ <#> _) + +enum Type: + case Prim(name: Str) + case Ptr(inner: Type) + case Ref(inner: Type) + case Array(inner: Type, size: Opt[Int]) + case FuncPtr(ret: Type, args: List[Type]) + case Struct(name: Str) + case Enum(name: Str) + case Template(name: Str, args: List[Type]) + case Var(name: Str) + case Qualifier(inner: Type, qual: Str) + + def toDocument(extraTypename: Bool = false): Document = + def aux(x: Type): Document = x match + case Prim(name) => name + case Ptr(inner) => aux(inner) <#> "*" + case Ref(inner) => aux(inner) <#> "&" + case Array(inner, size) => aux(inner) <#> "[" <#> size.fold(raw(""))(x => x.toString) <#> "]" + case FuncPtr(ret, args) => aux(ret) <#> "(" <#> Type.toDocuments(args, sep = ", ") <#> ")" + case Struct(name) => s"struct $name" + case Enum(name) => s"enum $name" + case Template(name, args) => s"$name" <#> "<" <#> Type.toDocuments(args, sep = ", ") <#> ">" + case Var(name) => name + case Qualifier(inner, qual) => aux(inner) <:> qual + aux(this) + + override def toString: Str = toDocument().print + +object Stmt: + def toDocuments(decl: Ls[Decl], stmts: Ls[Stmt]): Document = + stack_list(decl.map(_.toDocument) ++ stmts.map(_.toDocument)) + +enum Stmt: + case AutoBind(lhs: Ls[Str], rhs: Expr) + case Assign(lhs: Str, rhs: Expr) + case Return(expr: Expr) + case If(cond: Expr, thenStmt: Stmt, elseStmt: Opt[Stmt]) + case While(cond: Expr, body: Stmt) + case For(init: Stmt, cond: Expr, update: Stmt, body: Stmt) + case ExprStmt(expr: Expr) + case Break + case Continue + case Block(decl: Ls[Decl], stmts: Ls[Stmt]) + case Switch(expr: Expr, cases: Ls[(Expr, Stmt)]) + case Raw(stmt: Str) + + def toDocument: Document = + def aux(x: Stmt): Document = x match + case AutoBind(lhs, rhs) => + lhs match + case Nil => rhs.toDocument + case x :: Nil => "auto" <:> x <:> "=" <:> rhs.toDocument <#> ";" + case _ => "auto" <:> lhs.mkString("[", ",", "]") <:> "=" <:> rhs.toDocument <#> ";" + case Assign(lhs, rhs) => lhs <#> " = " <#> rhs.toDocument <#> ";" + case Return(expr) => "return " <#> expr.toDocument <#> ";" + case If(cond, thenStmt, elseStmt) => + "if (" <#> cond.toDocument <#> ")" <#> thenStmt.toDocument <:> elseStmt.fold(raw(""))(x => "else" <:> x.toDocument) + case While(cond, body) => + "while (" <#> cond.toDocument <#> ")" <#> body.toDocument + case For(init, cond, update, body) => + "for (" <#> init.toDocument <#> "; " <#> cond.toDocument <#> "; " <#> update.toDocument <#> ")" <#> body.toDocument + case ExprStmt(expr) => expr.toDocument <#> ";" + case Break => "break;" + case Continue => "continue;" + case Block(decl, stmts) => + stack( + "{", + Stmt.toDocuments(decl, stmts) |> indent, + "}") + case Switch(expr, cases) => + "switch (" <#> expr.toDocument <#> ")" <#> "{" <#> stack_list(cases.map { + case (cond, stmt) => "case " <#> cond.toDocument <#> ":" <#> stmt.toDocument + }) <#> "}" + case Raw(stmt) => stmt + aux(this) + +object Expr: + def toDocuments(args: Ls[Expr], sep: Document): Document = + args.zipWithIndex.map { + case (x, i) => + if i == 0 then x.toDocument + else sep <#> x.toDocument + }.fold(raw(""))(_ <#> _) + +enum Expr: + case Var(name: Str) + case IntLit(value: BigInt) + case DoubleLit(value: Double) + case StrLit(value: Str) + case CharLit(value: Char) + case Call(func: Expr, args: Ls[Expr]) + case Member(expr: Expr, member: Str) + case Index(expr: Expr, index: Expr) + case Unary(op: Str, expr: Expr) + case Binary(op: Str, lhs: Expr, rhs: Expr) + case Initializer(exprs: Ls[Expr]) + case Constructor(name: Str, init: Expr) + + def toDocument: Document = + def aux(x: Expr): Document = x match + case Var(name) => name + case IntLit(value) => value.toString + case DoubleLit(value) => value.toString + case StrLit(value) => s"\"$value\"" // need more reliable escape utils + case CharLit(value) => value.toInt.toString + case Call(func, args) => aux(func) <#> "(" <#> Expr.toDocuments(args, sep = ", ") <#> ")" + case Member(expr, member) => aux(expr) <#> "->" <#> member + case Index(expr, index) => aux(expr) <#> "[" <#> aux(index) <#> "]" + case Unary(op, expr) => "(" <#> op <#> aux(expr) <#> ")" + case Binary(op, lhs, rhs) => "(" <#> aux(lhs) <#> op <#> aux(rhs) <#> ")" + case Initializer(exprs) => "{" <#> Expr.toDocuments(exprs, sep = ", ") <#> "}" + case Constructor(name, init) => name <#> init.toDocument + aux(this) + +case class CompilationUnit(includes: Ls[Str], decls: Ls[Decl], defs: Ls[Def]): + def toDocument: Document = + stack_list(includes.map(x => raw(x)) ++ decls.map(_.toDocument) ++ defs.map(_.toDocument)) + def toDocumentWithoutHidden: Document = + val hiddenNames = Set( + "HiddenTheseEntities", "True", "False", "Callable", "List", "Cons", "Nil", "Option", "Some", "None", "Pair", "Tuple2", "Tuple3", "Nat", "S", "O" + ) + stack_list(defs.filterNot { + case Def.StructDef(name, _, _, _) => hiddenNames.contains(name.stripPrefix("_mls_")) + case _ => false + }.map(_.toDocument)) + +enum Decl: + case StructDecl(name: Str) + case EnumDecl(name: Str) + case FuncDecl(ret: Type, name: Str, args: Ls[Type]) + case VarDecl(name: Str, typ: Type) + + def toDocument: Document = + def aux(x: Decl): Document = x match + case StructDecl(name) => s"struct $name;" + case EnumDecl(name) => s"enum $name;" + case FuncDecl(ret, name, args) => ret.toDocument() <#> s" $name(" <#> Type.toDocuments(args, sep = ", ") <#> ");" + case VarDecl(name, typ) => typ.toDocument() <#> s" $name;" + aux(this) + +enum Def: + case StructDef(name: Str, fields: Ls[(Str, Type)], inherit: Opt[Ls[Str]], methods: Ls[Def] = Ls.empty) + case EnumDef(name: Str, fields: Ls[(Str, Opt[Int])]) + case FuncDef(specret: Type, name: Str, args: Ls[(Str, Type)], body: Stmt.Block, or: Bool = false, virt: Bool = false) + case VarDef(typ: Type, name: Str, init: Opt[Expr]) + case RawDef(raw: Str) + + def toDocument: Document = + def aux(x: Def): Document = x match + case StructDef(name, fields, inherit, defs) => + stack( + s"struct $name" <#> (if inherit.nonEmpty then ": public" <:> inherit.get.mkString(", ") else "" ) <:> "{", + stack_list(fields.map { + case (name, typ) => typ.toDocument() <#> " " <#> name <#> ";" + }) |> indent, + stack_list(defs.map(_.toDocument)) |> indent, + "};" + ) + case EnumDef(name, fields) => + s"enum $name" <:> "{" <#> stack_list(fields.map { + case (name, value) => value.fold(s"$name")(x => s"$name = $x") + }) <#> "};" + case FuncDef(specret, name, args, body, or, virt) => + (if virt then "virtual " else "") + <#> specret.toDocument() <#> s" $name(" <#> Type.toDocuments(args, sep = ", ") <#> ")" <#> (if or then " override" else "") <#> body.toDocument + case VarDef(typ, name, init) => + typ.toDocument() <#> s" $name" <#> init.fold(raw(""))(x => " = " <#> x.toDocument) <#> raw(";") + case RawDef(x) => x + aux(this) \ No newline at end of file diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/cpp/CodeGen.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/cpp/CodeGen.scala new file mode 100644 index 000000000..2ed12302e --- /dev/null +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/cpp/CodeGen.scala @@ -0,0 +1,238 @@ +package hkmc2.codegen.cpp + +import mlscript.utils._ +import mlscript.utils.shorthands._ +import scala.collection.mutable.ListBuffer + +import hkmc2.codegen.llir.{Expr => IExpr, _} +import hkmc2.codegen.cpp._ + +def codegen(prog: Program): CompilationUnit = + val codegen = CppCodeGen() + codegen.codegen(prog) + +private class CppCodeGen: + def mapName(name: Name): Str = "_mls_" + name.str.replace('$', '_').replace('\'', '_') + def mapName(name: Str): Str = "_mls_" + name.replace('$', '_').replace('\'', '_') + val freshName = Fresh(div = '_'); + val mlsValType = Type.Prim("_mlsValue") + val mlsUnitValue = Expr.Call(Expr.Var("_mlsValue::create<_mls_Unit>"), Ls()); + val mlsRetValue = "_mls_retval" + val mlsRetValueDecl = Decl.VarDecl(mlsRetValue, mlsValType) + val mlsMainName = "_mlsMain" + val mlsPrelude = "#include \"mlsprelude.h\"" + val mlsPreludeImpl = "#include \"mlsprelude.cpp\"" + val mlsInternalClass = Set("True", "False", "Boolean", "Callable") + val mlsObject = "_mlsObject" + val mlsBuiltin = "builtin" + val mlsEntryPoint = s"int main() { return _mlsLargeStack(_mlsMainWrapper); }"; + def mlsIntLit(x: BigInt) = Expr.Call(Expr.Var("_mlsValue::fromIntLit"), Ls(Expr.IntLit(x))) + def mlsStrLit(x: Str) = Expr.Call(Expr.Var("_mlsValue::fromStrLit"), Ls(Expr.StrLit(x))) + def mlsCharLit(x: Char) = Expr.Call(Expr.Var("_mlsValue::fromIntLit"), Ls(Expr.CharLit(x))) + def mlsNewValue(cls: Str, args: Ls[Expr]) = Expr.Call(Expr.Var(s"_mlsValue::create<$cls>"), args) + def mlsIsValueOf(cls: Str, scrut: Expr) = Expr.Call(Expr.Var(s"_mlsValue::isValueOf<$cls>"), Ls(scrut)) + def mlsIsIntLit(scrut: Expr, lit: hkmc2.syntax.Tree.IntLit) = Expr.Call(Expr.Var("_mlsValue::isIntLit"), Ls(scrut, Expr.IntLit(lit.value))) + def mlsDebugPrint(x: Expr) = Expr.Call(Expr.Var("_mlsValue::print"), Ls(x)) + def mlsTupleValue(init: Expr) = Expr.Constructor("_mlsValue::tuple", init) + def mlsAs(name: Str, cls: Str) = Expr.Var(s"_mlsValue::as<$cls>($name)") + def mlsAsUnchecked(name: Str, cls: Str) = Expr.Var(s"_mlsValue::cast<$cls>($name)") + def mlsObjectNameMethod(name: Str) = s"constexpr static inline const char *typeName = \"${name}\";" + def mlsTypeTag() = s"constexpr static inline uint32_t typeTag = nextTypeTag();" + def mlsTypeTag(n: Int) = s"constexpr static inline uint32_t typeTag = $n;" + def mlsCommonCreateMethod(cls: Str, fields: Ls[Str], id: Int) = + val parameters = fields.map{x => s"_mlsValue $x"}.mkString(", ") + val fieldsAssignment = fields.map{x => s"_mlsVal->$x = $x; "}.mkString + s"static _mlsValue create($parameters) { auto _mlsVal = new (std::align_val_t(_mlsAlignment)) $cls; _mlsVal->refCount = 1; _mlsVal->tag = typeTag; $fieldsAssignment return _mlsValue(_mlsVal); }" + def mlsCommonPrintMethod(fields: Ls[Str]) = + if fields.isEmpty then s"virtual void print() const override { std::printf(\"%s\", typeName); }" + else + val fieldsPrint = fields.map{x => s"this->$x.print(); "}.mkString("std::printf(\", \"); ") + s"virtual void print() const override { std::printf(\"%s\", typeName); std::printf(\"(\"); $fieldsPrint std::printf(\")\"); }" + def mlsCommonDestructorMethod(cls: Str, fields: Ls[Str]) = + val fieldsDeletion = fields.map{x => s"_mlsValue::destroy(this->$x); "}.mkString + s"virtual void destroy() override { $fieldsDeletion operator delete (this, std::align_val_t(_mlsAlignment)); }" + def mlsThrowNonExhaustiveMatch = Stmt.Raw("_mlsNonExhaustiveMatch();"); + def mlsCall(fn: Str, args: Ls[Expr]) = Expr.Call(Expr.Var("_mlsCall"), Expr.Var(fn) :: args) + def mlsMethodCall(cls: ClassRef, method: Str, args: Ls[Expr]) = + Expr.Call(Expr.Member(Expr.Call(Expr.Var(s"_mlsMethodCall<${cls.name |> mapName}>"), Ls(args.head)), method), args.tail) + def mlsFnWrapperName(fn: Str) = s"_mlsFn_$fn" + def mlsFnCreateMethod(fn: Str) = s"static _mlsValue create() { static _mlsFn_$fn mlsFn alignas(_mlsAlignment); mlsFn.refCount = stickyRefCount; mlsFn.tag = typeTag; return _mlsValue(&mlsFn); }" + def mlsNeverValue(n: Int) = if (n <= 1) then Expr.Call(Expr.Var(s"_mlsValue::never"), Ls()) else Expr.Call(Expr.Var(s"_mlsValue::never<$n>"), Ls()) + + case class Ctx( + defnCtx: Set[Str], + ) + + def codegenClassInfo(using ctx: Ctx)(cls: ClassInfo): (Opt[Def], Decl) = + val fields = cls.fields.map{x => (x |> mapName, mlsValType)} + val parents = if cls.parents.nonEmpty then cls.parents.toList.map(mapName) else mlsObject :: Nil + val decl = Decl.StructDecl(cls.name |> mapName) + if mlsInternalClass.contains(cls.name) then return (None, decl) + val theDef = Def.StructDef( + cls.name |> mapName, fields, + if parents.nonEmpty then Some(parents) else None, + Ls(Def.RawDef(mlsObjectNameMethod(cls.name)), + Def.RawDef(mlsTypeTag()), + Def.RawDef(mlsCommonPrintMethod(cls.fields.map(mapName))), + Def.RawDef(mlsCommonDestructorMethod(cls.name |> mapName, cls.fields.map(mapName))), + Def.RawDef(mlsCommonCreateMethod(cls.name |> mapName, cls.fields.map(mapName), cls.id))) + ++ cls.methods.map{case (name, defn) => { + val (theDef, decl) = codegenDefn(using Ctx(ctx.defnCtx + cls.name))(defn) + theDef match + case x @ Def.FuncDef(_, name, _, _, _, _) => x.copy(virt = true) + case _ => theDef + }} + ) + (S(theDef), decl) + + def toExpr(texpr: TrivialExpr, reifyUnit: Bool = false)(using ctx: Ctx): Opt[Expr] = texpr match + case IExpr.Ref(name) => S(Expr.Var(name |> mapName)) + case IExpr.Literal(hkmc2.syntax.Tree.BoolLit(x)) => S(mlsIntLit(if x then 1 else 0)) + case IExpr.Literal(hkmc2.syntax.Tree.IntLit(x)) => S(mlsIntLit(x)) + case IExpr.Literal(hkmc2.syntax.Tree.DecLit(x)) => S(mlsIntLit(x.toBigInt)) + case IExpr.Literal(hkmc2.syntax.Tree.StrLit(x)) => S(mlsStrLit(x)) + case IExpr.Literal(hkmc2.syntax.Tree.UnitLit(_)) => if reifyUnit then S(mlsUnitValue) else None + + def toExpr(texpr: TrivialExpr)(using ctx: Ctx): Expr = texpr match + case IExpr.Ref(name) => Expr.Var(name |> mapName) + case IExpr.Literal(hkmc2.syntax.Tree.BoolLit(x)) => mlsIntLit(if x then 1 else 0) + case IExpr.Literal(hkmc2.syntax.Tree.IntLit(x)) => mlsIntLit(x) + case IExpr.Literal(hkmc2.syntax.Tree.DecLit(x)) => mlsIntLit(x.toBigInt) + case IExpr.Literal(hkmc2.syntax.Tree.StrLit(x)) => mlsStrLit(x) + case IExpr.Literal(hkmc2.syntax.Tree.UnitLit(_)) => mlsUnitValue + + + def wrapMultiValues(exprs: Ls[TrivialExpr])(using ctx: Ctx): Expr = exprs match + case x :: Nil => toExpr(x, reifyUnit = true).get + case _ => + val init = Expr.Initializer(exprs.map{x => toExpr(x)}) + mlsTupleValue(init) + + def codegenCaseWithIfs(scrut: Name, cases: Ls[(Pat, Node)], default: Opt[Node], storeInto: Str)(using decls: Ls[Decl], stmts: Ls[Stmt])(using ctx: Ctx): (Ls[Decl], Ls[Stmt]) = + val scrutName = mapName(scrut) + val init: Stmt = + default.fold(mlsThrowNonExhaustiveMatch)(x => { + val (decls2, stmts2) = codegen(x, storeInto)(using Ls.empty, Ls.empty[Stmt]) + Stmt.Block(decls2, stmts2) + }) + val stmt = cases.foldRight(S(init)) { + case ((Pat.Class(cls), arm), nextarm) => + val (decls2, stmts2) = codegen(arm, storeInto)(using Ls.empty, Ls.empty[Stmt]) + val stmt = Stmt.If(mlsIsValueOf(cls.name |> mapName, Expr.Var(scrutName)), Stmt.Block(decls2, stmts2), nextarm) + S(stmt) + case ((Pat.Lit(i @ hkmc2.syntax.Tree.IntLit(_)), arm), nextarm) => + val (decls2, stmts2) = codegen(arm, storeInto)(using Ls.empty, Ls.empty[Stmt]) + val stmt = Stmt.If(mlsIsIntLit(Expr.Var(scrutName), i), Stmt.Block(decls2, stmts2), nextarm) + S(stmt) + case _ => ??? + } + (decls, stmt.fold(stmts)(x => stmts :+ x)) + + def codegenJumpWithCall(func: FuncRef, args: Ls[TrivialExpr], storeInto: Opt[Str])(using decls: Ls[Decl], stmts: Ls[Stmt])(using ctx: Ctx): (Ls[Decl], Ls[Stmt]) = + val call = Expr.Call(Expr.Var(func.name |> mapName), args.map(toExpr)) + val stmts2 = stmts ++ Ls(storeInto.fold(Stmt.Return(call))(x => Stmt.Assign(x, call))) + (decls, stmts2) + + def codegenOps(op: Str, args: Ls[TrivialExpr])(using ctx: Ctx) = op match + case "+" => Expr.Binary("+", toExpr(args(0)), toExpr(args(1))) + case "-" => Expr.Binary("-", toExpr(args(0)), toExpr(args(1))) + case "*" => Expr.Binary("*", toExpr(args(0)), toExpr(args(1))) + case "/" => Expr.Binary("/", toExpr(args(0)), toExpr(args(1))) + case "%" => Expr.Binary("%", toExpr(args(0)), toExpr(args(1))) + case "==" => Expr.Binary("==", toExpr(args(0)), toExpr(args(1))) + case "!=" => Expr.Binary("!=", toExpr(args(0)), toExpr(args(1))) + case "<" => Expr.Binary("<", toExpr(args(0)), toExpr(args(1))) + case "<=" => Expr.Binary("<=", toExpr(args(0)), toExpr(args(1))) + case ">" => Expr.Binary(">", toExpr(args(0)), toExpr(args(1))) + case ">=" => Expr.Binary(">=", toExpr(args(0)), toExpr(args(1))) + case "&&" => Expr.Binary("&&", toExpr(args(0)), toExpr(args(1))) + case "||" => Expr.Binary("||", toExpr(args(0)), toExpr(args(1))) + case "!" => Expr.Unary("!", toExpr(args(0))) + case _ => TODO("codegenOps") + + + def codegen(expr: IExpr)(using ctx: Ctx): Expr = expr match + case x @ (IExpr.Ref(_) | IExpr.Literal(_)) => toExpr(x, reifyUnit = true).get + case IExpr.CtorApp(cls, args) => mlsNewValue(cls.name |> mapName, args.map(toExpr)) + case IExpr.Select(name, cls, field) => Expr.Member(mlsAsUnchecked(name |> mapName, cls.name |> mapName), field |> mapName) + case IExpr.BasicOp(name, args) => codegenOps(name, args) + case IExpr.AssignField(assignee, cls, field, value) => TODO("codegen assign field") + + def codegenBuiltin(names: Ls[Name], builtin: Str, args: Ls[TrivialExpr])(using ctx: Ctx): Ls[Stmt] = builtin match + case "error" => Ls(Stmt.Raw("throw std::runtime_error(\"Error\");"), Stmt.AutoBind(names.map(mapName), mlsNeverValue(names.size))) + case _ => Ls(Stmt.AutoBind(names.map(mapName), Expr.Call(Expr.Var("_mls_builtin_" + builtin), args.map(toExpr)))) + + def codegen(body: Node, storeInto: Str)(using decls: Ls[Decl], stmts: Ls[Stmt])(using ctx: Ctx): (Ls[Decl], Ls[Stmt]) = body match + case Node.Result(res) => + val expr = wrapMultiValues(res) + val stmts2 = stmts ++ Ls(Stmt.Assign(storeInto, expr)) + (decls, stmts2) + case Node.Jump(defn, args) => + codegenJumpWithCall(defn, args, S(storeInto)) + case Node.LetExpr(name, expr, body) => + val stmts2 = stmts ++ Ls(Stmt.AutoBind(Ls(name |> mapName), codegen(expr))) + codegen(body, storeInto)(using decls, stmts2) + case Node.LetMethodCall(names, cls, method, IExpr.Ref(Name("builtin")) :: args, body) => + val stmts2 = stmts ++ codegenBuiltin(names, args.head.toString.replace("\"", ""), args.tail) + codegen(body, storeInto)(using decls, stmts2) + case Node.LetMethodCall(names, cls, method, args, body) => + val call = mlsMethodCall(cls, method.str |> mapName, args.map(toExpr)) + val stmts2 = stmts ++ Ls(Stmt.AutoBind(names.map(mapName), call)) + codegen(body, storeInto)(using decls, stmts2) + case Node.LetCall(names, defn, args, body) => + val call = Expr.Call(Expr.Var(defn.name |> mapName), args.map(toExpr)) + val stmts2 = stmts ++ Ls(Stmt.AutoBind(names.map(mapName), call)) + codegen(body, storeInto)(using decls, stmts2) + case Node.Case(scrut, cases, default) => + codegenCaseWithIfs(scrut, cases, default, storeInto) + + def codegenDefn(using ctx: Ctx)(defn: Func): (Def, Decl) = defn match + case Func(id, name, params, resultNum, body) => + val decls = Ls(mlsRetValueDecl) + val stmts = Ls.empty[Stmt] + val (decls2, stmts2) = codegen(body, mlsRetValue)(using decls, stmts) + val stmtsWithReturn = stmts2 :+ Stmt.Return(Expr.Var(mlsRetValue)) + val theDef = Def.FuncDef(mlsValType, name |> mapName, params.map(x => (x |> mapName, mlsValType)), Stmt.Block(decls2, stmtsWithReturn)) + val decl = Decl.FuncDecl(mlsValType, name |> mapName, params.map(x => mlsValType)) + (theDef, decl) + + def codegenTopNode(node: Node)(using ctx: Ctx): (Def, Decl) = + val decls = Ls(mlsRetValueDecl) + val stmts = Ls.empty[Stmt] + val (decls2, stmts2) = codegen(node, mlsRetValue)(using decls, stmts) + val stmtsWithReturn = stmts2 :+ Stmt.Return(Expr.Var(mlsRetValue)) + val theDef = Def.FuncDef(mlsValType, mlsMainName, Ls(), Stmt.Block(decls2, stmtsWithReturn)) + val decl = Decl.FuncDecl(mlsValType, mlsMainName, Ls()) + (theDef, decl) + + // Topological sort of classes based on inheritance relationships + def sortClasses(prog: Program): Ls[ClassInfo] = + var depgraph = prog.classes.map(x => (x.name, x.parents)).toMap + var degree = depgraph.view.mapValues(_.size).toMap + def removeNode(node: Str) = + degree -= node + depgraph -= node + depgraph = depgraph.view.mapValues(_.filter(_ != node)).toMap + degree = depgraph.view.mapValues(_.size).toMap + val sorted = ListBuffer.empty[ClassInfo] + var work = degree.filter(_._2 == 0).keys.toSet + while work.nonEmpty do + val node = work.head + work -= node + sorted.addOne(prog.classes.find(_.name == node).get) + removeNode(node) + val next = degree.filter(_._2 == 0).keys + work ++= next + if depgraph.nonEmpty then + val cycle = depgraph.keys.mkString(", ") + throw new Exception(s"Cycle detected in class hierarchy: $cycle") + sorted.toList + + def codegen(prog: Program): CompilationUnit = + val sortedClasses = sortClasses(prog) + val defnCtx = prog.defs.map(_.name) + val (defs, decls) = sortedClasses.map(codegenClassInfo(using Ctx(defnCtx))).unzip + val (defs2, decls2) = prog.defs.map(codegenDefn(using Ctx(defnCtx))).unzip + val (defMain, declMain) = codegenTopNode(prog.main)(using Ctx(defnCtx)) + CompilationUnit(Ls(mlsPrelude), decls ++ decls2 :+ declMain, defs.flatten ++ defs2 :+ defMain :+ Def.RawDef(mlsEntryPoint)) + diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/cpp/CompilerHost.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/cpp/CompilerHost.scala new file mode 100644 index 000000000..291f5a0c4 --- /dev/null +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/cpp/CompilerHost.scala @@ -0,0 +1,44 @@ +package hkmc2.codegen.cpp + +import mlscript._ +import mlscript.utils.shorthands._ +import scala.collection.mutable.ListBuffer + +final class CppCompilerHost(val auxPath: Str): + import scala.sys.process._ + private def ifAnyCppCompilerExists(): Boolean = + Seq("g++", "--version").! == 0 || Seq("clang++", "--version").! == 0 + + private def isMakeExists(): Boolean = + import scala.sys.process._ + Seq("make", "--version").! == 0 + + val ready = ifAnyCppCompilerExists() && isMakeExists() + + def compileAndRun(src: Str, output: Str => Unit): Unit = + if !ready then + return + val srcPath = os.temp(contents = src, suffix = ".cpp") + val binPath = os.temp(suffix = ".mls.out") + var stdout = ListBuffer[Str]() + var stderr = ListBuffer[Str]() + val buildLogger = ProcessLogger(stdout :+= _, stderr :+= _) + val buildResult = Seq("make", "-B", "-C", auxPath, "auto", s"SRC=$srcPath", s"DST=$binPath") ! buildLogger + if buildResult != 0 then + output("Compilation failed: ") + for line <- stdout do output(line) + for line <- stderr do output(line) + return + + stdout.clear() + stderr.clear() + val runCmd = Seq(binPath.toString) + val runResult = runCmd ! buildLogger + if runResult != 0 then + output("Execution failed: ") + for line <- stdout do output(line) + for line <- stderr do output(line) + return + + output("Execution succeeded: ") + for line <- stdout do output(line) \ No newline at end of file diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Fresh.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Fresh.scala new file mode 100644 index 000000000..0c5688eab --- /dev/null +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Fresh.scala @@ -0,0 +1,28 @@ +package hkmc2.codegen.llir + +import collection.mutable.{HashMap => MutHMap} +import mlscript.utils.shorthands._ + +final class Fresh(div : Char = '$'): + private val counter = MutHMap[Str, Int]() + private def gensym(s: Str) = { + val n = s.lastIndexOf(div) + val (ts, suffix) = s.splitAt(if n == -1 then s.length() else n) + var x = if suffix.stripPrefix(div.toString).forall(_.isDigit) then ts else s + val count = counter.getOrElse(x, 0) + val tmp = s"$x$div$count" + counter.update(x, count + 1) + Name(tmp) + } + + def make(s: Str) = gensym(s) + def make = gensym("x") + +final class FreshInt: + private var counter = 0 + def make: Int = { + val tmp = counter + counter += 1 + tmp + } + diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Llir.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Llir.scala new file mode 100644 index 000000000..bfa2fbe71 --- /dev/null +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Llir.scala @@ -0,0 +1,255 @@ +package hkmc2.codegen.llir + +import mlscript._ +import mlscript.utils._ +import mlscript.utils.shorthands._ + +import hkmc2.utils.legacy_document._ +import hkmc2.syntax._ + +import util.Sorting +import collection.immutable.SortedSet +import language.implicitConversions +import collection.mutable.{Map as MutMap, Set as MutSet, HashMap, ListBuffer} + +final case class LowLevelIRError(message: String) extends Exception(message) + +case class Program( + classes: Set[ClassInfo], + defs: Set[Func], + main: Node, +): + override def toString: String = + val t1 = classes.toArray + val t2 = defs.toArray + Sorting.quickSort(t1) + Sorting.quickSort(t2) + s"Program({${t1.mkString(",\n")}}, {\n${t2.mkString("\n")}\n},\n$main)" + + def show(hiddenNames: Set[Str] = Set.empty) = toDocument(hiddenNames).print + def toDocument(hiddenNames: Set[Str] = Set.empty) : Document = + val t1 = classes.toArray + val t2 = defs.toArray + Sorting.quickSort(t1) + Sorting.quickSort(t2) + given Conversion[String, Document] = raw + stack( + "Program:", + stack_list(t1.filter(x => !hiddenNames.contains(x.name)).map(_.toDocument).toList) |> indent, + stack_list(t2.map(_.toDocument).toList) |> indent, + main.toDocument |> indent + ) + +implicit object ClassInfoOrdering extends Ordering[ClassInfo] { + def compare(a: ClassInfo, b: ClassInfo) = a.id.compare(b.id) +} + +case class ClassInfo( + id: Int, + name: Str, + fields: Ls[Str], +): + var parents: Set[Str] = Set.empty + var methods: Map[Str, Func] = Map.empty + override def hashCode: Int = id + override def toString: String = + s"ClassInfo($id, $name, [${fields mkString ","}], parents: ${parents mkString ","}, methods:\n${methods mkString ",\n"})" + + def show = toDocument.print + def toDocument: Document = + given Conversion[String, Document] = raw + val extension = if parents.isEmpty then "" else " extends " + parents.mkString(", ") + if methods.isEmpty then + "class" <:> name <#> "(" <#> fields.mkString(",") <#> ")" <#> extension + else + stack( + "class" <:> name <#> "(" <#> fields.mkString(",") <#> ")" <#> extension <:> "{", + stack_list( methods.map { (_, func) => func.toDocument |> indent }.toList), + "}" + ) + +case class Name(str: Str): + def trySubst(map: Map[Str, Name]) = map.getOrElse(str, this) + override def toString: String = str + +class FuncRef(var func: Either[Func, Str]): + def name: String = func.fold(_.name, x => x) + def expectFn: Func = func.fold(identity, x => throw Exception(s"Expected a def, but got $x")) + def getFunc: Opt[Func] = func.left.toOption + override def equals(o: Any): Bool = o match { + case o: FuncRef => o.name == this.name + case _ => false + } + +class ClassRef(var cls: Either[ClassInfo, Str]): + def name: String = cls.fold(_.name, x => x) + def expectCls: ClassInfo = cls.fold(identity, x => throw Exception(s"Expected a class, but got $x")) + def getClass: Opt[ClassInfo] = cls.left.toOption + override def equals(o: Any): Bool = o match { + case o: ClassRef => o.name == this.name + case _ => false + } + +implicit object FuncOrdering extends Ordering[Func] { + def compare(a: Func, b: Func) = a.id.compare(b.id) +} + +case class Func( + id: Int, + name: Str, + params: Ls[Name], + resultNum: Int, + body: Node +): + var recBoundary: Opt[Int] = None + override def hashCode: Int = id + + override def toString: String = + val ps = params.map(_.toString).mkString("[", ",", "]") + s"Def($id, $name, $ps, \n$resultNum, \n$body\n)" + + def show = toDocument.print + def toDocument: Document = + given Conversion[String, Document] = raw + stack( + "def" <:> name <#> "(" <#> params.map(_.toString).mkString(",") <#> ")" <:> "=", + body.toDocument |> indent + ) + +sealed trait TrivialExpr: + import Expr._ + override def toString: String + def show: String + def toDocument: Document + def toExpr: Expr = this match { case x: Expr => x } + +private def showArguments(args: Ls[TrivialExpr]) = args map (_.show) mkString "," + +enum Expr: + case Ref(name: Name) extends Expr, TrivialExpr + case Literal(lit: hkmc2.syntax.Literal) extends Expr, TrivialExpr + case CtorApp(cls: ClassRef, args: Ls[TrivialExpr]) + case Select(name: Name, cls: ClassRef, field: Str) + case BasicOp(name: Str, args: Ls[TrivialExpr]) + case AssignField(assignee: Name, cls: ClassRef, field: Str, value: TrivialExpr) + + override def toString: String = show + + def show: String = + toDocument.print + + def toDocument: Document = + given Conversion[String, Document] = raw + this match + case Ref(s) => s.toString + case Literal(Tree.BoolLit(lit)) => s"$lit" + case Literal(Tree.IntLit(lit)) => s"$lit" + case Literal(Tree.DecLit(lit)) => s"$lit" + case Literal(Tree.StrLit(lit)) => s"$lit" + case Literal(Tree.UnitLit(undefinedOrNull)) => if undefinedOrNull then "undefined" else "null" + case CtorApp(cls, args) => + cls.name <#> "(" <#> (args |> showArguments) <#> ")" + case Select(s, cls, fld) => + cls.name <#> "." <#> fld <#> "(" <#> s.toString <#> ")" + case BasicOp(name: Str, args) => + name <#> "(" <#> (args |> showArguments) <#> ")" + case AssignField(assignee, clsInfo, fieldName, value) => + stack( + "assign" + <:> (assignee.toString + "." + fieldName) + <:> ":=" + <:> value.toDocument + ) + +enum Pat: + case Lit(lit: hkmc2.syntax.Literal) + case Class(cls: ClassRef) + + def isTrue = this match + case Class(cls) => cls.name == "True" + case _ => false + + def isFalse = this match + case Class(cls) => cls.name == "False" + case _ => false + + override def toString: String = this match + case Lit(lit) => s"$lit" + case Class(cls) => s"${cls.name}" + +enum Node: + // Terminal forms: + case Result(res: Ls[TrivialExpr]) + case Jump(func: FuncRef, args: Ls[TrivialExpr]) + case Case(scrutinee: Name, cases: Ls[(Pat, Node)], default: Opt[Node]) + // Intermediate forms: + case LetExpr(name: Name, expr: Expr, body: Node) + case LetMethodCall(names: Ls[Name], cls: ClassRef, method: Name, args: Ls[TrivialExpr], body: Node) + case LetCall(names: Ls[Name], func: FuncRef, args: Ls[TrivialExpr], body: Node) + + override def toString: String = show + + def show: String = + toDocument.print + + def toDocument: Document = + given Conversion[String, Document] = raw + this match + case Result(res) => (res |> showArguments) + case Jump(jp, args) => + "jump" + <:> jp.name + <#> "(" + <#> (args |> showArguments) + <#> ")" + case Case(x, Ls((true_pat, tru), (false_pat, fls)), N) if true_pat.isTrue && false_pat.isFalse => + val first = "if" <:> x.toString + val tru2 = indent(stack("true" <:> "=>", tru.toDocument |> indent)) + val fls2 = indent(stack("false" <:> "=>", fls.toDocument |> indent)) + Document.Stacked(Ls(first, tru2, fls2)) + case Case(x, cases, default) => + val first = "case" <:> x.toString <:> "of" + val other = cases flatMap { + case (pat, node) => + Ls(pat.toString <:> "=>", node.toDocument |> indent) + } + default match + case N => stack(first, (Document.Stacked(other) |> indent)) + case S(dc) => + val default = Ls("_" <:> "=>", dc.toDocument |> indent) + stack(first, (Document.Stacked(other ++ default) |> indent)) + case LetExpr(x, expr, body) => + stack( + "let" + <:> x.toString + <:> "=" + <:> expr.toDocument + <:> "in", + body.toDocument) + case LetMethodCall(xs, cls, method, args, body) => + stack( + "let" + <:> xs.map(_.toString).mkString(",") + <:> "=" + <:> cls.name + <#> "." + <#> method.toString + <#> "(" + <#> args.map{ x => x.toString }.mkString(",") + <#> ")" + <:> "in", + body.toDocument) + case LetCall(xs, func, args, body) => + stack( + "let*" + <:> "(" + <#> xs.map(_.toString).mkString(",") + <#> ")" + <:> "=" + <:> func.name + <#> "(" + <#> args.map{ x => x.toString }.mkString(",") + <#> ")" + <:> "in", + body.toDocument) + diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/RefResolver.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/RefResolver.scala new file mode 100644 index 000000000..54f0df95b --- /dev/null +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/RefResolver.scala @@ -0,0 +1,55 @@ +package hkmc2.codegen.llir + +import mlscript.utils.shorthands._ + +import Node._ + +// Resolves the definition references by turning them from Right(name) to Left(Func). +private final class RefResolver(defs: Map[Str, Func], classes: Map[Str, ClassInfo], allowInlineJp: Bool): + private def f(x: Expr): Unit = x match + case Expr.Ref(name) => + case Expr.Literal(lit) => + case Expr.CtorApp(cls, args) => classes.get(cls.name) match + case None => throw LowLevelIRError(f"unknown class ${cls.name} in ${classes.keySet.mkString(",")}") + case Some(value) => cls.cls = Left(value) + case Expr.Select(name, cls, field) => classes.get(cls.name) match + case None => throw LowLevelIRError(f"unknown class ${cls.name} in ${classes.keySet.mkString(",")}") + case Some(value) => cls.cls = Left(value) + case Expr.BasicOp(name, args) => + case Expr.AssignField(name, cls, field, value) => classes.get(cls.name) match + case None => throw LowLevelIRError(f"unknown class ${cls.name} in ${classes.keySet.mkString(",")}") + case Some(value) => cls.cls = Left(value) + + private def f(x: Pat): Unit = x match + case Pat.Lit(lit) => + case Pat.Class(cls) => classes.get(cls.name) match + case None => throw LowLevelIRError(f"unknown class ${cls.name} in ${classes.keySet.mkString(",")}") + case Some(value) => cls.cls = Left(value) + + private def f(x: Node): Unit = x match + case Result(res) => + case Case(scrut, cases, default) => cases foreach { (_, body) => f(body) }; default foreach f + case LetExpr(name, expr, body) => f(expr); f(body) + case LetMethodCall(names, cls, method, args, body) => f(body) + case LetCall(resultNames, defnref, args, body) => + defs.get(defnref.name) match + case Some(defn) => defnref.func = Left(defn) + case None => throw LowLevelIRError(f"unknown function ${defnref.name} in ${defs.keySet.mkString(",")}") + f(body) + case Jump(defnref, _) => + // maybe not promoted yet + defs.get(defnref.name) match + case Some(defn) => defnref.func = Left(defn) + case None => + if !allowInlineJp then + throw LowLevelIRError(f"unknown function ${defnref.name} in ${defs.keySet.mkString(",")}") + def run(node: Node) = f(node) + def run(node: Func) = f(node.body) + +def resolveRef(entry: Node, defs: Set[Func], classes: Set[ClassInfo], allowInlineJp: Bool = false): Unit = + val defsMap = defs.map(x => x.name -> x).toMap + val classesMap = classes.map(x => x.name -> x).toMap + val rl = RefResolver(defsMap, classesMap, allowInlineJp) + rl.run(entry) + defs.foreach(rl.run(_)) + diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Validator.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Validator.scala new file mode 100644 index 000000000..ee1459344 --- /dev/null +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Validator.scala @@ -0,0 +1,44 @@ +package hkmc2.codegen.llir + +import hkmc2.utils._ + +private final class FuncRefInSet(defs: Set[Func], classes: Set[ClassInfo]): + import Node._ + import Expr._ + + private def f(x: Expr): Unit = x match + case Ref(name) => + case Literal(lit) => + case CtorApp(name, args) => + case Select(name, ref, field) => ref.getClass match { + case Some(real_class) => if !classes.exists(_ eq real_class) then throw LowLevelIRError("ref is not in the set") + case _ => + } + case BasicOp(name, args) => + case AssignField(assignee, ref, _, value) => ref.getClass match { + case Some(real_class) => if !classes.exists(_ eq real_class) then throw LowLevelIRError("ref is not in the set") + case _ => + } + + private def f(x: Node): Unit = x match + case Result(res) => + case Jump(func, args) => + case Case(x, cases, default) => cases foreach { (_, body) => f(body) }; default foreach f + case LetExpr(name, expr, body) => f(body) + case LetMethodCall(names, cls, method, args, body) => f(body) + case LetCall(res, ref, args, body) => + ref.getFunc match { + case Some(real_func) => if !defs.exists(_ eq real_func) then throw LowLevelIRError("ref is not in the set") + case _ => + } + f(body) + def run(node: Node) = f(node) + def run(func: Func) = f(func.body) + +def validateRefInSet(entry: Node, defs: Set[Func], classes: Set[ClassInfo]): Unit = + val funcRefInSet = FuncRefInSet(defs, classes) + defs.foreach(funcRefInSet.run(_)) + +def validate(entry: Node, defs: Set[Func], classes: Set[ClassInfo]): Unit = + validateRefInSet(entry, defs, classes) + diff --git a/hkmc2/shared/src/main/scala/hkmc2/utils/document/LegacyDocument.scala b/hkmc2/shared/src/main/scala/hkmc2/utils/document/LegacyDocument.scala new file mode 100644 index 000000000..eec3b867f --- /dev/null +++ b/hkmc2/shared/src/main/scala/hkmc2/utils/document/LegacyDocument.scala @@ -0,0 +1,52 @@ +package hkmc2.utils.legacy_document + +enum Document: + case Indented(content: Document) + case Unindented(content: Document) + case Stacked(docs: List[Document], emptyLines: Boolean = false) + case Lined(docs: List[Document], separator: Document) + case Raw(s: String) + + def <:>(other: Document) = line(List(this, other)) + def <#>(other: Document) = line(List(this, other), sep = "") + + override def toString: String = print + + def print: String = { + val sb = StringBuffer() + + def rec(d: Document)(implicit ind: Int, first: Boolean): Unit = d match { + case Raw(s) => + if first && s.nonEmpty then sb append (" " * ind) + sb append s + case Indented(doc) => + rec(doc)(ind + 1, first) + case Unindented(doc) => + assume(ind > 0) + rec(doc)(ind - 1, first) + case Lined(Nil, _) => // skip + case Lined(docs, sep) => + rec(docs.head) + docs.tail foreach { doc => + rec(sep)(ind, false) + rec(doc)(ind, false) + } + case Stacked(Nil, _) => // skip + case Stacked(docs, emptyLines) => + rec(docs.head) + docs.tail foreach { doc => + sb append "\n" + if emptyLines then sb append "\n" + rec(doc)(ind, true) + } + } + + rec(this)(0, true) + sb.toString + } + +def stack(docs: Document*) = Document.Stacked(docs.toList) +def stack_list(docs: List[Document]) = Document.Stacked(docs) +def line(docs: List[Document], sep: String = " ") = Document.Lined(docs, Document.Raw(sep)) +def raw(s: String) = Document.Raw(s) +def indent(doc: Document) = Document.Indented(doc) From 92443254230aa43dcb7d150170f54613d7b18a1b Mon Sep 17 00:00:00 2001 From: waterlens Date: Tue, 7 Jan 2025 20:41:59 +0800 Subject: [PATCH 02/23] Initial Llir builder for `Block` --- .../src/test/scala/hkmc2/DiffTestRunner.scala | 4 +- .../src/test/scala/hkmc2/LlirDiffMaker.scala | 35 +++ .../scala/hkmc2/codegen/cpp/CodeGen.scala | 9 +- .../scala/hkmc2/codegen/llir/Analysis.scala | 125 ++++++++++ .../scala/hkmc2/codegen/llir/Builder.scala | 231 ++++++++++++++++++ .../main/scala/hkmc2/codegen/llir/Llir.scala | 15 +- .../hkmc2/codegen/llir/RefResolver.scala | 1 + .../scala/hkmc2/codegen/llir/Validator.scala | 1 + .../src/test/mlscript/llir/Playground.mls | 164 +++++++++++++ 9 files changed, 576 insertions(+), 9 deletions(-) create mode 100644 hkmc2/jvm/src/test/scala/hkmc2/LlirDiffMaker.scala create mode 100644 hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Analysis.scala create mode 100644 hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Builder.scala create mode 100644 hkmc2/shared/src/test/mlscript/llir/Playground.mls diff --git a/hkmc2/jvm/src/test/scala/hkmc2/DiffTestRunner.scala b/hkmc2/jvm/src/test/scala/hkmc2/DiffTestRunner.scala index 6390cf3c3..8f3c28405 100644 --- a/hkmc2/jvm/src/test/scala/hkmc2/DiffTestRunner.scala +++ b/hkmc2/jvm/src/test/scala/hkmc2/DiffTestRunner.scala @@ -10,9 +10,7 @@ import mlscript.utils._, shorthands._ class MainDiffMaker(val rootPath: Str, val file: os.Path, val preludeFile: os.Path, val predefFile: os.Path, val relativeName: Str) - extends BbmlDiffMaker - - + extends LlirDiffMaker class AllTests extends org.scalatest.Suites( new CompileTestRunner(DiffTestRunner.State){}, diff --git a/hkmc2/jvm/src/test/scala/hkmc2/LlirDiffMaker.scala b/hkmc2/jvm/src/test/scala/hkmc2/LlirDiffMaker.scala new file mode 100644 index 000000000..c639bede8 --- /dev/null +++ b/hkmc2/jvm/src/test/scala/hkmc2/LlirDiffMaker.scala @@ -0,0 +1,35 @@ +package hkmc2 + +import scala.collection.mutable + +import mlscript.utils.*, shorthands.* +import utils.* + +import document.* +import codegen.Block +import codegen.llir.LlirBuilder +import hkmc2.syntax.Tree.Ident +import hkmc2.codegen.Path +import hkmc2.semantics.Term.Blk +import hkmc2.codegen.llir.Fresh +import hkmc2.codegen.js.Scope +import hkmc2.codegen.llir.Ctx +import hkmc2.codegen.llir._ + +abstract class LlirDiffMaker extends BbmlDiffMaker: + val llir = NullaryCommand("llir") + + override def processTerm(trm: Blk, inImport: Bool)(using Raise): Unit = + super.processTerm(trm, inImport) + if llir.isSet then + val low = ltl.givenIn: + codegen.Lowering() + val le = low.program(trm) + given Scope = Scope.empty + val fresh = Fresh() + val fuid = FreshInt() + val llb = LlirBuilder(tl)(fresh, fuid) + given Ctx = Ctx.empty + val llirProg = llb.bProg(le) + output(llirProg.show()) + \ No newline at end of file diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/cpp/CodeGen.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/cpp/CodeGen.scala index 2ed12302e..81501fe2d 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/cpp/CodeGen.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/cpp/CodeGen.scala @@ -108,8 +108,8 @@ private class CppCodeGen: val init = Expr.Initializer(exprs.map{x => toExpr(x)}) mlsTupleValue(init) - def codegenCaseWithIfs(scrut: Name, cases: Ls[(Pat, Node)], default: Opt[Node], storeInto: Str)(using decls: Ls[Decl], stmts: Ls[Stmt])(using ctx: Ctx): (Ls[Decl], Ls[Stmt]) = - val scrutName = mapName(scrut) + def codegenCaseWithIfs(scrut: TrivialExpr, cases: Ls[(Pat, Node)], default: Opt[Node], storeInto: Str)(using decls: Ls[Decl], stmts: Ls[Stmt])(using ctx: Ctx): (Ls[Decl], Ls[Stmt]) = + val scrut2 = toExpr(scrut) val init: Stmt = default.fold(mlsThrowNonExhaustiveMatch)(x => { val (decls2, stmts2) = codegen(x, storeInto)(using Ls.empty, Ls.empty[Stmt]) @@ -118,11 +118,11 @@ private class CppCodeGen: val stmt = cases.foldRight(S(init)) { case ((Pat.Class(cls), arm), nextarm) => val (decls2, stmts2) = codegen(arm, storeInto)(using Ls.empty, Ls.empty[Stmt]) - val stmt = Stmt.If(mlsIsValueOf(cls.name |> mapName, Expr.Var(scrutName)), Stmt.Block(decls2, stmts2), nextarm) + val stmt = Stmt.If(mlsIsValueOf(cls.name |> mapName, scrut2), Stmt.Block(decls2, stmts2), nextarm) S(stmt) case ((Pat.Lit(i @ hkmc2.syntax.Tree.IntLit(_)), arm), nextarm) => val (decls2, stmts2) = codegen(arm, storeInto)(using Ls.empty, Ls.empty[Stmt]) - val stmt = Stmt.If(mlsIsIntLit(Expr.Var(scrutName), i), Stmt.Block(decls2, stmts2), nextarm) + val stmt = Stmt.If(mlsIsIntLit(scrut2, i), Stmt.Block(decls2, stmts2), nextarm) S(stmt) case _ => ??? } @@ -169,6 +169,7 @@ private class CppCodeGen: (decls, stmts2) case Node.Jump(defn, args) => codegenJumpWithCall(defn, args, S(storeInto)) + case Node.Panic => (decls, stmts :+ Stmt.Raw("throw std::runtime_error(\"Panic\");")) case Node.LetExpr(name, expr, body) => val stmts2 = stmts ++ Ls(Stmt.AutoBind(Ls(name |> mapName), codegen(expr))) codegen(body, storeInto)(using decls, stmts2) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Analysis.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Analysis.scala new file mode 100644 index 000000000..042295874 --- /dev/null +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Analysis.scala @@ -0,0 +1,125 @@ +package hkmc2.codegen.llir + +import mlscript._ +import hkmc2.codegen._ +import hkmc2.codegen.llir.{ Program => LlirProgram, Node, Func } +import mlscript.utils._ +import mlscript.utils.shorthands._ +import hkmc2.semantics.BuiltinSymbol +import hkmc2.syntax.Tree.UnitLit +import hkmc2.{Raise, raise, Diagnostic, ErrorReport, Message} +import hkmc2.Message.MessageContext +import hkmc2.semantics.InnerSymbol +import hkmc2.codegen.llir.FuncRef.fromName +import scala.collection.mutable.ListBuffer + +import scala.annotation.tailrec +import scala.collection.immutable.* +import scala.collection.mutable.{HashMap => MutHMap} +import scala.collection.mutable.{HashSet => MutHSet, Set => MutSet} + +class UsefulnessAnalysis(verbose: Bool = false): + import Expr._ + import Node._ + + def log(x: Any) = if verbose then println(x) + + val uses = MutHMap[(Name, Int), Int]() + val defs = MutHMap[Name, Int]() + + private def addDef(x: Name) = + defs.update(x, defs.getOrElse(x, 0) + 1) + + private def addUse(x: Name) = + val def_count = defs.get(x) match + case None => throw Exception(s"Use of undefined variable $x") + case Some(value) => value + val key = (x, defs(x)) + uses.update(key, uses.getOrElse(key, 0) + 1) + + private def f(x: TrivialExpr): Unit = x match + case Ref(name) => addUse(name) + case _ => () + + private def f(x: Expr): Unit = x match + case Ref(name) => addUse(name) + case Literal(lit) => + case CtorApp(name, args) => args.foreach(f) + case Select(name, cls, field) => addUse(name) + case BasicOp(name, args) => args.foreach(f) + case AssignField(assignee, _, _, value) => + addUse(assignee) + f(value) + + private def f(x: Node): Unit = x match + case Result(res) => res.foreach(f) + case Jump(defn, args) => args.foreach(f) + case Case(scrut, cases, default) => + scrut match + case Ref(name) => addUse(name) + case _ => () + cases.foreach { case (cls, body) => f(body) }; default.foreach(f) + case LetMethodCall(names, cls, method, args, body) => addUse(method); args.foreach(f); names.foreach(addDef); f(body) + case LetExpr(name, expr, body) => f(expr); addDef(name); f(body) + case LetCall(names, defn, args, body) => args.foreach(f); names.foreach(addDef); f(body) + + def run(x: Func) = + x.params.foreach(addDef) + f(x.body) + uses.toMap + +class FreeVarAnalysis(extended_scope: Bool = true, verbose: Bool = false): + import Expr._ + import Node._ + + private val visited = MutHSet[Str]() + private def f(using defined: Set[Str])(defn: Func, fv: Set[Str]): Set[Str] = + val defined2 = defn.params.foldLeft(defined)((acc, param) => acc + param.str) + f(using defined2)(defn.body, fv) + private def f(using defined: Set[Str])(expr: Expr, fv: Set[Str]): Set[Str] = expr match + case Ref(name) => if defined.contains(name.str) then fv else fv + name.str + case Literal(lit) => fv + case CtorApp(name, args) => args.foldLeft(fv)((acc, arg) => f(using defined)(arg.toExpr, acc)) + case Select(name, cls, field) => if defined.contains(name.str) then fv else fv + name.str + case BasicOp(name, args) => args.foldLeft(fv)((acc, arg) => f(using defined)(arg.toExpr, acc)) + case AssignField(assignee, _, _, value) => f(using defined)( + value.toExpr, + if defined.contains(assignee.str) then fv + assignee.str else fv + ) + private def f(using defined: Set[Str])(node: Node, fv: Set[Str]): Set[Str] = node match + case Result(res) => res.foldLeft(fv)((acc, arg) => f(using defined)(arg.toExpr, acc)) + case Jump(defnref, args) => + var fv2 = args.foldLeft(fv)((acc, arg) => f(using defined)(arg.toExpr, acc)) + if extended_scope && !visited.contains(defnref.name) then + val defn = defnref.expectFn + visited.add(defn.name) + val defined2 = defn.params.foldLeft(defined)((acc, param) => acc + param.str) + fv2 = f(using defined2)(defn, fv2) + fv2 + case Case(scrut, cases, default) => + val fv2 = scrut match + case Ref(name) => if defined.contains(name.str) then fv else fv + name.str + case _ => fv + val fv3 = cases.foldLeft(fv2) { + case (acc, (cls, body)) => f(using defined)(body, acc) + } + fv3 + case LetMethodCall(resultNames, cls, method, args, body) => + var fv2 = args.foldLeft(fv)((acc, arg) => f(using defined)(arg.toExpr, acc)) + val defined2 = resultNames.foldLeft(defined)((acc, name) => acc + name.str) + f(using defined2)(body, fv2) + case LetExpr(name, expr, body) => + val fv2 = f(using defined)(expr, fv) + val defined2 = defined + name.str + f(using defined2)(body, fv2) + case LetCall(resultNames, defnref, args, body) => + var fv2 = args.foldLeft(fv)((acc, arg) => f(using defined)(arg.toExpr, acc)) + val defined2 = resultNames.foldLeft(defined)((acc, name) => acc + name.str) + if extended_scope && !visited.contains(defnref.name) then + val defn = defnref.expectFn + visited.add(defn.name) + val defined2 = defn.params.foldLeft(defined)((acc, param) => acc + param.str) + fv2 = f(using defined2)(defn, fv2) + f(using defined2)(body, fv2) + def run(node: Node) = f(using Set.empty)(node, Set.empty) + def run_with(node: Node, defined: Set[Str]) = f(using defined)(node, Set.empty) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Builder.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Builder.scala new file mode 100644 index 000000000..32a57a138 --- /dev/null +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Builder.scala @@ -0,0 +1,231 @@ +package hkmc2 +package codegen.llir + +import hkmc2.codegen._ +import hkmc2.codegen.llir.{ Program => LlirProgram, Node, Func } +import mlscript.utils._ +import mlscript.utils.shorthands._ +import hkmc2.semantics.BuiltinSymbol +import hkmc2.syntax.Tree.UnitLit +import hkmc2.{Raise, raise, Diagnostic, ErrorReport, Message} +import hkmc2.Message.MessageContext +import hkmc2.codegen.llir.FuncRef.fromName +import scala.collection.mutable.ListBuffer +import hkmc2.codegen.js.Scope +import hkmc2._ +import hkmc2.document._ +import hkmc2.semantics.Elaborator.State +import hkmc2.codegen.Program +import hkmc2.utils.TraceLogger + + +def err(msg: Message)(using Raise): Unit = + raise(ErrorReport(msg -> N :: Nil, + source = Diagnostic.Source.Compilation)) + +final case class Ctx( + symbol_ctx: Map[Str, Name] = Map.empty, + fn_ctx: Map[Local, Name] = Map.empty, // is a known function + closure_ctx: Map[Local, Name] = Map.empty, // closure name + class_ctx: Map[Local, Name] = Map.empty, + block_ctx: Map[Local, Name] = Map.empty, + def_acc: ListBuffer[Func] = ListBuffer.empty, +): + def addName(n: Str, m: Name) = copy(symbol_ctx = symbol_ctx + (n -> m)) + def findName(n: Str)(using Raise): Name = symbol_ctx.get(n) match + case None => + err(msg"Name not found: $n") + Name("error") + case Some(value) => value + +object Ctx: + val empty = Ctx() + + +final class LlirBuilder(tl: TraceLogger)(fresh: Fresh, fnUid: FreshInt): + import tl.{trace, log} + def er = Expr.Ref + def nr = Node.Result + def nme(x: Str) = Name(x) + def sr(x: Str) = er(Name(x)) + def sr(x: Name) = er(x) + def nsr(xs: Ls[Name]) = xs.map(er(_)) + + private def allocIfNew(l: Local)(using Raise, Scope): String = + trace[Str](s"allocIfNew begin: $l", x => s"allocIfNew end: $x"): + if summon[Scope].lookup(l).isDefined then + getVar_!(l) + else + summon[Scope].allocateName(l) + + private def getVar_!(l: Local)(using Raise, Scope): String = + trace[Str](s"getVar_! begin", x => s"getVar_! end: $x"): + l match + case ts: semantics.TermSymbol => + ts.owner match + case S(owner) => ts.id.name + case N => + ts.id.name + case ts: semantics.BlockMemberSymbol => // this means it's a locally-defined member + ts.nme + // ts.trmTree + case ts: semantics.InnerSymbol => + summon[Scope].findThis_!(ts) + case _ => summon[Scope].lookup_!(l) + + private def bBind(name: Opt[Str], e: Result, body: Block)(k: TrivialExpr => Ctx ?=> Node)(using ctx: Ctx)(using Raise, Scope): Node = + trace[Node](s"bBind begin: $name", x => s"bBind end: ${x.show}"): + bResult(e): + case r: Expr.Ref => + given Ctx = ctx.addName(name.getOrElse(fresh.make.str), r.name) + log(s"bBind ref: $name -> $r") + bBlock(body)(k) + case l: Expr.Literal => + val v = fresh.make + given Ctx = ctx.addName(name.getOrElse(fresh.make.str), v) + log(s"bBind lit: $name -> $v") + Node.LetExpr(v, l, bBlock(body)(k)) + + private def bArgs(e: Ls[Arg])(k: Ls[TrivialExpr] => Ctx ?=> Node)(using ctx: Ctx)(using Raise, Scope): Node = + trace[Node](s"bArgs begin", x => s"bArgs end: ${x.show}"): + e match + case Nil => k(Nil) + case Arg(spread, x) :: xs => bPath(x): + case r: TrivialExpr => bArgs(xs): + case rs: Ls[TrivialExpr] => k(r :: rs) + + private def bFunDef(e: FunDefn)(using ctx: Ctx)(using Raise, Scope): Func = + val FunDefn(sym, params, body) = e + if params.length != 1 then + err(msg"Unsupported number of parameters: ${params.length.toString}") + val paramsList = params.head.params.map(x => summon[Scope].allocateName(x.sym)).map(Name(_)) + Func( + fnUid.make, + sym.nme, + params = paramsList, + resultNum = 1, + body = bBlock(body)(x => Node.Result(Ls(x))) + ) + + private def bValue(v: Value)(k: TrivialExpr => Ctx ?=> Node)(using ctx: Ctx)(using Raise, Scope) : Node = + trace[Node](s"bValue begin", x => s"bValue end: ${x.show}"): + v match + case Value.Ref(l) => k(ctx.findName(getVar_!(l)) |> sr) + case Value.This(sym) => err(msg"Unsupported value: This"); Node.Result(Ls()) + case Value.Lit(lit) => k(Expr.Literal(lit)) + case Value.Lam(params, body) => err(msg"Unsupported value: Lam"); Node.Result(Ls()) + case Value.Arr(elems) => err(msg"Unsupported value: Arr"); Node.Result(Ls()) + + private def bPath(p: Path)(k: TrivialExpr => Ctx ?=> Node)(using ctx: Ctx)(using Raise, Scope) : Node = + trace[Node](s"bPath begin", x => s"bPath end: ${x.show}"): + p match + case Select(qual, name) => err(msg"Unsupported path: Select"); Node.Result(Ls()) + case x: Value => bValue(x)(k) + + private def bResult(r: Result)(k: TrivialExpr => Ctx ?=> Node)(using ctx: Ctx)(using Raise, Scope) : Node = + trace[Node](s"bResult begin", x => s"bResult end: ${x.show}"): + r match + case Call(Value.Ref(sym: BuiltinSymbol), args) => + bArgs(args): + case args: Ls[TrivialExpr] => + val v = fresh.make + Node.LetExpr(v, Expr.BasicOp(sym.nme, args), k(v |> sr)) + case Call(Select(Value.Ref(sym: BuiltinSymbol), name), args) => + bArgs(args): + case args: Ls[TrivialExpr] => + val v = fresh.make + Node.LetExpr(v, Expr.CtorApp(ClassRef.fromName(name.name), args), k(v |> sr)) + case Call(Value.Ref(name), args) if ctx.fn_ctx.contains(name) => + bArgs(args): + case args: Ls[TrivialExpr] => + val v = fresh.make + val fn = ctx.fn_ctx.get(name).get + Node.LetCall(Ls(v), FuncRef.fromName(fn), args, k(v |> sr)) + case Call(fn, args) => + bPath(fn): + case f: TrivialExpr => + bArgs(args): + case args: Ls[TrivialExpr] => + val v = fresh.make + Node.LetMethodCall(Ls(v), ClassRef(R("Callable")), Name("apply" + args.length), f :: args, k(v |> sr)) + case Instantiate(cls, args) => ??? + case x: Path => bPath(x)(k) + + private def bBlock(blk: Block)(k: TrivialExpr => Ctx ?=> Node)(using ctx: Ctx)(using Raise, Scope) : Node = + trace[Node](s"bBlock begin", x => s"bBlock end: ${x.show}"): + blk match + case Match(scrut, arms, dflt, rest) => + bPath(scrut): + case e: TrivialExpr => + val jp = fresh.make("j") + // guess: the value of Match itself in Block is useless + // val res = fresh.make + val fvset = (rest.freeVars -- rest.definedVars).map(allocIfNew) + val fvs1 = fvset.toList + val new_ctx = fvs1.foldLeft(ctx)((acc, x) => acc.addName(x, fresh.make)) + val fvs = fvs1.map(new_ctx.findName(_)) + def cont(x: TrivialExpr)(using Ctx) = Node.Jump( + FuncRef.fromName(jp), + /* x :: */ fvs1.map(x => summon[Ctx].findName(x)).map(sr) + ) + given Ctx = new_ctx + val casesList: Ls[(Pat, Node)] = arms.map: + case (Case.Lit(lit), body) => + (Pat.Lit(lit), bBlock(body)(cont)) + case (Case.Cls(cls, _), body) => + (Pat.Class(ClassRef.fromName(cls.nme)), bBlock(body)(cont)) + case (Case.Tup(len, inf), body) => + (Pat.Class(ClassRef.fromName("Tuple" + len.toString())), bBlock(body)(cont)) + val defaultCase = dflt.map(bBlock(_)(cont)) + val jpdef = Func( + fnUid.make, + jp.str, + params = /* res :: */ fvs, + resultNum = 1, + bBlock(rest)(k), + ) + summon[Ctx].def_acc += jpdef + Node.Case(e, casesList, defaultCase) + case Return(res, implct) => bResult(res)(x => Node.Result(Ls(x))) + case Throw(exc) => TODO("Throw not supported") + case Label(label, body, rest) => ??? + case Break(label) => TODO("Break not supported") + case Continue(label) => TODO("Continue not supported") + case Begin(sub, rest) => + // re-associate rest blocks to correctly handle the continuation + sub match + case _: BlockTail => + val definedVars = sub.definedVars + definedVars.foreach(allocIfNew) + bBlock(sub): + x => bBlock(rest)(k) + case Assign(lhs, rhs, rest2) => + bBlock(Assign(lhs, rhs, Begin(rest2, rest)))(k) + case Begin(sub, rest2) => + bBlock(Begin(sub, Begin(rest2, rest)))(k) + case Define(defn, rest2) => + bBlock(Define(defn, Begin(rest2, rest)))(k) + case Match(scrut, arms, dflt, rest2) => + bBlock(Match(scrut, arms, dflt, Begin(rest2, rest)))(k) + case _ => TODO(s"Other non-tail sub components of Begin not supported $sub") + case TryBlock(sub, finallyDo, rest) => TODO("TryBlock not supported") + case Assign(lhs, rhs, rest) => + val name = allocIfNew(lhs) + bBind(S(name), rhs, rest)(k) + case AssignField(lhs, nme, rhs, rest) => TODO("AssignField not supported") + case Define(FunDefn(sym, params, body), rest) => + val f = bFunDef(FunDefn(sym, params, body)) + ctx.def_acc += f + bBlock(rest)(k) + case End(msg) => k(Expr.Literal(UnitLit(false))) + case _: Block => + val docBlock = blk.showAsTree + err(msg"Unsupported block: $docBlock") + Node.Result(Ls()) + + def bProg(e: Program)(using Raise, Scope): LlirProgram = + val ctx = Ctx.empty + given Ctx = ctx + ctx.def_acc.clear() + val entry = bBlock(e.main)(x => Node.Result(Ls(x))) + LlirProgram(Set.empty, ctx.def_acc.toSet, entry) \ No newline at end of file diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Llir.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Llir.scala index bfa2fbe71..0d3a6c3d2 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Llir.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Llir.scala @@ -34,7 +34,6 @@ case class Program( Sorting.quickSort(t2) given Conversion[String, Document] = raw stack( - "Program:", stack_list(t1.filter(x => !hiddenNames.contains(x.name)).map(_.toDocument).toList) |> indent, stack_list(t2.map(_.toDocument).toList) |> indent, main.toDocument |> indent @@ -72,6 +71,11 @@ case class Name(str: Str): def trySubst(map: Map[Str, Name]) = map.getOrElse(str, this) override def toString: String = str +object FuncRef: + def fromName(name: Str) = FuncRef(Right(name)) + def fromName(name: Name) = FuncRef(Right(name.str)) + def fromFunc(func: Func) = FuncRef(Left(func)) + class FuncRef(var func: Either[Func, Str]): def name: String = func.fold(_.name, x => x) def expectFn: Func = func.fold(identity, x => throw Exception(s"Expected a def, but got $x")) @@ -81,6 +85,11 @@ class FuncRef(var func: Either[Func, Str]): case _ => false } +object ClassRef: + def fromName(name: Str) = ClassRef(Right(name)) + def fromName(name: Name) = ClassRef(Right(name.str)) + def fromClass(cls: ClassInfo) = ClassRef(Left(cls)) + class ClassRef(var cls: Either[ClassInfo, Str]): def name: String = cls.fold(_.name, x => x) def expectCls: ClassInfo = cls.fold(identity, x => throw Exception(s"Expected a class, but got $x")) @@ -181,7 +190,8 @@ enum Node: // Terminal forms: case Result(res: Ls[TrivialExpr]) case Jump(func: FuncRef, args: Ls[TrivialExpr]) - case Case(scrutinee: Name, cases: Ls[(Pat, Node)], default: Opt[Node]) + case Case(scrutinee: TrivialExpr, cases: Ls[(Pat, Node)], default: Opt[Node]) + case Panic // Intermediate forms: case LetExpr(name: Name, expr: Expr, body: Node) case LetMethodCall(names: Ls[Name], cls: ClassRef, method: Name, args: Ls[TrivialExpr], body: Node) @@ -218,6 +228,7 @@ enum Node: case S(dc) => val default = Ls("_" <:> "=>", dc.toDocument |> indent) stack(first, (Document.Stacked(other ++ default) |> indent)) + case Panic => "panic" case LetExpr(x, expr, body) => stack( "let" diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/RefResolver.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/RefResolver.scala index 54f0df95b..403d210b2 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/RefResolver.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/RefResolver.scala @@ -43,6 +43,7 @@ private final class RefResolver(defs: Map[Str, Func], classes: Map[Str, ClassInf case None => if !allowInlineJp then throw LowLevelIRError(f"unknown function ${defnref.name} in ${defs.keySet.mkString(",")}") + case Panic => def run(node: Node) = f(node) def run(node: Func) = f(node.body) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Validator.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Validator.scala index ee1459344..9986835ff 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Validator.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Validator.scala @@ -24,6 +24,7 @@ private final class FuncRefInSet(defs: Set[Func], classes: Set[ClassInfo]): case Result(res) => case Jump(func, args) => case Case(x, cases, default) => cases foreach { (_, body) => f(body) }; default foreach f + case Panic => case LetExpr(name, expr, body) => f(body) case LetMethodCall(names, cls, method, args, body) => f(body) case LetCall(res, ref, args, body) => diff --git a/hkmc2/shared/src/test/mlscript/llir/Playground.mls b/hkmc2/shared/src/test/mlscript/llir/Playground.mls new file mode 100644 index 000000000..21519bd15 --- /dev/null +++ b/hkmc2/shared/src/test/mlscript/llir/Playground.mls @@ -0,0 +1,164 @@ +:js + + +:llir +fun f1() = + let x = 1 + let x = 2 + x +//│ +//│ def f1() = +//│ let x$0 = 1 in +//│ let x$1 = 2 in +//│ x$1 +//│ undefined + +:slot +:llir +fun f2() = + let x = 0 + if x == 1 then 2 else 3 +//│ Pretty Lowered: +//│ define fun f2() { set x = 0 in set scrut = ==(x, 1) in match scrut true => return 2 else return 3 } in return null +//│ +//│ def f2() = +//│ let x$0 = 0 in +//│ let x$1 = ==(x$0,1) in +//│ case x$1 of +//│ BoolLit(true) => +//│ 2 +//│ _ => +//│ 3 +//│ def j$0() = +//│ null +//│ undefined + +:llir +:slot +fun f3() = + let x1 = 0 + let x2 = 1 + if true then x1 else x2 +//│ Pretty Lowered: +//│ +//│ define fun f3() { +//│ set x1 = 0 in +//│ set x2 = 1 in +//│ set scrut = true in +//│ match scrut +//│ true => +//│ return x1 +//│ else +//│ return x2 +//│ } in +//│ return null +//│ +//│ def f3() = +//│ let x$0 = 0 in +//│ let x$1 = 1 in +//│ let x$2 = true in +//│ case x$2 of +//│ BoolLit(true) => +//│ x$0 +//│ _ => +//│ x$1 +//│ def j$0() = +//│ null +//│ undefined + + +:slot +:llir +fun f4() = + let x = 0 + let x = if x == 1 then 2 else 3 + x +//│ Pretty Lowered: +//│ +//│ define fun f4() { +//│ set x = 0 in +//│ begin +//│ set scrut = ==(x, 1) in +//│ match scrut +//│ true => +//│ set tmp = 2 in +//│ end +//│ else +//│ set tmp = 3 in +//│ end; +//│ set x1 = tmp in +//│ return x1 +//│ } in +//│ return null +//│ +//│ def f4() = +//│ let x$0 = 0 in +//│ let x$1 = ==(x$0,1) in +//│ case x$1 of +//│ BoolLit(true) => +//│ let x$3 = 2 in +//│ jump j$0(x$3) +//│ _ => +//│ let x$4 = 3 in +//│ jump j$0(x$4) +//│ def j$0(x$2) = +//│ x$2 +//│ undefined + +:slot +:llir +fun f5() = + let x = 0 + let x = if x == 1 then 2 else 3 + let x = if x == 2 then 4 else 5 + x +//│ Pretty Lowered: +//│ +//│ define fun f5() { +//│ set x = 0 in +//│ begin +//│ set scrut = ==(x, 1) in +//│ match scrut +//│ true => +//│ set tmp = 2 in +//│ end +//│ else +//│ set tmp = 3 in +//│ end; +//│ set x1 = tmp in +//│ begin +//│ set scrut1 = ==(x1, 2) in +//│ match scrut1 +//│ true => +//│ set tmp1 = 4 in +//│ end +//│ else +//│ set tmp1 = 5 in +//│ end; +//│ set x2 = tmp1 in +//│ return x2 +//│ } in +//│ return null +//│ +//│ def f5() = +//│ let x$0 = 0 in +//│ let x$1 = ==(x$0,1) in +//│ case x$1 of +//│ BoolLit(true) => +//│ let x$3 = 2 in +//│ jump j$0(x$3) +//│ _ => +//│ let x$4 = 3 in +//│ jump j$0(x$4) +//│ def j$0(x$2) = +//│ let x$5 = ==(x$2,2) in +//│ case x$5 of +//│ BoolLit(true) => +//│ let x$7 = 4 in +//│ jump j$1(x$7) +//│ _ => +//│ let x$8 = 5 in +//│ jump j$1(x$8) +//│ def j$1(x$6) = +//│ x$6 +//│ undefined From 78c9ab5882d63192107a1869f4e8de439b06ef44 Mon Sep 17 00:00:00 2001 From: waterlens Date: Wed, 8 Jan 2025 16:17:01 +0800 Subject: [PATCH 03/23] Resolve classes without methods --- .../src/test/scala/hkmc2/LlirDiffMaker.scala | 3 +- .../scala/hkmc2/codegen/llir/Builder.scala | 106 ++++++++++++++---- .../main/scala/hkmc2/codegen/llir/Llir.scala | 6 +- .../hkmc2/codegen/llir/RefResolver.scala | 2 +- .../scala/hkmc2/codegen/llir/Validator.scala | 2 +- .../src/test/mlscript/llir/Playground.mls | 35 ++++++ 6 files changed, 124 insertions(+), 30 deletions(-) diff --git a/hkmc2/jvm/src/test/scala/hkmc2/LlirDiffMaker.scala b/hkmc2/jvm/src/test/scala/hkmc2/LlirDiffMaker.scala index c639bede8..6323c3d1a 100644 --- a/hkmc2/jvm/src/test/scala/hkmc2/LlirDiffMaker.scala +++ b/hkmc2/jvm/src/test/scala/hkmc2/LlirDiffMaker.scala @@ -28,7 +28,8 @@ abstract class LlirDiffMaker extends BbmlDiffMaker: given Scope = Scope.empty val fresh = Fresh() val fuid = FreshInt() - val llb = LlirBuilder(tl)(fresh, fuid) + val cuid = FreshInt() + val llb = LlirBuilder(tl)(fresh, fuid, cuid) given Ctx = Ctx.empty val llirProg = llb.bProg(le) output(llirProg.show()) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Builder.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Builder.scala index 32a57a138..33b22034d 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Builder.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Builder.scala @@ -6,7 +6,7 @@ import hkmc2.codegen.llir.{ Program => LlirProgram, Node, Func } import mlscript.utils._ import mlscript.utils.shorthands._ import hkmc2.semantics.BuiltinSymbol -import hkmc2.syntax.Tree.UnitLit +import hkmc2.syntax.Tree import hkmc2.{Raise, raise, Diagnostic, ErrorReport, Message} import hkmc2.Message.MessageContext import hkmc2.codegen.llir.FuncRef.fromName @@ -17,6 +17,9 @@ import hkmc2.document._ import hkmc2.semantics.Elaborator.State import hkmc2.codegen.Program import hkmc2.utils.TraceLogger +import hkmc2.semantics.TermSymbol +import hkmc2.semantics.MemberSymbol +import hkmc2.semantics.FieldSymbol def err(msg: Message)(using Raise): Unit = @@ -24,25 +27,41 @@ def err(msg: Message)(using Raise): Unit = source = Diagnostic.Source.Compilation)) final case class Ctx( + def_acc: ListBuffer[Func], + class_acc: ListBuffer[ClassInfo], symbol_ctx: Map[Str, Name] = Map.empty, fn_ctx: Map[Local, Name] = Map.empty, // is a known function closure_ctx: Map[Local, Name] = Map.empty, // closure name class_ctx: Map[Local, Name] = Map.empty, block_ctx: Map[Local, Name] = Map.empty, - def_acc: ListBuffer[Func] = ListBuffer.empty, ): + def addFuncName(n: Local, m: Name) = copy(fn_ctx = fn_ctx + (n -> m)) + def findFuncName(n: Local)(using Raise) = fn_ctx.get(n) match + case None => + err(msg"Function name not found: ${n.toString()}") + Name("error") + case Some(value) => value + def addClassName(n: Local, m: Name) = copy(class_ctx = class_ctx + (n -> m)) + def findClassName(n: Local)(using Raise) = class_ctx.get(n) match + case None => + err(msg"Class name not found: ${n.toString()}") + Name("error") + case Some(value) => value def addName(n: Str, m: Name) = copy(symbol_ctx = symbol_ctx + (n -> m)) def findName(n: Str)(using Raise): Name = symbol_ctx.get(n) match case None => err(msg"Name not found: $n") Name("error") case Some(value) => value + def reset = + def_acc.clear() + class_acc.clear() object Ctx: - val empty = Ctx() + val empty = Ctx(ListBuffer.empty, ListBuffer.empty) -final class LlirBuilder(tl: TraceLogger)(fresh: Fresh, fnUid: FreshInt): +final class LlirBuilder(tl: TraceLogger)(fresh: Fresh, fnUid: FreshInt, clsUid: FreshInt): import tl.{trace, log} def er = Expr.Ref def nr = Node.Result @@ -95,17 +114,31 @@ final class LlirBuilder(tl: TraceLogger)(fresh: Fresh, fnUid: FreshInt): case rs: Ls[TrivialExpr] => k(r :: rs) private def bFunDef(e: FunDefn)(using ctx: Ctx)(using Raise, Scope): Func = - val FunDefn(sym, params, body) = e - if params.length != 1 then - err(msg"Unsupported number of parameters: ${params.length.toString}") - val paramsList = params.head.params.map(x => summon[Scope].allocateName(x.sym)).map(Name(_)) - Func( - fnUid.make, - sym.nme, - params = paramsList, - resultNum = 1, - body = bBlock(body)(x => Node.Result(Ls(x))) - ) + trace[Func](s"bFunDef begin", x => s"bFunDef end: ${x.show}"): + val FunDefn(sym, params, body) = e + if params.length != 1 then + err(msg"Curried function not supported: ${params.length.toString}") + val paramsList = params.head.params.map(x => x -> summon[Scope].allocateName(x.sym)) + val new_ctx = paramsList.foldLeft(ctx)((acc, x) => acc.addName(getVar_!(x._1.sym), x._2 |> nme)) + val pl = paramsList.map(_._2).map(nme) + Func( + fnUid.make, + sym.nme, + params = pl, + resultNum = 1, + body = bBlock(body)(x => Node.Result(Ls(x)))(using new_ctx) + ) + + private def bClsLikeDef(e: ClsLikeDefn)(using ctx: Ctx)(using Raise, Scope): ClassInfo = + trace[ClassInfo](s"bClsLikeDef begin", x => s"bClsLikeDef end: ${x.show}"): + val ClsLikeDefn(sym, kind, methods, privateFields, publicFields, ctor) = e + val clsDefn = sym.defn.getOrElse(die) + val clsParams = clsDefn.paramsOpt.fold(Nil)(_.paramSyms) + ClassInfo( + clsUid.make, + sym.nme, + clsParams.map(_.nme) + ) private def bValue(v: Value)(k: TrivialExpr => Ctx ?=> Node)(using ctx: Ctx)(using Raise, Scope) : Node = trace[Node](s"bValue begin", x => s"bValue end: ${x.show}"): @@ -116,10 +149,26 @@ final class LlirBuilder(tl: TraceLogger)(fresh: Fresh, fnUid: FreshInt): case Value.Lam(params, body) => err(msg"Unsupported value: Lam"); Node.Result(Ls()) case Value.Arr(elems) => err(msg"Unsupported value: Arr"); Node.Result(Ls()) + private def getClassOfMem(p: FieldSymbol)(using ctx: Ctx)(using Raise, Scope): Local = + trace[Local](s"bMemSym begin", x => s"bMemSym end: $x"): + p match + case ts: TermSymbol => ts.owner.get + case ms: MemberSymbol[?] => ms.defn.get.sym + private def bPath(p: Path)(k: TrivialExpr => Ctx ?=> Node)(using ctx: Ctx)(using Raise, Scope) : Node = trace[Node](s"bPath begin", x => s"bPath end: ${x.show}"): p match - case Select(qual, name) => err(msg"Unsupported path: Select"); Node.Result(Ls()) + case s @ Select(qual, name) => + log(s"bPath Select: $qual.$name with ${s.symbol.get}") + bPath(qual): + case q: Expr.Ref => + val v = fresh.make + val cls = ClassRef.fromName(ctx.findClassName(getClassOfMem(s.symbol.get))) + val field = name.name + Node.LetExpr(v, Expr.Select(q.name, cls, field), k(v |> sr)) + case q: Expr.Literal => + err(msg"Unsupported select on literal") + Node.Result(Ls()) case x: Value => bValue(x)(k) private def bResult(r: Result)(k: TrivialExpr => Ctx ?=> Node)(using ctx: Ctx)(using Raise, Scope) : Node = @@ -148,7 +197,9 @@ final class LlirBuilder(tl: TraceLogger)(fresh: Fresh, fnUid: FreshInt): case args: Ls[TrivialExpr] => val v = fresh.make Node.LetMethodCall(Ls(v), ClassRef(R("Callable")), Name("apply" + args.length), f :: args, k(v |> sr)) - case Instantiate(cls, args) => ??? + case Instantiate(cls, args) => + err(msg"Unsupported result: Instantiate") + Node.Result(Ls()) case x: Path => bPath(x)(k) private def bBlock(blk: Block)(k: TrivialExpr => Ctx ?=> Node)(using ctx: Ctx)(using Raise, Scope) : Node = @@ -187,7 +238,8 @@ final class LlirBuilder(tl: TraceLogger)(fresh: Fresh, fnUid: FreshInt): summon[Ctx].def_acc += jpdef Node.Case(e, casesList, defaultCase) case Return(res, implct) => bResult(res)(x => Node.Result(Ls(x))) - case Throw(exc) => TODO("Throw not supported") + case Throw(Instantiate(Select(Value.Ref(globalThis), ident), Ls(Value.Lit(Tree.StrLit(e))))) if ident.name == "Error" => + Node.Panic(e) case Label(label, body, rest) => ??? case Break(label) => TODO("Break not supported") case Continue(label) => TODO("Continue not supported") @@ -213,11 +265,17 @@ final class LlirBuilder(tl: TraceLogger)(fresh: Fresh, fnUid: FreshInt): val name = allocIfNew(lhs) bBind(S(name), rhs, rest)(k) case AssignField(lhs, nme, rhs, rest) => TODO("AssignField not supported") - case Define(FunDefn(sym, params, body), rest) => - val f = bFunDef(FunDefn(sym, params, body)) + case Define(fd @ FunDefn(sym, params, body), rest) => + val f = bFunDef(fd) ctx.def_acc += f - bBlock(rest)(k) - case End(msg) => k(Expr.Literal(UnitLit(false))) + val new_ctx = ctx.addFuncName(sym, Name(f.name)) + bBlock(rest)(k)(using new_ctx) + case Define(cd @ ClsLikeDefn(sym, kind, methods, privateFields, publicFields, ctor), rest) => + val c = bClsLikeDef(cd) + ctx.class_acc += c + val new_ctx = ctx.addClassName(sym, Name(c.name)) + bBlock(rest)(k)(using new_ctx) + case End(msg) => k(Expr.Literal(Tree.UnitLit(false))) case _: Block => val docBlock = blk.showAsTree err(msg"Unsupported block: $docBlock") @@ -226,6 +284,6 @@ final class LlirBuilder(tl: TraceLogger)(fresh: Fresh, fnUid: FreshInt): def bProg(e: Program)(using Raise, Scope): LlirProgram = val ctx = Ctx.empty given Ctx = ctx - ctx.def_acc.clear() + ctx.reset val entry = bBlock(e.main)(x => Node.Result(Ls(x))) - LlirProgram(Set.empty, ctx.def_acc.toSet, entry) \ No newline at end of file + LlirProgram(ctx.class_acc.toSet, ctx.def_acc.toSet, entry) \ No newline at end of file diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Llir.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Llir.scala index 0d3a6c3d2..41606f64d 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Llir.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Llir.scala @@ -159,7 +159,7 @@ enum Expr: case CtorApp(cls, args) => cls.name <#> "(" <#> (args |> showArguments) <#> ")" case Select(s, cls, fld) => - cls.name <#> "." <#> fld <#> "(" <#> s.toString <#> ")" + s.toString <#> ".<" <#> cls.name <#> ":" <#> fld <#> ">" case BasicOp(name: Str, args) => name <#> "(" <#> (args |> showArguments) <#> ")" case AssignField(assignee, clsInfo, fieldName, value) => @@ -191,7 +191,7 @@ enum Node: case Result(res: Ls[TrivialExpr]) case Jump(func: FuncRef, args: Ls[TrivialExpr]) case Case(scrutinee: TrivialExpr, cases: Ls[(Pat, Node)], default: Opt[Node]) - case Panic + case Panic(msg: Str) // Intermediate forms: case LetExpr(name: Name, expr: Expr, body: Node) case LetMethodCall(names: Ls[Name], cls: ClassRef, method: Name, args: Ls[TrivialExpr], body: Node) @@ -228,7 +228,7 @@ enum Node: case S(dc) => val default = Ls("_" <:> "=>", dc.toDocument |> indent) stack(first, (Document.Stacked(other ++ default) |> indent)) - case Panic => "panic" + case Panic(msg) => "panic" <:> "\"" <#> msg <#> "\"" case LetExpr(x, expr, body) => stack( "let" diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/RefResolver.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/RefResolver.scala index 403d210b2..5b6da3eab 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/RefResolver.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/RefResolver.scala @@ -43,7 +43,7 @@ private final class RefResolver(defs: Map[Str, Func], classes: Map[Str, ClassInf case None => if !allowInlineJp then throw LowLevelIRError(f"unknown function ${defnref.name} in ${defs.keySet.mkString(",")}") - case Panic => + case Panic(_) => def run(node: Node) = f(node) def run(node: Func) = f(node.body) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Validator.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Validator.scala index 9986835ff..a660f9f07 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Validator.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Validator.scala @@ -24,7 +24,7 @@ private final class FuncRefInSet(defs: Set[Func], classes: Set[ClassInfo]): case Result(res) => case Jump(func, args) => case Case(x, cases, default) => cases foreach { (_, body) => f(body) }; default foreach f - case Panic => + case Panic(_) => case LetExpr(name, expr, body) => f(body) case LetMethodCall(names, cls, method, args, body) => f(body) case LetCall(res, ref, args, body) => diff --git a/hkmc2/shared/src/test/mlscript/llir/Playground.mls b/hkmc2/shared/src/test/mlscript/llir/Playground.mls index 21519bd15..3126c123e 100644 --- a/hkmc2/shared/src/test/mlscript/llir/Playground.mls +++ b/hkmc2/shared/src/test/mlscript/llir/Playground.mls @@ -1,5 +1,40 @@ :js +:slot +:llir +abstract class Option[out T]: Some[T] | None +class Some[out T](x: T) extends Option[T] +object None extends Option +fun fromSome(s) = if s is Some(x) then x +//│ Pretty Lowered: +//│ +//│ define class Option in +//│ define class Some in +//│ define class None in +//│ define fun fromSome(s) { +//│ match s +//│ Some => +//│ set param0 = s.x in +//│ set x = param0 in +//│ return x +//│ else +//│ throw new globalThis.Error("match error") +//│ } in +//│ return null +//│ class Option() +//│ class Some(x) +//│ class None() +//│ def fromSome(s) = +//│ case s of +//│ Some => +//│ let x$0 = s. in +//│ x$0 +//│ _ => +//│ panic "match error" +//│ def j$0() = +//│ null +//│ undefined + :llir fun f1() = From c39293de2f70aaf4e2da2a74f05aa0891d59bc5b Mon Sep 17 00:00:00 2001 From: waterlens Date: Wed, 8 Jan 2025 19:06:46 +0800 Subject: [PATCH 04/23] Improve pretty printer --- .../jvm/src/test/scala/hkmc2/LlirDiffMaker.scala | 2 +- .../src/main/scala/hkmc2/codegen/Printer.scala | 15 ++++++++++++++- .../main/scala/hkmc2/codegen/llir/Builder.scala | 6 +++--- .../shared/src/test/mlscript/llir/Playground.mls | 6 +++--- 4 files changed, 21 insertions(+), 8 deletions(-) diff --git a/hkmc2/jvm/src/test/scala/hkmc2/LlirDiffMaker.scala b/hkmc2/jvm/src/test/scala/hkmc2/LlirDiffMaker.scala index 6323c3d1a..1dd115f2b 100644 --- a/hkmc2/jvm/src/test/scala/hkmc2/LlirDiffMaker.scala +++ b/hkmc2/jvm/src/test/scala/hkmc2/LlirDiffMaker.scala @@ -12,7 +12,7 @@ import hkmc2.syntax.Tree.Ident import hkmc2.codegen.Path import hkmc2.semantics.Term.Blk import hkmc2.codegen.llir.Fresh -import hkmc2.codegen.js.Scope +import hkmc2.utils.Scope import hkmc2.codegen.llir.Ctx import hkmc2.codegen.llir._ diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Printer.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Printer.scala index b86f008df..2015624b4 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Printer.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Printer.scala @@ -64,7 +64,20 @@ object Printer: case ValDefn(owner, k, sym, rhs) => doc"val ${sym.nme} = ${mkDocument(rhs)}" case ClsLikeDefn(sym, k, parentSym, methods, privateFields, publicFields, preCtor, ctor) => - doc"class ${sym.nme} #{ #} " + def optFldBody(t: semantics.TermDefinition) = + t.body match + case Some(x) => doc" = ..." + case None => doc"" + val clsDefn = sym.defn.getOrElse(die) + val clsParams = clsDefn.paramsOpt.fold(Nil)(_.paramSyms) + val ctorParams = clsParams.map(p => summon[Scope].allocateName(p)) + val privFields = privateFields.map(x => doc"let ${x.id.name} = ...").mkDocument(sep = doc" # ") + val pubFields = publicFields.map(x => doc"${x.k.str} ${x.sym.nme}${optFldBody(x)}").mkDocument(sep = doc" # ") + val docPrivFlds = if privateFields.isEmpty then doc"" else doc" # ${privFields}" + val docPubFlds = if publicFields.isEmpty then doc"" else doc" # ${pubFields}" + val docBody = if publicFields.isEmpty && privateFields.isEmpty then doc"" else doc" { #{ ${docPrivFlds}${docPubFlds} #} # }" + val docCtorParams = if clsParams.isEmpty then doc"" else doc"(${ctorParams.mkString(", ")})" + doc"class ${sym.nme}${docCtorParams}${docBody}" def mkDocument(arg: Arg)(using Raise, Scope): Document = val doc = mkDocument(arg.value) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Builder.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Builder.scala index 33b22034d..76759b429 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Builder.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Builder.scala @@ -11,7 +11,7 @@ import hkmc2.{Raise, raise, Diagnostic, ErrorReport, Message} import hkmc2.Message.MessageContext import hkmc2.codegen.llir.FuncRef.fromName import scala.collection.mutable.ListBuffer -import hkmc2.codegen.js.Scope +import hkmc2.utils.Scope import hkmc2._ import hkmc2.document._ import hkmc2.semantics.Elaborator.State @@ -131,7 +131,7 @@ final class LlirBuilder(tl: TraceLogger)(fresh: Fresh, fnUid: FreshInt, clsUid: private def bClsLikeDef(e: ClsLikeDefn)(using ctx: Ctx)(using Raise, Scope): ClassInfo = trace[ClassInfo](s"bClsLikeDef begin", x => s"bClsLikeDef end: ${x.show}"): - val ClsLikeDefn(sym, kind, methods, privateFields, publicFields, ctor) = e + val ClsLikeDefn(sym, kind, parentSym, methods, privateFields, publicFields, preCtor, ctor) = e val clsDefn = sym.defn.getOrElse(die) val clsParams = clsDefn.paramsOpt.fold(Nil)(_.paramSyms) ClassInfo( @@ -270,7 +270,7 @@ final class LlirBuilder(tl: TraceLogger)(fresh: Fresh, fnUid: FreshInt, clsUid: ctx.def_acc += f val new_ctx = ctx.addFuncName(sym, Name(f.name)) bBlock(rest)(k)(using new_ctx) - case Define(cd @ ClsLikeDefn(sym, kind, methods, privateFields, publicFields, ctor), rest) => + case Define(cd @ ClsLikeDefn(sym, kind, parentSym, methods, privateFields, publicFields, preCtor, ctor), rest) => val c = bClsLikeDef(cd) ctx.class_acc += c val new_ctx = ctx.addClassName(sym, Name(c.name)) diff --git a/hkmc2/shared/src/test/mlscript/llir/Playground.mls b/hkmc2/shared/src/test/mlscript/llir/Playground.mls index 3126c123e..667678dad 100644 --- a/hkmc2/shared/src/test/mlscript/llir/Playground.mls +++ b/hkmc2/shared/src/test/mlscript/llir/Playground.mls @@ -9,14 +9,14 @@ fun fromSome(s) = if s is Some(x) then x //│ Pretty Lowered: //│ //│ define class Option in -//│ define class Some in +//│ define class Some(x) in //│ define class None in //│ define fun fromSome(s) { //│ match s //│ Some => //│ set param0 = s.x in -//│ set x = param0 in -//│ return x +//│ set x1 = param0 in +//│ return x1 //│ else //│ throw new globalThis.Error("match error") //│ } in From a571f82bb268e78fa63d7f9e9004aa47878d46ec Mon Sep 17 00:00:00 2001 From: waterlens Date: Wed, 8 Jan 2025 20:02:41 +0800 Subject: [PATCH 05/23] Add support for ctor app and public fields in the class --- .../scala/hkmc2/codegen/llir/Builder.scala | 24 +++- .../src/test/mlscript/llir/Playground.mls | 121 +++++++++++++++--- 2 files changed, 125 insertions(+), 20 deletions(-) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Builder.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Builder.scala index 76759b429..6602f5845 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Builder.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Builder.scala @@ -20,6 +20,7 @@ import hkmc2.utils.TraceLogger import hkmc2.semantics.TermSymbol import hkmc2.semantics.MemberSymbol import hkmc2.semantics.FieldSymbol +import hkmc2.semantics.TopLevelSymbol def err(msg: Message)(using Raise): Unit = @@ -114,7 +115,7 @@ final class LlirBuilder(tl: TraceLogger)(fresh: Fresh, fnUid: FreshInt, clsUid: case rs: Ls[TrivialExpr] => k(r :: rs) private def bFunDef(e: FunDefn)(using ctx: Ctx)(using Raise, Scope): Func = - trace[Func](s"bFunDef begin", x => s"bFunDef end: ${x.show}"): + trace[Func](s"bFunDef begin: ${e.sym}", x => s"bFunDef end: ${x.show}"): val FunDefn(sym, params, body) = e if params.length != 1 then err(msg"Curried function not supported: ${params.length.toString}") @@ -134,10 +135,11 @@ final class LlirBuilder(tl: TraceLogger)(fresh: Fresh, fnUid: FreshInt, clsUid: val ClsLikeDefn(sym, kind, parentSym, methods, privateFields, publicFields, preCtor, ctor) = e val clsDefn = sym.defn.getOrElse(die) val clsParams = clsDefn.paramsOpt.fold(Nil)(_.paramSyms) + val clsFields = publicFields.map(_.sym) ClassInfo( clsUid.make, sym.nme, - clsParams.map(_.nme) + clsParams.map(_.nme) ++ clsFields.map(_.nme), ) private def bValue(v: Value)(k: TrivialExpr => Ctx ?=> Node)(using ctx: Ctx)(using Raise, Scope) : Node = @@ -158,8 +160,12 @@ final class LlirBuilder(tl: TraceLogger)(fresh: Fresh, fnUid: FreshInt, clsUid: private def bPath(p: Path)(k: TrivialExpr => Ctx ?=> Node)(using ctx: Ctx)(using Raise, Scope) : Node = trace[Node](s"bPath begin", x => s"bPath end: ${x.show}"): p match + case Select(Value.Ref(_: TopLevelSymbol), name) if name.name.head.isUpper => + val v = fresh.make + Node.LetExpr(v, Expr.CtorApp(ClassRef.fromName(name.name), Ls()), k(v |> sr)) + // field selection case s @ Select(qual, name) => - log(s"bPath Select: $qual.$name with ${s.symbol.get}") + log(s"bPath Select: $qual.$name with ${s.symbol}") bPath(qual): case q: Expr.Ref => val v = fresh.make @@ -179,7 +185,17 @@ final class LlirBuilder(tl: TraceLogger)(fresh: Fresh, fnUid: FreshInt, clsUid: case args: Ls[TrivialExpr] => val v = fresh.make Node.LetExpr(v, Expr.BasicOp(sym.nme, args), k(v |> sr)) - case Call(Select(Value.Ref(sym: BuiltinSymbol), name), args) => + case Call(Select(Value.Ref(_: TopLevelSymbol), name), args) if name.name.head.isUpper => + bArgs(args): + case args: Ls[TrivialExpr] => + val v = fresh.make + Node.LetExpr(v, Expr.CtorApp(ClassRef.fromName(name.name), args), k(v |> sr)) + case Call(Select(Value.Ref(_: TopLevelSymbol), name), args) => + bArgs(args): + case args: Ls[TrivialExpr] => + val v = fresh.make + Node.LetCall(Ls(v), FuncRef.fromName(name.name), args, k(v |> sr)) + case Call(Select(Value.Ref(_: BuiltinSymbol), name), args) => bArgs(args): case args: Ls[TrivialExpr] => val v = fresh.make diff --git a/hkmc2/shared/src/test/mlscript/llir/Playground.mls b/hkmc2/shared/src/test/mlscript/llir/Playground.mls index 667678dad..98709105e 100644 --- a/hkmc2/shared/src/test/mlscript/llir/Playground.mls +++ b/hkmc2/shared/src/test/mlscript/llir/Playground.mls @@ -1,29 +1,17 @@ :js -:slot :llir abstract class Option[out T]: Some[T] | None class Some[out T](x: T) extends Option[T] object None extends Option fun fromSome(s) = if s is Some(x) then x -//│ Pretty Lowered: -//│ -//│ define class Option in -//│ define class Some(x) in -//│ define class None in -//│ define fun fromSome(s) { -//│ match s -//│ Some => -//│ set param0 = s.x in -//│ set x1 = param0 in -//│ return x1 -//│ else -//│ throw new globalThis.Error("match error") -//│ } in -//│ return null +class Lazy[out A](init: () -> A) with + mut val cache: Option[A] = None +fun lazy(x) = Lazy(x) //│ class Option() //│ class Some(x) //│ class None() +//│ class Lazy(init,cache) //│ def fromSome(s) = //│ case s of //│ Some => @@ -33,8 +21,109 @@ fun fromSome(s) = if s is Some(x) then x //│ panic "match error" //│ def j$0() = //│ null +//│ def lazy(x1) = +//│ let x$1 = Lazy(x1) in +//│ x$1 //│ undefined +:llir +abstract class Option[out T]: Some[T] | None +class Some[out T](x: T) extends Option[T] +object None extends Option +fun fromSome(s) = if s is Some(x) then x +abstract class Nat: S[Nat] | O +class S(s: Nat) extends Nat +object O extends Nat +fun aaa() = + let m = 1 + let n = 2 + let p = 3 + let q = 4 + m + n - p + q +fun bbb() = + let x = aaa() + x * 100 + 4 +fun not(x) = + if x then false else true +fun foo(x) = + if x then None + else Some(foo(not(x))) +fun main() = + let x = foo(false) + if x is + None then aaa() + Some(b1) then bbb() +main() +//│ = 404 +//│ class Option() +//│ class Some(x) +//│ class None() +//│ class Nat() +//│ class S(s) +//│ class O() +//│ def fromSome(s) = +//│ case s of +//│ Some => +//│ let x$0 = s. in +//│ x$0 +//│ _ => +//│ panic "match error" +//│ def j$0() = +//│ null +//│ def aaa() = +//│ let x$1 = 1 in +//│ let x$2 = 2 in +//│ let x$3 = 3 in +//│ let x$4 = 4 in +//│ let x$5 = +(x$1,x$2) in +//│ let x$6 = -(x$5,x$3) in +//│ let x$7 = +(x$6,x$4) in +//│ x$7 +//│ def bbb() = +//│ let* (x$8) = aaa() in +//│ let x$9 = *(x$8,100) in +//│ let x$10 = +(x$9,4) in +//│ x$10 +//│ def not(x2) = +//│ case x2 of +//│ BoolLit(true) => +//│ false +//│ _ => +//│ true +//│ def j$1() = +//│ null +//│ def foo(x3) = +//│ case x3 of +//│ BoolLit(true) => +//│ let x$11 = None() in +//│ x$11 +//│ _ => +//│ let* (x$12) = not(x3) in +//│ let* (x$13) = foo(x$12) in +//│ let x$14 = Some(x$13) in +//│ x$14 +//│ def j$2() = +//│ null +//│ def main() = +//│ let* (x$15) = foo(false) in +//│ case x$15 of +//│ None => +//│ let* (x$16) = aaa() in +//│ x$16 +//│ _ => +//│ case x$15 of +//│ Some => +//│ let x$17 = x$15. in +//│ let* (x$18) = bbb() in +//│ x$18 +//│ _ => +//│ panic "match error" +//│ def j$4() = +//│ jump j$3() +//│ def j$3() = +//│ null +//│ let* (x$19) = main() in +//│ x$19 :llir fun f1() = From b1896825d729d3cb130c18515ef9cc5c42476ebc Mon Sep 17 00:00:00 2001 From: waterlens Date: Thu, 9 Jan 2025 20:11:04 +0800 Subject: [PATCH 06/23] Add options for cpp backend --- hkmc2/jvm/src/test/scala/hkmc2/LlirDiffMaker.scala | 7 ++++++- .../src/main/scala/hkmc2/codegen/cpp/CodeGen.scala | 13 +++++++------ hkmc2/shared/src/test/mlscript/llir/Playground.mls | 7 +++++++ 3 files changed, 20 insertions(+), 7 deletions(-) diff --git a/hkmc2/jvm/src/test/scala/hkmc2/LlirDiffMaker.scala b/hkmc2/jvm/src/test/scala/hkmc2/LlirDiffMaker.scala index 1dd115f2b..0f02bd255 100644 --- a/hkmc2/jvm/src/test/scala/hkmc2/LlirDiffMaker.scala +++ b/hkmc2/jvm/src/test/scala/hkmc2/LlirDiffMaker.scala @@ -18,6 +18,7 @@ import hkmc2.codegen.llir._ abstract class LlirDiffMaker extends BbmlDiffMaker: val llir = NullaryCommand("llir") + val scpp = NullaryCommand("scpp") override def processTerm(trm: Blk, inImport: Bool)(using Raise): Unit = super.processTerm(trm, inImport) @@ -32,5 +33,9 @@ abstract class LlirDiffMaker extends BbmlDiffMaker: val llb = LlirBuilder(tl)(fresh, fuid, cuid) given Ctx = Ctx.empty val llirProg = llb.bProg(le) + output("LLIR:") output(llirProg.show()) - \ No newline at end of file + if scpp.isSet then + val cpp = codegen.cpp.CppCodeGen.codegen(llirProg) + output("\nCpp:") + output(cpp.toDocument.print) \ No newline at end of file diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/cpp/CodeGen.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/cpp/CodeGen.scala index 81501fe2d..80f83ae7c 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/cpp/CodeGen.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/cpp/CodeGen.scala @@ -7,11 +7,7 @@ import scala.collection.mutable.ListBuffer import hkmc2.codegen.llir.{Expr => IExpr, _} import hkmc2.codegen.cpp._ -def codegen(prog: Program): CompilationUnit = - val codegen = CppCodeGen() - codegen.codegen(prog) - -private class CppCodeGen: +object CppCodeGen: def mapName(name: Name): Str = "_mls_" + name.str.replace('$', '_').replace('\'', '_') def mapName(name: Str): Str = "_mls_" + name.replace('$', '_').replace('\'', '_') val freshName = Fresh(div = '_'); @@ -31,6 +27,7 @@ private class CppCodeGen: def mlsCharLit(x: Char) = Expr.Call(Expr.Var("_mlsValue::fromIntLit"), Ls(Expr.CharLit(x))) def mlsNewValue(cls: Str, args: Ls[Expr]) = Expr.Call(Expr.Var(s"_mlsValue::create<$cls>"), args) def mlsIsValueOf(cls: Str, scrut: Expr) = Expr.Call(Expr.Var(s"_mlsValue::isValueOf<$cls>"), Ls(scrut)) + def mlsIsBoolLit(scrut: Expr, lit: hkmc2.syntax.Tree.BoolLit) = Expr.Call(Expr.Var("_mlsValue::isIntLit"), Ls(scrut, Expr.IntLit(if lit.value then 1 else 0))) def mlsIsIntLit(scrut: Expr, lit: hkmc2.syntax.Tree.IntLit) = Expr.Call(Expr.Var("_mlsValue::isIntLit"), Ls(scrut, Expr.IntLit(lit.value))) def mlsDebugPrint(x: Expr) = Expr.Call(Expr.Var("_mlsValue::print"), Ls(x)) def mlsTupleValue(init: Expr) = Expr.Constructor("_mlsValue::tuple", init) @@ -124,6 +121,10 @@ private class CppCodeGen: val (decls2, stmts2) = codegen(arm, storeInto)(using Ls.empty, Ls.empty[Stmt]) val stmt = Stmt.If(mlsIsIntLit(scrut2, i), Stmt.Block(decls2, stmts2), nextarm) S(stmt) + case ((Pat.Lit(i @ hkmc2.syntax.Tree.BoolLit(_)), arm), nextarm) => + val (decls2, stmts2) = codegen(arm, storeInto)(using Ls.empty, Ls.empty[Stmt]) + val stmt = Stmt.If(mlsIsBoolLit(scrut2, i), Stmt.Block(decls2, stmts2), nextarm) + S(stmt) case _ => ??? } (decls, stmt.fold(stmts)(x => stmts :+ x)) @@ -169,7 +170,7 @@ private class CppCodeGen: (decls, stmts2) case Node.Jump(defn, args) => codegenJumpWithCall(defn, args, S(storeInto)) - case Node.Panic => (decls, stmts :+ Stmt.Raw("throw std::runtime_error(\"Panic\");")) + case Node.Panic(msg) => (decls, stmts :+ Stmt.Raw(s"throw std::runtime_error(\"$msg\");")) case Node.LetExpr(name, expr, body) => val stmts2 = stmts ++ Ls(Stmt.AutoBind(Ls(name |> mapName), codegen(expr))) codegen(body, storeInto)(using decls, stmts2) diff --git a/hkmc2/shared/src/test/mlscript/llir/Playground.mls b/hkmc2/shared/src/test/mlscript/llir/Playground.mls index 98709105e..f10edeb6e 100644 --- a/hkmc2/shared/src/test/mlscript/llir/Playground.mls +++ b/hkmc2/shared/src/test/mlscript/llir/Playground.mls @@ -8,6 +8,7 @@ fun fromSome(s) = if s is Some(x) then x class Lazy[out A](init: () -> A) with mut val cache: Option[A] = None fun lazy(x) = Lazy(x) +//│ LLIR: //│ class Option() //│ class Some(x) //│ class None() @@ -55,6 +56,7 @@ fun main() = Some(b1) then bbb() main() //│ = 404 +//│ LLIR: //│ class Option() //│ class Some(x) //│ class None() @@ -130,6 +132,7 @@ fun f1() = let x = 1 let x = 2 x +//│ LLIR: //│ //│ def f1() = //│ let x$0 = 1 in @@ -144,6 +147,7 @@ fun f2() = if x == 1 then 2 else 3 //│ Pretty Lowered: //│ define fun f2() { set x = 0 in set scrut = ==(x, 1) in match scrut true => return 2 else return 3 } in return null +//│ LLIR: //│ //│ def f2() = //│ let x$0 = 0 in @@ -176,6 +180,7 @@ fun f3() = //│ return x2 //│ } in //│ return null +//│ LLIR: //│ //│ def f3() = //│ let x$0 = 0 in @@ -214,6 +219,7 @@ fun f4() = //│ return x1 //│ } in //│ return null +//│ LLIR: //│ //│ def f4() = //│ let x$0 = 0 in @@ -263,6 +269,7 @@ fun f5() = //│ return x2 //│ } in //│ return null +//│ LLIR: //│ //│ def f5() = //│ let x$0 = 0 in From b124738be4ad21c7dc9f82e32f8d7027afef1afb Mon Sep 17 00:00:00 2001 From: Lionel Parreaux Date: Fri, 10 Jan 2025 12:31:44 +0800 Subject: [PATCH 07/23] Minor improvements from meeting --- .../scala/hkmc2/codegen/llir/Builder.scala | 35 ++++++------- .../src/test/mlscript/llir/BadPrograms.mls | 50 +++++++++++++++++++ .../src/test/mlscript/llir/Playground.mls | 36 +++++++++++++ 3 files changed, 101 insertions(+), 20 deletions(-) create mode 100644 hkmc2/shared/src/test/mlscript/llir/BadPrograms.mls diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Builder.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Builder.scala index 6602f5845..8c55ea6d4 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Builder.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Builder.scala @@ -1,38 +1,33 @@ package hkmc2 -package codegen.llir +package codegen +package llir -import hkmc2.codegen._ -import hkmc2.codegen.llir.{ Program => LlirProgram, Node, Func } -import mlscript.utils._ -import mlscript.utils.shorthands._ -import hkmc2.semantics.BuiltinSymbol -import hkmc2.syntax.Tree -import hkmc2.{Raise, raise, Diagnostic, ErrorReport, Message} -import hkmc2.Message.MessageContext -import hkmc2.codegen.llir.FuncRef.fromName import scala.collection.mutable.ListBuffer + +import mlscript.utils.* +import mlscript.utils.shorthands.* +import hkmc2.document.* import hkmc2.utils.Scope -import hkmc2._ -import hkmc2.document._ -import hkmc2.semantics.Elaborator.State -import hkmc2.codegen.Program import hkmc2.utils.TraceLogger -import hkmc2.semantics.TermSymbol -import hkmc2.semantics.MemberSymbol -import hkmc2.semantics.FieldSymbol -import hkmc2.semantics.TopLevelSymbol +import hkmc2.Message.MessageContext + +import hkmc2.syntax.Tree +import hkmc2.semantics.* +import hkmc2.codegen.llir.{ Program => LlirProgram, Node, Func } +import FuncRef.fromName +import hkmc2.codegen.Program def err(msg: Message)(using Raise): Unit = raise(ErrorReport(msg -> N :: Nil, - source = Diagnostic.Source.Compilation)) + source = Diagnostic.Source.Compilation)) final case class Ctx( def_acc: ListBuffer[Func], class_acc: ListBuffer[ClassInfo], symbol_ctx: Map[Str, Name] = Map.empty, fn_ctx: Map[Local, Name] = Map.empty, // is a known function - closure_ctx: Map[Local, Name] = Map.empty, // closure name + closure_ctx: Map[Local, Name] = Map.empty, // closure name // TODO remove – not needed? class_ctx: Map[Local, Name] = Map.empty, block_ctx: Map[Local, Name] = Map.empty, ): diff --git a/hkmc2/shared/src/test/mlscript/llir/BadPrograms.mls b/hkmc2/shared/src/test/mlscript/llir/BadPrograms.mls new file mode 100644 index 000000000..1424a04f4 --- /dev/null +++ b/hkmc2/shared/src/test/mlscript/llir/BadPrograms.mls @@ -0,0 +1,50 @@ + +:global +:llir +:scpp + + +// TODO should be rejected +fun oops(a) = + class A with + fun m = a + let x = 1 +//│ LLIR: +//│ class A() +//│ def oops(a) = +//│ let x$0 = 1 in +//│ undefined +//│ undefined +//│ +//│ Cpp: +//│ #include "mlsprelude.h" +//│ struct _mls_A; +//│ _mlsValue _mls_oops(_mlsValue); +//│ _mlsValue _mlsMain(); +//│ struct _mls_A: public _mlsObject { +//│ +//│ constexpr static inline const char *typeName = "A"; +//│ constexpr static inline uint32_t typeTag = nextTypeTag(); +//│ virtual void print() const override { std::printf("%s", typeName); } +//│ virtual void destroy() override { operator delete (this, std::align_val_t(_mlsAlignment)); } +//│ static _mlsValue create() { auto _mlsVal = new (std::align_val_t(_mlsAlignment)) _mls_A; _mlsVal->refCount = 1; _mlsVal->tag = typeTag; return _mlsValue(_mlsVal); } +//│ }; +//│ _mlsValue _mls_oops(_mlsValue _mls_a){ +//│ _mlsValue _mls_retval; +//│ auto _mls_x_0 = _mlsValue::fromIntLit(1); +//│ _mls_retval = _mlsValue::create<_mls_Unit>(); +//│ return _mls_retval; +//│ } +//│ _mlsValue _mlsMain(){ +//│ _mlsValue _mls_retval; +//│ _mls_retval = _mlsValue::create<_mls_Unit>(); +//│ return _mls_retval; +//│ } +//│ int main() { return _mlsLargeStack(_mlsMainWrapper); } + +:todo // Properly reject +let x = "oops" +x.m +//│ /!!!\ Uncaught error: java.util.NoSuchElementException: None.get + + diff --git a/hkmc2/shared/src/test/mlscript/llir/Playground.mls b/hkmc2/shared/src/test/mlscript/llir/Playground.mls index f10edeb6e..7c3e50a95 100644 --- a/hkmc2/shared/src/test/mlscript/llir/Playground.mls +++ b/hkmc2/shared/src/test/mlscript/llir/Playground.mls @@ -293,3 +293,39 @@ fun f5() = //│ def j$1(x$6) = //│ x$6 //│ undefined + +:llir +fun test() = + if true do test() +//│ LLIR: +//│ +//│ def test() = +//│ let x$0 = true in +//│ case x$0 of +//│ BoolLit(true) => +//│ let* (x$1) = test() in +//│ x$1 +//│ _ => +//│ undefined +//│ def j$0() = +//│ null +//│ undefined + +:llir +fun test() = + (if true then test()) + 1 +//│ LLIR: +//│ +//│ def test() = +//│ let x$0 = true in +//│ case x$0 of +//│ BoolLit(true) => +//│ let* (x$2) = test() in +//│ jump j$0(x$2) +//│ _ => +//│ panic "match error" +//│ def j$0(x$1) = +//│ let x$3 = +(x$1,1) in +//│ x$3 +//│ undefined + From 2501860eb20142252713cb5761496a4c3d5fc86b Mon Sep 17 00:00:00 2001 From: waterlens Date: Tue, 14 Jan 2025 14:17:04 +0800 Subject: [PATCH 08/23] Fill blank lines --- hkmc2/jvm/src/test/scala/hkmc2/DiffTestRunner.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/hkmc2/jvm/src/test/scala/hkmc2/DiffTestRunner.scala b/hkmc2/jvm/src/test/scala/hkmc2/DiffTestRunner.scala index 8f3c28405..9dcc0d903 100644 --- a/hkmc2/jvm/src/test/scala/hkmc2/DiffTestRunner.scala +++ b/hkmc2/jvm/src/test/scala/hkmc2/DiffTestRunner.scala @@ -12,6 +12,9 @@ import mlscript.utils._, shorthands._ class MainDiffMaker(val rootPath: Str, val file: os.Path, val preludeFile: os.Path, val predefFile: os.Path, val relativeName: Str) extends LlirDiffMaker + + + class AllTests extends org.scalatest.Suites( new CompileTestRunner(DiffTestRunner.State){}, new DiffTestRunner(DiffTestRunner.State){}, From 42cffaa5d0abaf0a432bb10e50d902be64ab456b Mon Sep 17 00:00:00 2001 From: waterlens Date: Tue, 14 Jan 2025 14:18:02 +0800 Subject: [PATCH 09/23] Remove the extra blank line --- hkmc2/jvm/src/test/scala/hkmc2/DiffTestRunner.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/hkmc2/jvm/src/test/scala/hkmc2/DiffTestRunner.scala b/hkmc2/jvm/src/test/scala/hkmc2/DiffTestRunner.scala index 9dcc0d903..5d770db33 100644 --- a/hkmc2/jvm/src/test/scala/hkmc2/DiffTestRunner.scala +++ b/hkmc2/jvm/src/test/scala/hkmc2/DiffTestRunner.scala @@ -14,7 +14,6 @@ class MainDiffMaker(val rootPath: Str, val file: os.Path, val preludeFile: os.Pa - class AllTests extends org.scalatest.Suites( new CompileTestRunner(DiffTestRunner.State){}, new DiffTestRunner(DiffTestRunner.State){}, From 8ba879e05b2c1e38f538a17dfa8764a8ab5e7849 Mon Sep 17 00:00:00 2001 From: waterlens Date: Tue, 14 Jan 2025 14:23:20 +0800 Subject: [PATCH 10/23] Remove unused context --- hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Builder.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Builder.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Builder.scala index 8c55ea6d4..3dcf41e47 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Builder.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Builder.scala @@ -27,7 +27,6 @@ final case class Ctx( class_acc: ListBuffer[ClassInfo], symbol_ctx: Map[Str, Name] = Map.empty, fn_ctx: Map[Local, Name] = Map.empty, // is a known function - closure_ctx: Map[Local, Name] = Map.empty, // closure name // TODO remove – not needed? class_ctx: Map[Local, Name] = Map.empty, block_ctx: Map[Local, Name] = Map.empty, ): From f577445a09a0f28a0fc84ab22fe24fbdc7710fa9 Mon Sep 17 00:00:00 2001 From: waterlens Date: Tue, 14 Jan 2025 14:32:40 +0800 Subject: [PATCH 11/23] Add trailing newlines at the end of files --- hkmc2/jvm/src/test/scala/hkmc2/LlirDiffMaker.scala | 2 +- hkmc2/shared/src/main/scala/hkmc2/codegen/cpp/Ast.scala | 2 +- hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Builder.scala | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/hkmc2/jvm/src/test/scala/hkmc2/LlirDiffMaker.scala b/hkmc2/jvm/src/test/scala/hkmc2/LlirDiffMaker.scala index 0f02bd255..858b30702 100644 --- a/hkmc2/jvm/src/test/scala/hkmc2/LlirDiffMaker.scala +++ b/hkmc2/jvm/src/test/scala/hkmc2/LlirDiffMaker.scala @@ -38,4 +38,4 @@ abstract class LlirDiffMaker extends BbmlDiffMaker: if scpp.isSet then val cpp = codegen.cpp.CppCodeGen.codegen(llirProg) output("\nCpp:") - output(cpp.toDocument.print) \ No newline at end of file + output(cpp.toDocument.print) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/cpp/Ast.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/cpp/Ast.scala index 94787c2fa..ea46ea07b 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/cpp/Ast.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/cpp/Ast.scala @@ -209,4 +209,4 @@ enum Def: case VarDef(typ, name, init) => typ.toDocument() <#> s" $name" <#> init.fold(raw(""))(x => " = " <#> x.toDocument) <#> raw(";") case RawDef(x) => x - aux(this) \ No newline at end of file + aux(this) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Builder.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Builder.scala index 3dcf41e47..4153065fa 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Builder.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Builder.scala @@ -296,4 +296,4 @@ final class LlirBuilder(tl: TraceLogger)(fresh: Fresh, fnUid: FreshInt, clsUid: given Ctx = ctx ctx.reset val entry = bBlock(e.main)(x => Node.Result(Ls(x))) - LlirProgram(ctx.class_acc.toSet, ctx.def_acc.toSet, entry) \ No newline at end of file + LlirProgram(ctx.class_acc.toSet, ctx.def_acc.toSet, entry) From 6dc013031ab5819c20db32d3ffa54c7c64f22138 Mon Sep 17 00:00:00 2001 From: waterlens Date: Tue, 14 Jan 2025 16:45:55 +0800 Subject: [PATCH 12/23] Use new pretty printer --- .../src/test/scala/hkmc2/LlirDiffMaker.scala | 2 +- .../main/scala/hkmc2/codegen/cpp/Ast.scala | 162 +++++++++--------- .../main/scala/hkmc2/codegen/llir/Llir.scala | 130 +++++--------- .../hkmc2/utils/document/LegacyDocument.scala | 52 ------ .../src/test/mlscript/llir/BadPrograms.mls | 6 +- .../src/test/mlscript/llir/Playground.mls | 14 +- 6 files changed, 135 insertions(+), 231 deletions(-) delete mode 100644 hkmc2/shared/src/main/scala/hkmc2/utils/document/LegacyDocument.scala diff --git a/hkmc2/jvm/src/test/scala/hkmc2/LlirDiffMaker.scala b/hkmc2/jvm/src/test/scala/hkmc2/LlirDiffMaker.scala index 858b30702..08c480cf7 100644 --- a/hkmc2/jvm/src/test/scala/hkmc2/LlirDiffMaker.scala +++ b/hkmc2/jvm/src/test/scala/hkmc2/LlirDiffMaker.scala @@ -38,4 +38,4 @@ abstract class LlirDiffMaker extends BbmlDiffMaker: if scpp.isSet then val cpp = codegen.cpp.CppCodeGen.codegen(llirProg) output("\nCpp:") - output(cpp.toDocument.print) + output(cpp.toDocument.toString) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/cpp/Ast.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/cpp/Ast.scala index ea46ea07b..c05e4f1d3 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/cpp/Ast.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/cpp/Ast.scala @@ -4,42 +4,33 @@ import mlscript._ import mlscript.utils._ import mlscript.utils.shorthands._ -import hkmc2.utils.legacy_document._ +import hkmc2.Message.MessageContext +import hkmc2.document._ import scala.language.implicitConversions -given Conversion[String, Document] = raw +private def raw(x: String): Document = doc"$x" +given Conversion[String, Document] = x => doc"$x" enum Specifier: case Extern case Static case Inline - def toDocument = raw( + def toDocument = this match case Extern => "extern" case Static => "static" case Inline => "inline" - ) - override def toString: Str = toDocument.print + override def toString: Str = toDocument object Type: def toDocuments(args: Ls[Type], sep: Document, extraTypename: Bool = false): Document = - args.iterator.zipWithIndex.map { - case (x, 0) => - x.toDocument(extraTypename) - case (x, _) => - sep <#> x.toDocument(extraTypename) - }.fold(raw(""))(_ <#> _) + args.map(_.toDocument(extraTypename)).mkDocument(sep) def toDocuments(args: Ls[(Str, Type)], sep: Document): Document = - args.iterator.zipWithIndex.map { - case (x, 0) => - x._2.toDocument() <:> raw(x._1) - case (x, _) => - sep <#> x._2.toDocument() <:> raw(x._1) - }.fold(raw(""))(_ <#> _) + args.map(x => doc"${x._2.toDocument()} ${x._1}").mkDocument(sep) enum Type: case Prim(name: Str) @@ -56,22 +47,25 @@ enum Type: def toDocument(extraTypename: Bool = false): Document = def aux(x: Type): Document = x match case Prim(name) => name - case Ptr(inner) => aux(inner) <#> "*" - case Ref(inner) => aux(inner) <#> "&" - case Array(inner, size) => aux(inner) <#> "[" <#> size.fold(raw(""))(x => x.toString) <#> "]" - case FuncPtr(ret, args) => aux(ret) <#> "(" <#> Type.toDocuments(args, sep = ", ") <#> ")" - case Struct(name) => s"struct $name" - case Enum(name) => s"enum $name" - case Template(name, args) => s"$name" <#> "<" <#> Type.toDocuments(args, sep = ", ") <#> ">" + case Ptr(inner) => doc"${aux(inner)}*" + case Ref(inner) => doc"${aux(inner)}&" + case Array(inner, size) => + doc"${aux(inner)}[${size.fold("")(x => x.toString)}]" + case FuncPtr(ret, args) => + doc"${aux(ret)}(${Type.toDocuments(args, sep = ", ")})" + case Struct(name) => doc"struct $name" + case Enum(name) => doc"enum $name" + case Template(name, args) => + doc"$name<${Type.toDocuments(args, sep = ", ")}>" case Var(name) => name - case Qualifier(inner, qual) => aux(inner) <:> qual + case Qualifier(inner, qual) => doc"${aux(inner)} $qual" aux(this) - override def toString: Str = toDocument().print + override def toString: Str = toDocument().toString object Stmt: def toDocuments(decl: Ls[Decl], stmts: Ls[Stmt]): Document = - stack_list(decl.map(_.toDocument) ++ stmts.map(_.toDocument)) + (decl.map(_.toDocument) ++ stmts.map(_.toDocument)).mkDocument(doc" # ") enum Stmt: case AutoBind(lhs: Ls[Str], rhs: Expr) @@ -92,38 +86,37 @@ enum Stmt: case AutoBind(lhs, rhs) => lhs match case Nil => rhs.toDocument - case x :: Nil => "auto" <:> x <:> "=" <:> rhs.toDocument <#> ";" - case _ => "auto" <:> lhs.mkString("[", ",", "]") <:> "=" <:> rhs.toDocument <#> ";" - case Assign(lhs, rhs) => lhs <#> " = " <#> rhs.toDocument <#> ";" - case Return(expr) => "return " <#> expr.toDocument <#> ";" + case x :: Nil => + doc"auto $x = ${rhs.toDocument};" + case _ => + doc"auto [${lhs.mkDocument(", ")}] = ${rhs.toDocument};" + case Assign(lhs, rhs) => + doc"$lhs = ${rhs.toDocument};" + case Return(expr) => + doc"return ${expr.toDocument};" case If(cond, thenStmt, elseStmt) => - "if (" <#> cond.toDocument <#> ")" <#> thenStmt.toDocument <:> elseStmt.fold(raw(""))(x => "else" <:> x.toDocument) + doc"if (${cond.toDocument}) ${thenStmt.toDocument}${elseStmt.fold(doc" ")(x => doc" else ${x.toDocument}")}" case While(cond, body) => - "while (" <#> cond.toDocument <#> ")" <#> body.toDocument + doc"while (${cond.toDocument}) ${body.toDocument}" case For(init, cond, update, body) => - "for (" <#> init.toDocument <#> "; " <#> cond.toDocument <#> "; " <#> update.toDocument <#> ")" <#> body.toDocument - case ExprStmt(expr) => expr.toDocument <#> ";" + doc"for (${init.toDocument}; ${cond.toDocument}; ${update.toDocument}) ${body.toDocument}" + case ExprStmt(expr) => + doc"${expr.toDocument};" case Break => "break;" case Continue => "continue;" case Block(decl, stmts) => - stack( - "{", - Stmt.toDocuments(decl, stmts) |> indent, - "}") + doc"{ #{ # ${Stmt.toDocuments(decl, stmts)} #} # }" case Switch(expr, cases) => - "switch (" <#> expr.toDocument <#> ")" <#> "{" <#> stack_list(cases.map { - case (cond, stmt) => "case " <#> cond.toDocument <#> ":" <#> stmt.toDocument - }) <#> "}" + val docCases = cases.map { + case (cond, stmt) => doc"case ${cond.toDocument}: ${stmt.toDocument}" + }.mkDocument(doc" # ") + doc"switch (${expr.toDocument}) { #{ # ${docCases} #} # }" case Raw(stmt) => stmt aux(this) object Expr: def toDocuments(args: Ls[Expr], sep: Document): Document = - args.zipWithIndex.map { - case (x, i) => - if i == 0 then x.toDocument - else sep <#> x.toDocument - }.fold(raw(""))(_ <#> _) + args.map(_.toDocument).mkDocument(sep) enum Expr: case Var(name: Str) @@ -146,26 +139,31 @@ enum Expr: case DoubleLit(value) => value.toString case StrLit(value) => s"\"$value\"" // need more reliable escape utils case CharLit(value) => value.toInt.toString - case Call(func, args) => aux(func) <#> "(" <#> Expr.toDocuments(args, sep = ", ") <#> ")" - case Member(expr, member) => aux(expr) <#> "->" <#> member - case Index(expr, index) => aux(expr) <#> "[" <#> aux(index) <#> "]" - case Unary(op, expr) => "(" <#> op <#> aux(expr) <#> ")" - case Binary(op, lhs, rhs) => "(" <#> aux(lhs) <#> op <#> aux(rhs) <#> ")" - case Initializer(exprs) => "{" <#> Expr.toDocuments(exprs, sep = ", ") <#> "}" - case Constructor(name, init) => name <#> init.toDocument + case Call(func, args) => + doc"${func.toDocument}(${Expr.toDocuments(args, sep = ", ")})" + case Member(expr, member) => + doc"${expr.toDocument}->$member" + case Index(expr, index) => + doc"${expr.toDocument}[${index.toDocument}]" + case Unary(op, expr) => + doc"($op${expr.toDocument})" + case Binary(op, lhs, rhs) => + doc"(${lhs.toDocument} $op ${rhs.toDocument})" + case Initializer(exprs) => + doc"{${Expr.toDocuments(exprs, sep = ", ")}}" + case Constructor(name, init) => + doc"$name(${init.toDocument})" aux(this) case class CompilationUnit(includes: Ls[Str], decls: Ls[Decl], defs: Ls[Def]): def toDocument: Document = - stack_list(includes.map(x => raw(x)) ++ decls.map(_.toDocument) ++ defs.map(_.toDocument)) + (includes.map(raw) ++ decls.map(_.toDocument) ++ defs.map(_.toDocument)).mkDocument(doc" # ") def toDocumentWithoutHidden: Document = - val hiddenNames = Set( - "HiddenTheseEntities", "True", "False", "Callable", "List", "Cons", "Nil", "Option", "Some", "None", "Pair", "Tuple2", "Tuple3", "Nat", "S", "O" - ) - stack_list(defs.filterNot { + val hiddenNames: Set[Str] = Set() + defs.filterNot { case Def.StructDef(name, _, _, _) => hiddenNames.contains(name.stripPrefix("_mls_")) case _ => false - }.map(_.toDocument)) + }.map(_.toDocument).mkDocument(doc" # ") enum Decl: case StructDecl(name: Str) @@ -175,10 +173,12 @@ enum Decl: def toDocument: Document = def aux(x: Decl): Document = x match - case StructDecl(name) => s"struct $name;" - case EnumDecl(name) => s"enum $name;" - case FuncDecl(ret, name, args) => ret.toDocument() <#> s" $name(" <#> Type.toDocuments(args, sep = ", ") <#> ");" - case VarDecl(name, typ) => typ.toDocument() <#> s" $name;" + case StructDecl(name) => doc"struct $name;" + case EnumDecl(name) => doc"enum $name;" + case FuncDecl(ret, name, args) => + doc"${ret.toDocument()} $name(${Type.toDocuments(args, sep = ", ")});" + case VarDecl(name, typ) => + doc"${typ.toDocument()} $name;" aux(this) enum Def: @@ -191,22 +191,30 @@ enum Def: def toDocument: Document = def aux(x: Def): Document = x match case StructDef(name, fields, inherit, defs) => - stack( - s"struct $name" <#> (if inherit.nonEmpty then ": public" <:> inherit.get.mkString(", ") else "" ) <:> "{", - stack_list(fields.map { - case (name, typ) => typ.toDocument() <#> " " <#> name <#> ";" - }) |> indent, - stack_list(defs.map(_.toDocument)) |> indent, - "};" - ) + val docFirst = doc"struct $name${inherit.fold(doc"")(x => doc": public ${x.mkDocument(doc", ")}")} {" + val docFields = fields.map { + case (name, typ) => doc"${typ.toDocument()} $name;" + }.mkDocument(doc" # ") + val docDefs = defs.map(_.toDocument).mkDocument(doc" # ") + val docLast = "};" + doc"$docFirst #{ # $docFields # $docDefs #} # $docLast" case EnumDef(name, fields) => - s"enum $name" <:> "{" <#> stack_list(fields.map { + val docFirst = doc"enum $name {" + val docFields = fields.map { case (name, value) => value.fold(s"$name")(x => s"$name = $x") - }) <#> "};" + }.mkDocument(doc" # ") + val docLast = "};" + doc"$docFirst #{ # $docFields #} # $docLast" case FuncDef(specret, name, args, body, or, virt) => - (if virt then "virtual " else "") - <#> specret.toDocument() <#> s" $name(" <#> Type.toDocuments(args, sep = ", ") <#> ")" <#> (if or then " override" else "") <#> body.toDocument + val docVirt = (if virt then doc"virtual " else doc"") + val docSpecRet = specret.toDocument() + val docArgs = Type.toDocuments(args, sep = ", ") + val docOverride = if or then doc" override" else doc"" + val docBody = body.toDocument + doc"$docVirt$docSpecRet $name($docArgs)$docOverride ${body.toDocument}" case VarDef(typ, name, init) => - typ.toDocument() <#> s" $name" <#> init.fold(raw(""))(x => " = " <#> x.toDocument) <#> raw(";") + val docTyp = typ.toDocument() + val docInit = init.fold(raw(""))(x => doc" = ${x.toDocument}") + doc"$docTyp $name$docInit;" case RawDef(x) => x aux(this) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Llir.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Llir.scala index 41606f64d..72be21209 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Llir.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Llir.scala @@ -4,14 +4,17 @@ import mlscript._ import mlscript.utils._ import mlscript.utils.shorthands._ -import hkmc2.utils.legacy_document._ import hkmc2.syntax._ +import hkmc2.Message.MessageContext +import hkmc2.document._ import util.Sorting import collection.immutable.SortedSet import language.implicitConversions import collection.mutable.{Map as MutMap, Set as MutSet, HashMap, ListBuffer} +private def raw(x: String): Document = doc"$x" + final case class LowLevelIRError(message: String) extends Exception(message) case class Program( @@ -26,18 +29,17 @@ case class Program( Sorting.quickSort(t2) s"Program({${t1.mkString(",\n")}}, {\n${t2.mkString("\n")}\n},\n$main)" - def show(hiddenNames: Set[Str] = Set.empty) = toDocument(hiddenNames).print + def show(hiddenNames: Set[Str] = Set.empty) = toDocument(hiddenNames).toString def toDocument(hiddenNames: Set[Str] = Set.empty) : Document = val t1 = classes.toArray val t2 = defs.toArray Sorting.quickSort(t1) Sorting.quickSort(t2) given Conversion[String, Document] = raw - stack( - stack_list(t1.filter(x => !hiddenNames.contains(x.name)).map(_.toDocument).toList) |> indent, - stack_list(t2.map(_.toDocument).toList) |> indent, - main.toDocument |> indent - ) + val docClasses = t1.filter(x => !hiddenNames.contains(x.name)).map(_.toDocument).toList.mkDocument(doc" # ") + val docDefs = t2.map(_.toDocument).toList.mkDocument(doc" # ") + val docMain = main.toDocument + doc" #{ $docClasses\n$docDefs\n$docMain #} " implicit object ClassInfoOrdering extends Ordering[ClassInfo] { def compare(a: ClassInfo, b: ClassInfo) = a.id.compare(b.id) @@ -54,18 +56,17 @@ case class ClassInfo( override def toString: String = s"ClassInfo($id, $name, [${fields mkString ","}], parents: ${parents mkString ","}, methods:\n${methods mkString ",\n"})" - def show = toDocument.print + def show = toDocument.toString def toDocument: Document = given Conversion[String, Document] = raw - val extension = if parents.isEmpty then "" else " extends " + parents.mkString(", ") + val ext = if parents.isEmpty then "" else " extends " + parents.mkString(", ") if methods.isEmpty then - "class" <:> name <#> "(" <#> fields.mkString(",") <#> ")" <#> extension + doc"class $name(${fields.mkString(",")})$ext" else - stack( - "class" <:> name <#> "(" <#> fields.mkString(",") <#> ")" <#> extension <:> "{", - stack_list( methods.map { (_, func) => func.toDocument |> indent }.toList), - "}" - ) + val docFirst = doc"class $name (${fields.mkString(",")})$ext {" + val docMethods = methods.map { (_, func) => func.toDocument }.toList.mkDocument(doc" # ") + val docLast = doc"}" + doc"$docFirst #{ # $docMethods # #} $docLast" case class Name(str: Str): def trySubst(map: Map[Str, Name]) = map.getOrElse(str, this) @@ -117,13 +118,12 @@ case class Func( val ps = params.map(_.toString).mkString("[", ",", "]") s"Def($id, $name, $ps, \n$resultNum, \n$body\n)" - def show = toDocument.print + def show = toDocument def toDocument: Document = given Conversion[String, Document] = raw - stack( - "def" <:> name <#> "(" <#> params.map(_.toString).mkString(",") <#> ")" <:> "=", - body.toDocument |> indent - ) + val docFirst = doc"def $name(${params.map(_.toString).mkString(",")}) =" + val docBody = body.toDocument + doc"$docFirst #{ # $docBody #} " sealed trait TrivialExpr: import Expr._ @@ -144,8 +144,7 @@ enum Expr: override def toString: String = show - def show: String = - toDocument.print + def show: String = toDocument.toString def toDocument: Document = given Conversion[String, Document] = raw @@ -157,31 +156,18 @@ enum Expr: case Literal(Tree.StrLit(lit)) => s"$lit" case Literal(Tree.UnitLit(undefinedOrNull)) => if undefinedOrNull then "undefined" else "null" case CtorApp(cls, args) => - cls.name <#> "(" <#> (args |> showArguments) <#> ")" + doc"${cls.name}(${args.map(_.toString).mkString(",")})" case Select(s, cls, fld) => - s.toString <#> ".<" <#> cls.name <#> ":" <#> fld <#> ">" + doc"${s.toString}.<${cls.name}:$fld>" case BasicOp(name: Str, args) => - name <#> "(" <#> (args |> showArguments) <#> ")" + doc"$name(${args.map(_.toString).mkString(",")})" case AssignField(assignee, clsInfo, fieldName, value) => - stack( - "assign" - <:> (assignee.toString + "." + fieldName) - <:> ":=" - <:> value.toDocument - ) + doc"${assignee.toString}.${fieldName} := ${value.toString}" enum Pat: case Lit(lit: hkmc2.syntax.Literal) case Class(cls: ClassRef) - def isTrue = this match - case Class(cls) => cls.name == "True" - case _ => false - - def isFalse = this match - case Class(cls) => cls.name == "False" - case _ => false - override def toString: String = this match case Lit(lit) => s"$lit" case Class(cls) => s"${cls.name}" @@ -199,68 +185,30 @@ enum Node: override def toString: String = show - def show: String = - toDocument.print + def show: String = toDocument.toString def toDocument: Document = given Conversion[String, Document] = raw this match case Result(res) => (res |> showArguments) case Jump(jp, args) => - "jump" - <:> jp.name - <#> "(" - <#> (args |> showArguments) - <#> ")" - case Case(x, Ls((true_pat, tru), (false_pat, fls)), N) if true_pat.isTrue && false_pat.isFalse => - val first = "if" <:> x.toString - val tru2 = indent(stack("true" <:> "=>", tru.toDocument |> indent)) - val fls2 = indent(stack("false" <:> "=>", fls.toDocument |> indent)) - Document.Stacked(Ls(first, tru2, fls2)) + doc"jump ${jp.name}(${args |> showArguments})" case Case(x, cases, default) => - val first = "case" <:> x.toString <:> "of" - val other = cases flatMap { - case (pat, node) => - Ls(pat.toString <:> "=>", node.toDocument |> indent) - } + val docFirst = doc"case ${x.toString} of" + val docCases = cases.map { + case (pat, node) => doc"${pat.toString} => #{ # ${node.toDocument} #} " + }.mkDocument(doc" # ") default match - case N => stack(first, (Document.Stacked(other) |> indent)) + case N => doc"$docFirst #{ # $docCases #} " case S(dc) => - val default = Ls("_" <:> "=>", dc.toDocument |> indent) - stack(first, (Document.Stacked(other ++ default) |> indent)) - case Panic(msg) => "panic" <:> "\"" <#> msg <#> "\"" + val docDeft = doc"_ => #{ # ${dc.toDocument} #} " + doc"$docFirst #{ # $docCases # $docDeft #} " + case Panic(msg) => + doc"panic ${s"\"$msg\""}" case LetExpr(x, expr, body) => - stack( - "let" - <:> x.toString - <:> "=" - <:> expr.toDocument - <:> "in", - body.toDocument) + doc"let ${x.toString} = ${expr.toString} in # ${body.toDocument}" case LetMethodCall(xs, cls, method, args, body) => - stack( - "let" - <:> xs.map(_.toString).mkString(",") - <:> "=" - <:> cls.name - <#> "." - <#> method.toString - <#> "(" - <#> args.map{ x => x.toString }.mkString(",") - <#> ")" - <:> "in", - body.toDocument) + doc"let ${xs.map(_.toString).mkString(",")} = ${cls.name}.${method.toString}(${args.map(_.toString).mkString(",")}) in # ${body.toDocument}" case LetCall(xs, func, args, body) => - stack( - "let*" - <:> "(" - <#> xs.map(_.toString).mkString(",") - <#> ")" - <:> "=" - <:> func.name - <#> "(" - <#> args.map{ x => x.toString }.mkString(",") - <#> ")" - <:> "in", - body.toDocument) + doc"let* (${xs.map(_.toString).mkString(",")}) = ${func.name}(${args.map(_.toString).mkString(",")}) in # ${body.toDocument}" diff --git a/hkmc2/shared/src/main/scala/hkmc2/utils/document/LegacyDocument.scala b/hkmc2/shared/src/main/scala/hkmc2/utils/document/LegacyDocument.scala deleted file mode 100644 index eec3b867f..000000000 --- a/hkmc2/shared/src/main/scala/hkmc2/utils/document/LegacyDocument.scala +++ /dev/null @@ -1,52 +0,0 @@ -package hkmc2.utils.legacy_document - -enum Document: - case Indented(content: Document) - case Unindented(content: Document) - case Stacked(docs: List[Document], emptyLines: Boolean = false) - case Lined(docs: List[Document], separator: Document) - case Raw(s: String) - - def <:>(other: Document) = line(List(this, other)) - def <#>(other: Document) = line(List(this, other), sep = "") - - override def toString: String = print - - def print: String = { - val sb = StringBuffer() - - def rec(d: Document)(implicit ind: Int, first: Boolean): Unit = d match { - case Raw(s) => - if first && s.nonEmpty then sb append (" " * ind) - sb append s - case Indented(doc) => - rec(doc)(ind + 1, first) - case Unindented(doc) => - assume(ind > 0) - rec(doc)(ind - 1, first) - case Lined(Nil, _) => // skip - case Lined(docs, sep) => - rec(docs.head) - docs.tail foreach { doc => - rec(sep)(ind, false) - rec(doc)(ind, false) - } - case Stacked(Nil, _) => // skip - case Stacked(docs, emptyLines) => - rec(docs.head) - docs.tail foreach { doc => - sb append "\n" - if emptyLines then sb append "\n" - rec(doc)(ind, true) - } - } - - rec(this)(0, true) - sb.toString - } - -def stack(docs: Document*) = Document.Stacked(docs.toList) -def stack_list(docs: List[Document]) = Document.Stacked(docs) -def line(docs: List[Document], sep: String = " ") = Document.Lined(docs, Document.Raw(sep)) -def raw(s: String) = Document.Raw(s) -def indent(doc: Document) = Document.Indented(doc) diff --git a/hkmc2/shared/src/test/mlscript/llir/BadPrograms.mls b/hkmc2/shared/src/test/mlscript/llir/BadPrograms.mls index 1424a04f4..185eba347 100644 --- a/hkmc2/shared/src/test/mlscript/llir/BadPrograms.mls +++ b/hkmc2/shared/src/test/mlscript/llir/BadPrograms.mls @@ -22,20 +22,20 @@ fun oops(a) = //│ _mlsValue _mls_oops(_mlsValue); //│ _mlsValue _mlsMain(); //│ struct _mls_A: public _mlsObject { -//│ +//│ //│ constexpr static inline const char *typeName = "A"; //│ constexpr static inline uint32_t typeTag = nextTypeTag(); //│ virtual void print() const override { std::printf("%s", typeName); } //│ virtual void destroy() override { operator delete (this, std::align_val_t(_mlsAlignment)); } //│ static _mlsValue create() { auto _mlsVal = new (std::align_val_t(_mlsAlignment)) _mls_A; _mlsVal->refCount = 1; _mlsVal->tag = typeTag; return _mlsValue(_mlsVal); } //│ }; -//│ _mlsValue _mls_oops(_mlsValue _mls_a){ +//│ _mlsValue _mls_oops(_mlsValue _mls_a) { //│ _mlsValue _mls_retval; //│ auto _mls_x_0 = _mlsValue::fromIntLit(1); //│ _mls_retval = _mlsValue::create<_mls_Unit>(); //│ return _mls_retval; //│ } -//│ _mlsValue _mlsMain(){ +//│ _mlsValue _mlsMain() { //│ _mlsValue _mls_retval; //│ _mls_retval = _mlsValue::create<_mls_Unit>(); //│ return _mls_retval; diff --git a/hkmc2/shared/src/test/mlscript/llir/Playground.mls b/hkmc2/shared/src/test/mlscript/llir/Playground.mls index 7c3e50a95..81faaddcf 100644 --- a/hkmc2/shared/src/test/mlscript/llir/Playground.mls +++ b/hkmc2/shared/src/test/mlscript/llir/Playground.mls @@ -133,7 +133,7 @@ fun f1() = let x = 2 x //│ LLIR: -//│ +//│ //│ def f1() = //│ let x$0 = 1 in //│ let x$1 = 2 in @@ -148,7 +148,7 @@ fun f2() = //│ Pretty Lowered: //│ define fun f2() { set x = 0 in set scrut = ==(x, 1) in match scrut true => return 2 else return 3 } in return null //│ LLIR: -//│ +//│ //│ def f2() = //│ let x$0 = 0 in //│ let x$1 = ==(x$0,1) in @@ -181,7 +181,7 @@ fun f3() = //│ } in //│ return null //│ LLIR: -//│ +//│ //│ def f3() = //│ let x$0 = 0 in //│ let x$1 = 1 in @@ -220,7 +220,7 @@ fun f4() = //│ } in //│ return null //│ LLIR: -//│ +//│ //│ def f4() = //│ let x$0 = 0 in //│ let x$1 = ==(x$0,1) in @@ -270,7 +270,7 @@ fun f5() = //│ } in //│ return null //│ LLIR: -//│ +//│ //│ def f5() = //│ let x$0 = 0 in //│ let x$1 = ==(x$0,1) in @@ -298,7 +298,7 @@ fun f5() = fun test() = if true do test() //│ LLIR: -//│ +//│ //│ def test() = //│ let x$0 = true in //│ case x$0 of @@ -315,7 +315,7 @@ fun test() = fun test() = (if true then test()) + 1 //│ LLIR: -//│ +//│ //│ def test() = //│ let x$0 = true in //│ case x$0 of From 63bada9565ed0a4c8fd3a606b88e3d3ad5aacfba Mon Sep 17 00:00:00 2001 From: waterlens Date: Tue, 14 Jan 2025 20:51:26 +0800 Subject: [PATCH 13/23] Support instantiate --- .../scala/hkmc2/codegen/llir/Builder.scala | 18 ++++++++++++++++-- .../src/test/mlscript/llir/Playground.mls | 13 +++++++++++++ 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Builder.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Builder.scala index 4153065fa..1c57d8104 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Builder.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Builder.scala @@ -108,6 +108,14 @@ final class LlirBuilder(tl: TraceLogger)(fresh: Fresh, fnUid: FreshInt, clsUid: case r: TrivialExpr => bArgs(xs): case rs: Ls[TrivialExpr] => k(r :: rs) + private def bPaths(e: Ls[Path])(k: Ls[TrivialExpr] => Ctx ?=> Node)(using ctx: Ctx)(using Raise, Scope): Node = + trace[Node](s"bArgs begin", x => s"bArgs end: ${x.show}"): + e match + case Nil => k(Nil) + case x :: xs => bPath(x): + case r: TrivialExpr => bPaths(xs): + case rs: Ls[TrivialExpr] => k(r :: rs) + private def bFunDef(e: FunDefn)(using ctx: Ctx)(using Raise, Scope): Func = trace[Func](s"bFunDef begin: ${e.sym}", x => s"bFunDef end: ${x.show}"): val FunDefn(sym, params, body) = e @@ -207,8 +215,14 @@ final class LlirBuilder(tl: TraceLogger)(fresh: Fresh, fnUid: FreshInt, clsUid: case args: Ls[TrivialExpr] => val v = fresh.make Node.LetMethodCall(Ls(v), ClassRef(R("Callable")), Name("apply" + args.length), f :: args, k(v |> sr)) + case Instantiate( + Select(Select(Value.Ref(_: TopLevelSymbol), name), Tree.Ident("class")), args) => + bPaths(args): + case args: Ls[TrivialExpr] => + val v = fresh.make + Node.LetExpr(v, Expr.CtorApp(ClassRef.fromName(name.name), args), k(v |> sr)) case Instantiate(cls, args) => - err(msg"Unsupported result: Instantiate") + err(msg"Unsupported kind of Instantiate") Node.Result(Ls()) case x: Path => bPath(x)(k) @@ -248,7 +262,7 @@ final class LlirBuilder(tl: TraceLogger)(fresh: Fresh, fnUid: FreshInt, clsUid: summon[Ctx].def_acc += jpdef Node.Case(e, casesList, defaultCase) case Return(res, implct) => bResult(res)(x => Node.Result(Ls(x))) - case Throw(Instantiate(Select(Value.Ref(globalThis), ident), Ls(Value.Lit(Tree.StrLit(e))))) if ident.name == "Error" => + case Throw(Instantiate(Select(Value.Ref(_), ident), Ls(Value.Lit(Tree.StrLit(e))))) if ident.name == "Error" => Node.Panic(e) case Label(label, body, rest) => ??? case Break(label) => TODO("Break not supported") diff --git a/hkmc2/shared/src/test/mlscript/llir/Playground.mls b/hkmc2/shared/src/test/mlscript/llir/Playground.mls index 81faaddcf..4025dfb65 100644 --- a/hkmc2/shared/src/test/mlscript/llir/Playground.mls +++ b/hkmc2/shared/src/test/mlscript/llir/Playground.mls @@ -27,6 +27,19 @@ fun lazy(x) = Lazy(x) //│ x$1 //│ undefined +:llir +fun testCtor1() = None +fun testCtor2() = new None +//│ LLIR: +//│ +//│ def testCtor1() = +//│ let x$0 = None() in +//│ x$0 +//│ def testCtor2() = +//│ let x$1 = None() in +//│ x$1 +//│ undefined + :llir abstract class Option[out T]: Some[T] | None class Some[out T](x: T) extends Option[T] From 1c267a5b38822ce235ebfc370f89eda690031b95 Mon Sep 17 00:00:00 2001 From: waterlens Date: Wed, 15 Jan 2025 20:42:55 +0800 Subject: [PATCH 14/23] Generate necessary errors --- .../src/test/scala/hkmc2/LlirDiffMaker.scala | 19 +++-- .../scala/hkmc2/codegen/llir/Builder.scala | 77 +++++++++++-------- .../src/test/mlscript/llir/BadPrograms.mls | 38 ++------- 3 files changed, 64 insertions(+), 70 deletions(-) diff --git a/hkmc2/jvm/src/test/scala/hkmc2/LlirDiffMaker.scala b/hkmc2/jvm/src/test/scala/hkmc2/LlirDiffMaker.scala index 08c480cf7..eb6ffaf44 100644 --- a/hkmc2/jvm/src/test/scala/hkmc2/LlirDiffMaker.scala +++ b/hkmc2/jvm/src/test/scala/hkmc2/LlirDiffMaker.scala @@ -32,10 +32,15 @@ abstract class LlirDiffMaker extends BbmlDiffMaker: val cuid = FreshInt() val llb = LlirBuilder(tl)(fresh, fuid, cuid) given Ctx = Ctx.empty - val llirProg = llb.bProg(le) - output("LLIR:") - output(llirProg.show()) - if scpp.isSet then - val cpp = codegen.cpp.CppCodeGen.codegen(llirProg) - output("\nCpp:") - output(cpp.toDocument.toString) + try + val llirProg = llb.bProg(le) + output("LLIR:") + output(llirProg.show()) + if scpp.isSet then + val cpp = codegen.cpp.CppCodeGen.codegen(llirProg) + output("\nCpp:") + output(cpp.toDocument.toString) + catch + case e: LowLevelIRError => + output("Stopped due to an error during the Llir generation") + diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Builder.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Builder.scala index 1c57d8104..6a7375991 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Builder.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Builder.scala @@ -22,6 +22,10 @@ def err(msg: Message)(using Raise): Unit = raise(ErrorReport(msg -> N :: Nil, source = Diagnostic.Source.Compilation)) +def errStop(msg: Message)(using Raise) = + err(msg) + throw LowLevelIRError("stopped") + final case class Ctx( def_acc: ListBuffer[Func], class_acc: ListBuffer[ClassInfo], @@ -29,6 +33,7 @@ final case class Ctx( fn_ctx: Map[Local, Name] = Map.empty, // is a known function class_ctx: Map[Local, Name] = Map.empty, block_ctx: Map[Local, Name] = Map.empty, + is_top_level: Bool = true, ): def addFuncName(n: Local, m: Name) = copy(fn_ctx = fn_ctx + (n -> m)) def findFuncName(n: Local)(using Raise) = fn_ctx.get(n) match @@ -51,6 +56,7 @@ final case class Ctx( def reset = def_acc.clear() class_acc.clear() + def nonTopLevel = copy(is_top_level = false) object Ctx: val empty = Ctx(ListBuffer.empty, ListBuffer.empty) @@ -119,30 +125,37 @@ final class LlirBuilder(tl: TraceLogger)(fresh: Fresh, fnUid: FreshInt, clsUid: private def bFunDef(e: FunDefn)(using ctx: Ctx)(using Raise, Scope): Func = trace[Func](s"bFunDef begin: ${e.sym}", x => s"bFunDef end: ${x.show}"): val FunDefn(sym, params, body) = e - if params.length != 1 then - err(msg"Curried function not supported: ${params.length.toString}") - val paramsList = params.head.params.map(x => x -> summon[Scope].allocateName(x.sym)) - val new_ctx = paramsList.foldLeft(ctx)((acc, x) => acc.addName(getVar_!(x._1.sym), x._2 |> nme)) - val pl = paramsList.map(_._2).map(nme) - Func( - fnUid.make, - sym.nme, - params = pl, - resultNum = 1, - body = bBlock(body)(x => Node.Result(Ls(x)))(using new_ctx) - ) + if !ctx.is_top_level then + errStop(msg"Non top-level definition ${sym.nme} not supported") + else if params.length != 1 then + errStop(msg"Curried function or zero arguments function not supported: ${params.length.toString}") + else + val paramsList = params.head.params.map(x => x -> summon[Scope].allocateName(x.sym)) + val ctx2 = paramsList.foldLeft(ctx)((acc, x) => acc.addName(getVar_!(x._1.sym), x._2 |> nme)) + val ctx3 = ctx2.nonTopLevel + val pl = paramsList.map(_._2).map(nme) + Func( + fnUid.make, + sym.nme, + params = pl, + resultNum = 1, + body = bBlock(body)(x => Node.Result(Ls(x)))(using ctx3) + ) private def bClsLikeDef(e: ClsLikeDefn)(using ctx: Ctx)(using Raise, Scope): ClassInfo = trace[ClassInfo](s"bClsLikeDef begin", x => s"bClsLikeDef end: ${x.show}"): val ClsLikeDefn(sym, kind, parentSym, methods, privateFields, publicFields, preCtor, ctor) = e - val clsDefn = sym.defn.getOrElse(die) - val clsParams = clsDefn.paramsOpt.fold(Nil)(_.paramSyms) - val clsFields = publicFields.map(_.sym) - ClassInfo( - clsUid.make, - sym.nme, - clsParams.map(_.nme) ++ clsFields.map(_.nme), - ) + if !ctx.is_top_level then + errStop(msg"Non top-level definition ${sym.nme} not supported") + else + val clsDefn = sym.defn.getOrElse(die) + val clsParams = clsDefn.paramsOpt.fold(Nil)(_.paramSyms) + val clsFields = publicFields.map(_.sym) + ClassInfo( + clsUid.make, + sym.nme, + clsParams.map(_.nme) ++ clsFields.map(_.nme), + ) private def bValue(v: Value)(k: TrivialExpr => Ctx ?=> Node)(using ctx: Ctx)(using Raise, Scope) : Node = trace[Node](s"bValue begin", x => s"bValue end: ${x.show}"): @@ -166,17 +179,21 @@ final class LlirBuilder(tl: TraceLogger)(fresh: Fresh, fnUid: FreshInt, clsUid: val v = fresh.make Node.LetExpr(v, Expr.CtorApp(ClassRef.fromName(name.name), Ls()), k(v |> sr)) // field selection - case s @ Select(qual, name) => + case s @ Select(qual, name) => log(s"bPath Select: $qual.$name with ${s.symbol}") - bPath(qual): - case q: Expr.Ref => - val v = fresh.make - val cls = ClassRef.fromName(ctx.findClassName(getClassOfMem(s.symbol.get))) - val field = name.name - Node.LetExpr(v, Expr.Select(q.name, cls, field), k(v |> sr)) - case q: Expr.Literal => - err(msg"Unsupported select on literal") - Node.Result(Ls()) + s.symbol match + case None => + errStop(msg"Unsupported selection by users") + case Some(value) => + bPath(qual): + case q: Expr.Ref => + val v = fresh.make + val cls = ClassRef.fromName(ctx.findClassName(getClassOfMem(s.symbol.get))) + val field = name.name + Node.LetExpr(v, Expr.Select(q.name, cls, field), k(v |> sr)) + case q: Expr.Literal => + err(msg"Unsupported select on literal") + Node.Result(Ls()) case x: Value => bValue(x)(k) private def bResult(r: Result)(k: TrivialExpr => Ctx ?=> Node)(using ctx: Ctx)(using Raise, Scope) : Node = diff --git a/hkmc2/shared/src/test/mlscript/llir/BadPrograms.mls b/hkmc2/shared/src/test/mlscript/llir/BadPrograms.mls index 185eba347..8bd737250 100644 --- a/hkmc2/shared/src/test/mlscript/llir/BadPrograms.mls +++ b/hkmc2/shared/src/test/mlscript/llir/BadPrograms.mls @@ -9,42 +9,14 @@ fun oops(a) = class A with fun m = a let x = 1 -//│ LLIR: -//│ class A() -//│ def oops(a) = -//│ let x$0 = 1 in -//│ undefined -//│ undefined -//│ -//│ Cpp: -//│ #include "mlsprelude.h" -//│ struct _mls_A; -//│ _mlsValue _mls_oops(_mlsValue); -//│ _mlsValue _mlsMain(); -//│ struct _mls_A: public _mlsObject { -//│ -//│ constexpr static inline const char *typeName = "A"; -//│ constexpr static inline uint32_t typeTag = nextTypeTag(); -//│ virtual void print() const override { std::printf("%s", typeName); } -//│ virtual void destroy() override { operator delete (this, std::align_val_t(_mlsAlignment)); } -//│ static _mlsValue create() { auto _mlsVal = new (std::align_val_t(_mlsAlignment)) _mls_A; _mlsVal->refCount = 1; _mlsVal->tag = typeTag; return _mlsValue(_mlsVal); } -//│ }; -//│ _mlsValue _mls_oops(_mlsValue _mls_a) { -//│ _mlsValue _mls_retval; -//│ auto _mls_x_0 = _mlsValue::fromIntLit(1); -//│ _mls_retval = _mlsValue::create<_mls_Unit>(); -//│ return _mls_retval; -//│ } -//│ _mlsValue _mlsMain() { -//│ _mlsValue _mls_retval; -//│ _mls_retval = _mlsValue::create<_mls_Unit>(); -//│ return _mls_retval; -//│ } -//│ int main() { return _mlsLargeStack(_mlsMainWrapper); } +//│ FAILURE: Unexpected runtime error +//│ ═══[COMPILATION ERROR] Non top-level definition A not supported +//│ Stopped due to an error during the Llir generation :todo // Properly reject let x = "oops" x.m -//│ /!!!\ Uncaught error: java.util.NoSuchElementException: None.get +//│ ═══[COMPILATION ERROR] Unsupported selection by users +//│ Stopped due to an error during the Llir generation From f243486a43b2ac2025d9aebaee536fad6abebb48 Mon Sep 17 00:00:00 2001 From: waterlens Date: Wed, 15 Jan 2025 20:53:38 +0800 Subject: [PATCH 15/23] Update names of the flags --- .../src/test/scala/hkmc2/LlirDiffMaker.scala | 14 +-- .../src/test/mlscript/llir/BadPrograms.mls | 6 +- .../src/test/mlscript/llir/Playground.mls | 88 ++++--------------- 3 files changed, 26 insertions(+), 82 deletions(-) diff --git a/hkmc2/jvm/src/test/scala/hkmc2/LlirDiffMaker.scala b/hkmc2/jvm/src/test/scala/hkmc2/LlirDiffMaker.scala index eb6ffaf44..21a5bfa08 100644 --- a/hkmc2/jvm/src/test/scala/hkmc2/LlirDiffMaker.scala +++ b/hkmc2/jvm/src/test/scala/hkmc2/LlirDiffMaker.scala @@ -18,6 +18,8 @@ import hkmc2.codegen.llir._ abstract class LlirDiffMaker extends BbmlDiffMaker: val llir = NullaryCommand("llir") + val cpp = NullaryCommand("cpp") + val sllir = NullaryCommand("sllir") val scpp = NullaryCommand("scpp") override def processTerm(trm: Blk, inImport: Bool)(using Raise): Unit = @@ -34,12 +36,14 @@ abstract class LlirDiffMaker extends BbmlDiffMaker: given Ctx = Ctx.empty try val llirProg = llb.bProg(le) - output("LLIR:") - output(llirProg.show()) - if scpp.isSet then + if sllir.isSet then + output("LLIR:") + output(llirProg.show()) + if cpp.isSet then val cpp = codegen.cpp.CppCodeGen.codegen(llirProg) - output("\nCpp:") - output(cpp.toDocument.toString) + if scpp.isSet then + output("\nCpp:") + output(cpp.toDocument.toString) catch case e: LowLevelIRError => output("Stopped due to an error during the Llir generation") diff --git a/hkmc2/shared/src/test/mlscript/llir/BadPrograms.mls b/hkmc2/shared/src/test/mlscript/llir/BadPrograms.mls index 8bd737250..0cd5a8212 100644 --- a/hkmc2/shared/src/test/mlscript/llir/BadPrograms.mls +++ b/hkmc2/shared/src/test/mlscript/llir/BadPrograms.mls @@ -1,10 +1,8 @@ :global :llir -:scpp +:cpp - -// TODO should be rejected fun oops(a) = class A with fun m = a @@ -13,9 +11,9 @@ fun oops(a) = //│ ═══[COMPILATION ERROR] Non top-level definition A not supported //│ Stopped due to an error during the Llir generation -:todo // Properly reject let x = "oops" x.m +//│ FAILURE: Unexpected runtime error //│ ═══[COMPILATION ERROR] Unsupported selection by users //│ Stopped due to an error during the Llir generation diff --git a/hkmc2/shared/src/test/mlscript/llir/Playground.mls b/hkmc2/shared/src/test/mlscript/llir/Playground.mls index 4025dfb65..c2e170470 100644 --- a/hkmc2/shared/src/test/mlscript/llir/Playground.mls +++ b/hkmc2/shared/src/test/mlscript/llir/Playground.mls @@ -1,6 +1,7 @@ :js - :llir + +:sllir abstract class Option[out T]: Some[T] | None class Some[out T](x: T) extends Option[T] object None extends Option @@ -27,7 +28,7 @@ fun lazy(x) = Lazy(x) //│ x$1 //│ undefined -:llir +:sllir fun testCtor1() = None fun testCtor2() = new None //│ LLIR: @@ -40,7 +41,7 @@ fun testCtor2() = new None //│ x$1 //│ undefined -:llir +:sllir abstract class Option[out T]: Some[T] | None class Some[out T](x: T) extends Option[T] object None extends Option @@ -140,7 +141,7 @@ main() //│ let* (x$19) = main() in //│ x$19 -:llir +:sllir fun f1() = let x = 1 let x = 2 @@ -153,13 +154,10 @@ fun f1() = //│ x$1 //│ undefined -:slot -:llir +:sllir fun f2() = let x = 0 if x == 1 then 2 else 3 -//│ Pretty Lowered: -//│ define fun f2() { set x = 0 in set scrut = ==(x, 1) in match scrut true => return 2 else return 3 } in return null //│ LLIR: //│ //│ def f2() = @@ -174,25 +172,12 @@ fun f2() = //│ null //│ undefined -:llir -:slot + +:sllir fun f3() = let x1 = 0 let x2 = 1 if true then x1 else x2 -//│ Pretty Lowered: -//│ -//│ define fun f3() { -//│ set x1 = 0 in -//│ set x2 = 1 in -//│ set scrut = true in -//│ match scrut -//│ true => -//│ return x1 -//│ else -//│ return x2 -//│ } in -//│ return null //│ LLIR: //│ //│ def f3() = @@ -209,29 +194,11 @@ fun f3() = //│ undefined -:slot -:llir +:sllir fun f4() = let x = 0 let x = if x == 1 then 2 else 3 x -//│ Pretty Lowered: -//│ -//│ define fun f4() { -//│ set x = 0 in -//│ begin -//│ set scrut = ==(x, 1) in -//│ match scrut -//│ true => -//│ set tmp = 2 in -//│ end -//│ else -//│ set tmp = 3 in -//│ end; -//│ set x1 = tmp in -//│ return x1 -//│ } in -//│ return null //│ LLIR: //│ //│ def f4() = @@ -248,40 +215,13 @@ fun f4() = //│ x$2 //│ undefined -:slot -:llir +:sllir +:scpp fun f5() = let x = 0 let x = if x == 1 then 2 else 3 let x = if x == 2 then 4 else 5 x -//│ Pretty Lowered: -//│ -//│ define fun f5() { -//│ set x = 0 in -//│ begin -//│ set scrut = ==(x, 1) in -//│ match scrut -//│ true => -//│ set tmp = 2 in -//│ end -//│ else -//│ set tmp = 3 in -//│ end; -//│ set x1 = tmp in -//│ begin -//│ set scrut1 = ==(x1, 2) in -//│ match scrut1 -//│ true => -//│ set tmp1 = 4 in -//│ end -//│ else -//│ set tmp1 = 5 in -//│ end; -//│ set x2 = tmp1 in -//│ return x2 -//│ } in -//│ return null //│ LLIR: //│ //│ def f5() = @@ -307,7 +247,8 @@ fun f5() = //│ x$6 //│ undefined -:llir +:sllir +:scpp fun test() = if true do test() //│ LLIR: @@ -324,7 +265,8 @@ fun test() = //│ null //│ undefined -:llir +:sllir +:scpp fun test() = (if true then test()) + 1 //│ LLIR: From 8f57a4e651a7fa379ddd2f526fc8d913f459f452 Mon Sep 17 00:00:00 2001 From: waterlens Date: Wed, 15 Jan 2025 21:06:58 +0800 Subject: [PATCH 16/23] Add the interpreter --- .../src/test/scala/hkmc2/LlirDiffMaker.scala | 7 +- .../scala/hkmc2/codegen/llir/Interp.scala | 212 ++++++++++++++++++ .../src/test/mlscript/llir/Playground.mls | 103 ++++++++- 3 files changed, 315 insertions(+), 7 deletions(-) create mode 100644 hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Interp.scala diff --git a/hkmc2/jvm/src/test/scala/hkmc2/LlirDiffMaker.scala b/hkmc2/jvm/src/test/scala/hkmc2/LlirDiffMaker.scala index 21a5bfa08..505ceadf1 100644 --- a/hkmc2/jvm/src/test/scala/hkmc2/LlirDiffMaker.scala +++ b/hkmc2/jvm/src/test/scala/hkmc2/LlirDiffMaker.scala @@ -21,6 +21,7 @@ abstract class LlirDiffMaker extends BbmlDiffMaker: val cpp = NullaryCommand("cpp") val sllir = NullaryCommand("sllir") val scpp = NullaryCommand("scpp") + val intl = NullaryCommand("intl") override def processTerm(trm: Blk, inImport: Bool)(using Raise): Unit = super.processTerm(trm, inImport) @@ -39,11 +40,15 @@ abstract class LlirDiffMaker extends BbmlDiffMaker: if sllir.isSet then output("LLIR:") output(llirProg.show()) - if cpp.isSet then + if cpp.isSet || scpp.isSet then val cpp = codegen.cpp.CppCodeGen.codegen(llirProg) if scpp.isSet then output("\nCpp:") output(cpp.toDocument.toString) + if intl.isSet then + val intr = codegen.llir.Interpreter(verbose = true) + output("\nInterpreted:") + output(intr.interpret(llirProg)) catch case e: LowLevelIRError => output("Stopped due to an error during the Llir generation") diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Interp.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Interp.scala new file mode 100644 index 000000000..60ab4cca1 --- /dev/null +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Interp.scala @@ -0,0 +1,212 @@ +package hkmc2.codegen.llir + +import mlscript.* +import mlscript.utils.* +import scala.collection.immutable.* +import scala.collection.mutable.ListBuffer +import shorthands.* +import scala.util.boundary, boundary.break + +import hkmc2.codegen.llir.* +import hkmc2.syntax.Tree + +enum Stuck: + case StuckExpr(expr: Expr, msg: Str) + case StuckNode(node: Node, msg: Str) + + override def toString: String = + this match + case StuckExpr(expr, msg) => s"StuckExpr(${expr.show}, $msg)" + case StuckNode(node, msg) => s"StuckNode(${node.show}, $msg)" + +final case class InterpreterError(message: String) extends Exception(message) + +class Interpreter(verbose: Bool): + private def log(x: Any) = if verbose then println(x) + import Stuck._ + + private case class Configuration( + var context: Ctx + ) + + private type Result[T] = Either[Stuck, T] + + private enum Value: + case Class(cls: ClassInfo, var fields: Ls[Value]) + case Literal(lit: hkmc2.syntax.Literal) + + override def toString: String = + import hkmc2.syntax.Tree.* + this match + case Class(cls, fields) => s"${cls.name}(${fields.mkString(",")})" + case Literal(IntLit(lit)) => lit.toString + case Literal(BoolLit(lit)) => lit.toString + case Literal(DecLit(lit)) => lit.toString + case Literal(StrLit(lit)) => lit.toString + case Literal(UnitLit(undefinedOrNull)) => if undefinedOrNull then "undefined" else "null" + + private final case class Ctx( + bindingCtx: Map[Str, Value], + classCtx: Map[Str, ClassInfo], + funcCtx: Map[Str, Func], + ) + + import Node._ + import Expr._ + + private def getTrue(using ctx: Ctx) = Value.Literal(hkmc2.syntax.Tree.BoolLit(true)) + private def getFalse(using ctx: Ctx) = Value.Literal(hkmc2.syntax.Tree.BoolLit(false)) + + private def eval(op: Str, x1: Value, x2: Value)(using ctx: Ctx): Opt[Value] = + import Value.{Literal => Li} + import hkmc2.syntax.Tree.* + (op, x1, x2) match + case ("+", Li(IntLit(x)), Li(IntLit(y))) => S(Li(IntLit(x + y))) + case ("-", Li(IntLit(x)), Li(IntLit(y))) => S(Li(IntLit(x - y))) + case ("*", Li(IntLit(x)), Li(IntLit(y))) => S(Li(IntLit(x * y))) + case ("/", Li(IntLit(x)), Li(IntLit(y))) => S(Li(IntLit(x / y))) + case ("==", Li(IntLit(x)), Li(IntLit(y))) => S(if x == y then getTrue else getFalse) + case ("!=", Li(IntLit(x)), Li(IntLit(y))) => S(if x != y then getTrue else getFalse) + case ("<=", Li(IntLit(x)), Li(IntLit(y))) => S(if x <= y then getTrue else getFalse) + case (">=", Li(IntLit(x)), Li(IntLit(y))) => S(if x >= y then getTrue else getFalse) + case (">", Li(IntLit(x)), Li(IntLit(y))) => S(if x > y then getTrue else getFalse) + case ("<", Li(IntLit(x)), Li(IntLit(y))) => S(if x < y then getTrue else getFalse) + case _ => N + + private def evalArgs(using ctx: Ctx)(exprs: Ls[TrivialExpr]): Result[Ls[Value]] = + var values = ListBuffer.empty[Value] + var stuck: Opt[Stuck] = None + exprs foreach { expr => + stuck match + case None => eval(expr) match + case L(x) => stuck = Some(x) + case R(x) => values += x + case _ => () + } + stuck.toLeft(values.toList) + + private def eval(expr: TrivialExpr)(using ctx: Ctx): Result[Value] = expr match + case e @ Ref(name) => ctx.bindingCtx.get(name.str).toRight(StuckExpr(e, s"undefined variable $name")) + case Literal(lit) => R(Value.Literal(lit)) + + private def eval(expr: Expr)(using ctx: Ctx): Result[Value] = expr match + case Ref(Name(x)) => ctx.bindingCtx.get(x).toRight(StuckExpr(expr, s"undefined variable $x")) + case Literal(x) => R(Value.Literal(x)) + case CtorApp(cls, args) => + for + xs <- evalArgs(args) + cls <- ctx.classCtx.get(cls.name).toRight(StuckExpr(expr, s"undefined class ${cls.name}")) + yield Value.Class(cls, xs) + case Select(name, cls, field) => + ctx.bindingCtx.get(name.str).toRight(StuckExpr(expr, s"undefined variable $name")).flatMap { + case Value.Class(cls2, xs) if cls.name == cls2.name => + xs.zip(cls2.fields).find{_._2 == field} match + case Some((x, _)) => R(x) + case None => L(StuckExpr(expr, s"unable to find selected field $field")) + case Value.Class(cls2, xs) => L(StuckExpr(expr, s"unexpected class $cls2")) + case x => L(StuckExpr(expr, s"unexpected value $x")) + } + case BasicOp(name, args) => boundary: + evalArgs(args).flatMap( + xs => + name match + case "+" | "-" | "*" | "/" | "==" | "!=" | "<=" | ">=" | "<" | ">" => + if xs.length < 2 then break: + L(StuckExpr(expr, s"not enough arguments for basic operation $name")) + else eval(name, xs.head, xs.tail.head).toRight(StuckExpr(expr, s"unable to evaluate basic operation")) + case _ => L(StuckExpr(expr, s"unexpected basic operation $name"))) + case AssignField(assignee, cls, field, value) => + for + x <- eval(Ref(assignee): TrivialExpr) + y <- eval(value) + res <- x match + case obj @ Value.Class(cls2, xs) if cls.name == cls2.name => + xs.zip(cls2.fields).find{_._2 == field} match + case Some((_, _)) => + obj.fields = xs.map(x => if x == obj then y else x) + // Ideally, we should return a unit value here, but here we return the assignee value for simplicity. + R(obj) + case None => L(StuckExpr(expr, s"unable to find selected field $field")) + case Value.Class(cls2, xs) => L(StuckExpr(expr, s"unexpected class $cls2")) + case x => L(StuckExpr(expr, s"unexpected value $x")) + yield res + + private def eval(node: Node)(using ctx: Ctx): Result[Ls[Value]] = node match + case Result(res) => evalArgs(res) + case Jump(func, args) => + for + xs <- evalArgs(args) + func <- ctx.funcCtx.get(func.name).toRight(StuckNode(node, s"undefined function ${func.name}")) + ctx1 = ctx.copy(bindingCtx = ctx.bindingCtx ++ func.params.map{_.str}.zip(xs)) + res <- eval(func.body)(using ctx1) + yield res + case Case(scrut, cases, default) => + eval(scrut) flatMap { + case Value.Class(cls, fields) => + cases.find { + case (Pat.Class(cls2), _) => cls.name == cls2.name + case _ => false + } match { + case Some((_, x)) => eval(x) + case None => + default match + case S(x) => eval(x) + case N => L(StuckNode(node, s"can not find the matched case, ${cls.name} expected")) + } + case Value.Literal(lit) => + cases.find { + case (Pat.Lit(lit2), _) => lit == lit2 + case _ => false + } match { + case Some((_, x)) => eval(x) + case None => + default match + case S(x) => eval(x) + case N => L(StuckNode(node, s"can not find the matched case, $lit expected")) + } + } + case LetExpr(name, expr, body) => + for + x <- eval(expr) + ctx1 = ctx.copy(bindingCtx = ctx.bindingCtx + (name.str -> x)) + res <- eval(body)(using ctx1) + yield res + case LetMethodCall(names, cls, method, args, body) => + for + ys <- evalArgs(args).flatMap { + case Value.Class(cls2, xs) :: args => + cls2.methods.get(method.str).toRight(StuckNode(node, s"undefined method ${method.str}")).flatMap { method => + val ctx1 = ctx.copy(bindingCtx = ctx.bindingCtx ++ cls2.fields.zip(xs) ++ method.params.map{_.str}.zip(args)) + eval(method.body)(using ctx1) + } + case _ => L(StuckNode(node, s"not enough arguments for method call, or the first argument is not a class")) + } + ctx2 = ctx.copy(bindingCtx = ctx.bindingCtx ++ names.map{_.str}.zip(ys)) + res <- eval(body)(using ctx2) + yield res + // case LetApply(names, fn, args, body) => eval(LetMethodCall(names, ClassRef(R("Callable")), Name("apply" + args.length), (Ref(fn): TrivialExpr) :: args, body)) + case LetCall(names, func, args, body) => + for + xs <- evalArgs(args) + func <- ctx.funcCtx.get(func.name).toRight(StuckNode(node, s"undefined function ${func.name}")) + ctx1 = ctx.copy(bindingCtx = ctx.bindingCtx ++ func.params.map{_.str}.zip(xs)) + ys <- eval(func.body)(using ctx1) + ctx2 = ctx.copy(bindingCtx = ctx.bindingCtx ++ names.map{_.str}.zip(ys)) + res <- eval(body)(using ctx2) + yield res + case Panic(msg) => L(StuckNode(node, msg)) + + private def f(prog: Program): Ls[Value] = + val Program(classes, defs, main) = prog + given Ctx = Ctx( + bindingCtx = Map.empty, + classCtx = classes.map{cls => cls.name -> cls}.toMap, + funcCtx = defs.map{func => func.name -> func}.toMap, + ) + eval(main) match + case R(x) => x + case L(x) => throw InterpreterError("Stuck evaluation: " + x.toString) + + def interpret(prog: Program): Str = + val node = f(prog) + node.map(_.toString).mkString(",") diff --git a/hkmc2/shared/src/test/mlscript/llir/Playground.mls b/hkmc2/shared/src/test/mlscript/llir/Playground.mls index c2e170470..4a4a1f308 100644 --- a/hkmc2/shared/src/test/mlscript/llir/Playground.mls +++ b/hkmc2/shared/src/test/mlscript/llir/Playground.mls @@ -42,6 +42,7 @@ fun testCtor2() = new None //│ undefined :sllir +:intl abstract class Option[out T]: Some[T] | None class Some[out T](x: T) extends Option[T] object None extends Option @@ -140,24 +141,37 @@ main() //│ null //│ let* (x$19) = main() in //│ x$19 +//│ +//│ Interpreted: +//│ 404 :sllir +:intl fun f1() = let x = 1 let x = 2 x +f1() +//│ = 2 //│ LLIR: //│ //│ def f1() = //│ let x$0 = 1 in //│ let x$1 = 2 in //│ x$1 -//│ undefined +//│ let* (x$2) = f1() in +//│ x$2 +//│ +//│ Interpreted: +//│ 2 :sllir +:intl fun f2() = let x = 0 if x == 1 then 2 else 3 +f2() +//│ = 3 //│ LLIR: //│ //│ def f2() = @@ -170,7 +184,11 @@ fun f2() = //│ 3 //│ def j$0() = //│ null -//│ undefined +//│ let* (x$2) = f2() in +//│ x$2 +//│ +//│ Interpreted: +//│ 3 :sllir @@ -178,6 +196,8 @@ fun f3() = let x1 = 0 let x2 = 1 if true then x1 else x2 +f3() +//│ = 0 //│ LLIR: //│ //│ def f3() = @@ -191,14 +211,18 @@ fun f3() = //│ x$1 //│ def j$0() = //│ null -//│ undefined +//│ let* (x$3) = f3() in +//│ x$3 :sllir +:intl fun f4() = let x = 0 let x = if x == 1 then 2 else 3 x +f4() +//│ = 3 //│ LLIR: //│ //│ def f4() = @@ -213,15 +237,21 @@ fun f4() = //│ jump j$0(x$4) //│ def j$0(x$2) = //│ x$2 -//│ undefined +//│ let* (x$5) = f4() in +//│ x$5 +//│ +//│ Interpreted: +//│ 3 :sllir -:scpp +:intl fun f5() = let x = 0 let x = if x == 1 then 2 else 3 let x = if x == 2 then 4 else 5 x +f5() +//│ = 5 //│ LLIR: //│ //│ def f5() = @@ -245,7 +275,11 @@ fun f5() = //│ jump j$1(x$8) //│ def j$1(x$6) = //│ x$6 -//│ undefined +//│ let* (x$9) = f5() in +//│ x$9 +//│ +//│ Interpreted: +//│ 5 :sllir :scpp @@ -264,6 +298,34 @@ fun test() = //│ def j$0() = //│ null //│ undefined +//│ +//│ Cpp: +//│ #include "mlsprelude.h" +//│ _mlsValue _mls_j_0(); +//│ _mlsValue _mls_test(); +//│ _mlsValue _mlsMain(); +//│ _mlsValue _mls_j_0() { +//│ _mlsValue _mls_retval; +//│ _mls_retval = _mlsValue::create<_mls_Unit>(); +//│ return _mls_retval; +//│ } +//│ _mlsValue _mls_test() { +//│ _mlsValue _mls_retval; +//│ auto _mls_x_0 = _mlsValue::fromIntLit(1); +//│ if (_mlsValue::isIntLit(_mls_x_0, 1)) { +//│ auto _mls_x_1 = _mls_test(); +//│ _mls_retval = _mls_x_1; +//│ } else { +//│ _mls_retval = _mlsValue::create<_mls_Unit>(); +//│ } +//│ return _mls_retval; +//│ } +//│ _mlsValue _mlsMain() { +//│ _mlsValue _mls_retval; +//│ _mls_retval = _mlsValue::create<_mls_Unit>(); +//│ return _mls_retval; +//│ } +//│ int main() { return _mlsLargeStack(_mlsMainWrapper); } :sllir :scpp @@ -283,4 +345,33 @@ fun test() = //│ let x$3 = +(x$1,1) in //│ x$3 //│ undefined +//│ +//│ Cpp: +//│ #include "mlsprelude.h" +//│ _mlsValue _mls_j_0(_mlsValue); +//│ _mlsValue _mls_test(); +//│ _mlsValue _mlsMain(); +//│ _mlsValue _mls_j_0(_mlsValue _mls_x_1) { +//│ _mlsValue _mls_retval; +//│ auto _mls_x_3 = (_mls_x_1 + _mlsValue::fromIntLit(1)); +//│ _mls_retval = _mls_x_3; +//│ return _mls_retval; +//│ } +//│ _mlsValue _mls_test() { +//│ _mlsValue _mls_retval; +//│ auto _mls_x_0 = _mlsValue::fromIntLit(1); +//│ if (_mlsValue::isIntLit(_mls_x_0, 1)) { +//│ auto _mls_x_2 = _mls_test(); +//│ _mls_retval = _mls_j_0(_mls_x_2); +//│ } else { +//│ throw std::runtime_error("match error"); +//│ } +//│ return _mls_retval; +//│ } +//│ _mlsValue _mlsMain() { +//│ _mlsValue _mls_retval; +//│ _mls_retval = _mlsValue::create<_mls_Unit>(); +//│ return _mls_retval; +//│ } +//│ int main() { return _mlsLargeStack(_mlsMainWrapper); } From 72f817d717d1502248bb073bb2244c170317c8ee Mon Sep 17 00:00:00 2001 From: waterlens Date: Wed, 15 Jan 2025 21:45:17 +0800 Subject: [PATCH 17/23] Add the compiler host --- .../src/test/scala/hkmc2/LlirDiffMaker.scala | 14 +- .../hkmc2/codegen/cpp/CompilerHost.scala | 4 +- .../src/test/mlscript-compile/cpp/Makefile | 27 + .../test/mlscript-compile/cpp/mlsprelude.h | 568 ++++++++++++++++++ .../src/test/mlscript/llir/Playground.mls | 8 + 5 files changed, 617 insertions(+), 4 deletions(-) create mode 100644 hkmc2/shared/src/test/mlscript-compile/cpp/Makefile create mode 100644 hkmc2/shared/src/test/mlscript-compile/cpp/mlsprelude.h diff --git a/hkmc2/jvm/src/test/scala/hkmc2/LlirDiffMaker.scala b/hkmc2/jvm/src/test/scala/hkmc2/LlirDiffMaker.scala index 505ceadf1..db6ae399f 100644 --- a/hkmc2/jvm/src/test/scala/hkmc2/LlirDiffMaker.scala +++ b/hkmc2/jvm/src/test/scala/hkmc2/LlirDiffMaker.scala @@ -7,7 +7,8 @@ import utils.* import document.* import codegen.Block -import codegen.llir.LlirBuilder +import codegen.llir.* +import codegen.cpp.* import hkmc2.syntax.Tree.Ident import hkmc2.codegen.Path import hkmc2.semantics.Term.Blk @@ -21,6 +22,7 @@ abstract class LlirDiffMaker extends BbmlDiffMaker: val cpp = NullaryCommand("cpp") val sllir = NullaryCommand("sllir") val scpp = NullaryCommand("scpp") + val rcpp = NullaryCommand("rcpp") val intl = NullaryCommand("intl") override def processTerm(trm: Blk, inImport: Bool)(using Raise): Unit = @@ -40,11 +42,19 @@ abstract class LlirDiffMaker extends BbmlDiffMaker: if sllir.isSet then output("LLIR:") output(llirProg.show()) - if cpp.isSet || scpp.isSet then + if cpp.isSet || scpp.isSet || rcpp.isSet then val cpp = codegen.cpp.CppCodeGen.codegen(llirProg) if scpp.isSet then output("\nCpp:") output(cpp.toDocument.toString) + if rcpp.isSet then + val auxPath = os.pwd/os.up/"shared"/"src"/"test"/"mlscript-compile"/"cpp" + val cppHost = CppCompilerHost(auxPath.toString, output.apply) + if !cppHost.ready then + output("\nCpp Compilation Failed: Cpp compiler or GNU Make not found") + else + output("\n") + cppHost.compileAndRun(cpp.toDocument.toString) if intl.isSet then val intr = codegen.llir.Interpreter(verbose = true) output("\nInterpreted:") diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/cpp/CompilerHost.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/cpp/CompilerHost.scala index 291f5a0c4..867ffe4e6 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/cpp/CompilerHost.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/cpp/CompilerHost.scala @@ -4,7 +4,7 @@ import mlscript._ import mlscript.utils.shorthands._ import scala.collection.mutable.ListBuffer -final class CppCompilerHost(val auxPath: Str): +final class CppCompilerHost(val auxPath: Str, val output: Str => Unit): import scala.sys.process._ private def ifAnyCppCompilerExists(): Boolean = Seq("g++", "--version").! == 0 || Seq("clang++", "--version").! == 0 @@ -15,7 +15,7 @@ final class CppCompilerHost(val auxPath: Str): val ready = ifAnyCppCompilerExists() && isMakeExists() - def compileAndRun(src: Str, output: Str => Unit): Unit = + def compileAndRun(src: Str): Unit = if !ready then return val srcPath = os.temp(contents = src, suffix = ".cpp") diff --git a/hkmc2/shared/src/test/mlscript-compile/cpp/Makefile b/hkmc2/shared/src/test/mlscript-compile/cpp/Makefile new file mode 100644 index 000000000..45aae4802 --- /dev/null +++ b/hkmc2/shared/src/test/mlscript-compile/cpp/Makefile @@ -0,0 +1,27 @@ +CXX := g++ +CFLAGS += -O3 -Wall -Wextra -std=c++20 -I. -Wno-inconsistent-missing-override -I/opt/homebrew/include +LDFLAGS += -L/opt/homebrew/lib +LDLIBS := -lmimalloc -lgmp +SRC := +INCLUDES := mlsprelude.h +DST := +DEFAULT_TARGET := mls +TARGET := $(or $(DST),$(DEFAULT_TARGET)) + +.PHONY: pre all run clean auto + +all: $(TARGET) + +run: $(TARGET) + ./$(TARGET) + +pre: $(SRC) + sed -i '' 's#^//│ ##g' $(SRC) + +clean: + rm -r $(TARGET) $(TARGET).dSYM + +auto: $(TARGET) + +$(TARGET): $(SRC) $(INCLUDES) + $(CXX) $(CFLAGS) $(LDFLAGS) $(SRC) $(LDLIBS) -o $(TARGET) diff --git a/hkmc2/shared/src/test/mlscript-compile/cpp/mlsprelude.h b/hkmc2/shared/src/test/mlscript-compile/cpp/mlsprelude.h new file mode 100644 index 000000000..8415951c9 --- /dev/null +++ b/hkmc2/shared/src/test/mlscript-compile/cpp/mlsprelude.h @@ -0,0 +1,568 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +constexpr std::size_t _mlsAlignment = 8; + +template class tuple_type { + template > struct impl; + template struct impl> { + template using wrap = T; + using type = std::tuple...>; + }; + +public: + using type = typename impl<>::type; +}; +template struct counter { + using tag = counter; + + struct generator { + friend consteval auto is_defined(tag) { return true; } + }; + friend consteval auto is_defined(tag); + + template + static consteval auto exists(auto) { + return true; + } + + static consteval auto exists(...) { return generator(), false; } +}; + +template +consteval auto nextTypeTag() { + if constexpr (not counter::exists(Id)) + return Id; + else + return nextTypeTag(); +} + +struct _mlsObject { + uint32_t refCount; + uint32_t tag; + constexpr static inline uint32_t stickyRefCount = + std::numeric_limits::max(); + + void incRef() { + if (refCount != stickyRefCount) + ++refCount; + } + bool decRef() { + if (refCount != stickyRefCount && --refCount == 0) + return true; + return false; + } + + virtual void print() const = 0; + virtual void destroy() = 0; +}; + +struct _mls_True; +struct _mls_False; + +class _mlsValue { + using uintptr_t = std::uintptr_t; + using uint64_t = std::uint64_t; + + void *value alignas(_mlsAlignment); + + bool isInt63() const { return (reinterpret_cast(value) & 1) == 1; } + + bool isPtr() const { return (reinterpret_cast(value) & 1) == 0; } + + uint64_t asInt63() const { return reinterpret_cast(value) >> 1; } + + uintptr_t asRawInt() const { return reinterpret_cast(value); } + + static _mlsValue fromRawInt(uintptr_t i) { + return _mlsValue(reinterpret_cast(i)); + } + + static _mlsValue fromInt63(uint64_t i) { + return _mlsValue(reinterpret_cast((i << 1) | 1)); + } + + void *asPtr() const { + assert(!isInt63()); + return value; + } + + _mlsObject *asObject() const { + assert(isPtr()); + return static_cast<_mlsObject *>(value); + } + + bool eqInt63(const _mlsValue &other) const { + return asRawInt() == other.asRawInt(); + } + + _mlsValue addInt63(const _mlsValue &other) const { + return fromRawInt(asRawInt() + other.asRawInt() - 1); + } + + _mlsValue subInt63(const _mlsValue &other) const { + return fromRawInt(asRawInt() - other.asRawInt() + 1); + } + + _mlsValue mulInt63(const _mlsValue &other) const { + return fromInt63(asInt63() * other.asInt63()); + } + + _mlsValue divInt63(const _mlsValue &other) const { + return fromInt63(asInt63() / other.asInt63()); + } + + _mlsValue gtInt63(const _mlsValue &other) const { + return asInt63() > other.asInt63() ? _mlsValue::create<_mls_True>() + : _mlsValue::create<_mls_False>(); + } + + _mlsValue ltInt63(const _mlsValue &other) const { + return asInt63() < other.asInt63() ? _mlsValue::create<_mls_True>() + : _mlsValue::create<_mls_False>(); + } + + _mlsValue geInt63(const _mlsValue &other) const { + return asInt63() >= other.asInt63() ? _mlsValue::create<_mls_True>() + : _mlsValue::create<_mls_False>(); + } + + _mlsValue leInt63(const _mlsValue &other) const { + return asInt63() <= other.asInt63() ? _mlsValue::create<_mls_True>() + : _mlsValue::create<_mls_False>(); + } + +public: + explicit _mlsValue() : value(nullptr) {} + explicit _mlsValue(void *value) : value(value) {} + _mlsValue(const _mlsValue &other) : value(other.value) { + if (isPtr()) + asObject()->incRef(); + } + + _mlsValue &operator=(const _mlsValue &other) { + if (value != nullptr && isPtr()) + asObject()->decRef(); + value = other.value; + if (isPtr()) + asObject()->incRef(); + return *this; + } + + ~_mlsValue() { + if (isPtr()) + if (asObject()->decRef()) { + asObject()->destroy(); + value = nullptr; + } + } + + uint64_t asInt() const { + assert(isInt63()); + return asInt63(); + } + + static _mlsValue fromIntLit(uint64_t i) { return fromInt63(i); } + + template static tuple_type<_mlsValue, N> never() { + __builtin_unreachable(); + } + static _mlsValue never() { __builtin_unreachable(); } + + template static _mlsValue create(U... args) { + return _mlsValue(T::create(args...)); + } + + static void destroy(_mlsValue &v) { v.~_mlsValue(); } + + template static bool isValueOf(const _mlsValue &v) { + return v.asObject()->tag == T::typeTag; + } + + static bool isIntLit(const _mlsValue &v, uint64_t n) { + return v.asInt63() == n; + } + + static bool isIntLit(const _mlsValue &v) { return v.isInt63(); } + + template static T *as(const _mlsValue &v) { + return dynamic_cast(v.asObject()); + } + + template static T *cast(_mlsValue &v) { + return static_cast(v.asObject()); + } + + // Operators + + _mlsValue operator==(const _mlsValue &other) const { + if (isInt63() && other.isInt63()) + return eqInt63(other) ? _mlsValue::create<_mls_True>() + : _mlsValue::create<_mls_False>(); + assert(false); + } + + _mlsValue operator+(const _mlsValue &other) const { + if (isInt63() && other.isInt63()) + return addInt63(other); + assert(false); + } + + _mlsValue operator-(const _mlsValue &other) const { + if (isInt63() && other.isInt63()) + return subInt63(other); + assert(false); + } + + _mlsValue operator*(const _mlsValue &other) const { + if (isInt63() && other.isInt63()) + return mulInt63(other); + assert(false); + } + + _mlsValue operator/(const _mlsValue &other) const { + if (isInt63() && other.isInt63()) + return divInt63(other); + assert(false); + } + + _mlsValue operator>(const _mlsValue &other) const { + if (isInt63() && other.isInt63()) + return gtInt63(other); + assert(false); + } + + _mlsValue operator<(const _mlsValue &other) const { + if (isInt63() && other.isInt63()) + return ltInt63(other); + assert(false); + } + + _mlsValue operator>=(const _mlsValue &other) const { + if (isInt63() && other.isInt63()) + return geInt63(other); + assert(false); + } + + _mlsValue operator<=(const _mlsValue &other) const { + if (isInt63() && other.isInt63()) + return leInt63(other); + assert(false); + } + + // Auxiliary functions + + void print() const { + if (isInt63()) + std::printf("%" PRIu64, asInt63()); + else if (isPtr() && asObject()) + asObject()->print(); + } +}; + +struct _mls_Callable : public _mlsObject { + virtual _mlsValue _mls_apply0() { throw std::runtime_error("Not implemented"); } + virtual _mlsValue _mls_apply1(_mlsValue) { + throw std::runtime_error("Not implemented"); + } + virtual _mlsValue _mls_apply2(_mlsValue, _mlsValue) { + throw std::runtime_error("Not implemented"); + } + virtual _mlsValue _mls_apply3(_mlsValue, _mlsValue, _mlsValue) { + throw std::runtime_error("Not implemented"); + } + virtual _mlsValue _mls_apply4(_mlsValue, _mlsValue, _mlsValue, _mlsValue) { + throw std::runtime_error("Not implemented"); + } + virtual void destroy() override {} +}; + +inline static _mls_Callable *_mlsToCallable(_mlsValue fn) { + auto *ptr = _mlsValue::as<_mls_Callable>(fn); + if (!ptr) + throw std::runtime_error("Not a callable object"); + return ptr; +} + +template +inline static _mlsValue _mlsCall(_mlsValue f, U... args) { + static_assert(sizeof...(U) <= 4, "Too many arguments"); + if constexpr (sizeof...(U) == 0) + return _mlsToCallable(f)->_mls_apply0(); + else if constexpr (sizeof...(U) == 1) + return _mlsToCallable(f)->_mls_apply1(args...); + else if constexpr (sizeof...(U) == 2) + return _mlsToCallable(f)->_mls_apply2(args...); + else if constexpr (sizeof...(U) == 3) + return _mlsToCallable(f)->_mls_apply3(args...); + else if constexpr (sizeof...(U) == 4) + return _mlsToCallable(f)->_mls_apply4(args...); +} + +template +inline static T *_mlsMethodCall(_mlsValue self) { + auto *ptr = _mlsValue::as(self); + if (!ptr) + throw std::runtime_error("unable to convert object for method calls"); + return ptr; +} + +inline int _mlsLargeStack(void *(*fn)(void *)) { + pthread_t thread; + pthread_attr_t attr; + + size_t stacksize = 512 * 1024 * 1024; + pthread_attr_init(&attr); + pthread_attr_setstacksize(&attr, stacksize); + + int rc = pthread_create(&thread, &attr, fn, nullptr); + if (rc) { + printf("ERROR: return code from pthread_create() is %d\n", rc); + return 1; + } + pthread_join(thread, NULL); + return 0; +} + +_mlsValue _mlsMain(); + +inline void *_mlsMainWrapper(void *) { + _mlsValue res = _mlsMain(); + res.print(); + return nullptr; +} + +struct _mls_Unit final : public _mlsObject { + constexpr static inline const char *typeName = "Unit"; + constexpr static inline uint32_t typeTag = nextTypeTag(); + virtual void print() const override { std::printf(typeName); } + static _mlsValue create() { + static _mls_Unit mlsUnit alignas(_mlsAlignment); + mlsUnit.refCount = stickyRefCount; + mlsUnit.tag = typeTag; + return _mlsValue(&mlsUnit); + } + virtual void destroy() override {} +}; + +struct _mls_Boolean : public _mlsObject {}; + +struct _mls_True final : public _mls_Boolean { + constexpr static inline const char *typeName = "True"; + constexpr static inline uint32_t typeTag = nextTypeTag(); + virtual void print() const override { std::printf(typeName); } + static _mlsValue create() { + static _mls_True mlsTrue alignas(_mlsAlignment); + mlsTrue.refCount = stickyRefCount; + mlsTrue.tag = typeTag; + return _mlsValue(&mlsTrue); + } + virtual void destroy() override {} +}; + +struct _mls_False final : public _mls_Boolean { + constexpr static inline const char *typeName = "False"; + constexpr static inline uint32_t typeTag = nextTypeTag(); + virtual void print() const override { std::printf(typeName); } + static _mlsValue create() { + static _mls_False mlsFalse alignas(_mlsAlignment); + mlsFalse.refCount = stickyRefCount; + mlsFalse.tag = typeTag; + return _mlsValue(&mlsFalse); + } + virtual void destroy() override {} +}; + +#include + +struct _mls_ZInt final : public _mlsObject { + boost::multiprecision::mpz_int z; + constexpr static inline const char *typeName = "Z"; + constexpr static inline uint32_t typeTag = nextTypeTag(); + virtual void print() const override { + std::printf(typeName); + std::printf("("); + std::printf("%s", z.str().c_str()); + std::printf(")"); + } + virtual void destroy() override { + z.~number(); + operator delete(this, std::align_val_t(_mlsAlignment)); + } + static _mlsValue create() { + auto _mlsVal = new (std::align_val_t(_mlsAlignment)) _mls_ZInt; + _mlsVal->refCount = 1; + _mlsVal->tag = typeTag; + return _mlsValue(_mlsVal); + } + static _mlsValue create(_mlsValue z) { + auto _mlsVal = new (std::align_val_t(_mlsAlignment)) _mls_ZInt; + _mlsVal->z = z.asInt(); + _mlsVal->refCount = 1; + _mlsVal->tag = typeTag; + return _mlsValue(_mlsVal); + } + _mlsValue operator+(const _mls_ZInt &other) const { + auto _mlsVal = _mlsValue::create<_mls_ZInt>(); + _mlsValue::cast<_mls_ZInt>(_mlsVal)->z = z + other.z; + return _mlsVal; + } + + _mlsValue operator-(const _mls_ZInt &other) const { + auto _mlsVal = _mlsValue::create<_mls_ZInt>(); + _mlsValue::cast<_mls_ZInt>(_mlsVal)->z = z - other.z; + return _mlsVal; + } + + _mlsValue operator*(const _mls_ZInt &other) const { + auto _mlsVal = _mlsValue::create<_mls_ZInt>(); + _mlsValue::cast<_mls_ZInt>(_mlsVal)->z = z * other.z; + return _mlsVal; + } + + _mlsValue operator/(const _mls_ZInt &other) const { + auto _mlsVal = _mlsValue::create<_mls_ZInt>(); + _mlsValue::cast<_mls_ZInt>(_mlsVal)->z = z / other.z; + return _mlsVal; + } + + _mlsValue operator%(const _mls_ZInt &other) const { + auto _mlsVal = _mlsValue::create<_mls_ZInt>(); + _mlsValue::cast<_mls_ZInt>(_mlsVal)->z = z % other.z; + return _mlsVal; + } + + _mlsValue operator==(const _mls_ZInt &other) const { + return z == other.z ? _mlsValue::create<_mls_True>() + : _mlsValue::create<_mls_False>(); + } + + _mlsValue operator>(const _mls_ZInt &other) const { + return z > other.z ? _mlsValue::create<_mls_True>() + : _mlsValue::create<_mls_False>(); + } + + _mlsValue operator<(const _mls_ZInt &other) const { + return z < other.z ? _mlsValue::create<_mls_True>() + : _mlsValue::create<_mls_False>(); + } + + _mlsValue operator>=(const _mls_ZInt &other) const { + return z >= other.z ? _mlsValue::create<_mls_True>() + : _mlsValue::create<_mls_False>(); + } + + _mlsValue operator<=(const _mls_ZInt &other) const { + return z <= other.z ? _mlsValue::create<_mls_True>() + : _mlsValue::create<_mls_False>(); + } + + _mlsValue toInt() const { + return _mlsValue::fromIntLit(z.convert_to()); + } + + static _mlsValue fromInt(uint64_t i) { + return _mlsValue::create<_mls_ZInt>(_mlsValue::fromIntLit(i)); + } +}; + +__attribute__((noinline)) inline void _mlsNonExhaustiveMatch() { + throw std::runtime_error("Non-exhaustive match"); +} + +inline _mlsValue _mls_builtin_z_add(_mlsValue a, _mlsValue b) { + assert(_mlsValue::isValueOf<_mls_ZInt>(a)); + assert(_mlsValue::isValueOf<_mls_ZInt>(b)); + return *_mlsValue::cast<_mls_ZInt>(a) + *_mlsValue::cast<_mls_ZInt>(b); +} + +inline _mlsValue _mls_builtin_z_sub(_mlsValue a, _mlsValue b) { + assert(_mlsValue::isValueOf<_mls_ZInt>(a)); + assert(_mlsValue::isValueOf<_mls_ZInt>(b)); + return *_mlsValue::cast<_mls_ZInt>(a) - *_mlsValue::cast<_mls_ZInt>(b); +} + +inline _mlsValue _mls_builtin_z_mul(_mlsValue a, _mlsValue b) { + assert(_mlsValue::isValueOf<_mls_ZInt>(a)); + assert(_mlsValue::isValueOf<_mls_ZInt>(b)); + return *_mlsValue::cast<_mls_ZInt>(a) * *_mlsValue::cast<_mls_ZInt>(b); +} + +inline _mlsValue _mls_builtin_z_div(_mlsValue a, _mlsValue b) { + assert(_mlsValue::isValueOf<_mls_ZInt>(a)); + assert(_mlsValue::isValueOf<_mls_ZInt>(b)); + return *_mlsValue::cast<_mls_ZInt>(a) / *_mlsValue::cast<_mls_ZInt>(b); +} + +inline _mlsValue _mls_builtin_z_mod(_mlsValue a, _mlsValue b) { + assert(_mlsValue::isValueOf<_mls_ZInt>(a)); + assert(_mlsValue::isValueOf<_mls_ZInt>(b)); + return *_mlsValue::cast<_mls_ZInt>(a) % *_mlsValue::cast<_mls_ZInt>(b); +} + +inline _mlsValue _mls_builtin_z_equal(_mlsValue a, _mlsValue b) { + assert(_mlsValue::isValueOf<_mls_ZInt>(a)); + assert(_mlsValue::isValueOf<_mls_ZInt>(b)); + return *_mlsValue::cast<_mls_ZInt>(a) == *_mlsValue::cast<_mls_ZInt>(b); +} + +inline _mlsValue _mls_builtin_z_gt(_mlsValue a, _mlsValue b) { + assert(_mlsValue::isValueOf<_mls_ZInt>(a)); + assert(_mlsValue::isValueOf<_mls_ZInt>(b)); + return *_mlsValue::cast<_mls_ZInt>(a) > *_mlsValue::cast<_mls_ZInt>(b); +} + +inline _mlsValue _mls_builtin_z_lt(_mlsValue a, _mlsValue b) { + assert(_mlsValue::isValueOf<_mls_ZInt>(a)); + assert(_mlsValue::isValueOf<_mls_ZInt>(b)); + return *_mlsValue::cast<_mls_ZInt>(a) < *_mlsValue::cast<_mls_ZInt>(b); +} + +inline _mlsValue _mls_builtin_z_geq(_mlsValue a, _mlsValue b) { + assert(_mlsValue::isValueOf<_mls_ZInt>(a)); + assert(_mlsValue::isValueOf<_mls_ZInt>(b)); + return *_mlsValue::cast<_mls_ZInt>(a) >= *_mlsValue::cast<_mls_ZInt>(b); +} + +inline _mlsValue _mls_builtin_z_leq(_mlsValue a, _mlsValue b) { + assert(_mlsValue::isValueOf<_mls_ZInt>(a)); + assert(_mlsValue::isValueOf<_mls_ZInt>(b)); + return *_mlsValue::cast<_mls_ZInt>(a) <= *_mlsValue::cast<_mls_ZInt>(b); +} + +inline _mlsValue _mls_builtin_z_to_int(_mlsValue a) { + assert(_mlsValue::isValueOf<_mls_ZInt>(a)); + return _mlsValue::cast<_mls_ZInt>(a)->toInt(); +} + +inline _mlsValue _mls_builtin_z_of_int(_mlsValue a) { + assert(_mlsValue::isIntLit(a)); + return _mlsValue::create<_mls_ZInt>(a); +} + +inline _mlsValue _mls_builtin_print(_mlsValue a) { + a.print(); + return _mlsValue::create<_mls_Unit>(); +} + +inline _mlsValue _mls_builtin_println(_mlsValue a) { + a.print(); + std::puts(""); + return _mlsValue::create<_mls_Unit>(); +} + +inline _mlsValue _mls_builtin_debug(_mlsValue a) { + a.print(); + std::puts(""); + return a; +} diff --git a/hkmc2/shared/src/test/mlscript/llir/Playground.mls b/hkmc2/shared/src/test/mlscript/llir/Playground.mls index 4a4a1f308..ae496abdc 100644 --- a/hkmc2/shared/src/test/mlscript/llir/Playground.mls +++ b/hkmc2/shared/src/test/mlscript/llir/Playground.mls @@ -281,6 +281,14 @@ f5() //│ Interpreted: //│ 5 +:rcpp +1 +//│ = 1 +//│ +//│ +//│ Execution succeeded: +//│ 1 + :sllir :scpp fun test() = From 35f734637cec3b453712a044720cf756b4fd3abd Mon Sep 17 00:00:00 2001 From: waterlens Date: Wed, 15 Jan 2025 21:55:51 +0800 Subject: [PATCH 18/23] Fix --- hkmc2/jvm/src/test/scala/hkmc2/LlirDiffMaker.scala | 4 ++-- hkmc2/shared/src/test/mlscript/llir/BadPrograms.mls | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/hkmc2/jvm/src/test/scala/hkmc2/LlirDiffMaker.scala b/hkmc2/jvm/src/test/scala/hkmc2/LlirDiffMaker.scala index db6ae399f..d5f3578c9 100644 --- a/hkmc2/jvm/src/test/scala/hkmc2/LlirDiffMaker.scala +++ b/hkmc2/jvm/src/test/scala/hkmc2/LlirDiffMaker.scala @@ -48,7 +48,7 @@ abstract class LlirDiffMaker extends BbmlDiffMaker: output("\nCpp:") output(cpp.toDocument.toString) if rcpp.isSet then - val auxPath = os.pwd/os.up/"shared"/"src"/"test"/"mlscript-compile"/"cpp" + val auxPath = os.pwd/"hkmc2"/"shared"/"src"/"test"/"mlscript-compile"/"cpp" val cppHost = CppCompilerHost(auxPath.toString, output.apply) if !cppHost.ready then output("\nCpp Compilation Failed: Cpp compiler or GNU Make not found") @@ -62,4 +62,4 @@ abstract class LlirDiffMaker extends BbmlDiffMaker: catch case e: LowLevelIRError => output("Stopped due to an error during the Llir generation") - + diff --git a/hkmc2/shared/src/test/mlscript/llir/BadPrograms.mls b/hkmc2/shared/src/test/mlscript/llir/BadPrograms.mls index 0cd5a8212..5aa3264fc 100644 --- a/hkmc2/shared/src/test/mlscript/llir/BadPrograms.mls +++ b/hkmc2/shared/src/test/mlscript/llir/BadPrograms.mls @@ -3,17 +3,17 @@ :llir :cpp +:ge fun oops(a) = class A with fun m = a let x = 1 -//│ FAILURE: Unexpected runtime error //│ ═══[COMPILATION ERROR] Non top-level definition A not supported //│ Stopped due to an error during the Llir generation +:ge let x = "oops" x.m -//│ FAILURE: Unexpected runtime error //│ ═══[COMPILATION ERROR] Unsupported selection by users //│ Stopped due to an error during the Llir generation From 7ff48a7c4caf896542d45f7599274498723b4aed Mon Sep 17 00:00:00 2001 From: waterlens Date: Thu, 16 Jan 2025 14:18:27 +0800 Subject: [PATCH 19/23] Add wcpp option; Change bool implementation --- .../src/test/scala/hkmc2/LlirDiffMaker.scala | 12 +++- .../test/mlscript-compile/cpp/mlsprelude.h | 65 ++++--------------- 2 files changed, 24 insertions(+), 53 deletions(-) diff --git a/hkmc2/jvm/src/test/scala/hkmc2/LlirDiffMaker.scala b/hkmc2/jvm/src/test/scala/hkmc2/LlirDiffMaker.scala index d5f3578c9..01462854e 100644 --- a/hkmc2/jvm/src/test/scala/hkmc2/LlirDiffMaker.scala +++ b/hkmc2/jvm/src/test/scala/hkmc2/LlirDiffMaker.scala @@ -24,6 +24,12 @@ abstract class LlirDiffMaker extends BbmlDiffMaker: val scpp = NullaryCommand("scpp") val rcpp = NullaryCommand("rcpp") val intl = NullaryCommand("intl") + val wcpp = Command[Str]("wcpp", false)(x => x.stripLeading()) + + def printToFile(f: java.io.File)(op: java.io.PrintWriter => Unit) = { + val p = new java.io.PrintWriter(f) + try { op(p) } finally { p.close() } + } override def processTerm(trm: Blk, inImport: Bool)(using Raise): Unit = super.processTerm(trm, inImport) @@ -42,11 +48,15 @@ abstract class LlirDiffMaker extends BbmlDiffMaker: if sllir.isSet then output("LLIR:") output(llirProg.show()) - if cpp.isSet || scpp.isSet || rcpp.isSet then + if cpp.isSet || scpp.isSet || rcpp.isSet || wcpp.isSet then val cpp = codegen.cpp.CppCodeGen.codegen(llirProg) if scpp.isSet then output("\nCpp:") output(cpp.toDocument.toString) + if wcpp.isSet then + printToFile(java.io.File(s"hkmc2/shared/src/test/mlscript-compile/cpp/${wcpp.get.get}.cpp")) { p => + p.println(cpp.toDocument.toString) + } if rcpp.isSet then val auxPath = os.pwd/"hkmc2"/"shared"/"src"/"test"/"mlscript-compile"/"cpp" val cppHost = CppCompilerHost(auxPath.toString, output.apply) diff --git a/hkmc2/shared/src/test/mlscript-compile/cpp/mlsprelude.h b/hkmc2/shared/src/test/mlscript-compile/cpp/mlsprelude.h index 8415951c9..fed52549b 100644 --- a/hkmc2/shared/src/test/mlscript-compile/cpp/mlsprelude.h +++ b/hkmc2/shared/src/test/mlscript-compile/cpp/mlsprelude.h @@ -67,9 +67,6 @@ struct _mlsObject { virtual void destroy() = 0; }; -struct _mls_True; -struct _mls_False; - class _mlsValue { using uintptr_t = std::uintptr_t; using uint64_t = std::uint64_t; @@ -123,23 +120,19 @@ class _mlsValue { } _mlsValue gtInt63(const _mlsValue &other) const { - return asInt63() > other.asInt63() ? _mlsValue::create<_mls_True>() - : _mlsValue::create<_mls_False>(); + return _mlsValue::fromBoolLit(asInt63() > other.asInt63()); } _mlsValue ltInt63(const _mlsValue &other) const { - return asInt63() < other.asInt63() ? _mlsValue::create<_mls_True>() - : _mlsValue::create<_mls_False>(); + return _mlsValue::fromBoolLit(asInt63() < other.asInt63()); } _mlsValue geInt63(const _mlsValue &other) const { - return asInt63() >= other.asInt63() ? _mlsValue::create<_mls_True>() - : _mlsValue::create<_mls_False>(); + return _mlsValue::fromBoolLit(asInt63() >= other.asInt63()); } _mlsValue leInt63(const _mlsValue &other) const { - return asInt63() <= other.asInt63() ? _mlsValue::create<_mls_True>() - : _mlsValue::create<_mls_False>(); + return _mlsValue::fromBoolLit(asInt63() <= other.asInt63()); } public: @@ -174,6 +167,8 @@ class _mlsValue { static _mlsValue fromIntLit(uint64_t i) { return fromInt63(i); } + static _mlsValue fromBoolLit(bool b) { return fromInt63(b); } + template static tuple_type<_mlsValue, N> never() { __builtin_unreachable(); } @@ -207,8 +202,7 @@ class _mlsValue { _mlsValue operator==(const _mlsValue &other) const { if (isInt63() && other.isInt63()) - return eqInt63(other) ? _mlsValue::create<_mls_True>() - : _mlsValue::create<_mls_False>(); + return _mlsValue::fromBoolLit(eqInt63(other)); assert(false); } @@ -355,34 +349,6 @@ struct _mls_Unit final : public _mlsObject { virtual void destroy() override {} }; -struct _mls_Boolean : public _mlsObject {}; - -struct _mls_True final : public _mls_Boolean { - constexpr static inline const char *typeName = "True"; - constexpr static inline uint32_t typeTag = nextTypeTag(); - virtual void print() const override { std::printf(typeName); } - static _mlsValue create() { - static _mls_True mlsTrue alignas(_mlsAlignment); - mlsTrue.refCount = stickyRefCount; - mlsTrue.tag = typeTag; - return _mlsValue(&mlsTrue); - } - virtual void destroy() override {} -}; - -struct _mls_False final : public _mls_Boolean { - constexpr static inline const char *typeName = "False"; - constexpr static inline uint32_t typeTag = nextTypeTag(); - virtual void print() const override { std::printf(typeName); } - static _mlsValue create() { - static _mls_False mlsFalse alignas(_mlsAlignment); - mlsFalse.refCount = stickyRefCount; - mlsFalse.tag = typeTag; - return _mlsValue(&mlsFalse); - } - virtual void destroy() override {} -}; - #include struct _mls_ZInt final : public _mlsObject { @@ -402,7 +368,7 @@ struct _mls_ZInt final : public _mlsObject { static _mlsValue create() { auto _mlsVal = new (std::align_val_t(_mlsAlignment)) _mls_ZInt; _mlsVal->refCount = 1; - _mlsVal->tag = typeTag; + _mlsVal->tag = typeTag; return _mlsValue(_mlsVal); } static _mlsValue create(_mlsValue z) { @@ -443,28 +409,23 @@ struct _mls_ZInt final : public _mlsObject { } _mlsValue operator==(const _mls_ZInt &other) const { - return z == other.z ? _mlsValue::create<_mls_True>() - : _mlsValue::create<_mls_False>(); + return _mlsValue::fromBoolLit(z == other.z); } _mlsValue operator>(const _mls_ZInt &other) const { - return z > other.z ? _mlsValue::create<_mls_True>() - : _mlsValue::create<_mls_False>(); + return _mlsValue::fromBoolLit(z > other.z); } _mlsValue operator<(const _mls_ZInt &other) const { - return z < other.z ? _mlsValue::create<_mls_True>() - : _mlsValue::create<_mls_False>(); + return _mlsValue::fromBoolLit(z < other.z); } _mlsValue operator>=(const _mls_ZInt &other) const { - return z >= other.z ? _mlsValue::create<_mls_True>() - : _mlsValue::create<_mls_False>(); + return _mlsValue::fromBoolLit(z >= other.z); } _mlsValue operator<=(const _mls_ZInt &other) const { - return z <= other.z ? _mlsValue::create<_mls_True>() - : _mlsValue::create<_mls_False>(); + return _mlsValue::fromBoolLit(z <= other.z); } _mlsValue toInt() const { From 9a41b6a94d537e46f0a60046292ae24024b773c4 Mon Sep 17 00:00:00 2001 From: Lionel Parreaux Date: Thu, 16 Jan 2025 14:48:10 +0800 Subject: [PATCH 20/23] Try with new CI/nix versions --- .github/workflows/nix.yml | 2 +- flake.nix | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/nix.yml b/.github/workflows/nix.yml index 044e4df66..670030768 100644 --- a/.github/workflows/nix.yml +++ b/.github/workflows/nix.yml @@ -18,7 +18,7 @@ jobs: - name: Install TypeScript run: npm ci - name: Run test - run: sbt -J-Xmx4096M -J-Xss4M test + run: sbt -J-Xmx4096M -J-Xss8M test - name: Check no changes run: | git update-index -q --refresh diff --git a/flake.nix b/flake.nix index 5f81da24a..8e657face 100644 --- a/flake.nix +++ b/flake.nix @@ -11,7 +11,7 @@ (system: let sbtOverlay = self: super: { - sbt = super.sbt.override { jre = super.jdk8_headless; }; + sbt = super.sbt.override { jre = super.jdk11_headless; }; }; pkgs = import nixpkgs { inherit system; From b6eb77e0fad04da970de1cb4f721262455a9f0e3 Mon Sep 17 00:00:00 2001 From: waterlens Date: Thu, 16 Jan 2025 15:54:04 +0800 Subject: [PATCH 21/23] Fix silly path --- hkmc2/jvm/src/test/scala/hkmc2/LlirDiffMaker.scala | 10 ++++------ hkmc2/shared/src/test/mlscript/llir/Playground.mls | 8 -------- 2 files changed, 4 insertions(+), 14 deletions(-) diff --git a/hkmc2/jvm/src/test/scala/hkmc2/LlirDiffMaker.scala b/hkmc2/jvm/src/test/scala/hkmc2/LlirDiffMaker.scala index 01462854e..ef9317828 100644 --- a/hkmc2/jvm/src/test/scala/hkmc2/LlirDiffMaker.scala +++ b/hkmc2/jvm/src/test/scala/hkmc2/LlirDiffMaker.scala @@ -26,10 +26,9 @@ abstract class LlirDiffMaker extends BbmlDiffMaker: val intl = NullaryCommand("intl") val wcpp = Command[Str]("wcpp", false)(x => x.stripLeading()) - def printToFile(f: java.io.File)(op: java.io.PrintWriter => Unit) = { + def printToFile(f: java.io.File)(op: java.io.PrintWriter => Unit) = val p = new java.io.PrintWriter(f) try { op(p) } finally { p.close() } - } override def processTerm(trm: Blk, inImport: Bool)(using Raise): Unit = super.processTerm(trm, inImport) @@ -53,12 +52,11 @@ abstract class LlirDiffMaker extends BbmlDiffMaker: if scpp.isSet then output("\nCpp:") output(cpp.toDocument.toString) + val auxPath = os.Path(rootPath) / "hkmc2"/"shared"/"src"/"test"/"mlscript-compile"/"cpp" if wcpp.isSet then - printToFile(java.io.File(s"hkmc2/shared/src/test/mlscript-compile/cpp/${wcpp.get.get}.cpp")) { p => - p.println(cpp.toDocument.toString) - } + printToFile(java.io.File((auxPath / s"${wcpp.get.get}.cpp").toString)): + p => p.println(cpp.toDocument.toString) if rcpp.isSet then - val auxPath = os.pwd/"hkmc2"/"shared"/"src"/"test"/"mlscript-compile"/"cpp" val cppHost = CppCompilerHost(auxPath.toString, output.apply) if !cppHost.ready then output("\nCpp Compilation Failed: Cpp compiler or GNU Make not found") diff --git a/hkmc2/shared/src/test/mlscript/llir/Playground.mls b/hkmc2/shared/src/test/mlscript/llir/Playground.mls index ae496abdc..4a4a1f308 100644 --- a/hkmc2/shared/src/test/mlscript/llir/Playground.mls +++ b/hkmc2/shared/src/test/mlscript/llir/Playground.mls @@ -281,14 +281,6 @@ f5() //│ Interpreted: //│ 5 -:rcpp -1 -//│ = 1 -//│ -//│ -//│ Execution succeeded: -//│ 1 - :sllir :scpp fun test() = From 376ad91b8f8fc93ede0fb33ebe4b44227084800d Mon Sep 17 00:00:00 2001 From: Lionel Parreaux Date: Thu, 16 Jan 2025 17:03:32 +0800 Subject: [PATCH 22/23] Add some problematic examples --- .../src/test/mlscript/llir/Playground.mls | 111 ++++++++++++++++++ 1 file changed, 111 insertions(+) diff --git a/hkmc2/shared/src/test/mlscript/llir/Playground.mls b/hkmc2/shared/src/test/mlscript/llir/Playground.mls index 4a4a1f308..ed8f15052 100644 --- a/hkmc2/shared/src/test/mlscript/llir/Playground.mls +++ b/hkmc2/shared/src/test/mlscript/llir/Playground.mls @@ -1,6 +1,7 @@ :js :llir + :sllir abstract class Option[out T]: Some[T] | None class Some[out T](x: T) extends Option[T] @@ -375,3 +376,113 @@ fun test() = //│ } //│ int main() { return _mlsLargeStack(_mlsMainWrapper); } + +// FIXME: this fails C++ compilation +// :rcpp +:sllir +:scpp +let x = 10 +if true do + set x += 1 +x +//│ = 11 +//│ x = 11 +//│ LLIR: +//│ +//│ def j$0(x$2) = +//│ x$2 +//│ let x$0 = 10 in +//│ let x$1 = true in +//│ case x$1 of +//│ BoolLit(true) => +//│ let x$3 = +(x$2,1) in +//│ let x$4 = undefined in +//│ jump j$0(x$3) +//│ _ => +//│ let x$5 = undefined in +//│ jump j$0(x$2) +//│ +//│ Cpp: +//│ #include "mlsprelude.h" +//│ _mlsValue _mls_j_0(_mlsValue); +//│ _mlsValue _mlsMain(); +//│ _mlsValue _mls_j_0(_mlsValue _mls_x_2) { +//│ _mlsValue _mls_retval; +//│ _mls_retval = _mls_x_2; +//│ return _mls_retval; +//│ } +//│ _mlsValue _mlsMain() { +//│ _mlsValue _mls_retval; +//│ auto _mls_x_0 = _mlsValue::fromIntLit(10); +//│ auto _mls_x_1 = _mlsValue::fromIntLit(1); +//│ if (_mlsValue::isIntLit(_mls_x_1, 1)) { +//│ auto _mls_x_3 = (_mls_x_2 + _mlsValue::fromIntLit(1)); +//│ auto _mls_x_4 = _mlsValue::create<_mls_Unit>(); +//│ _mls_retval = _mls_j_0(_mls_x_3); +//│ } else { +//│ auto _mls_x_5 = _mlsValue::create<_mls_Unit>(); +//│ _mls_retval = _mls_j_0(_mls_x_2); +//│ } +//│ return _mls_retval; +//│ } +//│ int main() { return _mlsLargeStack(_mlsMainWrapper); } + + +// TODO: catch this early in the pipeline (currently it fails the C++ compilation) +// :rcpp +:sllir +:scpp +fun foo(a) = + let x + if a > 0 do + x = 1 + x + 1 +//│ LLIR: +//│ +//│ def foo(a) = +//│ let x$0 = >(a,0) in +//│ case x$0 of +//│ BoolLit(true) => +//│ let x$2 = 1 in +//│ let x$3 = undefined in +//│ jump j$0(x$2) +//│ _ => +//│ let x$4 = undefined in +//│ jump j$0(x$1) +//│ def j$0(x$1) = +//│ let x$5 = +(x$1,1) in +//│ x$5 +//│ undefined +//│ +//│ Cpp: +//│ #include "mlsprelude.h" +//│ _mlsValue _mls_j_0(_mlsValue); +//│ _mlsValue _mls_foo(_mlsValue); +//│ _mlsValue _mlsMain(); +//│ _mlsValue _mls_j_0(_mlsValue _mls_x_1) { +//│ _mlsValue _mls_retval; +//│ auto _mls_x_5 = (_mls_x_1 + _mlsValue::fromIntLit(1)); +//│ _mls_retval = _mls_x_5; +//│ return _mls_retval; +//│ } +//│ _mlsValue _mls_foo(_mlsValue _mls_a) { +//│ _mlsValue _mls_retval; +//│ auto _mls_x_0 = (_mls_a > _mlsValue::fromIntLit(0)); +//│ if (_mlsValue::isIntLit(_mls_x_0, 1)) { +//│ auto _mls_x_2 = _mlsValue::fromIntLit(1); +//│ auto _mls_x_3 = _mlsValue::create<_mls_Unit>(); +//│ _mls_retval = _mls_j_0(_mls_x_2); +//│ } else { +//│ auto _mls_x_4 = _mlsValue::create<_mls_Unit>(); +//│ _mls_retval = _mls_j_0(_mls_x_1); +//│ } +//│ return _mls_retval; +//│ } +//│ _mlsValue _mlsMain() { +//│ _mlsValue _mls_retval; +//│ _mls_retval = _mlsValue::create<_mls_Unit>(); +//│ return _mls_retval; +//│ } +//│ int main() { return _mlsLargeStack(_mlsMainWrapper); } + + From a46ba4ee6f7b3bd099973dc166f4bfbbd6d24411 Mon Sep 17 00:00:00 2001 From: waterlens Date: Thu, 16 Jan 2025 18:52:22 +0800 Subject: [PATCH 23/23] stop early; fix wrong ctx passing --- .../scala/hkmc2/codegen/llir/Builder.scala | 36 +++--- .../scala/hkmc2/codegen/llir/Interp.scala | 1 - .../src/test/mlscript/llir/BadPrograms.mls | 9 +- .../src/test/mlscript/llir/Playground.mls | 104 +++++------------- 4 files changed, 53 insertions(+), 97 deletions(-) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Builder.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Builder.scala index 6a7375991..06caaffce 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Builder.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Builder.scala @@ -38,19 +38,19 @@ final case class Ctx( def addFuncName(n: Local, m: Name) = copy(fn_ctx = fn_ctx + (n -> m)) def findFuncName(n: Local)(using Raise) = fn_ctx.get(n) match case None => - err(msg"Function name not found: ${n.toString()}") + errStop(msg"Function name not found: ${n.toString()}") Name("error") case Some(value) => value def addClassName(n: Local, m: Name) = copy(class_ctx = class_ctx + (n -> m)) def findClassName(n: Local)(using Raise) = class_ctx.get(n) match case None => - err(msg"Class name not found: ${n.toString()}") + errStop(msg"Class name not found: ${n.toString()}") Name("error") case Some(value) => value def addName(n: Str, m: Name) = copy(symbol_ctx = symbol_ctx + (n -> m)) def findName(n: Str)(using Raise): Name = symbol_ctx.get(n) match case None => - err(msg"Name not found: $n") + errStop(msg"Name not found: $n") Name("error") case Some(value) => value def reset = @@ -88,7 +88,6 @@ final class LlirBuilder(tl: TraceLogger)(fresh: Fresh, fnUid: FreshInt, clsUid: ts.id.name case ts: semantics.BlockMemberSymbol => // this means it's a locally-defined member ts.nme - // ts.trmTree case ts: semantics.InnerSymbol => summon[Scope].findThis_!(ts) case _ => summon[Scope].lookup_!(l) @@ -161,10 +160,10 @@ final class LlirBuilder(tl: TraceLogger)(fresh: Fresh, fnUid: FreshInt, clsUid: trace[Node](s"bValue begin", x => s"bValue end: ${x.show}"): v match case Value.Ref(l) => k(ctx.findName(getVar_!(l)) |> sr) - case Value.This(sym) => err(msg"Unsupported value: This"); Node.Result(Ls()) + case Value.This(sym) => errStop(msg"Unsupported value: This"); Node.Result(Ls()) case Value.Lit(lit) => k(Expr.Literal(lit)) - case Value.Lam(params, body) => err(msg"Unsupported value: Lam"); Node.Result(Ls()) - case Value.Arr(elems) => err(msg"Unsupported value: Arr"); Node.Result(Ls()) + case Value.Lam(params, body) => errStop(msg"Unsupported value: Lam"); Node.Result(Ls()) + case Value.Arr(elems) => errStop(msg"Unsupported value: Arr"); Node.Result(Ls()) private def getClassOfMem(p: FieldSymbol)(using ctx: Ctx)(using Raise, Scope): Local = trace[Local](s"bMemSym begin", x => s"bMemSym end: $x"): @@ -192,7 +191,7 @@ final class LlirBuilder(tl: TraceLogger)(fresh: Fresh, fnUid: FreshInt, clsUid: val field = name.name Node.LetExpr(v, Expr.Select(q.name, cls, field), k(v |> sr)) case q: Expr.Literal => - err(msg"Unsupported select on literal") + errStop(msg"Unsupported select on literal") Node.Result(Ls()) case x: Value => bValue(x)(k) @@ -239,7 +238,7 @@ final class LlirBuilder(tl: TraceLogger)(fresh: Fresh, fnUid: FreshInt, clsUid: val v = fresh.make Node.LetExpr(v, Expr.CtorApp(ClassRef.fromName(name.name), args), k(v |> sr)) case Instantiate(cls, args) => - err(msg"Unsupported kind of Instantiate") + errStop(msg"Unsupported kind of Instantiate") Node.Result(Ls()) case x: Path => bPath(x)(k) @@ -250,31 +249,28 @@ final class LlirBuilder(tl: TraceLogger)(fresh: Fresh, fnUid: FreshInt, clsUid: bPath(scrut): case e: TrivialExpr => val jp = fresh.make("j") - // guess: the value of Match itself in Block is useless - // val res = fresh.make val fvset = (rest.freeVars -- rest.definedVars).map(allocIfNew) val fvs1 = fvset.toList val new_ctx = fvs1.foldLeft(ctx)((acc, x) => acc.addName(x, fresh.make)) val fvs = fvs1.map(new_ctx.findName(_)) - def cont(x: TrivialExpr)(using Ctx) = Node.Jump( + def cont(x: TrivialExpr)(using ctx: Ctx) = Node.Jump( FuncRef.fromName(jp), - /* x :: */ fvs1.map(x => summon[Ctx].findName(x)).map(sr) + fvs1.map(x => ctx.findName(x)).map(sr) ) - given Ctx = new_ctx val casesList: Ls[(Pat, Node)] = arms.map: case (Case.Lit(lit), body) => - (Pat.Lit(lit), bBlock(body)(cont)) + (Pat.Lit(lit), bBlock(body)(cont)(using ctx)) case (Case.Cls(cls, _), body) => - (Pat.Class(ClassRef.fromName(cls.nme)), bBlock(body)(cont)) + (Pat.Class(ClassRef.fromName(cls.nme)), bBlock(body)(cont)(using ctx)) case (Case.Tup(len, inf), body) => - (Pat.Class(ClassRef.fromName("Tuple" + len.toString())), bBlock(body)(cont)) + (Pat.Class(ClassRef.fromName("Tuple" + len.toString())), bBlock(body)(cont)(using ctx)) val defaultCase = dflt.map(bBlock(_)(cont)) val jpdef = Func( fnUid.make, jp.str, - params = /* res :: */ fvs, + params = fvs, resultNum = 1, - bBlock(rest)(k), + bBlock(rest)(k)(using new_ctx), ) summon[Ctx].def_acc += jpdef Node.Case(e, casesList, defaultCase) @@ -319,7 +315,7 @@ final class LlirBuilder(tl: TraceLogger)(fresh: Fresh, fnUid: FreshInt, clsUid: case End(msg) => k(Expr.Literal(Tree.UnitLit(false))) case _: Block => val docBlock = blk.showAsTree - err(msg"Unsupported block: $docBlock") + errStop(msg"Unsupported block: $docBlock") Node.Result(Ls()) def bProg(e: Program)(using Raise, Scope): LlirProgram = diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Interp.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Interp.scala index 60ab4cca1..4f5416d1f 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Interp.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Interp.scala @@ -184,7 +184,6 @@ class Interpreter(verbose: Bool): ctx2 = ctx.copy(bindingCtx = ctx.bindingCtx ++ names.map{_.str}.zip(ys)) res <- eval(body)(using ctx2) yield res - // case LetApply(names, fn, args, body) => eval(LetMethodCall(names, ClassRef(R("Callable")), Name("apply" + args.length), (Ref(fn): TrivialExpr) :: args, body)) case LetCall(names, func, args, body) => for xs <- evalArgs(args) diff --git a/hkmc2/shared/src/test/mlscript/llir/BadPrograms.mls b/hkmc2/shared/src/test/mlscript/llir/BadPrograms.mls index 5aa3264fc..ec2a852a7 100644 --- a/hkmc2/shared/src/test/mlscript/llir/BadPrograms.mls +++ b/hkmc2/shared/src/test/mlscript/llir/BadPrograms.mls @@ -17,4 +17,11 @@ x.m //│ ═══[COMPILATION ERROR] Unsupported selection by users //│ Stopped due to an error during the Llir generation - +:ge +fun foo(a) = + let x + if a > 0 do + x = 1 + x + 1 +//│ ═══[COMPILATION ERROR] Name not found: x +//│ Stopped due to an error during the Llir generation diff --git a/hkmc2/shared/src/test/mlscript/llir/Playground.mls b/hkmc2/shared/src/test/mlscript/llir/Playground.mls index ed8f15052..a0bce1b91 100644 --- a/hkmc2/shared/src/test/mlscript/llir/Playground.mls +++ b/hkmc2/shared/src/test/mlscript/llir/Playground.mls @@ -377,112 +377,66 @@ fun test() = //│ int main() { return _mlsLargeStack(_mlsMainWrapper); } -// FIXME: this fails C++ compilation -// :rcpp :sllir +:intl :scpp -let x = 10 -if true do - set x += 1 -x +fun f() = + let x = 10 + if true do + set x += 1 + x +f() //│ = 11 -//│ x = 11 //│ LLIR: //│ +//│ def f() = +//│ let x$0 = 10 in +//│ let x$1 = true in +//│ case x$1 of +//│ BoolLit(true) => +//│ let x$3 = +(x$0,1) in +//│ let x$4 = undefined in +//│ jump j$0(x$3) +//│ _ => +//│ let x$5 = undefined in +//│ jump j$0(x$0) //│ def j$0(x$2) = //│ x$2 -//│ let x$0 = 10 in -//│ let x$1 = true in -//│ case x$1 of -//│ BoolLit(true) => -//│ let x$3 = +(x$2,1) in -//│ let x$4 = undefined in -//│ jump j$0(x$3) -//│ _ => -//│ let x$5 = undefined in -//│ jump j$0(x$2) +//│ let* (x$6) = f() in +//│ x$6 //│ //│ Cpp: //│ #include "mlsprelude.h" //│ _mlsValue _mls_j_0(_mlsValue); +//│ _mlsValue _mls_f(); //│ _mlsValue _mlsMain(); //│ _mlsValue _mls_j_0(_mlsValue _mls_x_2) { //│ _mlsValue _mls_retval; //│ _mls_retval = _mls_x_2; //│ return _mls_retval; //│ } -//│ _mlsValue _mlsMain() { +//│ _mlsValue _mls_f() { //│ _mlsValue _mls_retval; //│ auto _mls_x_0 = _mlsValue::fromIntLit(10); //│ auto _mls_x_1 = _mlsValue::fromIntLit(1); //│ if (_mlsValue::isIntLit(_mls_x_1, 1)) { -//│ auto _mls_x_3 = (_mls_x_2 + _mlsValue::fromIntLit(1)); +//│ auto _mls_x_3 = (_mls_x_0 + _mlsValue::fromIntLit(1)); //│ auto _mls_x_4 = _mlsValue::create<_mls_Unit>(); //│ _mls_retval = _mls_j_0(_mls_x_3); //│ } else { //│ auto _mls_x_5 = _mlsValue::create<_mls_Unit>(); -//│ _mls_retval = _mls_j_0(_mls_x_2); -//│ } -//│ return _mls_retval; -//│ } -//│ int main() { return _mlsLargeStack(_mlsMainWrapper); } - - -// TODO: catch this early in the pipeline (currently it fails the C++ compilation) -// :rcpp -:sllir -:scpp -fun foo(a) = - let x - if a > 0 do - x = 1 - x + 1 -//│ LLIR: -//│ -//│ def foo(a) = -//│ let x$0 = >(a,0) in -//│ case x$0 of -//│ BoolLit(true) => -//│ let x$2 = 1 in -//│ let x$3 = undefined in -//│ jump j$0(x$2) -//│ _ => -//│ let x$4 = undefined in -//│ jump j$0(x$1) -//│ def j$0(x$1) = -//│ let x$5 = +(x$1,1) in -//│ x$5 -//│ undefined -//│ -//│ Cpp: -//│ #include "mlsprelude.h" -//│ _mlsValue _mls_j_0(_mlsValue); -//│ _mlsValue _mls_foo(_mlsValue); -//│ _mlsValue _mlsMain(); -//│ _mlsValue _mls_j_0(_mlsValue _mls_x_1) { -//│ _mlsValue _mls_retval; -//│ auto _mls_x_5 = (_mls_x_1 + _mlsValue::fromIntLit(1)); -//│ _mls_retval = _mls_x_5; -//│ return _mls_retval; -//│ } -//│ _mlsValue _mls_foo(_mlsValue _mls_a) { -//│ _mlsValue _mls_retval; -//│ auto _mls_x_0 = (_mls_a > _mlsValue::fromIntLit(0)); -//│ if (_mlsValue::isIntLit(_mls_x_0, 1)) { -//│ auto _mls_x_2 = _mlsValue::fromIntLit(1); -//│ auto _mls_x_3 = _mlsValue::create<_mls_Unit>(); -//│ _mls_retval = _mls_j_0(_mls_x_2); -//│ } else { -//│ auto _mls_x_4 = _mlsValue::create<_mls_Unit>(); -//│ _mls_retval = _mls_j_0(_mls_x_1); +//│ _mls_retval = _mls_j_0(_mls_x_0); //│ } //│ return _mls_retval; //│ } //│ _mlsValue _mlsMain() { //│ _mlsValue _mls_retval; -//│ _mls_retval = _mlsValue::create<_mls_Unit>(); +//│ auto _mls_x_6 = _mls_f(); +//│ _mls_retval = _mls_x_6; //│ return _mls_retval; //│ } //│ int main() { return _mlsLargeStack(_mlsMainWrapper); } - +//│ +//│ Interpreted: +//│ 11