From 2c6912ed16f56ea7ae6c15af4cadc87e31ec62fd Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki Date: Wed, 14 Dec 2022 20:22:25 +0900 Subject: [PATCH] simplify caching logic --- src/abstractinterpret/abstractanalyzer.jl | 10 +- src/abstractinterpret/typeinfer.jl | 112 +++--------------- src/analyzers/jetanalyzer.jl | 84 ++++++------- src/analyzers/optanalyzer.jl | 14 +-- .../test_inferenceerrorreport.jl | 2 +- 5 files changed, 74 insertions(+), 148 deletions(-) diff --git a/src/abstractinterpret/abstractanalyzer.jl b/src/abstractinterpret/abstractanalyzer.jl index fe14aeee9..256cf01a1 100644 --- a/src/abstractinterpret/abstractanalyzer.jl +++ b/src/abstractinterpret/abstractanalyzer.jl @@ -536,23 +536,27 @@ get_reports(analyzer::AbstractAnalyzer, result::InferenceResult) = (analyzer[res get_cached_reports(analyzer::AbstractAnalyzer, result::InferenceResult) = (analyzer[result]::JETCachedResult).reports get_any_reports(analyzer::AbstractAnalyzer, result::InferenceResult) = (analyzer[result]::AnyJETResult).reports +# HACK to avoid runtime dispatch +@inline push_report!(reports::Vector{InferenceErrorReport}, @nospecialize(report::InferenceErrorReport)) = + @invoke push!(reports::Vector, report::InferenceErrorReport) + """ add_new_report!(analyzer::AbstractAnalyzer, result::InferenceResult, report::InferenceErrorReport) Adds new [`report::InferenceErrorReport`](@ref InferenceErrorReport) associated with `result::InferenceResult`. 
""" function add_new_report!(analyzer::AbstractAnalyzer, result::InferenceResult, @nospecialize(report::InferenceErrorReport)) - push!(get_reports(analyzer, result), report) + push_report!(get_reports(analyzer, result), report) return report end function add_cached_report!(analyzer::AbstractAnalyzer, caller::InferenceResult, @nospecialize(cached::InferenceErrorReport)) cached = copy_report′(cached) - push!(get_reports(analyzer, caller), cached) + push_report!(get_reports(analyzer, caller), cached) return cached end -add_caller_cache!(analyzer::AbstractAnalyzer, @nospecialize(report::InferenceErrorReport)) = push!(get_caller_cache(analyzer), report) +add_caller_cache!(analyzer::AbstractAnalyzer, @nospecialize(report::InferenceErrorReport)) = push_report!(get_caller_cache(analyzer), report) add_caller_cache!(analyzer::AbstractAnalyzer, reports::Vector{InferenceErrorReport}) = append!(get_caller_cache(analyzer), reports) # AbstractInterpreter diff --git a/src/abstractinterpret/typeinfer.jl b/src/abstractinterpret/typeinfer.jl index 5a5552fa9..5ec866193 100644 --- a/src/abstractinterpret/typeinfer.jl +++ b/src/abstractinterpret/typeinfer.jl @@ -240,8 +240,8 @@ end # cache # ===== -cache_report!(cache, @nospecialize(report::InferenceErrorReport)) = - push!(cache, copy_report′(report)::InferenceErrorReport) +cache_report!(cache::Vector{InferenceErrorReport}, @nospecialize(report::InferenceErrorReport)) = + push_report!(cache, copy_report′(report)::InferenceErrorReport) struct AbstractAnalyzerView{Analyzer<:AbstractAnalyzer} analyzer::Analyzer @@ -340,6 +340,7 @@ end # @static if hasmethod(CC.transform_result_for_cache, (...)) function CC.transform_result_for_cache(analyzer::AbstractAnalyzer, linfo::MethodInstance, valid_worlds::WorldRange, result::InferenceResult) + istoplevel(linfo) && return nothing cache = InferenceErrorReport[] for report in get_reports(analyzer, result) @static if JET_DEV_MODE @@ -543,104 +544,27 @@ function 
filter_lineages!(analyzer::AbstractAnalyzer, caller::InferenceResult, c filter!(!islineage(caller.linfo, current), get_reports(analyzer, caller)) end -# in this overload we can work on `frame.src::CodeInfo` (and also `frame::InferenceState`) -# where type inference (and also optimization if applied) already ran on -function CC._typeinf(analyzer::AbstractAnalyzer, frame::InferenceState) - CC.typeinf_nocycle(analyzer, frame) || return false # frame is now part of a higher cycle - # with no active ip's, frame is done - frames = frame.callers_in_cycle - isempty(frames) && push!(frames, frame) - valid_worlds = WorldRange() - for caller in frames - @assert !(caller.dont_work_on_me) - caller.dont_work_on_me = true - # might might not fully intersect these earlier, so do that now - valid_worlds = CC.intersect(caller.valid_worlds, valid_worlds) - end - for caller in frames - caller.valid_worlds = valid_worlds - CC.finish(caller, analyzer) - # finalize and record the linfo result - caller.inferred = true - end - # NOTE we don't discard `InferenceState`s here so that some analyzers can use them in `finish!` - # # collect results for the new expanded frame - # results = Tuple{InferenceResult, Vector{Any}, Bool}[ - # ( frames[i].result, - # frames[i].stmt_edges[1]::Vector{Any}, - # frames[i].cached ) - # for i in 1:length(frames) ] - # empty!(frames) - for frame in frames - caller = frame.result - opt = caller.src - if (@static VERSION ≥ v"1.9.0-DEV.1636" ? - (opt isa OptimizationState{typeof(analyzer)}) : - (opt isa OptimizationState)) - CC.optimize(analyzer, opt, OptimizationParams(analyzer), caller) - # # COMBAK we may want to enable inlining ? - # if opt.const_api - # # XXX: The work in ir_to_codeinf! is essentially wasted. The only reason - # # we're doing it is so that code_llvm can return the code - # # for the `return ...::Const` (which never runs anyway). We should do this - # # as a post processing step instead. 
- # CC.ir_to_codeinf!(opt) - # if result_type isa Const - # caller.src = result_type - # else - # @assert CC.isconstType(result_type) - # caller.src = Const(result_type.parameters[1]) - # end - # end - caller.valid_worlds = CC.getindex((opt.inlining.et::CC.EdgeTracker).valid_worlds) - end - end +function CC.finish!(analyzer::AbstractAnalyzer, caller::InferenceResult) + reports = get_reports(analyzer, caller) - for frame in frames - caller = frame.result - edges = frame.stmt_edges[1]::Vector{Any} - cached = frame.cached - valid_worlds = caller.valid_worlds - if CC.last(valid_worlds) >= get_world_counter() - # if we aren't cached, we don't need this edge - # but our caller might, so let's just make it anyways - CC.store_backedges(caller, edges) - end - CC.finish!(analyzer, frame) - - reports = get_reports(analyzer, caller) + # XXX this is a dirty fix for a performance problem; we need a more "proper" fix + # https://github.com/aviatesk/JET.jl/issues/75 + unique!(aggregation_policy(analyzer), reports) - # XXX this is a dirty fix for performance problem, we need more "proper" fix - # https://github.com/aviatesk/JET.jl/issues/75 - unique!(aggregation_policy(analyzer), reports) + if get_entry(analyzer) !== caller.linfo + # inter-procedural handling: get back to the caller what we got from these results + add_caller_cache!(analyzer, reports) - # global cache management - if cached && !istoplevel(frame) - CC.cache_result!(analyzer, caller) - end - - if frame.parent !== nothing - # inter-procedural handling: get back to the caller what we got from these results - add_caller_cache!(analyzer, reports) - - # local cache management - # TODO there are duplicated work here and `transform_result_for_cache` - cache = InferenceErrorReport[] - for report in reports - cache_report!(cache, report) - end - set_cached_result!(analyzer, caller, cache) + # local cache management + # TODO there is duplicated work here and in `transform_result_for_cache` + cache = InferenceErrorReport[] + for 
report in reports + cache_report!(cache, report) end + set_cached_result!(analyzer, caller, cache) end - return true -end - -# by default, this overload just is forwarded to the AbstractInterpreter's implementation -# but the only reason we have this overload is that some analyzers (like `JETAnalyzer`) -# can further overload this to generate `InferenceErrorReport` with an access to `frame` -function CC.finish!(analyzer::AbstractAnalyzer, frame::InferenceState) - return CC.finish!(analyzer, frame.result) + return @invoke CC.finish!(analyzer::AbstractInterpreter, caller::InferenceResult) end # top-level bridge diff --git a/src/analyzers/jetanalyzer.jl b/src/analyzers/jetanalyzer.jl index b89194409..06720a2e4 100644 --- a/src/analyzers/jetanalyzer.jl +++ b/src/analyzers/jetanalyzer.jl @@ -170,18 +170,17 @@ function CC.InferenceState(result::InferenceResult, cache::Symbol, analyzer::JET return frame end -function CC.finish!(analyzer::JETAnalyzer, frame::InferenceState) - src = @invoke CC.finish!(analyzer::AbstractAnalyzer, frame::InferenceState) - - if isnothing(src) - # caught in cycle, similar error should have been reported where the source is available - return src - else - code = (src::CodeInfo).code +function CC.finish!(analyzer::JETAnalyzer, caller::InferenceResult) + src = @invoke CC.finish!(analyzer::AbstractInterpreter, caller::InferenceResult) + if src isa CodeInfo # report pass for uncaught `throw` calls - ReportPass(analyzer)(UncaughtExceptionReport, analyzer, frame, code) - return src + ReportPass(analyzer)(UncaughtExceptionReport, analyzer, caller, src) + else + # very much optimized (nothing to report), or very much unoptimized: + # in the latter case, a similar error should have been reported + # where the source is available end + return @invoke CC.finish!(analyzer::AbstractAnalyzer, caller::InferenceResult) end let # overload `abstract_call_gf_by_type` @@ -487,56 +486,60 @@ end Represents general `throw` calls traced during inference. 
This is reported only when it's not caught by control flow. """ -@jetreport struct UncaughtExceptionReport <: InferenceErrorReport - throw_calls::Vector{Tuple{Int,Expr}} # (pc, call) -end -function UncaughtExceptionReport(sv::InferenceState, throw_calls::Vector{Tuple{Int,Expr}}) - vf = get_virtual_frame(sv.linfo) - sig = Any[] - ncalls = length(throw_calls) - for (i, (pc, call)) in enumerate(throw_calls) - call_sig = get_sig_nowrap((sv, pc), call) - append!(sig, call_sig) - i ≠ ncalls && push!(sig, ", ") - end - return UncaughtExceptionReport([vf], Signature(sig), throw_calls) -end -function print_report_message(io::IO, (; throw_calls)::UncaughtExceptionReport) - msg = length(throw_calls) == 1 ? "may throw" : "may throw either of" - print(io, msg) -end +@jetreport struct UncaughtExceptionReport <: InferenceErrorReport end +print_report_message(io::IO, ::UncaughtExceptionReport) = print(io, "may throw") +print_signature(::UncaughtExceptionReport) = false + +# @jetreport struct UncaughtExceptionReport <: InferenceErrorReport +# throw_calls::Vector{Tuple{Int,Expr}} # (pc, call) +# end +# function UncaughtExceptionReport(caller::InferenceResult, throw_calls::Vector{Tuple{Int,Expr}}) +# vf = get_virtual_frame(caller.linfo) +# sig = Any[] +# ncalls = length(throw_calls) +# for (i, (pc, call)) in enumerate(throw_calls) +# call_sig = get_sig_nowrap((caller.src::CodeInfo, pc), call) +# append!(sig, call_sig) +# i ≠ ncalls && push!(sig, ", ") +# end +# return UncaughtExceptionReport([vf], Signature(sig), throw_calls) +# end +# function print_report_message(io::IO, (; throw_calls)::UncaughtExceptionReport) +# msg = length(throw_calls) == 1 ? "may throw" : "may throw either of" +# print(io, msg) +# end # report `throw` calls "appropriately" # this error report pass is very special, since 1.) it's tightly bound to the report pass of # `SeriousExceptionReport` and 2.) 
it involves "report filtering" on its own -function (::BasicPass)(::Type{UncaughtExceptionReport}, analyzer::JETAnalyzer, frame::InferenceState, stmts::Vector{Any}) - if frame.bestguess === Bottom - report_uncaught_exceptions!(analyzer, frame, stmts) +function (::BasicPass)(::Type{UncaughtExceptionReport}, analyzer::JETAnalyzer, caller::InferenceResult, src::CodeInfo) + if caller.result === Bottom + report_uncaught_exceptions!(analyzer, caller, src) return true else # the non-`Bottom` result may mean `throw` calls from the children frames # (if exists) are caught and not propagated here # we don't want to cache the caught `UncaughtExceptionReport`s for this frame and # its parents, and just filter them away now - filter!(get_reports(analyzer, frame.result)) do @nospecialize(report::InferenceErrorReport) + filter!(get_reports(analyzer, caller)) do @nospecialize(report::InferenceErrorReport) return !isa(report, UncaughtExceptionReport) end end return false end -(::SoundPass)(::Type{UncaughtExceptionReport}, analyzer::JETAnalyzer, frame::InferenceState, stmts::Vector{Any}) = - report_uncaught_exceptions!(analyzer, frame, stmts) # yes, you want tons of false positives ! -function report_uncaught_exceptions!(analyzer::JETAnalyzer, frame::InferenceState, stmts::Vector{Any}) +(::SoundPass)(::Type{UncaughtExceptionReport}, analyzer::JETAnalyzer, caller::InferenceResult, src::CodeInfo) = + report_uncaught_exceptions!(analyzer, caller, src) # yes, you want tons of false positives ! 
+function report_uncaught_exceptions!(analyzer::JETAnalyzer, caller::InferenceResult, src::CodeInfo) # if the return type here is `Bottom` annotated, this _may_ mean there're uncaught # `throw` calls # XXX it's possible that the `throw` calls within them are all caught but the other # critical errors still make the return type `Bottom` # NOTE to reduce the false positive cases described above, we count `throw` calls # after optimization, since it may have eliminated "unreachable" `throw` calls - codelocs = frame.src.codelocs - linetable = frame.src.linetable::LineTable + codelocs = src.codelocs + linetable = src.linetable::LineTable reported_locs = nothing - for report in get_reports(analyzer, frame.result) + for report in get_reports(analyzer, caller) if isa(report, SeriousExceptionReport) if isnothing(reported_locs) reported_locs = LineInfoNode[] @@ -545,7 +548,7 @@ function report_uncaught_exceptions!(analyzer::JETAnalyzer, frame::InferenceStat end end throw_calls = nothing - for (pc, stmt) in enumerate(stmts) + for (pc, stmt) in enumerate(src.code) isa(stmt, Expr) || continue is_throw_call(stmt) || continue # if this `throw` is already reported, don't duplciate @@ -558,7 +561,8 @@ function report_uncaught_exceptions!(analyzer::JETAnalyzer, frame::InferenceStat push!(throw_calls, (pc, stmt)) end if !isnothing(throw_calls) && !isempty(throw_calls) - add_new_report!(analyzer, frame.result, UncaughtExceptionReport(frame, throw_calls)) + # TODO add_new_report!(analyzer, caller, UncaughtExceptionReport(caller, throw_calls)) + add_new_report!(analyzer, caller, UncaughtExceptionReport(caller)) return true end return false diff --git a/src/analyzers/optanalyzer.jl b/src/analyzers/optanalyzer.jl index da9dd137a..d268f8b38 100644 --- a/src/analyzers/optanalyzer.jl +++ b/src/analyzers/optanalyzer.jl @@ -218,7 +218,7 @@ struct OptAnalysisPass <: ReportPass end optanalyzer_function_filter(@nospecialize ft) = true -# TODO better to work only `finish!` +# TODO better to work 
only `finish!`, i.e. only work on `CodeInfo` (with static parameters) function CC.finish(frame::InferenceState, analyzer::OptAnalyzer) ret = @invoke CC.finish(frame::InferenceState, analyzer::AbstractAnalyzer) @@ -272,20 +272,15 @@ function (::OptAnalysisPass)(::Type{CapturedVariableReport}, analyzer::OptAnalyz return reported end -function CC.finish!(analyzer::OptAnalyzer, frame::InferenceState) - caller = frame.result - +function CC.finish!(analyzer::OptAnalyzer, caller::InferenceResult) # get the source before running `finish!` to keep the reference to `OptimizationState` src = caller.src - - ret = @invoke CC.finish!(analyzer::AbstractAnalyzer, frame::InferenceState) - if popfirst!(analyzer.__analyze_frame) ReportPass(analyzer)(OptimizationFailureReport, analyzer, caller) - if (@static VERSION ≥ v"1.9.0-DEV.1636" ? (src isa OptimizationState{typeof(analyzer)}) : (src isa OptimizationState)) # the compiler optimized it, analyze it + src.ir === nothing || CC.ir_to_codeinf!(src) ReportPass(analyzer)(RuntimeDispatchReport, analyzer, caller, src) elseif (@static JET_DEV_MODE ? true : false) if isa(src, CC.ConstAPI) @@ -298,8 +293,7 @@ function CC.finish!(analyzer::OptAnalyzer, frame::InferenceState) end end end - - return ret + return @invoke CC.finish!(analyzer::AbstractAnalyzer, caller::InferenceResult) end # report optimization failure due to recursive calls, etc. diff --git a/test/abstractinterpret/test_inferenceerrorreport.jl b/test/abstractinterpret/test_inferenceerrorreport.jl index a28475f7a..4a9b29fac 100644 --- a/test/abstractinterpret/test_inferenceerrorreport.jl +++ b/test/abstractinterpret/test_inferenceerrorreport.jl @@ -168,7 +168,7 @@ end result = report_call(m.foo, (String,)) r = only(get_reports_with_test(result)) @test isa(r, UncaughtExceptionReport) - @test Any['(', 's', String, ')', ArgumentError] ⫇ r.sig._sig + @test_broken Any['(', 's', String, ')', ArgumentError] ⫇ r.sig._sig end sparams1(::Type{T}) where T = zero(T)