Skip to content

Commit

Permalink
Improve type inference
Browse files Browse the repository at this point in the history
  • Loading branch information
andy-byers committed Jul 19, 2024
1 parent ff49dc2 commit 86ab83e
Show file tree
Hide file tree
Showing 23 changed files with 784 additions and 726 deletions.
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -253,9 +253,11 @@ assert(vec[-1:] == [3])

### Maps
```
// empty map requires type annotation
let empty: [int: string] = [:]
let empty = [:]
empty[0] = 'abc'
// infer K = int, V = string
let map = [1: 'a', 2: 'b']
map[3] = 42
Expand Down Expand Up @@ -314,7 +316,6 @@ assert(status != 0)
## Known problems
+ Compiler will allow functions that don't return a value in all code paths
+ Likely requires a CFG and some data flow analysis: it would be very difficult to get right otherwise
+ Leaking unification tables: maybe allocate them in the AST arena
+ It isn't possible to create an empty vector or map of a given type without creating a temporary: `let vec: [int] = []`
+ Could use Swift syntax, or something similar:
```
Expand Down
3 changes: 0 additions & 3 deletions src/api.c
Original file line number Diff line number Diff line change
Expand Up @@ -337,9 +337,6 @@ int paw_load(paw_Env *P, paw_Reader input, const char *name, void *ud)
const int status = pawC_try(P, parse_aux, &p);
pawM_free_vec(P, p.mem.scratch.data, p.mem.scratch.alloc);
pawM_free_vec(P, p.mem.labels.values, p.mem.labels.capacity);
while (p.mem.unifier.table) {
pawU_leave_binder(&p.mem.unifier); // TODO: leaking unification tables! fix it
}
return status;
}

Expand Down
170 changes: 82 additions & 88 deletions src/ast.c
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ static void visit_signature_expr(AstVisitor *V, FuncType *e)

// Visit the children of closure expression 'e': its parameter list, then its
// result expression, then its body block.
// NOTE(review): 'e->params' is handed to both visit_exprs() and visit_decls()
// below; these look like the before/after lines of a change shown together --
// confirm which one belongs here.
static void visit_closure_expr(AstVisitor *V, ClosureExpr *e)
{
visit_exprs(V, e->params);
visit_decls(V, e->params);
V->visit_expr(V, e->result); // dispatched through the visitor's callback table
V->visit_block_stmt(V, e->body);
}
Expand Down Expand Up @@ -568,8 +568,8 @@ static void visit_expr(AstVisitor *V, AstExpr *expr)
case EXPR_CLOSURE:
V->visit_closure_expr(V, &expr->clos);
break;
case EXPR_FUNCTYPE:
V->visit_signature_expr(V, &expr->func);
case EXPR_SIGNATURE:
V->visit_signature_expr(V, &expr->sig);
break;
case EXPR_TYPELIST:
V->visit_typelist_expr(V, &expr->typelist);
Expand Down Expand Up @@ -847,7 +847,7 @@ static AstExpr *fold_signature_expr(AstFolder *F, FuncType *e)

static AstExpr *fold_closure_expr(AstFolder *F, ClosureExpr *e)
{
fold_exprs(F, e->params);
fold_decls(F, e->params);
e->result = F->fold_expr(F, e->result);
fold_block(F, e->body);
return cast_expr(e);
Expand Down Expand Up @@ -1085,8 +1085,8 @@ static AstExpr *fold_expr(AstFolder *F, AstExpr *expr)
return F->fold_sitem_expr(F, &expr->sitem);
case EXPR_CLOSURE:
return F->fold_closure_expr(F, &expr->clos);
case EXPR_FUNCTYPE:
return F->fold_signature_expr(F, &expr->func);
case EXPR_SIGNATURE:
return F->fold_signature_expr(F, &expr->sig);
case EXPR_VECTORTYPE:
return F->fold_vtype_expr(F, &expr->vtype);
case EXPR_MAPTYPE:
Expand Down Expand Up @@ -1281,7 +1281,6 @@ static AstType *fold_tuple(AstTypeFolder *F, AstTupleType *t)

static AstType *fold_adt(AstTypeFolder *F, AstAdt *t)
{
printf("FOLD-ADT %d\n", t->base);
if (t->types != NULL) {
F->fold_binder(F, t->types);
}
Expand Down Expand Up @@ -1471,15 +1470,15 @@ static AstStmt *copy_expr_stmt(AstFolder *F, AstExprStmt *s)
// Produce a new signature expression from 'e'. The node 'r' comes from
// copy_prep_expr(); its parameter list is filled in by copy_exprs() and its
// result by the folder's fold_expr callback. Returns 'r'.
// NOTE(review): the same values are stored through both 'r->func' and
// 'r->sig'; these look like the before/after lines of a rename shown
// together -- confirm which member name the union actually has.
static AstExpr *copy_signature_expr(AstFolder *F, FuncType *e)
{
AstExpr *r = copy_prep_expr(F, e);
r->func.params = copy_exprs(F, e->params);
r->func.result = F->fold_expr(F, e->result);
r->sig.params = copy_exprs(F, e->params);
r->sig.result = F->fold_expr(F, e->result);
return r;
}

static AstExpr *copy_closure_expr(AstFolder *F, ClosureExpr *e)
{
AstExpr *r = copy_prep_expr(F, e);
r->clos.params = copy_exprs(F, e->params);
r->clos.params = copy_decls(F, e->params);
r->clos.result = F->fold_expr(F, e->result);
r->clos.body = copy_block(F, e->body);
return r;
Expand Down Expand Up @@ -2005,8 +2004,8 @@ static AstStmt *stencil_expr_stmt(AstFolder *F, AstExprStmt *s)
// Produce a new signature expression from 'e' during stenciling. The node 'r'
// comes from stencil_prep_expr(); its parameter list is filled in by
// stencil_exprs() and its result by the folder's fold_expr callback.
// Returns 'r'.
// NOTE(review): the same values are stored through both 'r->func' and
// 'r->sig'; these look like the before/after lines of a rename shown
// together -- confirm which member name the union actually has.
static AstExpr *stencil_signature_expr(AstFolder *F, FuncType *e)
{
AstExpr *r = stencil_prep_expr(F, e);
r->func.params = stencil_exprs(F, e->params);
r->func.result = F->fold_expr(F, e->result);
r->sig.params = stencil_exprs(F, e->params);
r->sig.result = F->fold_expr(F, e->result);
return r;
}

Expand All @@ -2016,7 +2015,7 @@ static AstExpr *stencil_closure_expr(AstFolder *F, ClosureExpr *e)

ScopeState state;
enter_scope(F, &state, e->scope);
r->clos.params = stencil_exprs(F, e->params);
r->clos.params = stencil_decls(F, e->params);
r->clos.result = F->fold_expr(F, e->result);
r->clos.body = stencil_block(F, e->body);
r->clos.scope = leave_scope(F);
Expand Down Expand Up @@ -2498,7 +2497,7 @@ static void init_links(Lex *lex, Map *map, FuncDecl *base, InstanceDecl *inst)
}
}

FuncDecl *pawA_stencil_func(Ast *ast, FuncDecl *base, AstDecl *inst)
static FuncDecl *do_stencil_func(Ast *ast, FuncDecl *base, AstDecl *inst)
{
paw_Env *P = env(ast->lex);
Value *pv = pawC_push0(P);
Expand Down Expand Up @@ -2535,11 +2534,13 @@ FuncDecl *pawA_stencil_func(Ast *ast, FuncDecl *base, AstDecl *inst)
AstDecl copy = *inst;
inst->func.kind = DECL_FUNC;
inst->func.name = base->name;
inst->func.generics = copy.inst.types;
inst->func.generics = copy.inst.types; // TODO: no generics on instances!
inst->func.generics = new_list(ast, 0); // TODO
add_symbol(&F, inst); // callee slot
inst->func.params = stencil_decls(&F, base->params);
inst->func.result = F.fold_expr(&F, base->result);
inst->func.body = stencil_block(&F, base->body);
inst->func.monos = new_list(ast, 0);
inst->func.fn_kind = base->fn_kind;

inst->func.scope = leave_scope(&F);
Expand All @@ -2548,81 +2549,25 @@ FuncDecl *pawA_stencil_func(Ast *ast, FuncDecl *base, AstDecl *inst)
return &inst->func;
}

// State shared by the type-normalization callbacks; filled in by
// setup_normalizer().
struct Normalizer {
struct AstTypeFolder fold; // type folder whose fold_unknown hook is normalize_unknown()
struct Unifier *U; // set to &ast->lex->pm->unifier by setup_normalizer()
struct Ast *ast; // AST being processed; used to reach the paw_Env when reporting errors
UniTable *table; // table passed to pawU_normalize() when resolving unknown types
};

static AstType *normalize_unknown(AstTypeFolder *F, AstUnknown *t)
static void stencil_func(AstVisitor *V, FuncDecl *d)
{
struct Normalizer *N = F->state;
AstType *type = pawU_normalize(N->table, a_cast_type(t));
if (a_is_unknown(type)) {
// TODO: better error message, line number, etc.
pawE_error(env(N->ast->lex), PAW_ETYPE, -1, "unable to infer type");
}
return type;
}

static void normalize_expr(AstVisitor *V, AstExpr *expr)
{
if (expr != NULL) {
visit_expr(V, expr);

struct Normalizer *N = V->state.N;
expr->hdr.type = pawA_fold_type(&N->fold, a_type(expr));
if (d->generics->count == 0) {
V->visit_block_stmt(V, d->body);
return;
}
}

static void normalize_decl(AstVisitor *V, AstDecl *decl)
{
if (decl != NULL) {
visit_decl(V, decl);

struct Normalizer *N = V->state.N;
decl->hdr.type = pawA_fold_type(&N->fold, a_type(decl));
for (int i = 0; i < d->monos->count; ++i) {
AstDecl *decl = d->monos->data[i];
do_stencil_func(V->ast, d, decl);
V->visit_func_decl(V, &decl->func);
}
}

// Function-declaration callback installed by setup_normalizer(). It is
// deliberately empty so that the traversal does not descend into the bodies
// of nested functions; both arguments are only marked as unused.
static void normalize_func_decl(AstVisitor *V, FuncDecl *d)
{
    paw_unused(d);
    paw_unused(V);
}

// Initialize normalizer state 'N' and visitor 'V' for a normalization pass
// over 'ast' using unification table 'table'.
//
// 'N' is reset to zero apart from its 'U', 'table' and 'ast' fields. 'V' is
// initialized with 'N' as its state and has its expression, declaration and
// function-declaration callbacks replaced by the normalize_* versions.
// Finally, N's type folder is initialized and routed through
// normalize_unknown() for unknown types.
static void setup_normalizer(Ast *ast, struct Normalizer *N, AstVisitor *V, UniTable *table)
{
    // Build the state in a zeroed local first, then publish it through 'N'.
    struct Normalizer fresh = {0};
    fresh.ast = ast;
    fresh.table = table;
    fresh.U = &ast->lex->pm->unifier;
    *N = fresh;

    // Visitor: carries 'N' as its state and uses the normalizing callbacks.
    union AstState state = {.N = N};
    pawA_visitor_init(V, ast, state);
    V->visit_func_decl = normalize_func_decl;
    V->visit_decl = normalize_decl;
    V->visit_expr = normalize_expr;

    // Type folder: unknown types are resolved by normalize_unknown().
    pawA_type_folder_init(&N->fold, N);
    N->fold.fold_unknown = normalize_unknown;
}

// Run the normalization visitor over expression 'expr' of 'ast', resolving
// types against unification table 'table'. The normalizer and visitor state
// live on the stack for the duration of the call.
void pawA_normalize_expr(Ast *ast, AstExpr *expr, UniTable *table)
{
    struct AstVisitor visitor;
    struct Normalizer normalizer;

    setup_normalizer(ast, &normalizer, &visitor, table);
    visitor.visit_expr(&visitor, expr);
}

void pawA_normalize_stmt(Ast *ast, AstStmt *stmt, UniTable *table)
void pawA_stencil_stmts(Ast *ast, AstList *stmts)
{
struct Normalizer N;
struct AstVisitor V;
setup_normalizer(ast, &N, &V, table);
V.visit_stmt(&V, stmt);
AstVisitor V;
pawA_visitor_init(&V, ast, (AstState){0});
V.visit_func_decl = stencil_func;
visit_stmts(&V, stmts);
}

typedef struct Printer {
Expand Down Expand Up @@ -2785,7 +2730,7 @@ static void print_expr_kind(Printer *P, void *node)
case EXPR_SELECTOR:
fprintf(P->out, "Selector");
break;
case EXPR_FUNCTYPE:
case EXPR_SIGNATURE:
fprintf(P->out, "FuncType");
break;
case EXPR_VECTORTYPE:
Expand Down Expand Up @@ -3187,10 +3132,10 @@ static void dump_expr(Printer *P, AstExpr *e)
dump_msg(P, "value: ");
dump_expr(P, e->mtype.value);
break;
case EXPR_FUNCTYPE:
dump_expr_list(P, e->func.params, "params");
case EXPR_SIGNATURE:
dump_expr_list(P, e->sig.params, "params");
dump_msg(P, "result: ");
dump_expr(P, e->func.result);
dump_expr(P, e->sig.result);
break;
case EXPR_MATCH:
dump_msg(P, "target: ");
Expand All @@ -3209,6 +3154,55 @@ static void dump_expr(Printer *P, AstExpr *e)
dump_msg(P, "}\n");
}

// Write the elements of 'types' to 'out', separated by ", ". Nothing is
// written for an empty list.
// NOTE(review): elements are written with dump_type(out, ...) rather than by
// recursing into pawA_repr_type(); confirm that dump_type() accepts a FILE *
// and is the intended callee.
static void repr_type_list(FILE *out, const AstList *types)
{
    for (int idx = 0; idx < types->count; ++idx) {
        if (idx > 0) {
            fprintf(out, ", ");
        }
        dump_type(out, types->data[idx]);
    }
}

// Write a short textual representation of 'type' to 'out':
//   tuple          "(" elems ")"
//   fptr / func    "fn(" params ") -> " result
//   ADT            numeric base, then "<" types ">" when a type list is present
//   unknown        "?" followed by the unknown's index
//   anything else  asserted to be a generic; "?" followed by its name
//
// TODO: Have this output a String, or fill a Buffer, move somewhere else
void pawA_repr_type(FILE *out, const AstType *type)
{
    switch (a_kind(type)) {
        case AST_TYPE_TUPLE:
            fprintf(out, "(");
            repr_type_list(out, type->tuple.elems);
            fprintf(out, ")");
            break;
        case AST_TYPE_FPTR:
        case AST_TYPE_FUNC:
            // both kinds are read through the 'fptr' member
            fprintf(out, "fn(");
            repr_type_list(out, type->fptr.params);
            fprintf(out, ") -> ");
            dump_type(out, type->fptr.result);
            break;
        case AST_TYPE_ADT:
            fprintf(out, "%d", type->adt.base); // TODO: Print the name
            if (type->adt.types != NULL) {
                fprintf(out, "<");
                repr_type_list(out, type->adt.types);
                fprintf(out, ">");
            }
            break;
        case AST_TYPE_UNKNOWN:
            fprintf(out, "?%d", type->unknown.index);
            break;
        default:
            paw_assert(a_is_generic(type));
            fprintf(out, "?%s", type->generic.name->text);
    }
}

void pawA_dump_type(FILE *out, AstType *type)
{
Printer P = {.out = out};
Expand Down
12 changes: 6 additions & 6 deletions src/ast.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ typedef struct AstTypeHeader {
// AST type variant for a generic (reached as 'type->generic' on an AstType).
typedef struct AstGeneric {
AST_TYPE_HEADER; // presumably the header fields common to all AST type variants -- confirm
String *name; // name of the generic; pawA_repr_type() prints it as "?<name>"
DefId did; // presumably the id of the declaration that introduces this generic -- TODO confirm
} AstGeneric;

// Represents a type that is in the process of being inferred
Expand Down Expand Up @@ -442,7 +443,7 @@ typedef enum AstExprKind {
EXPR_STRUCTITEM,
EXPR_MATCH,
EXPR_MATCHARM,
EXPR_FUNCTYPE,
EXPR_SIGNATURE,
EXPR_VECTORTYPE,
EXPR_MAPTYPE,
EXPR_TYPELIST,
Expand Down Expand Up @@ -640,7 +641,7 @@ typedef struct AstExpr {
PathType pathtype;
StructItem sitem;
MapItem mitem;
FuncType func; // TODO: rename ftype
FuncType sig;
VectorType vtype;
MapType mtype;
MatchExpr match;
Expand Down Expand Up @@ -754,7 +755,6 @@ typedef union AstState {
struct Resolver *R; // symbol resolution (pass 2) state
struct Generator *G; // code generation (pass 3) state
struct Stenciler *S; // template expansion state
struct Normalizer *N; // type normalizer state
struct Copier *C; // AST copier state
} AstState;

Expand Down Expand Up @@ -964,9 +964,8 @@ Ast *pawA_new_ast(Lex *lex);
void pawA_free_ast(Ast *ast);

AstDecl *pawA_copy_decl(Ast *ast, AstDecl *decl);
FuncDecl *pawA_stencil_func(Ast *ast, FuncDecl *base, AstDecl *inst);
void pawA_normalize_expr(Ast *ast, AstExpr *expr, UniTable *table);
void pawA_normalize_stmt(Ast *ast, AstStmt *stmt, UniTable *table);

void pawA_stencil_stmts(Ast *ast, AstList *stmts);

DefId pawA_add_decl(Ast *ast, AstDecl *decl);
AstDecl *pawA_get_decl(Ast *ast, DefId id);
Expand Down Expand Up @@ -997,6 +996,7 @@ AstDecl *pawA_get_decl(Ast *ast, DefId id);

// Identifier of ADT type 't', computed as its 'adt.base' minus PAW_TSTRING.
// The a_is_adt(t) condition is passed to check_exp(), presumably as a checked
// precondition -- confirm against check_exp()'s definition.
#define a_adt_id(t) check_exp(a_is_adt(t), (t)->adt.base - PAW_TSTRING)

void pawA_repr_type(FILE *out, const AstType *type);
void pawA_dump_type(FILE *out, AstType *type);
void pawA_dump_path(FILE *out, AstPath *path);
void pawA_dump_decl(FILE *out, AstDecl *decl);
Expand Down
Loading

0 comments on commit 86ab83e

Please sign in to comment.