diff --git a/Manifest.toml b/Manifest.toml index 8e455713..ad49abcd 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -26,6 +26,11 @@ git-tree-sha1 = "1fce616fa0806c67c133eb1d2f68f0f1a7504665" uuid = "c5f51814-7f29-56b8-a69c-e4d8f6be1fde" version = "5.0.1" +[[Cassette]] +git-tree-sha1 = "36bd4e0088652b0b2d25a03e531f0d04258feb78" +uuid = "7057c7e9-c182-5462-911a-8362d720325c" +version = "0.3.0" + [[DataStructures]] deps = ["InteractiveUtils", "OrderedCollections"] git-tree-sha1 = "b7720de347734f4716d1815b00ce5664ed6bbfd4" diff --git a/Project.toml b/Project.toml index 498e1a5a..bdb011a1 100644 --- a/Project.toml +++ b/Project.toml @@ -7,6 +7,7 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" CUDAapi = "3895d2a7-ec45-59b8-82bb-cfc6a382f9b3" CUDAdrv = "c5f51814-7f29-56b8-a69c-e4d8f6be1fde" +Cassette = "7057c7e9-c182-5462-911a-8362d720325c" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" LLVM = "929cbde3-209d-540e-8aea-75f648917ca0" diff --git a/src/CUDAnative.jl b/src/CUDAnative.jl index ff7e27f9..2ba9ebc8 100644 --- a/src/CUDAnative.jl +++ b/src/CUDAnative.jl @@ -34,6 +34,7 @@ const ptxas = Ref{String}() include("utils.jl") # needs to be loaded _before_ the compiler infrastructure, because of generated functions +isdevice() = false include("device/tools.jl") include("device/pointer.jl") include("device/array.jl") @@ -44,6 +45,7 @@ include("device/runtime.jl") include("init.jl") include("compiler.jl") +include("context.jl") include("execution.jl") include("exceptions.jl") include("reflection.jl") diff --git a/src/compiler/common.jl b/src/compiler/common.jl index 72350640..0e5a19f0 100644 --- a/src/compiler/common.jl +++ b/src/compiler/common.jl @@ -7,6 +7,8 @@ Base.@kwdef struct CompilerJob cap::VersionNumber kernel::Bool + contextualize::Bool = true + # optional properties minthreads::Union{Nothing,CuDim} = nothing maxthreads::Union{Nothing,CuDim} = nothing diff --git a/src/compiler/driver.jl b/src/compiler/driver.jl index d673991d..4fffe9de 100644 --- a/src/compiler/driver.jl +++ b/src/compiler/driver.jl @@ -62,11 +62,12 @@ function codegen(target::Symbol, job::CompilerJob; @timeit_debug to "validation" check_method(job) @timeit_debug to "Julia front-end" begin + f = job.contextualize ? contextualize(job.f) : job.f # get the method instance world = typemax(UInt) - meth = which(job.f, job.tt) - sig = Base.signature_type(job.f, job.tt)::Type + meth = which(f, job.tt) + sig = Base.signature_type(f, job.tt)::Type (ti, env) = ccall(:jl_type_intersection_with_env, Any, (Any, Any), sig, meth.sig)::Core.SimpleVector if VERSION >= v"1.2.0-DEV.320" diff --git a/src/context.jl b/src/context.jl new file mode 100644 index 00000000..6d2c6bd9 --- /dev/null +++ b/src/context.jl @@ -0,0 +1,75 @@ +## +# Implements contextual dispatch through Cassette.jl +# Goals: +# - Rewrite common CPU functions to appropriate GPU intrinsics +# +# TODO: +# - error (erf, ...) +# - pow +# - min, max +# - mod, rem +# - gamma +# - bessel +# - distributions +# - unsorted + +using Cassette + +function transform(ctx, ref) + CI = ref.code_info + noinline = any(@nospecialize(x) -> + Core.Compiler.isexpr(x, :meta) && + x.args[1] == :noinline, + CI.code) + CI.inlineable = !noinline + + CI.ssavaluetypes = length(CI.code) + # Core.Compiler.validate_code(CI) + return CI +end + +const InlinePass = Cassette.@pass transform + +Cassette.@context CUDACtx +const cudactx = Cassette.disablehooks(CUDACtx(pass = InlinePass)) + +### +# Cassette fixes +### + +# kwfunc fix +Cassette.overdub(::CUDACtx, ::typeof(Core.kwfunc), f) = return Core.kwfunc(f) + +# the functions below are marked `@pure` and by rewritting them we hide that from +# inference so we leave them alone (see https://github.com/jrevels/Cassette.jl/issues/108). +@inline Cassette.overdub(::CUDACtx, ::typeof(Base.isimmutable), x) = return Base.isimmutable(x) +@inline Cassette.overdub(::CUDACtx, ::typeof(Base.isstructtype), t) = return Base.isstructtype(t) +@inline Cassette.overdub(::CUDACtx, ::typeof(Base.isprimitivetype), t) = return Base.isprimitivetype(t) +@inline Cassette.overdub(::CUDACtx, ::typeof(Base.isbitstype), t) = return Base.isbitstype(t) +@inline Cassette.overdub(::CUDACtx, ::typeof(Base.isbits), x) = return Base.isbits(x) + +@inline Cassette.overdub(::CUDACtx, ::typeof(datatype_align), ::Type{T}) where {T} = datatype_align(T) + +### +# Rewrite functions +### +Cassette.overdub(ctx::CUDACtx, ::typeof(isdevice)) = true + +# libdevice.jl +for f in (:cos, :cospi, :sin, :sinpi, :tan, + :acos, :asin, :atan, + :cosh, :sinh, :tanh, + :acosh, :asinh, :atanh, + :log, :log10, :log1p, :log2, + :exp, :exp2, :exp10, :expm1, :ldexp, + :isfinite, :isinf, :isnan, + :signbit, :abs, + :sqrt, :cbrt, + :ceil, :floor,) + @eval function Cassette.overdub(ctx::CUDACtx, ::typeof(Base.$f), x::Union{Float32, Float64}) + @Base._inline_meta + return CUDAnative.$f(x) + end +end + +contextualize(f::F) where F = (args...) -> Cassette.overdub(cudactx, f, args...) diff --git a/src/execution.jl b/src/execution.jl index 3dacef22..5d9339d3 100644 --- a/src/execution.jl +++ b/src/execution.jl @@ -9,7 +9,7 @@ export @cuda, cudaconvert, cufunction, dynamic_cufunction, nearest_warpsize # the code it generates, or the execution function split_kwargs(kwargs) macro_kws = [:dynamic] - compiler_kws = [:minthreads, :maxthreads, :blocks_per_sm, :maxregs, :name] + compiler_kws = [:minthreads, :maxthreads, :blocks_per_sm, :maxregs, :name, :contextualize] call_kws = [:cooperative, :blocks, :threads, :config, :shmem, :stream] macro_kwargs = [] compiler_kwargs = [] @@ -351,6 +351,7 @@ The following keyword arguments are supported: - `maxregs`: the maximum number of registers to be allocated to a single thread (only supported on LLVM 4.0+) - `name`: override the name that the kernel will have in the generated code +- `contextualize`: whether to contextualize functions using Cassette (default: true) The output of this function is automatically cached, i.e. you can simply call `cufunction` in a hot path without degrading performance. New code will be generated automatically, when diff --git a/test/codegen.jl b/test/codegen.jl index 65ecb8e6..dac5d2c8 100644 --- a/test/codegen.jl +++ b/test/codegen.jl @@ -8,7 +8,8 @@ valid_kernel() = return invalid_kernel() = 1 - ir = sprint(io->CUDAnative.code_llvm(io, valid_kernel, Tuple{}; optimize=false, dump_module=true)) + ir = sprint(io->CUDAnative.code_llvm(io, valid_kernel, Tuple{}; dump_module=true, + contextualize=false, optimize=false)) # module should contain our function + a generic call wrapper @test occursin("define void @julia_valid_kernel", ir) @@ -52,7 +53,7 @@ end @noinline child(i) = sink(i) parent(i) = child(i) - ir = sprint(io->CUDAnative.code_llvm(io, parent, Tuple{Int})) + ir = sprint(io->CUDAnative.code_llvm(io, parent, Tuple{Int}; contextualize=false)) @test occursin(r"call .+ @julia_child_", ir) end @@ -76,10 +77,10 @@ end x::Int end - ir = sprint(io->CUDAnative.code_llvm(io, kernel, Tuple{Aggregate})) + ir = sprint(io->CUDAnative.code_llvm(io, kernel, Tuple{Aggregate}; contextualize=false)) @test occursin(r"@julia_kernel_\d+\(({ i64 }|\[1 x i64\]) addrspace\(\d+\)?\*", ir) - ir = sprint(io->CUDAnative.code_llvm(io, kernel, Tuple{Aggregate}; kernel=true)) + ir = sprint(io->CUDAnative.code_llvm(io, kernel, Tuple{Aggregate}; contextualize=false, kernel=true)) @test occursin(r"@ptxcall_kernel_\d+\(({ i64 }|\[1 x i64\])\)", ir) end @@ -135,7 +136,7 @@ end closure = ()->return function test_name(f, name; kwargs...) - code = sprint(io->CUDAnative.code_llvm(io, f, Tuple{}; kwargs...)) + code = sprint(io->CUDAnative.code_llvm(io, f, Tuple{}; contextualize=false, kwargs...)) @test occursin(name, code) end @@ -221,7 +222,7 @@ end return end - asm = sprint(io->CUDAnative.code_ptx(io, parent, Tuple{Int64})) + asm = sprint(io->CUDAnative.code_ptx(io, parent, Tuple{Int64}; contextualize=false)) @test occursin(r"call.uni\s+julia_child_"m, asm) end @@ -232,7 +233,7 @@ end return end - asm = sprint(io->CUDAnative.code_ptx(io, entry, Tuple{Int64}; kernel=true)) + asm = sprint(io->CUDAnative.code_ptx(io, entry, Tuple{Int64}; contextualize=false, kernel=true)) @test occursin(r"\.visible \.entry ptxcall_entry_", asm) @test !occursin(r"\.visible \.func julia_nonentry_", asm) @test occursin(r"\.func julia_nonentry_", asm) @@ -279,7 +280,7 @@ end return end - asm = sprint(io->CUDAnative.code_ptx(io, parent1, Tuple{Int})) + asm = sprint(io->CUDAnative.code_ptx(io, parent1, Tuple{Int}; contextualize=false)) @test occursin(r".func julia_child_", asm) function parent2(i) @@ -287,7 +288,7 @@ end return end - asm = sprint(io->CUDAnative.code_ptx(io, parent2, Tuple{Int})) + asm = sprint(io->CUDAnative.code_ptx(io, parent2, Tuple{Int}; contextualize=false)) @test occursin(r".func julia_child_", asm) end @@ -357,7 +358,7 @@ end closure = ()->nothing function test_name(f, name; kwargs...) - code = sprint(io->CUDAnative.code_ptx(io, f, Tuple{}; kwargs...)) + code = sprint(io->CUDAnative.code_ptx(io, f, Tuple{}; contextualize=false, kwargs...)) @test occursin(name, code) end @@ -429,7 +430,7 @@ end return end - ir = sprint(io->CUDAnative.code_llvm(io, kernel, Tuple{Float32,Ptr{Float32}})) + ir = sprint(io->CUDAnative.code_llvm(io, kernel, Tuple{Float32,Ptr{Float32}}; contextualize=false)) @test occursin("jl_box_float32", ir) CUDAnative.code_ptx(devnull, kernel, Tuple{Float32,Ptr{Float32}}) end @@ -444,11 +445,12 @@ end # some validation happens in the emit_function hook, which is called by code_llvm +# NOTE: contextualization changes order of frames @testset "recursion" begin @eval recurse_outer(i) = i > 0 ? i : recurse_inner(i) @eval @noinline recurse_inner(i) = i < 0 ? i : recurse_outer(i) - @test_throws_message(CUDAnative.KernelError, CUDAnative.code_llvm(devnull, recurse_outer, Tuple{Int})) do msg + @test_throws_message(CUDAnative.KernelError, CUDAnative.code_llvm(devnull, recurse_outer, Tuple{Int}; contextualize=false)) do msg occursin("recursion is currently not supported", msg) && occursin("[1] recurse_outer", msg) && occursin("[2] recurse_inner", msg) && @@ -456,6 +458,7 @@ end end end +# FIXME: contextualization removes all frames here -- changed inlining behavior? @testset "base intrinsics" begin foobar(i) = sin(i) diff --git a/test/device/execution.jl b/test/device/execution.jl index cb9d2527..33c96b81 100644 --- a/test/device/execution.jl +++ b/test/device/execution.jl @@ -70,9 +70,9 @@ end @test_throws ErrorException @device_code_lowered nothing # make sure kernel name aliases are preserved in the generated code - @test occursin("ptxcall_dummy", sprint(io->(@device_code_llvm io=io @cuda dummy()))) - @test occursin("ptxcall_dummy", sprint(io->(@device_code_ptx io=io @cuda dummy()))) - @test occursin("ptxcall_dummy", sprint(io->(@device_code_sass io=io @cuda dummy()))) + @test occursin("ptxcall_dummy", sprint(io->(@device_code_llvm io=io @cuda contextualize=false dummy()))) + @test occursin("ptxcall_dummy", sprint(io->(@device_code_ptx io=io @cuda contextualize=false dummy()))) + @test occursin("ptxcall_dummy", sprint(io->(@device_code_sass io=io @cuda contextualize=false dummy()))) # make sure invalid kernels can be partially reflected upon let @@ -96,7 +96,7 @@ end # set name of kernel @test occursin("ptxcall_mykernel", sprint(io->(@device_code_llvm io=io begin - k = cufunction(dummy, name="mykernel") + k = cufunction(dummy; name="mykernel", contextualize=false) k() end))) end @@ -463,7 +463,7 @@ end val_dev = CuArray(val) cuda_ptr = pointer(val_dev) ptr = CUDAnative.DevicePtr{Int}(cuda_ptr) - for i in (1, 10, 20, 35) + for i in (1, 10, 20, 32) variables = ('a':'z'..., 'A':'Z'...) params = [Symbol(variables[j]) for j in 1:i] # generate a kernel @@ -553,11 +553,11 @@ let (code, out, err) = julia_script(script, `-g2`) @test occursin("ERROR: KernelException: exception thrown during kernel execution on device", err) @test occursin("ERROR: a exception was thrown during kernel execution", out) if VERSION < v"1.3.0-DEV.270" - @test occursin("[1] Type at float.jl", out) + @test occursin(r"\[.\] Type at float.jl", out) else - @test occursin("[1] Int64 at float.jl", out) + @test occursin(r"\[.\] Int64 at float.jl", out) end - @test occursin("[2] kernel at none:2", out) + @test occursin(r"\[.\] kernel at none:2", out) end end