This repository has been archived by the owner on May 27, 2021. It is now read-only.

WIP: Use contextual dispatch for replacing functions #334

Open · wants to merge 3 commits into base: master
5 changes: 5 additions & 0 deletions Manifest.toml
@@ -26,6 +26,11 @@ git-tree-sha1 = "1fce616fa0806c67c133eb1d2f68f0f1a7504665"
uuid = "c5f51814-7f29-56b8-a69c-e4d8f6be1fde"
version = "5.0.1"

[[Cassette]]
git-tree-sha1 = "36bd4e0088652b0b2d25a03e531f0d04258feb78"
uuid = "7057c7e9-c182-5462-911a-8362d720325c"
version = "0.3.0"

[[DataStructures]]
deps = ["InteractiveUtils", "OrderedCollections"]
git-tree-sha1 = "b7720de347734f4716d1815b00ce5664ed6bbfd4"
1 change: 1 addition & 0 deletions Project.toml
@@ -7,6 +7,7 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
CUDAapi = "3895d2a7-ec45-59b8-82bb-cfc6a382f9b3"
CUDAdrv = "c5f51814-7f29-56b8-a69c-e4d8f6be1fde"
Cassette = "7057c7e9-c182-5462-911a-8362d720325c"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
2 changes: 2 additions & 0 deletions src/CUDAnative.jl
@@ -34,6 +34,7 @@ const ptxas = Ref{String}()
include("utils.jl")

# needs to be loaded _before_ the compiler infrastructure, because of generated functions
isdevice() = false
include("device/tools.jl")
include("device/pointer.jl")
include("device/array.jl")
@@ -44,6 +45,7 @@ include("device/runtime.jl")
include("init.jl")

include("compiler.jl")
include("context.jl")
include("execution.jl")
include("exceptions.jl")
include("reflection.jl")
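The `isdevice() = false` fallback added above pairs with an overdub in src/context.jl that makes the call return `true` under the Cassette context, so a plain branch can select between host and device code. A minimal sketch of that pattern (the `describe_target` helper is hypothetical, for illustration only):

# Hypothetical helper, not part of this diff: branch on `isdevice()`.
# On the host the plain definition `isdevice() = false` applies; when the call
# runs under CUDACtx (see src/context.jl), the overdub flips the branch.
function describe_target()
    if CUDAnative.isdevice()
        return "device"   # taken when contextualized for GPU execution
    else
        return "host"     # taken during ordinary host execution
    end
end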
2 changes: 2 additions & 0 deletions src/compiler/common.jl
@@ -7,6 +7,8 @@ Base.@kwdef struct CompilerJob
cap::VersionNumber
kernel::Bool

contextualize::Bool = true

# optional properties
minthreads::Union{Nothing,CuDim} = nothing
maxthreads::Union{Nothing,CuDim} = nothing
5 changes: 3 additions & 2 deletions src/compiler/driver.jl
@@ -62,11 +62,12 @@ function codegen(target::Symbol, job::CompilerJob;
@timeit_debug to "validation" check_method(job)

@timeit_debug to "Julia front-end" begin
f = job.contextualize ? contextualize(job.f) : job.f

# get the method instance
world = typemax(UInt)
meth = which(job.f, job.tt)
sig = Base.signature_type(job.f, job.tt)::Type
meth = which(f, job.tt)
sig = Base.signature_type(f, job.tt)::Type
(ti, env) = ccall(:jl_type_intersection_with_env, Any,
(Any, Any), sig, meth.sig)::Core.SimpleVector
if VERSION >= v"1.2.0-DEV.320"
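Note on the lookup change above: `contextualize` returns an anonymous closure that wraps the kernel in `Cassette.overdub`, so `which` and `signature_type` must resolve against that closure's vararg method rather than against the original function. A plain-Julia illustration of why both calls switch to `f` (the `wrap` helper is a stand-in for `contextualize`, not the actual implementation):

# `wrap` mimics the shape of `contextualize`: a vararg closure over the kernel.
wrap(f) = (args...) -> f(args...)
kernel(x::Int) = x + 1

meth = which(wrap(kernel), Tuple{Int})                # resolves the closure's (args...) method
sig  = Base.signature_type(wrap(kernel), Tuple{Int})  # signature of the wrapper, not of `kernel`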
75 changes: 75 additions & 0 deletions src/context.jl
@@ -0,0 +1,75 @@
##
# Implements contextual dispatch through Cassette.jl
# Goals:
# - Rewrite common CPU functions to appropriate GPU intrinsics
#
# TODO:
# - error (erf, ...)
# - pow
# - min, max
# - mod, rem
# - gamma
# - bessel
# - distributions
# - unsorted

using Cassette

function transform(ctx, ref)
CI = ref.code_info
noinline = any(@nospecialize(x) ->
Core.Compiler.isexpr(x, :meta) &&
x.args[1] == :noinline,
CI.code)
CI.inlineable = !noinline

CI.ssavaluetypes = length(CI.code)
# Core.Compiler.validate_code(CI)
return CI
end

const InlinePass = Cassette.@pass transform

Cassette.@context CUDACtx
const cudactx = Cassette.disablehooks(CUDACtx(pass = InlinePass))

###
# Cassette fixes
###

# kwfunc fix
Cassette.overdub(::CUDACtx, ::typeof(Core.kwfunc), f) = return Core.kwfunc(f)

# the functions below are marked `@pure`; rewriting them would hide that from
# inference, so we leave them alone (see https://github.com/jrevels/Cassette.jl/issues/108).
@inline Cassette.overdub(::CUDACtx, ::typeof(Base.isimmutable), x) = return Base.isimmutable(x)
@inline Cassette.overdub(::CUDACtx, ::typeof(Base.isstructtype), t) = return Base.isstructtype(t)
@inline Cassette.overdub(::CUDACtx, ::typeof(Base.isprimitivetype), t) = return Base.isprimitivetype(t)
@inline Cassette.overdub(::CUDACtx, ::typeof(Base.isbitstype), t) = return Base.isbitstype(t)
@inline Cassette.overdub(::CUDACtx, ::typeof(Base.isbits), x) = return Base.isbits(x)

@inline Cassette.overdub(::CUDACtx, ::typeof(datatype_align), ::Type{T}) where {T} = datatype_align(T)

###
# Rewrite functions
###
Cassette.overdub(ctx::CUDACtx, ::typeof(isdevice)) = true

# libdevice.jl
for f in (:cos, :cospi, :sin, :sinpi, :tan,
:acos, :asin, :atan,
:cosh, :sinh, :tanh,
:acosh, :asinh, :atanh,
:log, :log10, :log1p, :log2,
:exp, :exp2, :exp10, :expm1, :ldexp,
:isfinite, :isinf, :isnan,
:signbit, :abs,
:sqrt, :cbrt,
:ceil, :floor,)
@eval function Cassette.overdub(ctx::CUDACtx, ::typeof(Base.$f), x::Union{Float32, Float64})
@Base._inline_meta
return CUDAnative.$f(x)
end
end

contextualize(f::F) where F = (args...) -> Cassette.overdub(cudactx, f, args...)
Review comment (Member Author):
Suggested change
contextualize(f::F) where F = (args...) -> Cassette.overdub(cudactx, f, args...)
contextualize(f::F) where F = (args...) -> (Cassette.overdub(cudactx, f, args...); return nothing)

We could go back to automatically returning nothing here
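For readers unfamiliar with Cassette, here is a self-contained sketch of the dispatch mechanism that src/context.jl sets up; `fake_gpu_sin` is a stand-in for the real `CUDAnative.sin` intrinsic, while the context, overdub method, and closure mirror the code above:

using Cassette

Cassette.@context SketchCtx
const sketchctx = Cassette.disablehooks(SketchCtx())

fake_gpu_sin(x) = sin(x)  # placeholder for a device intrinsic

# reroute Base.sin for Float32/Float64 whenever execution runs under the context
@inline Cassette.overdub(::SketchCtx, ::typeof(Base.sin), x::Union{Float32, Float64}) =
    fake_gpu_sin(x)

sketch_contextualize(f) = (args...) -> Cassette.overdub(sketchctx, f, args...)

kernel(x) = sin(x) + one(x)
sketch_contextualize(kernel)(1.0f0)  # the nested `sin` call now dispatches to fake_gpu_sin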

3 changes: 2 additions & 1 deletion src/execution.jl
@@ -9,7 +9,7 @@ export @cuda, cudaconvert, cufunction, dynamic_cufunction, nearest_warpsize
# the code it generates, or the execution
function split_kwargs(kwargs)
macro_kws = [:dynamic]
compiler_kws = [:minthreads, :maxthreads, :blocks_per_sm, :maxregs, :name]
compiler_kws = [:minthreads, :maxthreads, :blocks_per_sm, :maxregs, :name, :contextualize]
call_kws = [:cooperative, :blocks, :threads, :config, :shmem, :stream]
macro_kwargs = []
compiler_kwargs = []
@@ -351,6 +351,7 @@ The following keyword arguments are supported:
- `maxregs`: the maximum number of registers to be allocated to a single thread (only
supported on LLVM 4.0+)
- `name`: override the name that the kernel will have in the generated code
- `contextualize`: whether to contextualize functions using Cassette (default: true)

The output of this function is automatically cached, i.e. you can simply call `cufunction`
in a hot path without degrading performance. New code will be generated automatically, when
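Usage sketch for the new keyword, mirroring the tests below (assumes a functional CUDA setup with CUDAnative loaded):

using CUDAnative

dummy() = return

@cuda contextualize=false dummy()           # disable the Cassette rewrite for a single launch
k = cufunction(dummy; contextualize=false)  # or when compiling a kernel object explicitly
k()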
27 changes: 15 additions & 12 deletions test/codegen.jl
@@ -8,7 +8,8 @@
valid_kernel() = return
invalid_kernel() = 1

ir = sprint(io->CUDAnative.code_llvm(io, valid_kernel, Tuple{}; optimize=false, dump_module=true))
ir = sprint(io->CUDAnative.code_llvm(io, valid_kernel, Tuple{}; dump_module=true,
contextualize=false, optimize=false))

# module should contain our function + a generic call wrapper
@test occursin("define void @julia_valid_kernel", ir)
@@ -52,7 +53,7 @@ end
@noinline child(i) = sink(i)
parent(i) = child(i)

ir = sprint(io->CUDAnative.code_llvm(io, parent, Tuple{Int}))
ir = sprint(io->CUDAnative.code_llvm(io, parent, Tuple{Int}; contextualize=false))
@test occursin(r"call .+ @julia_child_", ir)
end

@@ -76,10 +77,10 @@ end
x::Int
end

ir = sprint(io->CUDAnative.code_llvm(io, kernel, Tuple{Aggregate}))
ir = sprint(io->CUDAnative.code_llvm(io, kernel, Tuple{Aggregate}; contextualize=false))
@test occursin(r"@julia_kernel_\d+\(({ i64 }|\[1 x i64\]) addrspace\(\d+\)?\*", ir)

ir = sprint(io->CUDAnative.code_llvm(io, kernel, Tuple{Aggregate}; kernel=true))
ir = sprint(io->CUDAnative.code_llvm(io, kernel, Tuple{Aggregate}; contextualize=false, kernel=true))
@test occursin(r"@ptxcall_kernel_\d+\(({ i64 }|\[1 x i64\])\)", ir)
end

@@ -135,7 +136,7 @@ end
closure = ()->return

function test_name(f, name; kwargs...)
code = sprint(io->CUDAnative.code_llvm(io, f, Tuple{}; kwargs...))
code = sprint(io->CUDAnative.code_llvm(io, f, Tuple{}; contextualize=false, kwargs...))
@test occursin(name, code)
end

@@ -221,7 +222,7 @@ end
return
end

asm = sprint(io->CUDAnative.code_ptx(io, parent, Tuple{Int64}))
asm = sprint(io->CUDAnative.code_ptx(io, parent, Tuple{Int64}; contextualize=false))
@test occursin(r"call.uni\s+julia_child_"m, asm)
end

@@ -232,7 +233,7 @@ end
return
end

asm = sprint(io->CUDAnative.code_ptx(io, entry, Tuple{Int64}; kernel=true))
asm = sprint(io->CUDAnative.code_ptx(io, entry, Tuple{Int64}; contextualize=false, kernel=true))
@test occursin(r"\.visible \.entry ptxcall_entry_", asm)
@test !occursin(r"\.visible \.func julia_nonentry_", asm)
@test occursin(r"\.func julia_nonentry_", asm)
@@ -279,15 +280,15 @@ end
return
end

asm = sprint(io->CUDAnative.code_ptx(io, parent1, Tuple{Int}))
asm = sprint(io->CUDAnative.code_ptx(io, parent1, Tuple{Int}; contextualize=false))
@test occursin(r".func julia_child_", asm)

function parent2(i)
child(i+1)
return
end

asm = sprint(io->CUDAnative.code_ptx(io, parent2, Tuple{Int}))
asm = sprint(io->CUDAnative.code_ptx(io, parent2, Tuple{Int}; contextualize=false))
@test occursin(r".func julia_child_", asm)
end

@@ -357,7 +358,7 @@ end
closure = ()->nothing

function test_name(f, name; kwargs...)
code = sprint(io->CUDAnative.code_ptx(io, f, Tuple{}; kwargs...))
code = sprint(io->CUDAnative.code_ptx(io, f, Tuple{}; contextualize=false, kwargs...))
@test occursin(name, code)
end

@@ -429,7 +430,7 @@ end
return
end

ir = sprint(io->CUDAnative.code_llvm(io, kernel, Tuple{Float32,Ptr{Float32}}))
ir = sprint(io->CUDAnative.code_llvm(io, kernel, Tuple{Float32,Ptr{Float32}}; contextualize=false))
@test occursin("jl_box_float32", ir)
CUDAnative.code_ptx(devnull, kernel, Tuple{Float32,Ptr{Float32}})
end
@@ -444,18 +445,20 @@ end

# some validation happens in the emit_function hook, which is called by code_llvm

# NOTE: contextualization changes order of frames
@testset "recursion" begin
@eval recurse_outer(i) = i > 0 ? i : recurse_inner(i)
@eval @noinline recurse_inner(i) = i < 0 ? i : recurse_outer(i)

@test_throws_message(CUDAnative.KernelError, CUDAnative.code_llvm(devnull, recurse_outer, Tuple{Int})) do msg
@test_throws_message(CUDAnative.KernelError, CUDAnative.code_llvm(devnull, recurse_outer, Tuple{Int}; contextualize=false)) do msg
occursin("recursion is currently not supported", msg) &&
occursin("[1] recurse_outer", msg) &&
occursin("[2] recurse_inner", msg) &&
occursin("[3] recurse_outer", msg)
end
end

# FIXME: contextualization removes all frames here -- changed inlining behavior?
@testset "base intrinsics" begin
foobar(i) = sin(i)

16 changes: 8 additions & 8 deletions test/device/execution.jl
@@ -70,9 +70,9 @@ end
@test_throws ErrorException @device_code_lowered nothing

# make sure kernel name aliases are preserved in the generated code
@test occursin("ptxcall_dummy", sprint(io->(@device_code_llvm io=io @cuda dummy())))
@test occursin("ptxcall_dummy", sprint(io->(@device_code_ptx io=io @cuda dummy())))
@test occursin("ptxcall_dummy", sprint(io->(@device_code_sass io=io @cuda dummy())))
@test occursin("ptxcall_dummy", sprint(io->(@device_code_llvm io=io @cuda contextualize=false dummy())))
@test occursin("ptxcall_dummy", sprint(io->(@device_code_ptx io=io @cuda contextualize=false dummy())))
@test occursin("ptxcall_dummy", sprint(io->(@device_code_sass io=io @cuda contextualize=false dummy())))

# make sure invalid kernels can be partially reflected upon
let
@@ -96,7 +96,7 @@

# set name of kernel
@test occursin("ptxcall_mykernel", sprint(io->(@device_code_llvm io=io begin
k = cufunction(dummy, name="mykernel")
k = cufunction(dummy; name="mykernel", contextualize=false)
k()
end)))
end
@@ -463,7 +463,7 @@ end
val_dev = CuArray(val)
cuda_ptr = pointer(val_dev)
ptr = CUDAnative.DevicePtr{Int}(cuda_ptr)
for i in (1, 10, 20, 35)
for i in (1, 10, 20, 32)
variables = ('a':'z'..., 'A':'Z'...)
params = [Symbol(variables[j]) for j in 1:i]
# generate a kernel
@@ -553,11 +553,11 @@ let (code, out, err) = julia_script(script, `-g2`)
@test occursin("ERROR: KernelException: exception thrown during kernel execution on device", err)
@test occursin("ERROR: a exception was thrown during kernel execution", out)
if VERSION < v"1.3.0-DEV.270"
@test occursin("[1] Type at float.jl", out)
@test occursin(r"\[.\] Type at float.jl", out)
else
@test occursin("[1] Int64 at float.jl", out)
@test occursin(r"\[.\] Int64 at float.jl", out)
end
@test occursin("[2] kernel at none:2", out)
@test occursin(r"\[.\] kernel at none:2", out)
end

end