Skip to content

Commit

Permalink
Merge pull request #187 from hkust-taco/new-definition-typing
Browse files Browse the repository at this point in the history
New MLscript frontend
  • Loading branch information
LPTK authored Jan 12, 2024
2 parents ea4ddab + 05752ff commit 4511944
Show file tree
Hide file tree
Showing 532 changed files with 48,226 additions and 7,604 deletions.
1 change: 0 additions & 1 deletion .github/workflows/scala.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ on:
push:
branches: [ mlscript ]
pull_request:
branches: [ mlscript ]

jobs:
build:
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ metals.sbt
project/Dependencies.scala
project/metals.sbt
**.worksheet.sc
.DS_Store
3 changes: 2 additions & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,6 @@
"other": "off",
"strings": "off"
}
}
},
"files.autoSave": "off"
}
18 changes: 8 additions & 10 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,15 @@ import Wart._

enablePlugins(ScalaJSPlugin)

ThisBuild / scalaVersion := "2.13.9"
ThisBuild / scalaVersion := "2.13.12"
ThisBuild / version := "0.1.0-SNAPSHOT"
ThisBuild / organization := "io.lptk"
ThisBuild / organizationName := "LPTK"
ThisBuild / scalacOptions ++= Seq(
"-deprecation",
"-feature",
"-unchecked",
)

lazy val root = project.in(file("."))
.aggregate(mlscriptJS, mlscriptJVM, ts2mlsTest, compilerJVM)
Expand All @@ -18,9 +23,6 @@ lazy val mlscript = crossProject(JSPlatform, JVMPlatform).in(file("."))
.settings(
name := "mlscript",
scalacOptions ++= Seq(
"-deprecation",
"-feature",
"-unchecked",
"-language:higherKinds",
"-Ywarn-value-discard",
"-Ypatmat-exhaust-depth:160",
Expand All @@ -36,7 +38,8 @@ lazy val mlscript = crossProject(JSPlatform, JVMPlatform).in(file("."))
StringPlusAny, Any, ToString,
JavaSerializable, Serializable, Product, ToString,
LeakingSealed, Overloading,
Option2Iterable, IterableOps, ListAppend
Option2Iterable, IterableOps, ListAppend, SeqApply,
TripleQuestionMark,
),
libraryDependencies += "org.scalatest" %% "scalatest" % "3.2.12" % Test,
libraryDependencies += "com.lihaoyi" %%% "sourcecode" % "0.3.0",
Expand All @@ -60,10 +63,6 @@ lazy val mlscriptJS = mlscript.js
lazy val ts2mls = crossProject(JSPlatform, JVMPlatform).in(file("ts2mls"))
.settings(
name := "ts2mls",
scalaVersion := "2.13.8",
scalacOptions ++= Seq(
"-deprecation"
)
)
.jvmSettings()
.jsSettings(
Expand All @@ -76,7 +75,6 @@ lazy val ts2mlsJVM = ts2mls.jvm

lazy val ts2mlsTest = project.in(file("ts2mls"))
.settings(
scalaVersion := "2.13.8",
Test / test := ((ts2mlsJVM / Test / test) dependsOn (ts2mlsJS / Test / test)).value
)

Expand Down
341 changes: 221 additions & 120 deletions compiler/shared/main/scala/mlscript/compiler/ClassLifter.scala

Large diffs are not rendered by default.

36 changes: 36 additions & 0 deletions compiler/shared/main/scala/mlscript/compiler/DataType.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package mlscript.compiler

abstract class DataType

object DataType:
sealed class Singleton(value: Expr.Literal, dataType: DataType) extends DataType:
override def toString(): String = value.toString()

enum Primitive(name: String) extends DataType:
case Integer extends Primitive("int")
case Decimal extends Primitive("real")
case Boolean extends Primitive("bool")
case String extends Primitive("str")
override def toString(): String = this.name
end Primitive

sealed case class Tuple(elementTypes: List[DataType]) extends DataType:
override def toString(): String = elementTypes.mkString("(", ", ", ")")

sealed case class Class(declaration: Item.TypeDecl) extends DataType:
override def toString(): String = s"class ${declaration.name.name}"

sealed case class Function(parameterTypes: List[DataType], returnType: DataType) extends DataType:
def this(returnType: DataType, parameterTypes: DataType*) =
this(parameterTypes.toList, returnType)
override def toString(): String =
val parameterList = parameterTypes.mkString("(", ", ", ")")
s"$parameterList -> $returnType"

sealed case class Record(fields: Map[String, DataType]) extends DataType:
def this(fields: (String, DataType)*) = this(Map.from(fields))
override def toString(): String =
fields.iterator.map { (name, ty) => s"$name: $ty" }.mkString("{", ", ", "}")

case object Unknown extends DataType:
override def toString(): String = "unknown"
18 changes: 18 additions & 0 deletions compiler/shared/main/scala/mlscript/compiler/DataTypeInferer.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package mlscript.compiler
import mlscript.compiler.mono.MonomorphError

trait DataTypeInferer:
import DataType._

def findClassByName(name: String): Option[Item.TypeDecl]

def infer(expr: Expr, compatiableType: Option[DataType]): DataType =
expr match
case Expr.Tuple(elements) => DataType.Tuple(elements.map(infer(_, None)))
case lit @ Expr.Literal(value: BigInt) => Singleton(lit, Primitive.Integer)
case lit @ Expr.Literal(value: BigDecimal) => Singleton(lit, Primitive.Decimal)
case lit @ Expr.Literal(value: String) => Singleton(lit, Primitive.String)
case lit @ Expr.Literal(value: Boolean) => Singleton(lit, Primitive.Boolean)
case Expr.Apply(Expr.Ref(name), args) =>
findClassByName(name).fold(DataType.Unknown)(DataType.Class(_))
case _ => throw MonomorphError(s"I can't infer the type of $expr now")
177 changes: 177 additions & 0 deletions compiler/shared/main/scala/mlscript/compiler/Helpers.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
package mlscript
package compiler

import mlscript.compiler.mono.{Monomorph, MonomorphError}
import scala.collection.mutable.ArrayBuffer

object Helpers:
/**
* Extract parameters for monomorphization from a `Tup`.
*/
def toFuncParams(term: Term): Iterator[Parameter] = term match
case Tup(fields) => fields.iterator.flatMap {
// The new parser emits `Tup(_: UnitLit(true))` from `fun f() = x`.
case (_, Fld(FldFlags(_, _, _), UnitLit(true))) => None
case (None, Fld(FldFlags(_, spec, _), Var(name))) => Some((spec, Expr.Ref(name)))
case (Some(Var(name)), Fld(FldFlags(_, spec, _), _)) => Some((spec, Expr.Ref(name)))
case _ => throw new MonomorphError(
s"only `Var` can be parameters but we meet $term"
)
}
case _ => throw MonomorphError("expect the list of parameters to be a `Tup`")

def toFuncArgs(term: Term): IterableOnce[Term] = term match
// The new parser generates `(undefined, )` when no arguments.
// Let's do this temporary fix.
case Tup((_, Fld(FldFlags(_, _, _), UnitLit(true))) :: Nil) => Iterable.empty
case Tup(fields) => fields.iterator.map(_._2.value)
case _ => Some(term)

def term2Expr(term: Term): Expr = {
term match
case Var(name) => Expr.Ref(name)
case Lam(lhs, rhs) =>
val params = toFuncParams(lhs).toList
Expr.Lambda(params, term2Expr(rhs))
case App(App(Var("=>"), Bra(false, args: Tup)), body) =>
val params = toFuncParams(args).toList
Expr.Lambda(params, term2Expr(body))
case App(App(Var("."), self), App(Var(method), args: Tup)) =>
Expr.Apply(Expr.Select(term2Expr(self), Expr.Ref(method)), List.from(toFuncArgs(args).iterator.map(term2Expr)))
case App(lhs, rhs) =>
val callee = term2Expr(lhs)
val arguments = toFuncArgs(rhs).iterator.map(term2Expr).toList
Expr.Apply(callee, arguments)
case Tup(fields) =>
Expr.Tuple(fields.map {
case (_, Fld(FldFlags(mut, spec, genGetter), value)) => term2Expr(value)
})
case Rcd(fields) =>
Expr.Record(fields.map {
case (name, Fld(FldFlags(mut, spec, genGetter), value)) => (Expr.Ref(name.name), term2Expr(value))
})
case Sel(receiver, fieldName) =>
Expr.Select(term2Expr(receiver), Expr.Ref(fieldName.name))
case Let(rec, Var(name), rhs, body) =>
val exprRhs = term2Expr(rhs)
val exprBody = term2Expr(body)
Expr.LetIn(rec, Expr.Ref(name), exprRhs, exprBody)
case Blk(stmts) => Expr.Block(stmts.flatMap[Expr | Item.FuncDecl | Item.FuncDefn] {
case term: Term => Some(term2Expr(term))
case tyDef: NuTypeDef => throw MonomorphError(s"Unimplemented term2Expr ${term}")
case funDef: NuFunDef =>
val NuFunDef(_, nme, sn, targs, rhs) = funDef
val ret: Item.FuncDecl | Item.FuncDefn = rhs match
case Left(Lam(params, body)) =>
Item.FuncDecl(Expr.Ref(nme.name), toFuncParams(params).toList, term2Expr(body))
case Left(body: Term) => Item.FuncDecl(Expr.Ref(nme.name), Nil, term2Expr(body))
case Right(tp) => Item.FuncDefn(Expr.Ref(nme.name), targs, PolyType(Nil, tp)) //TODO: Check correctness in Type -> Polytype conversion
Some(ret)
case mlscript.DataDefn(_) => throw MonomorphError("unsupported DataDefn")
case mlscript.DatatypeDefn(_, _) => throw MonomorphError("unsupported DatatypeDefn")
case mlscript.TypeDef(_, _, _, _, _, _, _, _) => throw MonomorphError("unsupported TypeDef")
case mlscript.Def(_, _, _, _) => throw MonomorphError("unsupported Def")
case mlscript.LetS(_, _, _) => throw MonomorphError("unsupported LetS")
case mlscript.Constructor(_, _) => throw MonomorphError("unsupported Constructor")
})
case Bra(rcd, term) => term2Expr(term)
case Asc(term, ty) => Expr.As(term2Expr(term), ty)
case _: Bind => throw MonomorphError("cannot monomorphize `Bind`")
case _: Test => throw MonomorphError("cannot monomorphize `Test`")
case With(term, Rcd(fields)) =>
Expr.With(term2Expr(term), Expr.Record(fields.map {
case (name, Fld(FldFlags(mut, spec, getGetter), value)) => (Expr.Ref(name.name), term2Expr(term))
}))
case CaseOf(term, cases) =>
def rec(bra: CaseBranches)(using buffer: ArrayBuffer[CaseBranch]): Unit = bra match
case Case(pat, body, rest) =>
val newCase = pat match
case Var(name) => CaseBranch.Instance(Expr.Ref(name), Expr.Ref("_"), term2Expr(body))
case DecLit(value) => CaseBranch.Constant(Expr.Literal(value), term2Expr(body))
case IntLit(value) => CaseBranch.Constant(Expr.Literal(value), term2Expr(body))
case StrLit(value) => CaseBranch.Constant(Expr.Literal(value), term2Expr(body))
case UnitLit(undefinedOrNull) => CaseBranch.Constant(Expr.Literal(UnitValue.Undefined), term2Expr(body))
buffer.addOne(newCase)
rec(rest)
case NoCases => ()
case Wildcard(body) =>
buffer.addOne(CaseBranch.Wildcard(term2Expr(body)))
val branchBuffer = ArrayBuffer[CaseBranch]()
rec(cases)(using branchBuffer)
Expr.Match(term2Expr(term), branchBuffer)

case Subs(array, index) =>
Expr.Subscript(term2Expr(array), term2Expr(index))
case Assign(lhs, rhs) =>
Expr.Assign(term2Expr(lhs), term2Expr(rhs))
case New(None, body) =>
throw MonomorphError(s"Unimplemented term2Expr ${term}")
case New(Some((constructor, args)), body) =>
val typeName = constructor match
case AppliedType(TypeName(name), _) => name
case TypeName(name) => name
Expr.New(TypeName(typeName), toFuncArgs(args).iterator.map(term2Expr).toList)
// case Blk(unit) => Expr.Isolated(trans2Expr(TypingUnit(unit)))
case If(body, alternate) => body match
case IfThen(condition, consequent) =>
Expr.IfThenElse(
term2Expr(condition),
term2Expr(consequent),
alternate.map(term2Expr)
)
case term: IfElse => throw MonomorphError("unsupported IfElse")
case term: IfLet => throw MonomorphError("unsupported IfLet")
case term: IfOpApp => throw MonomorphError("unsupported IfOpApp")
case term: IfOpsApp => throw MonomorphError("unsupported IfOpsApp")
case term: IfBlock => throw MonomorphError("unsupported IfBlock")
case IntLit(value) => Expr.Literal(value)
case DecLit(value) => Expr.Literal(value)
case StrLit(value) => Expr.Literal(value)
case UnitLit(undefinedOrNull) =>
Expr.Literal(if undefinedOrNull
then UnitValue.Undefined
else UnitValue.Null)
case _ => throw MonomorphError("unsupported term"+ term.toString)
}

def func2Item(funDef: NuFunDef): Item.FuncDecl | Item.FuncDefn =
val NuFunDef(_, nme, sn, targs, rhs) = funDef
rhs match
case Left(Lam(params, body)) =>
Item.FuncDecl(Expr.Ref(nme.name), toFuncParams(params).toList, term2Expr(body))
case Left(body: Term) => Item.FuncDecl(Expr.Ref(nme.name), Nil, term2Expr(body))
case Right(tp) => Item.FuncDefn(Expr.Ref(nme.name), targs, PolyType(Nil, tp)) //TODO: Check correctness in Type -> Polytype conversion

def type2Item(tyDef: NuTypeDef): Item.TypeDecl =
val NuTypeDef(kind, className, tparams, params, _, _, parents, _, _, body) = tyDef
val isolation = Isolation(body.entities.flatMap {
// Question: Will there be pure terms in class body?
case term: Term =>
Some(term2Expr(term))
case subTypeDef: NuTypeDef => throw MonomorphError(s"Unimplemented func2Item ${tyDef}")
case subFunDef: NuFunDef =>
Some(func2Item(subFunDef))
case term => throw MonomorphError(term.toString)
})
val typeDecl: Item.TypeDecl = Item.TypeDecl(
Expr.Ref(className.name), // name
kind match // kind
case Als => TypeDeclKind.Alias
case Cls => TypeDeclKind.Class
case Trt => TypeDeclKind.Trait
case _ => throw MonomorphError(s"Unsupported TypeDefKind conversion ${kind}")
,
tparams.map(_._2), // typeParams
toFuncParams(params.getOrElse(Tup(Nil))).toList, // params
parents.map {
case Var(name) => (TypeName(name), Nil)
case App(Var(name), args) => (TypeName(name), term2Expr(args) match{
case Expr.Tuple(fields) => fields
case _ => Nil
})
case _ => throw MonomorphError("unsupported parent term")
}, // parents
isolation // body
)
typeDecl

20 changes: 10 additions & 10 deletions compiler/shared/main/scala/mlscript/compiler/PrettyPrinter.scala
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
package mlscript.compiler

import mlscript.{TypingUnit, NuFunDef, NuTypeDef, Term}
import mlscript.{TypingUnit, NuFunDef, NuTypeDef, Term, Tup}
import mlscript.compiler.debug.DebugOutput

// For pretty printing terms in debug output.
object PrettyPrinter:
def show(term: Term): DebugOutput = DebugOutput.Code(term.toString.linesIterator.toList)
def show(term: Term): DebugOutput = DebugOutput.Code(term.showDbg.linesIterator.toList)
def show(unit: TypingUnit): DebugOutput = DebugOutput.Code(showTypingUnit(unit, 0).linesIterator.toList)
def show(funDef: NuFunDef): DebugOutput = DebugOutput.Code(showFunDef(funDef).linesIterator.toList)
def show(tyDef: NuTypeDef): DebugOutput = DebugOutput.Code(showTypeDef(tyDef, 0).linesIterator.toList)
Expand All @@ -16,7 +16,7 @@ object PrettyPrinter:
case term: Term => show(term)
case tyDef: NuTypeDef => showTypeDef(tyDef)
case funDef: NuFunDef => showFunDef(funDef)
case others => others.toString()
case others => others.showDbg
}.mkString("{", "; ", "}")
if (singleLine.length < 60)
singleLine
Expand All @@ -26,7 +26,7 @@ object PrettyPrinter:
case term: Term => show(term)
case tyDef: NuTypeDef => showTypeDef(tyDef)
case funDef: NuFunDef => showFunDef(funDef)
case others => others.toString()
case others => others.showDbg
}.map(indentStr + " " + _).mkString("{\n", "\n", s"\n$indentStr}")

def showFunDef(funDef: NuFunDef): String =
Expand All @@ -36,19 +36,19 @@ object PrettyPrinter:
case Some(true) => "let'"
}
s"$st ${funDef.nme.name}"
+ (if funDef.targs.isEmpty
+ (if funDef.tparams.isEmpty
then ""
else funDef.targs.map(_.name).mkString("[", ", ", "]"))
else funDef.tparams.map(_.name).mkString("[", ", ", "]"))
+ " = "
+ funDef.rhs.fold(_.toString, _.body.show)
+ funDef.rhs.fold(_.showDbg, _.show(newDefs = true))

def showTypeDef(tyDef: NuTypeDef, indent: Int = 0): String =
s"${tyDef.kind.str} ${tyDef.nme.name}"
+ (if tyDef.tparams.isEmpty
then ""
else tyDef.tparams.map(_.name).mkString("[", ",", "]"))
+ "(" + tyDef.params + ")"
else tyDef.tparams.map(_._2.name).mkString("[", ",", "]"))
+ tyDef.params.fold("")(params => s"(${params.showDbg})")
+ (if tyDef.parents.isEmpty
then ""
else ": " + tyDef.parents.map(_.toString).mkString(", "))
else ": " + tyDef.parents.map(_.showDbg).mkString(", "))
+ showTypingUnit(tyDef.body, indent + 1)
Loading

0 comments on commit 4511944

Please sign in to comment.