Skip to content

Commit

Permalink
Track function argument types
Browse files Browse the repository at this point in the history
  • Loading branch information
Negabinary committed Jan 23, 2025
1 parent 5cd1744 commit 2b3225c
Show file tree
Hide file tree
Showing 23 changed files with 95 additions and 115 deletions.
2 changes: 1 addition & 1 deletion src/haz3lcore/dynamics/DHExp.re
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ let rec strip_casts =
let assign_name_if_none = (t, name) => {
let (term, rewrap) = unwrap(t);
switch (term) {
| Fun(arg, body, None) => Fun(arg, body, name) |> rewrap
| Fun(arg, body, typ, None) => Fun(arg, body, typ, name) |> rewrap
| TypFun(utpat, body, None) => TypFun(utpat, body, name) |> rewrap
| _ => t
};
Expand Down
6 changes: 3 additions & 3 deletions src/haz3lcore/dynamics/EvalCtx.re
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ type term =
| Seq2(DHExp.t, t)
| Let1(Pat.t, t, DHExp.t)
| Let2(Pat.t, DHExp.t, t)
| Fun(Pat.t, t, option(Var.t))
| Fun(Pat.t, t, option(Typ.t), option(Var.t))
| FixF(Pat.t, t, option(ClosureEnvironment.t))
| TypAp(t, Typ.t)
| Ap1(Operators.ap_direction, t, DHExp.t)
Expand Down Expand Up @@ -124,9 +124,9 @@ let rec compose = (ctx: t, d: DHExp.t): DHExp.t => {
| Let2(dp, d1, ctx) =>
let d = compose(ctx, d);
Let(dp, d1, d) |> wrap;
| Fun(dp, ctx, v) =>
| Fun(dp, ctx, typ, v) =>
let d = compose(ctx, d);
Fun(dp, d, v) |> wrap;
Fun(dp, d, typ, v) |> wrap;
| FixF(v, ctx, env) =>
let d = compose(ctx, d);
FixF(v, d, env) |> wrap;
Expand Down
4 changes: 2 additions & 2 deletions src/haz3lcore/dynamics/EvaluatorStep.re
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,10 @@ let rec matches =
| Let2(d1, d2, ctx) =>
let+ ctx = matches(env, flt, ctx, exp, act, idx);
Let2(d1, d2, ctx) |> rewrap;
| Fun(dp, ctx, name) =>
| Fun(dp, ctx, ty, name) =>
// TODO: Should this env include the bound variables?
let+ ctx = matches(env, flt, ctx, exp, act, idx);
Fun(dp, ctx, name) |> rewrap;
Fun(dp, ctx, ty, name) |> rewrap;
| FixF(name, ctx, env') =>
let+ ctx =
matches(Option.value(~default=env, env'), flt, ctx, exp, act, idx);
Expand Down
2 changes: 1 addition & 1 deletion src/haz3lcore/dynamics/FilterMatcher.re
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ let rec matches_exp =
| (TypFun(pat1, d1, s1), TypFun(pat2, d2, s2)) =>
s1 == s2 && matches_utpat(pat1, pat2) && matches_exp(d1, d2)
| (TypFun(_), _) => false
| (Fun(dp1, d1, _), Fun(fp1, f1, _)) =>
| (Fun(dp1, d1, _, _), Fun(fp1, f1, _, _)) =>
matches_fun(~denv, dp1, d1, ~fenv, fp1, f1)
| (Fun(_), _) => false

Expand Down
6 changes: 3 additions & 3 deletions src/haz3lcore/dynamics/Substitution.re
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,12 @@ let rec subst_var = (m, d1: DHExp.t, x: Var.t, d2: DHExp.t): DHExp.t => {
subst_var(m, d1, x, d3);
};
FixF(y, d3, env') |> rewrap;
| Fun(dp, d3, s) =>
| Fun(dp, d3, ty, s) =>
if (binds_var(m, x, dp)) {
Fun(dp, d3, s) |> rewrap;
Fun(dp, d3, ty, s) |> rewrap;
} else {
let d3 = subst_var(m, d1, x, d3);
Fun(dp, d3, s) |> rewrap;
Fun(dp, d3, ty, s) |> rewrap;
}
| TypFun(tpat, d3, s) =>
TypFun(tpat, subst_var(m, d1, x, d3), s) |> rewrap
Expand Down
2 changes: 1 addition & 1 deletion src/haz3lcore/dynamics/Transition.re
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ module Transition = (EV: EV_MODE) => {
is_value: false,
});
| TypFun(_)
| Fun(_, _, _) =>
| Fun(_, _, _, _) =>
let. _ = otherwise(env, d);
let.wrap_closure _ = env;
Value;
Expand Down
9 changes: 7 additions & 2 deletions src/haz3lcore/dynamics/TypeAssignment.re
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,13 @@ and typ_of_dhexp = (ctx: Ctx.t, m: Statics.Map.t, dh: DHExp.t): option(Typ.t) =>
};
let* ctx = dhpat_extend_ctx(dhp, ty_p, ctx);
typ_of_dhexp(ctx, m, d);
| Fun(dhp, d, _) =>
let* ty_p = dhpat_synthesize(dhp, ctx);
| Fun(dhp, d, ty, _) =>
let* ty_p =
switch (ty) {
| None => dhpat_synthesize(dhp, ctx)
| Some(t) => Some(t)
};

let* ctx = dhpat_extend_ctx(dhp, ty_p, ctx);
let* ty2 = typ_of_dhexp(ctx, m, d);
Some(Arrow(ty_p, ty2) |> Typ.temp);
Expand Down
4 changes: 2 additions & 2 deletions src/haz3lcore/dynamics/Unboxing.re
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ let rec unbox: type a. (unbox_request(a), DHExp.t) => unboxed(a) =

/* Function-like things can look like the following when values */
| (Fun, Constructor(name, _)) => Matches(Constructor(name)) // Perhaps we should check if the constructor actually is a function?
| (Fun, Closure(env', {term: Fun(dp, d3, _), _})) =>
| (Fun, Closure(env', {term: Fun(dp, d3, _, _), _})) =>
Matches(FunEnv(dp, d3, env'))
| (
Fun,
Expand Down Expand Up @@ -221,7 +221,7 @@ let rec unbox: type a. (unbox_request(a), DHExp.t) => unboxed(a) =
Invalid(_) | Undefined | EmptyHole | MultiHole(_) | DynamicErrorHole(_) |
Var(_) |
Let(_) |
Fun(_, _, _) |
Fun(_, _, _, _) |
FixF(_) |
TyAlias(_) |
Ap(_) |
Expand Down
7 changes: 4 additions & 3 deletions src/haz3lcore/pretty/ExpToSegment.re
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ let rec exp_to_pretty = (~settings: Settings.t, exp: Exp.t): pretty => {
let id = exp |> Exp.rep_id;
let+ es = es |> List.map(any_to_pretty(~settings)) |> all;
ListUtil.flat_intersperse(Grout({id, shape: Concave}), es);
| Parens({term: Fun(p, e, _), _} as inner_exp) =>
| Parens({term: Fun(p, e, _, _), _} as inner_exp) =>
// TODO: Add optional newlines
let id = inner_exp |> Exp.rep_id;
let+ p = pat_to_pretty(~settings: Settings.t, p)
Expand All @@ -249,7 +249,7 @@ let rec exp_to_pretty = (~settings: Settings.t, exp: Exp.t): pretty => {
let fun_form = [mk_form("fun_", id, [p])] @ e;
[mk_form("parens_exp", exp |> Exp.rep_id, [fun_form])]
|> fold_fun_if(settings.fold_fn_bodies, name);
| Fun(p, e, _) =>
| Fun(p, e, _, _) =>
// TODO: Add optional newlines
let id = exp |> Exp.rep_id;
let+ p = pat_to_pretty(~settings: Settings.t, p)
Expand Down Expand Up @@ -850,10 +850,11 @@ let rec parenthesize = (~show_filters: bool, exp: Exp.t): Exp.t => {
// Other forms
| Constructor(c, t) =>
Constructor(c, paren_typ_at(Precedence.cast, t)) |> rewrap
| Fun(p, e, n) =>
| Fun(p, e, typ, n) =>
Fun(
parenthesize_pat(p) |> paren_pat_at(Precedence.min),
parenthesize(e) |> paren_assoc_at(Precedence.fun_),
typ, // this typ is currently never output
n,
)
|> rewrap
Expand Down
8 changes: 5 additions & 3 deletions src/haz3lcore/statics/Elaborator.re
Original file line number Diff line number Diff line change
Expand Up @@ -255,10 +255,12 @@ let rec elaborate = (m: Statics.Map.t, uexp: Exp.t): (DHExp.t, Typ.t) => {
};
let t = t |> Typ.normalize(ctx) |> Typ.all_ids_temp;
Constructor(c, t) |> rewrap |> cast_from(t);
| Fun(p, e, n) =>
| Fun(p, e, _, n) =>
let (p', typ) = elaborate_pattern(m, p);
let (e', tye) = elaborate(m, e);
Fun(p', e', n) |> rewrap |> cast_from(Arrow(typ, tye) |> Typ.temp);
Fun(p', e', Some(typ), n)
|> rewrap
|> cast_from(Arrow(typ, tye) |> Typ.temp);
| TypFun(tpat, e, name) =>
let (e', tye) = elaborate(m, e);
TypFun(tpat, e', name)
Expand All @@ -281,7 +283,7 @@ let rec elaborate = (m: Statics.Map.t, uexp: Exp.t): (DHExp.t, Typ.t) => {
(name, exp) => {
let (term, rewrap) = DHExp.unwrap(exp);
switch (term) {
| Fun(p, e, _) => Fun(p, e, name) |> rewrap
| Fun(p, e, t, _) => Fun(p, e, t, name) |> rewrap
| TypFun(tpat, e, _) => TypFun(tpat, e, name) |> rewrap
| _ => exp
};
Expand Down
2 changes: 1 addition & 1 deletion src/haz3lcore/statics/MakeTerm.re
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ and exp_term: unsorted => (Exp.term, list(Id.t)) = {
| (["$"], []) => UnOp(Meta(Unquote), r)
| (["-"], []) => UnOp(Int(Minus), r)
| (["!"], []) => UnOp(Bool(Not), r)
| (["fun", "->"], [Pat(pat)]) => Fun(pat, r, None)
| (["fun", "->"], [Pat(pat)]) => Fun(pat, r, None, None)
| (["fix", "->"], [Pat(pat)]) => FixF(pat, r, None)
| (["typfun", "->"], [TPat(tpat)]) => TypFun(tpat, r, None)
| (["let", "=", "in"], [Pat(pat), Exp(def)]) => Let(pat, def, r)
Expand Down
16 changes: 10 additions & 6 deletions src/haz3lcore/statics/Mode.re
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,16 @@ let ty_of: t => Typ.t =
Forall(Var("syntypfun") |> TPat.fresh, Unknown(SynSwitch) |> Typ.temp)
|> Typ.temp; /* TODO: naming the type variable? */

let of_arrow = (ctx: Ctx.t, mode: t): (t, t) =>
switch (mode) {
| Syn
| SynFun
| SynTypFun => (Syn, Syn)
| Ana(ty) => ty |> Typ.matched_arrow(ctx) |> TupleUtil.map2(ana)
// ty is Some if the expression is an annotated lambda
let of_arrow = (ctx: Ctx.t, mode: t, ty: option(Typ.t)): (t, t) =>
switch (mode, ty) {
| (Syn | SynFun | SynTypFun, None) => (Syn, Syn)
| (Syn | SynFun | SynTypFun, Some(ty)) => (Ana(ty), Syn)
| (Ana(ty), None) => ty |> Typ.matched_arrow(ctx) |> TupleUtil.map2(ana)
| (Ana(ty), Some(ty')) =>
let (t1, t2) = ty |> Typ.matched_arrow(ctx);
(Typ.join(~fix=true, ctx, t1, ty') |> Option.value(~default=ty'), t2)
|> TupleUtil.map2(ana);
};

let of_forall = (ctx: Ctx.t, name_opt: option(string), mode: t): t =>
Expand Down
4 changes: 2 additions & 2 deletions src/haz3lcore/statics/Statics.re
Original file line number Diff line number Diff line change
Expand Up @@ -360,8 +360,8 @@ and uexp_to_info_map =
let (args, m) = map_m_go(m, modes, args);
let arg_co_ctx = CoCtx.union(List.map(Info.exp_co_ctx, args));
add'(~self, ~co_ctx=CoCtx.union([fn.co_ctx, arg_co_ctx]), m);
| Fun(p, e, _) =>
let (mode_pat, mode_body) = Mode.of_arrow(ctx, mode);
| Fun(p, e, typ, _) =>
let (mode_pat, mode_body) = Mode.of_arrow(ctx, mode, typ);
let (p', _) =
go_pat(~is_synswitch=false, ~co_ctx=CoCtx.empty, ~mode=mode_pat, p, m);
let (e, m) = go'(~ctx=p'.ctx, ~mode=mode_body, e, m);
Expand Down
5 changes: 3 additions & 2 deletions src/haz3lcore/statics/Term.re
Original file line number Diff line number Diff line change
Expand Up @@ -618,7 +618,7 @@ module Exp = {
new_bound_vars,
e,
)
| Fun(p, e, n) =>
| Fun(p, e, t, n) =>
let pat_bound_vars = Pat.bound_vars(p);
Fun(
p,
Expand All @@ -628,6 +628,7 @@ module Exp = {
pat_bound_vars @ new_bound_vars,
e,
),
t,
n,
)
|> rewrap;
Expand Down Expand Up @@ -742,7 +743,7 @@ module Exp = {

let rec get_fn_name = (e: t) => {
switch (e.term) {
| Fun(_, _, n) => n
| Fun(_, _, _, n) => n
| FixF(_, e, _) => get_fn_name(e)
| Parens(e) => get_fn_name(e)
| TypFun(_, _, n) => n
Expand Down
16 changes: 12 additions & 4 deletions src/haz3lcore/statics/TermBase.re
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ and exp_term =
| String(string)
| ListLit(list(exp_t))
| Constructor(string, typ_t) // Typ.t field is only meaningful in dynamic expressions
| Fun(pat_t, exp_t, option(Var.t))
| Fun(pat_t, exp_t, option(typ_t), option(Var.t)) // typ_t field is only used to display types in results
| TypFun(tpat_t, exp_t, option(Var.t))
| Tuple(list(exp_t))
| Var(Var.t)
Expand Down Expand Up @@ -304,7 +304,13 @@ and Exp: {
| FailedCast(e, t1, t2) =>
FailedCast(exp_map_term(e), typ_map_term(t1), typ_map_term(t2))
| ListLit(ts) => ListLit(List.map(exp_map_term, ts))
| Fun(p, e, f) => Fun(pat_map_term(p), exp_map_term(e), f)
| Fun(p, e, t, f) =>
Fun(
pat_map_term(p),
exp_map_term(e),
Option.map(typ_map_term, t),
f,
)
| TypFun(tp, e, f) => TypFun(tpat_map_term(tp), exp_map_term(e), f)
| Tuple(xs) => Tuple(List.map(exp_map_term, xs))
| Let(p, e1, e2) =>
Expand Down Expand Up @@ -369,8 +375,10 @@ and Exp: {
List.length(xs) == List.length(ys) && List.equal(fast_equal, xs, ys)
| (Constructor(c1, ty1), Constructor(c2, ty2)) =>
c1 == c2 && Typ.fast_equal(ty1, ty2)
| (Fun(p1, e1, _), Fun(p2, e2, _)) =>
Pat.fast_equal(p1, p2) && fast_equal(e1, e2)
| (Fun(p1, e1, t1, _), Fun(p2, e2, t2, _)) =>
Pat.fast_equal(p1, p2)
&& fast_equal(e1, e2)
&& Option.equal(Typ.fast_equal, t1, t2)
| (TypFun(tp1, e1, _), TypFun(tp2, e2, _)) =>
TPat.fast_equal(tp1, tp2) && fast_equal(e1, e2)
| (Tuple(xs), Tuple(ys)) =>
Expand Down
6 changes: 3 additions & 3 deletions src/haz3lmenhir/Conversion.re
Original file line number Diff line number Diff line change
Expand Up @@ -225,8 +225,8 @@ module rec Exp: {
| Fun(p, e, name_opt) =>
switch (name_opt) {
| Some(name_str) =>
Fun(Pat.of_menhir_ast(p), of_menhir_ast(e), Some(name_str))
| None => Fun(Pat.of_menhir_ast(p), of_menhir_ast(e), None)
Fun(Pat.of_menhir_ast(p), of_menhir_ast(e), None, Some(name_str))
| None => Fun(Pat.of_menhir_ast(p), of_menhir_ast(e), None, None)
}
| ApExp(e1, args) =>
switch (args) {
Expand Down Expand Up @@ -350,7 +350,7 @@ module rec Exp: {
| Constructor(s, typ) => Constructor(s, Typ.of_core(typ))
| DeferredAp(e, es) =>
ApExp(of_core(e), TupleExp(List.map(of_core, es)))
| Fun(p, e, name_opt) => Fun(Pat.of_core(p), of_core(e), name_opt)
| Fun(p, e, _, name_opt) => Fun(Pat.of_core(p), of_core(e), name_opt)
| Ap(Reverse, _, _) => raise(Failure("Reverse not supported"))
};
};
Expand Down
2 changes: 1 addition & 1 deletion src/haz3lweb/app/explainthis/ExplainThis.re
Original file line number Diff line number Diff line change
Expand Up @@ -608,7 +608,7 @@ let get_doc =
};
/* TODO: More could be done here probably for different patterns. */
basic(TypFunctionExp.type_functions_basic);
| Fun(pat, body, _) =>
| Fun(pat, body, _, _) =>
let basic = group_id => {
let pat_id = List.nth(pat.ids, 0);
let body_id = List.nth(body.ids, 0);
Expand Down
8 changes: 4 additions & 4 deletions src/haz3lweb/exercises/SyntaxTest.re
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ let rec find_fn = (name: string, uexp: Exp.t, l: list(Exp.t)): list(Exp.t) => {
List.fold_left((acc, u1) => {find_fn(name, u1, acc)}, l, ul)
| TypFun(_, body, _)
| FixF(_, body, _)
| Fun(_, body, _) => l |> find_fn(name, body)
| Fun(_, body, _, _) => l |> find_fn(name, body)
| TypAp(u1, _)
| Parens(u1)
| Cast(u1, _, _)
Expand Down Expand Up @@ -178,7 +178,7 @@ let rec var_mention = (name: string, uexp: Exp.t): bool => {
| Constructor(_)
| Undefined
| Deferral(_) => false
| Fun(args, body, _) =>
| Fun(args, body, _, _) =>
var_mention_upat(name, args) ? false : var_mention(name, body)
| ListLit(l)
| Tuple(l) =>
Expand Down Expand Up @@ -239,7 +239,7 @@ let rec var_applied = (name: string, uexp: Exp.t): bool => {
| Constructor(_)
| Undefined
| Deferral(_) => false
| Fun(args, body, _)
| Fun(args, body, _, _)
| FixF(args, body, _) =>
var_mention_upat(name, args) ? false : var_applied(name, body)
| ListLit(l)
Expand Down Expand Up @@ -332,7 +332,7 @@ let rec tail_check = (name: string, uexp: Exp.t): bool => {
| Var(_)
| BuiltinFun(_) => true
| FixF(args, body, _)
| Fun(args, body, _) =>
| Fun(args, body, _, _) =>
var_mention_upat(name, args) ? false : tail_check(name, body)
| Let(p, def, body) =>
var_mention_upat(name, p) || var_mention(name, def)
Expand Down
Loading

0 comments on commit 2b3225c

Please sign in to comment.