Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Loader, generate-json-serializer, generate-pprint, generate-eq, and generate-utest #901

Open
wants to merge 5 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
199 changes: 199 additions & 0 deletions src/stdlib/mexpr/generate-eq.mc
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
-- Generate code required to compare two arbitrary (monomorphic,
-- non-function) values based on their structure

include "ast.mc"
include "type-check.mc"

include "mlang/loader.mc"

lang GenerateEq = Ast
type GEqEnv =
{ conFunctions : Map Name Name -- For TyCons
, varFunctions : Map Name Name -- For TyVars
, newFunctions : [(Name, Expr)] -- To be defined

, tcEnv : TCEnv -- Current typechecking environment

, eqSeq : Name
, eqBool : Name
}

sem getEqFunction : GEqEnv -> Type -> (GEqEnv, Expr)
sem getEqFunction env = | ty -> _getEqFunction env (unwrapType ty)

sem _getEqFunction : GEqEnv -> Type -> (GEqEnv, Expr)
end

lang GenerateEqInt = GenerateEq + IntTypeAst + CmpIntAst
sem _getEqFunction env =
| TyInt _ -> (env, uconst_ (CEqi ()))
end

lang GenerateEqFloat = GenerateEq + FloatTypeAst + CmpFloatAst
sem _getEqFunction env =
| TyFloat _ -> (env, uconst_ (CEqf ()))
end

lang GenerateEqBool = GenerateEq + BoolTypeAst
sem _getEqFunction env =
| TyBool _ -> (env, nvar_ env.eqBool)
end

lang GenerateEqSeq = GenerateEq + SeqTypeAst
sem _getEqFunction env =
| TySeq x ->
match getEqFunction env x.ty with (env, elemF) in
(env, app_ (nvar_ env.eqSeq) elemF)
end

lang GenerateEqChar = GenerateEq + CharTypeAst + CmpCharAst
sem _getEqFunction env =
| TyChar _ ->
(env, uconst_ (CEqc ()))
end

lang GenerateEqRecord = GenerateEq + RecordTypeAst
sem _getEqFunction env =
| ty & TyRecord x ->
if mapIsEmpty x.fields then (env, ulam_ "" (ulam_ "" true_)) else

let lName = nameSym "l" in
let l = withType ty (nvar_ lName) in
let rName = nameSym "r" in
let r = withType ty (nvar_ rName) in

let genRecElem = lam acc. lam label. lam ty. snoc acc (lam env.
match getEqFunction env ty with (env, eqF) in
let label = sidToString label in
(env, appf2_ eqF (recordproj_ label l) (recordproj_ label r))) in
let elems = mapFoldWithKey genRecElem [] x.fields in
match mapAccumL (lam env. lam f. f env) env elems with (env, [first] ++ elems) in

let f = lam acc. lam elem. if_ elem acc false_ in
(env, nulam_ lName (nulam_ rName (foldl f first elems)))
end

lang GenerateEqApp = GenerateEq + AppTypeAst
sem _getEqFunction env =
| TyApp x ->
match getEqFunction env x.lhs with (env, lhs) in
match getEqFunction env x.rhs with (env, rhs) in
(env, app_ lhs rhs)
end

lang GenerateEqCon = GenerateEq + ConTypeAst
sem _getEqFunction env =
| TyCon x ->
-- TODO(vipa, 2025-01-27): Invalidate old eq functions if
-- we've introduced constructors to pre-existing types
match mapLookup x.ident env.conFunctions with Some n then (env, nvar_ n) else

let fname = nameSym (concat "eq" (nameGetStr x.ident)) in
let env = {env with conFunctions = mapInsert x.ident fname env.conFunctions} in

-- TODO(vipa, 2025-01-27): We cannot see locally defined types
-- here, which might be an issue
let params = match mapLookup x.ident env.tcEnv.tyConEnv with Some (_, params, _)
then params
else errorSingle [x.info] (concat "Typecheck environment does not contain information about type " (nameGetStr x.ident)) in
let paramFNames = foldl (lam acc. lam n. mapInsert n (nameSetNewSym n) acc) (mapEmpty nameCmp) params in
let prevVarFunctions = env.varFunctions in
let env = {env with varFunctions = mapUnion env.varFunctions paramFNames} in

let constructors = mapIntersectWith
(lam. lam pair. pair.1)
(mapLookupOr (setEmpty nameCmp) x.ident env.tcEnv.conDeps)
env.tcEnv.conEnv in

let lName = nameSym "l" in
let rName = nameSym "r" in
let addMatch = lam acc. lam c. lam t.
match acc with (env, tm) in
match getEqFunction env t with (env, subf) in
let subl = nameSym "subl" in
let subr = nameSym "subr" in
let tm = match_ (nvar_ lName) (npcon_ c (npvar_ subl))
(match_ (nvar_ rName) (npcon_ c (npvar_ subr))
(appf2_ subf (nvar_ subl) (nvar_ subr))
false_)
tm in
(env, tm) in
match mapFoldWithKey addMatch (env, never_) constructors with (env, matchChain) in
let matchChain = nulam_ lName (nulam_ rName matchChain) in
let body = foldr (lam pname. lam body. nulam_ (mapFindExn pname paramFNames) body) matchChain params in

let env = {env with varFunctions = prevVarFunctions, newFunctions = snoc env.newFunctions (fname, body)} in
(env, nvar_ fname)
end

lang GenerateEqVar = GenerateEq + VarTypeAst
-- NOTE(vipa, 2025-01-27): This function will error when it
-- encounters a polymorphic value of unknown type. We could
-- arbitrarily say "equal" or "not equal", but that seems error
-- prone, or we could somehow ask surrounding code to be rewritten
-- to carry an extra eq function for the polymorphic type.
sem _getEqFunction env =
| TyVar x ->
match mapLookup x.ident env.varFunctions with Some fname
then (env, nvar_ fname)
else errorSingle [x.info] (join ["I don't know how to compare values of the polymorphic type ", nameGetStr x.ident])
end

lang MExprGenerateEq
= GenerateEqRecord
+ GenerateEqBool
+ GenerateEqInt
+ GenerateEqFloat
+ GenerateEqSeq
+ GenerateEqChar
+ GenerateEqApp
+ GenerateEqCon
+ GenerateEqVar
end

lang GenerateEqLoader = MCoreLoader + GenerateEq
syn Hook =
| EqHook
{ baseEnv : GEqEnv
, functions : Ref (Map Name Name) -- Names for TyCon related Eq functions
}

sem enableEqGeneration : Loader -> Loader
sem enableEqGeneration = | loader ->
if hasHook (lam x. match x with EqHook _ then true else false) loader then loader else

match includeFileExn "." "stdlib::seq.mc" loader with (seqEnv, loader) in
match includeFileExn "." "stdlib::bool.mc" loader with (boolEnv, loader) in

let baseEnv =
{ conFunctions = mapEmpty nameCmp
, varFunctions = mapEmpty nameCmp
, newFunctions = []
, tcEnv = typcheckEnvEmpty
, eqSeq = _getVarExn "eqSeq" seqEnv
, eqBool = _getVarExn "eqBool" boolEnv
} in

let hook = EqHook
{ baseEnv = baseEnv
, functions = ref (mapEmpty nameCmp)
} in
addHook loader hook

sem _eqFunctionsFor : [Type] -> Loader -> Hook -> Option (Loader, [Expr])
sem _eqFunctionsFor tys loader =
| _ -> None ()
| EqHook hook ->
match mapAccumL getEqFunction {hook.baseEnv with conFunctions = deref hook.functions, tcEnv = _getTCEnv loader} tys
with (env, printFs) in

modref hook.functions env.conFunctions;
let loader = if null env.newFunctions
then loader
else _addDeclExn loader (decl_nureclets_ env.newFunctions) in
Some (loader, printFs)

sem eqFunctionsFor : [Type] -> Loader -> (Loader, [Expr])
sem eqFunctionsFor tys = | loader ->
withHookState (_eqFunctionsFor tys) loader
end
Loading
Loading