From 677195df4ab30bd9c588c3e2ebaec723b2b7da41 Mon Sep 17 00:00:00 2001 From: Keno Fischer Date: Sat, 21 Sep 2019 19:09:36 -0400 Subject: [PATCH] Add a builtin that allows specifying which iterate method to use When using the Casette mechanism to intercept calls to _apply, a common strategy is to rewrite the function argument to properly consider the context and then falling back to regular _apply. However, as showin in https://github.com/jrevels/Cassette.jl/issues/146, this strategy is insufficient as the _apply itself may recurse into various `iterate` calls which are not properly tracked. This is an attempt to resolve this problem with a minimal performance penalty. Attempting to duplicate the _apply logic in julia, would lead to code that is very hard for inference (and nested Cassette passes to understand). In contrast, this simply adds a version of _apply that takes `iterate` as an explicit argument. Cassette and similar tools can override this argument and provide a function that properly allows the context to recurse through the iteration, while still allowing inference to take advantage of the special handling of _apply for simple cases. Also change the lowering of splatting to use this new intrinsic directly, thus fixing #26001. --- base/boot.jl | 1 + base/compiler/abstractinterpretation.jl | 27 ++++++++++++++++-------- base/compiler/compiler.jl | 2 +- base/compiler/ssair/inlining.jl | 21 ++++++++++--------- base/essentials.jl | 19 +++++++++-------- src/builtin_proto.h | 2 +- src/builtins.c | 28 ++++++++++++++++++------- src/codegen.cpp | 1 + src/julia-syntax.scm | 2 +- src/staticdata.c | 3 ++- 10 files changed, 67 insertions(+), 39 deletions(-) diff --git a/base/boot.jl b/base/boot.jl index fd610eda65e6a2..ef0bc5e5074cb3 100644 --- a/base/boot.jl +++ b/base/boot.jl @@ -210,6 +210,7 @@ else const UInt = UInt32 end +function iterate end function Typeof end ccall(:jl_toplevel_eval_in, Any, (Any, Any), Core, quote diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl index 0645f41879b517..6146b895071afa 100644 --- a/base/compiler/abstractinterpretation.jl +++ b/base/compiler/abstractinterpretation.jl @@ -427,7 +427,7 @@ end # refine its type to an array of element types. # Union of Tuples of the same length is converted to Tuple of Unions. # returns an array of types -function precise_container_type(@nospecialize(typ), vtypes::VarTable, sv::InferenceState) +function precise_container_type(@nospecialize(itft), @nospecialize(typ), vtypes::VarTable, sv::InferenceState) if isa(typ, PartialStruct) && typ.typ.name === Tuple.name return typ.fields end @@ -489,17 +489,24 @@ function precise_container_type(@nospecialize(typ), vtypes::VarTable, sv::Infere elseif tti0 <: Array return Any[Vararg{eltype(tti0)}] else - return abstract_iteration(typ, vtypes, sv) + return abstract_iteration(itft, typ, vtypes, sv) end end # simulate iteration protocol on container type up to fixpoint -function abstract_iteration(@nospecialize(itertype), vtypes::VarTable, sv::InferenceState) +function abstract_iteration(@nospecialize(itft), @nospecialize(itertype), vtypes::VarTable, sv::InferenceState) if !isdefined(Main, :Base) || !isdefined(Main.Base, :iterate) || !isconst(Main.Base, :iterate) return Any[Vararg{Any}] end - iteratef = getfield(Main.Base, :iterate) - stateordonet = abstract_call(iteratef, nothing, Any[Const(iteratef), itertype], vtypes, sv) + if itft === nothing + iteratef = getfield(Main.Base, :iterate) + itft = Const(iteratef) + elseif isa(itft, Const) + iteratef = itft.val + else + return Any[Vararg{Any}] + end + stateordonet = abstract_call(iteratef, nothing, Any[itft, itertype], vtypes, sv) # Return Bottom if this is not an iterator. # WARNING: Changes to the iteration protocol must be reflected here, # this is not just an optimization. @@ -543,7 +550,7 @@ function abstract_iteration(@nospecialize(itertype), vtypes::VarTable, sv::Infer end # do apply(af, fargs...), where af is a function value -function abstract_apply(@nospecialize(aft), aargtypes::Vector{Any}, vtypes::VarTable, sv::InferenceState, +function abstract_apply(@nospecialize(itft), @nospecialize(aft), aargtypes::Vector{Any}, vtypes::VarTable, sv::InferenceState, max_methods = sv.params.MAX_METHODS) aftw = widenconst(aft) if !isa(aft, Const) && (!isType(aftw) || has_free_typevars(aftw)) @@ -561,7 +568,7 @@ function abstract_apply(@nospecialize(aft), aargtypes::Vector{Any}, vtypes::VarT for i = 1:nargs ctypes´ = [] for ti in (splitunions ? uniontypes(aargtypes[i]) : Any[aargtypes[i]]) - cti = precise_container_type(ti, vtypes, sv) + cti = precise_container_type(itft, ti, vtypes, sv) if _any(t -> t === Bottom, cti) continue end @@ -634,7 +641,9 @@ end function abstract_call(@nospecialize(f), fargs::Union{Nothing,Vector{Any}}, argtypes::Vector{Any}, vtypes::VarTable, sv::InferenceState, max_methods = sv.params.MAX_METHODS) if f === _apply - return abstract_apply(argtypes[2], argtypes[3:end], vtypes, sv, max_methods) + return abstract_apply(nothing, argtypes[2], argtypes[3:end], vtypes, sv, max_methods) + elseif f === _apply_iterate + return abstract_apply(argtypes[2], argtypes[3], argtypes[4:end], vtypes, sv, max_methods) end la = length(argtypes) @@ -662,7 +671,7 @@ function abstract_call(@nospecialize(f), fargs::Union{Nothing,Vector{Any}}, argt end rt = builtin_tfunction(f, argtypes[2:end], sv) if f === getfield && isa(fargs, Vector{Any}) && length(argtypes) == 3 && isa(argtypes[3], Const) && isa(argtypes[3].val, Int) && argtypes[2] ⊑ Tuple - cti = precise_container_type(argtypes[2], vtypes, sv) + cti = precise_container_type(nothing, argtypes[2], vtypes, sv) idx = argtypes[3].val if 1 <= idx <= length(cti) rt = unwrapva(cti[idx]) diff --git a/base/compiler/compiler.jl b/base/compiler/compiler.jl index c2d97e40cded49..b44a60f23d626e 100644 --- a/base/compiler/compiler.jl +++ b/base/compiler/compiler.jl @@ -5,7 +5,7 @@ getfield(getfield(Main, :Core), :eval)(getfield(Main, :Core), :(baremodule Compi using Core.Intrinsics, Core.IR import Core: print, println, show, write, unsafe_write, stdout, stderr, - _apply, svec, apply_type, Builtin, IntrinsicFunction, MethodInstance, CodeInstance + _apply, _apply_iterate, svec, apply_type, Builtin, IntrinsicFunction, MethodInstance, CodeInstance const getproperty = getfield const setproperty! = setfield! diff --git a/base/compiler/ssair/inlining.jl b/base/compiler/ssair/inlining.jl index 3e07c7e4dbfeb2..7f1325ad84c098 100644 --- a/base/compiler/ssair/inlining.jl +++ b/base/compiler/ssair/inlining.jl @@ -592,11 +592,11 @@ function spec_lambda(@nospecialize(atype), sv::OptimizationState, @nospecialize( end # This assumes the caller has verified that all arguments to the _apply call are Tuples. -function rewrite_apply_exprargs!(ir::IRCode, idx::Int, argexprs::Vector{Any}, atypes::Vector{Any}) - new_argexprs = Any[argexprs[2]] - new_atypes = Any[atypes[2]] +function rewrite_apply_exprargs!(ir::IRCode, idx::Int, argexprs::Vector{Any}, atypes::Vector{Any}, arg_start::Int) + new_argexprs = Any[argexprs[arg_start]] + new_atypes = Any[atypes[arg_start]] # loop over original arguments and flatten any known iterators - for i in 3:length(argexprs) + for i in (arg_start+1):length(argexprs) def = argexprs[i] def_type = atypes[i] if def_type isa PartialStruct @@ -882,11 +882,12 @@ end function inline_apply!(ir::IRCode, idx::Int, sig::Signature, params::Params) stmt = ir.stmts[idx] - while sig.f === Core._apply + while sig.f === Core._apply || sig.f === Core._apply_iterate + arg_start = sig.f === Core._apply ? 2 : 3 atypes = sig.atypes # Try to figure out the signature of the function being called # and if rewrite_apply_exprargs can deal with this form - for i = 3:length(atypes) + for i = (arg_start + 1):length(atypes) # TODO: We could basically run the iteration protocol here if !is_valid_type_for_apply_rewrite(atypes[i], params) return nothing @@ -894,13 +895,13 @@ function inline_apply!(ir::IRCode, idx::Int, sig::Signature, params::Params) end # Independent of whether we can inline, the above analysis allows us to rewrite # this apply call to a regular call - ft = atypes[2] - if length(atypes) == 3 && ft isa Const && ft.val === Core.tuple && atypes[3] ⊑ Tuple + ft = atypes[arg_start] + if length(atypes) == arg_start+1 && ft isa Const && ft.val === Core.tuple && atypes[arg_start+1] ⊑ Tuple # rewrite `((t::Tuple)...,)` to `t` - ir.stmts[idx] = stmt.args[3] + ir.stmts[idx] = stmt.args[arg_start+1] return nothing end - stmt.args, atypes = rewrite_apply_exprargs!(ir, idx, stmt.args, atypes) + stmt.args, atypes = rewrite_apply_exprargs!(ir, idx, stmt.args, atypes, arg_start) has_free_typevars(ft) && return nothing f = singleton_type(ft) sig = Signature(f, ft, atypes) diff --git a/base/essentials.jl b/base/essentials.jl index 4f36de7baecef5..d308207003d82c 100644 --- a/base/essentials.jl +++ b/base/essentials.jl @@ -113,6 +113,16 @@ macro _propagate_inbounds_meta() return Expr(:meta, :inline, :propagate_inbounds) end + +""" + iterate(iter [, state]) -> Union{Nothing, Tuple{Any, Any}} + +Advance the iterator to obtain the next element. If no elements +remain, `nothing` should be returned. Otherwise, a 2-tuple of the +next element and the new iteration state should be returned. +""" +function iterate end + """ convert(T, x) @@ -817,15 +827,6 @@ compute a definite answer. """ isdone(itr, state...) = missing -""" - iterate(iter [, state]) -> Union{Nothing, Tuple{Any, Any}} - -Advance the iterator to obtain the next element. If no elements -remain, `nothing` should be returned. Otherwise, a 2-tuple of the -next element and the new iteration state should be returned. -""" -function iterate end - """ isiterable(T) -> Bool diff --git a/src/builtin_proto.h b/src/builtin_proto.h index ca3096d3a57e07..85762ce2f9ffa9 100644 --- a/src/builtin_proto.h +++ b/src/builtin_proto.h @@ -23,7 +23,7 @@ DECLARE_BUILTIN(throw); DECLARE_BUILTIN(is); DECLARE_BUILTIN(typeof); DECLARE_BUILTIN(sizeof); DECLARE_BUILTIN(issubtype); DECLARE_BUILTIN(isa); DECLARE_BUILTIN(_apply); DECLARE_BUILTIN(_apply_pure); -DECLARE_BUILTIN(_apply_latest); +DECLARE_BUILTIN(_apply_latest); DECLARE_BUILTIN(_apply_iterate); DECLARE_BUILTIN(isdefined); DECLARE_BUILTIN(nfields); DECLARE_BUILTIN(tuple); DECLARE_BUILTIN(svec); DECLARE_BUILTIN(getfield); DECLARE_BUILTIN(setfield); diff --git a/src/builtins.c b/src/builtins.c index 69ce064badd0e3..f778b32a125be7 100644 --- a/src/builtins.c +++ b/src/builtins.c @@ -474,7 +474,7 @@ void STATIC_INLINE _grow_to(jl_value_t **root, jl_value_t ***oldargs, jl_svec_t static jl_function_t *jl_iterate_func JL_GLOBALLY_ROOTED; -JL_CALLABLE(jl_f__apply) +static jl_value_t *do_apply(jl_value_t *F, jl_value_t **args, uint32_t nargs, jl_value_t *iterate) { JL_NARGSV(apply, 1); jl_function_t *f = args[0]; @@ -516,10 +516,13 @@ JL_CALLABLE(jl_f__apply) extra += 1; } } - if (extra && jl_iterate_func == NULL) { - jl_iterate_func = jl_get_function(jl_top_module, "iterate"); - if (jl_iterate_func == NULL) - jl_undefined_var_error(jl_symbol("iterate")); + if (extra && iterate == NULL) { + if (jl_iterate_func == NULL) { + jl_iterate_func = jl_get_function(jl_top_module, "iterate"); + if (jl_iterate_func == NULL) + jl_undefined_var_error(jl_symbol("iterate")); + } + iterate = jl_iterate_func; } // allocate space for the argument array and gc roots for it // based on our previous estimates @@ -599,7 +602,7 @@ JL_CALLABLE(jl_f__apply) assert(extra > 0); jl_value_t *args[2]; args[0] = ai; - jl_value_t *next = jl_apply_generic(jl_iterate_func, args, 1); + jl_value_t *next = jl_apply_generic(iterate, args, 1); while (next != jl_nothing) { roots[stackalloc] = next; jl_value_t *value = jl_get_nth_field_checked(next, 0); @@ -614,7 +617,7 @@ JL_CALLABLE(jl_f__apply) roots[stackalloc + 1] = NULL; JL_GC_ASSERT_LIVE(state); args[1] = state; - next = jl_apply_generic(jl_iterate_func, args, 2); + next = jl_apply_generic(iterate, args, 2); } roots[stackalloc] = NULL; extra -= 1; @@ -629,6 +632,16 @@ JL_CALLABLE(jl_f__apply) return result; } +JL_CALLABLE(jl_f__apply_iterate) +{ + return do_apply(F, args+1, nargs-1, args[0]); +} + +JL_CALLABLE(jl_f__apply) +{ + return do_apply(F, args, nargs, NULL); +} + // this is like `_apply`, but with quasi-exact checks to make sure it is pure JL_CALLABLE(jl_f__apply_pure) { @@ -1301,6 +1314,7 @@ void jl_init_primitives(void) JL_GC_DISABLED // internal functions jl_builtin_apply_type = add_builtin_func("apply_type", jl_f_apply_type); jl_builtin__apply = add_builtin_func("_apply", jl_f__apply); + jl_builtin__apply_iterate = add_builtin_func("_apply_iterate", jl_f__apply_iterate); jl_builtin__expr = add_builtin_func("_expr", jl_f__expr); jl_builtin_svec = add_builtin_func("svec", jl_f_svec); add_builtin_func("_apply_pure", jl_f__apply_pure); diff --git a/src/codegen.cpp b/src/codegen.cpp index 2141ea22361c4d..dd0e96b8d49179 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -7278,6 +7278,7 @@ static void init_julia_llvm_env(Module *m) builtin_func_map[jl_f_typeassert] = jlcall_func_to_llvm("jl_f_typeassert", &jl_f_typeassert, m); builtin_func_map[jl_f_ifelse] = jlcall_func_to_llvm("jl_f_ifelse", &jl_f_ifelse, m); builtin_func_map[jl_f__apply] = jlcall_func_to_llvm("jl_f__apply", &jl_f__apply, m); + builtin_func_map[jl_f__apply_iterate] = jlcall_func_to_llvm("jl_f__apply_iterate", &jl_f__apply_iterate, m); builtin_func_map[jl_f__apply_pure] = jlcall_func_to_llvm("jl_f__apply_pure", &jl_f__apply_pure, m); builtin_func_map[jl_f__apply_latest] = jlcall_func_to_llvm("jl_f__apply_latest", &jl_f__apply_latest, m); builtin_func_map[jl_f_throw] = jlcall_func_to_llvm("jl_f_throw", &jl_f_throw, m); diff --git a/src/julia-syntax.scm b/src/julia-syntax.scm index ac477a9966ac3a..9f24e2bfc22dc8 100644 --- a/src/julia-syntax.scm +++ b/src/julia-syntax.scm @@ -2096,7 +2096,7 @@ (tuple-wrap (cdr a) '()))) (tuple-wrap (cdr a) (cons x run)))))) (expand-forms - `(call (core _apply) ,f ,@(tuple-wrap argl '()))))) + `(call (core _apply_iterate) (top iterate) ,f ,@(tuple-wrap argl '()))))) ((and (eq? (identifier-name f) '^) (length= e 4) (integer? (cadddr e))) (expand-forms diff --git a/src/staticdata.c b/src/staticdata.c index b500752ac2f774..752d2bef6c64f3 100644 --- a/src/staticdata.c +++ b/src/staticdata.c @@ -76,6 +76,7 @@ static void *const _tags[] = { // some Core.Builtin Functions that we want to be able to reference: &jl_builtin_throw, &jl_builtin_is, &jl_builtin_typeof, &jl_builtin_sizeof, &jl_builtin_issubtype, &jl_builtin_isa, &jl_builtin_typeassert, &jl_builtin__apply, + &jl_builtin__apply_iterate, &jl_builtin_isdefined, &jl_builtin_nfields, &jl_builtin_tuple, &jl_builtin_svec, &jl_builtin_getfield, &jl_builtin_setfield, &jl_builtin_fieldtype, &jl_builtin_arrayref, &jl_builtin_const_arrayref, &jl_builtin_arrayset, &jl_builtin_arraysize, @@ -109,7 +110,7 @@ static htable_t fptr_to_id; // This is a manually constructed dual of the fvars array, which would be produced by codegen for Julia code, for C. static const jl_fptr_args_t id_to_fptrs[] = { &jl_f_throw, &jl_f_is, &jl_f_typeof, &jl_f_issubtype, &jl_f_isa, - &jl_f_typeassert, &jl_f__apply, &jl_f__apply_pure, &jl_f__apply_latest, &jl_f_isdefined, + &jl_f_typeassert, &jl_f__apply, &jl_f__apply_iterate, &jl_f__apply_pure, &jl_f__apply_latest, &jl_f_isdefined, &jl_f_tuple, &jl_f_svec, &jl_f_intrinsic_call, &jl_f_invoke_kwsorter, &jl_f_getfield, &jl_f_setfield, &jl_f_fieldtype, &jl_f_nfields, &jl_f_arrayref, &jl_f_const_arrayref, &jl_f_arrayset, &jl_f_arraysize, &jl_f_apply_type,