Skip to content

Commit

Permalink
Add a builtin that allows specifying which iterate method to use
Browse files Browse the repository at this point in the history
When using the Casette mechanism to intercept calls to _apply,
a common strategy is to rewrite the function argument to properly
consider the context and then falling back to regular _apply.
However, as showin in JuliaLabs/Cassette.jl#146,
this strategy is insufficient as the _apply itself may recurse into
various `iterate` calls which are not properly tracked. This is an
attempt to resolve this problem with a minimal performance penalty.
Attempting to duplicate the _apply logic in julia, would lead to
code that is very hard for inference (and nested Cassette passes to
understand). In contrast, this simply adds a version of _apply that
takes `iterate` as an explicit argument. Cassette and similar tools
can override this argument and provide a function that properly
allows the context to recurse through the iteration, while still
allowing inference to take advantage of the special handling of _apply
for simple cases.
  • Loading branch information
Keno committed Sep 21, 2019
1 parent 3a20af1 commit 5bb6d0a
Show file tree
Hide file tree
Showing 6 changed files with 28 additions and 10 deletions.
2 changes: 2 additions & 0 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -635,6 +635,8 @@ end
function abstract_call(@nospecialize(f), fargs::Union{Nothing,Vector{Any}}, argtypes::Vector{Any}, vtypes::VarTable, sv::InferenceState, max_methods = sv.params.MAX_METHODS)
if f === _apply
return abstract_apply(argtypes[2], argtypes[3:end], vtypes, sv, max_methods)
elseif f === _apply_iterate
return abstract_apply(argtypes[3], argtypes[4:end], vtypes, sv, max_methods)
end

la = length(argtypes)
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ getfield(getfield(Main, :Core), :eval)(getfield(Main, :Core), :(baremodule Compi
using Core.Intrinsics, Core.IR

import Core: print, println, show, write, unsafe_write, stdout, stderr,
_apply, svec, apply_type, Builtin, IntrinsicFunction, MethodInstance, CodeInstance
_apply, _apply_iterate, svec, apply_type, Builtin, IntrinsicFunction, MethodInstance, CodeInstance

const getproperty = getfield
const setproperty! = setfield!
Expand Down
2 changes: 1 addition & 1 deletion src/builtin_proto.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ DECLARE_BUILTIN(throw); DECLARE_BUILTIN(is);
DECLARE_BUILTIN(typeof); DECLARE_BUILTIN(sizeof);
DECLARE_BUILTIN(issubtype); DECLARE_BUILTIN(isa);
DECLARE_BUILTIN(_apply); DECLARE_BUILTIN(_apply_pure);
DECLARE_BUILTIN(_apply_latest);
DECLARE_BUILTIN(_apply_latest); DECLARE_BUILTIN(_apply_iterate);
DECLARE_BUILTIN(isdefined); DECLARE_BUILTIN(nfields);
DECLARE_BUILTIN(tuple); DECLARE_BUILTIN(svec);
DECLARE_BUILTIN(getfield); DECLARE_BUILTIN(setfield);
Expand Down
28 changes: 21 additions & 7 deletions src/builtins.c
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,7 @@ void STATIC_INLINE _grow_to(jl_value_t **root, jl_value_t ***oldargs, jl_svec_t

static jl_function_t *jl_iterate_func JL_GLOBALLY_ROOTED;

JL_CALLABLE(jl_f__apply)
static jl_value_t *do_apply(jl_value_t *F, jl_value_t **args, uint32_t nargs, jl_value_t *iterate)
{
JL_NARGSV(apply, 1);
jl_function_t *f = args[0];
Expand Down Expand Up @@ -516,10 +516,13 @@ JL_CALLABLE(jl_f__apply)
extra += 1;
}
}
if (extra && jl_iterate_func == NULL) {
jl_iterate_func = jl_get_function(jl_top_module, "iterate");
if (jl_iterate_func == NULL)
jl_undefined_var_error(jl_symbol("iterate"));
if (extra && iterate == NULL) {
if (jl_iterate_func == NULL) {
jl_iterate_func = jl_get_function(jl_top_module, "iterate");
if (jl_iterate_func == NULL)
jl_undefined_var_error(jl_symbol("iterate"));
}
iterate = jl_iterate_func;
}
// allocate space for the argument array and gc roots for it
// based on our previous estimates
Expand Down Expand Up @@ -599,7 +602,7 @@ JL_CALLABLE(jl_f__apply)
assert(extra > 0);
jl_value_t *args[2];
args[0] = ai;
jl_value_t *next = jl_apply_generic(jl_iterate_func, args, 1);
jl_value_t *next = jl_apply_generic(iterate, args, 1);
while (next != jl_nothing) {
roots[stackalloc] = next;
jl_value_t *value = jl_fieldref(next, 0);
Expand All @@ -614,7 +617,7 @@ JL_CALLABLE(jl_f__apply)
roots[stackalloc + 1] = NULL;
JL_GC_ASSERT_LIVE(state);
args[1] = state;
next = jl_apply_generic(jl_iterate_func, args, 2);
next = jl_apply_generic(iterate, args, 2);
}
roots[stackalloc] = NULL;
extra -= 1;
Expand All @@ -629,6 +632,16 @@ JL_CALLABLE(jl_f__apply)
return result;
}

JL_CALLABLE(jl_f__apply_iterate)
{
return do_apply(F, args+1, nargs-1, args[0]);
}

JL_CALLABLE(jl_f__apply)
{
return do_apply(F, args, nargs, NULL);
}

// this is like `_apply`, but with quasi-exact checks to make sure it is pure
JL_CALLABLE(jl_f__apply_pure)
{
Expand Down Expand Up @@ -1301,6 +1314,7 @@ void jl_init_primitives(void) JL_GC_DISABLED
// internal functions
jl_builtin_apply_type = add_builtin_func("apply_type", jl_f_apply_type);
jl_builtin__apply = add_builtin_func("_apply", jl_f__apply);
jl_builtin__apply_iterate = add_builtin_func("_apply_iterate", jl_f__apply_iterate);
jl_builtin__expr = add_builtin_func("_expr", jl_f__expr);
jl_builtin_svec = add_builtin_func("svec", jl_f_svec);
add_builtin_func("_apply_pure", jl_f__apply_pure);
Expand Down
1 change: 1 addition & 0 deletions src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7278,6 +7278,7 @@ static void init_julia_llvm_env(Module *m)
builtin_func_map[jl_f_typeassert] = jlcall_func_to_llvm("jl_f_typeassert", &jl_f_typeassert, m);
builtin_func_map[jl_f_ifelse] = jlcall_func_to_llvm("jl_f_ifelse", &jl_f_ifelse, m);
builtin_func_map[jl_f__apply] = jlcall_func_to_llvm("jl_f__apply", &jl_f__apply, m);
builtin_func_map[jl_f__apply_iterate] = jlcall_func_to_llvm("jl_f__apply_iterate", &jl_f__apply_iterate, m);
builtin_func_map[jl_f__apply_pure] = jlcall_func_to_llvm("jl_f__apply_pure", &jl_f__apply_pure, m);
builtin_func_map[jl_f__apply_latest] = jlcall_func_to_llvm("jl_f__apply_latest", &jl_f__apply_latest, m);
builtin_func_map[jl_f_throw] = jlcall_func_to_llvm("jl_f_throw", &jl_f_throw, m);
Expand Down
3 changes: 2 additions & 1 deletion src/staticdata.c
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ static void *const _tags[] = {
// some Core.Builtin Functions that we want to be able to reference:
&jl_builtin_throw, &jl_builtin_is, &jl_builtin_typeof, &jl_builtin_sizeof,
&jl_builtin_issubtype, &jl_builtin_isa, &jl_builtin_typeassert, &jl_builtin__apply,
&jl_builtin__apply_iterate,
&jl_builtin_isdefined, &jl_builtin_nfields, &jl_builtin_tuple, &jl_builtin_svec,
&jl_builtin_getfield, &jl_builtin_setfield, &jl_builtin_fieldtype, &jl_builtin_arrayref,
&jl_builtin_const_arrayref, &jl_builtin_arrayset, &jl_builtin_arraysize,
Expand Down Expand Up @@ -109,7 +110,7 @@ static htable_t fptr_to_id;
// This is a manually constructed dual of the fvars array, which would be produced by codegen for Julia code, for C.
static const jl_fptr_args_t id_to_fptrs[] = {
&jl_f_throw, &jl_f_is, &jl_f_typeof, &jl_f_issubtype, &jl_f_isa,
&jl_f_typeassert, &jl_f__apply, &jl_f__apply_pure, &jl_f__apply_latest, &jl_f_isdefined,
&jl_f_typeassert, &jl_f__apply, &jl_f__apply_iterate, &jl_f__apply_pure, &jl_f__apply_latest, &jl_f_isdefined,
&jl_f_tuple, &jl_f_svec, &jl_f_intrinsic_call, &jl_f_invoke_kwsorter,
&jl_f_getfield, &jl_f_setfield, &jl_f_fieldtype, &jl_f_nfields,
&jl_f_arrayref, &jl_f_const_arrayref, &jl_f_arrayset, &jl_f_arraysize, &jl_f_apply_type,
Expand Down

0 comments on commit 5bb6d0a

Please sign in to comment.