diff --git a/src/typecheck/typecheck.test.ts b/src/typecheck/typecheck.test.ts index 8792a679..a2c7ade2 100644 --- a/src/typecheck/typecheck.test.ts +++ b/src/typecheck/typecheck.test.ts @@ -2,7 +2,7 @@ import { expect, test } from "vitest"; import { unsafeParse } from "../parser"; import { typecheck, TypeError } from "./typecheck"; import { typePPrint } from "./pretty-printer"; -import { Context } from "./unify"; +import { ConcreteType, Context } from "./unify"; test("infer int", () => { const [types, errors] = tc(` @@ -72,6 +72,37 @@ test("typechecking previously defined vars", () => { }); }); +test("fn returning a constant", () => { + const [types, errors] = tc(` + let f = fn { 42 } + `); + + expect(errors).toEqual([]); + expect(types).toEqual({ + f: "Fn() -> Int", + }); +}); + +test("application return type", () => { + const [types, errors] = tc( + ` + let x = 1 > 2 + `, + { + ">": { + type: "fn", + args: [Int, Int], + return: Bool, + }, + }, + ); + + expect(errors).toEqual([]); + expect(types).toEqual({ + x: "Bool", + }); +}); + function tc(src: string, context: Context = {}) { const parsedProgram = unsafeParse(src); const [typed, errors] = typecheck(parsedProgram, context); @@ -83,3 +114,6 @@ function tc(src: string, context: Context = {}) { return [Object.fromEntries(kvs), errors]; } + +const Int: ConcreteType = { type: "named", name: "Int", args: [] }; +const Bool: ConcreteType = { type: "named", name: "Bool", args: [] }; diff --git a/src/typecheck/typecheck.ts b/src/typecheck/typecheck.ts index f53ad6f1..ca17674d 100644 --- a/src/typecheck/typecheck.ts +++ b/src/typecheck/typecheck.ts @@ -69,7 +69,26 @@ function* typecheckAnnotatedExpr( } case "fn": + // TODO handle params + yield* unifyYieldErr(ast, ast.$.asType(), { + type: "fn", + args: [], + return: ast.body.$.asType(), + }); + yield* typecheckAnnotatedExpr(ast.body, context); + return; + case "application": + yield* unifyYieldErr(ast, ast.caller.$.asType(), { + type: "fn", + args: ast.args.map((arg) => arg.$.asType()), + return: ast.$.asType(), + }); + // TODO typecheck args + yield* typecheckAnnotatedExpr(ast.caller, context); + + return; + case "let": case "if": throw new Error("TODO typecheckExpr with type: " + ast.type); @@ -83,7 +102,22 @@ function annotateExpr(ast: Expr): Expr { return { ...ast, $: TVar.fresh() }; case "fn": + return { + ...ast, + $: TVar.fresh(), + body: annotateExpr(ast.body), + params: ast.params.map((p) => ({ + ...p, + $: TVar.fresh(), + })), + }; case "application": + return { + ...ast, + $: TVar.fresh(), + caller: annotateExpr(ast.caller), + args: ast.args.map(annotateExpr), + }; case "let": case "if": throw new Error("TODO annotateExpr of: " + ast.type);