WIP: Use contextual dispatch for replacing functions #334
base: master
Conversation
bors try
try: Build failed
Yes, https://github.com/JuliaGPU/CUDAnative.jl/compare/tb/cassette
bors try
As bors tells us, apparently not ;) @jrevels https://gitlab.com/JuliaGPU/CUDAnative.jl/-/jobs/153739960 is full of interesting cases.
bors try
Yeah, as I feared... Let's mark this WIP then 🙁
bors try
try: Build failed
Same error count; inlining doesn't help.
I was planning on grabbing Jarrett this week to see if we can figure it out. (I am in the process of adding GPU support to Cthulhu, so that should make it easier.)
bors try
try: Build failed
Ok! The debugging session with Jarrett proved fruitful; we are down to 10ish failures :)
Cool! What were the changes?
We applied my usual Cassette issue workaround of "isolate the problematic thing and make it a contextual primitive (i.e. don't overdub into it)". The problematic thing here was the
It turns out that while Cassette propagates purity to the compiler correctly, the compiler is (probably rightfully) pessimistic and just bails out on purity optimization for generated functions (i.e.
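For reference, a minimal sketch of that workaround (not the PR's actual code; the context and the problematic function below are hypothetical stand-ins): defining an overdub method for a specific function turns it into a contextual primitive, so Cassette calls it directly instead of recursing into and re-transforming its body.

using Cassette

Cassette.@context PrimCtx

# hypothetical stand-in for the problematic function
troublesome(x) = x + 1

# contextual primitive: call it directly instead of overdubbing into it
Cassette.overdub(::PrimCtx, ::typeof(troublesome), x) = troublesome(x)

caller(x) = 2 * troublesome(x)
Cassette.overdub(PrimCtx(), caller, 3)  # recurses into caller, but not into troublesome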
contextualize(f::F) where F = (args...) -> Cassette.overdub(cudactx, f, args...)
Suggested change:
contextualize(f::F) where F = (args...) -> Cassette.overdub(cudactx, f, args...)
contextualize(f::F) where F = (args...) -> (Cassette.overdub(cudactx, f, args...); return nothing)
We could go back to automatically returning nothing here.
bors try
Rebased and removed the 265 hacks now that we have JuliaLang/julia#32237.

EDIT: oh wow, everything broke. @vchuravy, any ideas? Looks like the transform "works", there are just plenty of dynamic invocations:

julia> kernel(a) = (a[1] = 1; nothing)
kernel (generic function with 2 methods)

julia> kernel([0])

julia> CUDAnative.contextualize(kernel)([0])

julia> @cuda kernel(cu([0]))
ERROR: InvalidIRError: compiling #148(CuDeviceArray{Float32,1,CUDAnative.AS.Global}) resulted in invalid LLVM IR
Reason: unsupported dynamic function invocation (call to overdub)
Stacktrace:
 [1] Val at essentials.jl:694
 [2] setindex! at /home/tbesard/Julia/pkg/CUDAnative/src/device/array.jl:84
 [3] kernel at REPL[5]:1
 [4] #148 at /home/tbesard/Julia/pkg/CUDAnative/src/context.jl:75

Seems something has changed with
MWE:

@generated bar(::Val{align}) where {align} = :(42)
foo(i) = i+bar(Val(1))
using Cassette
function transform(ctx, ref)
    CI = ref.code_info
    noinline = any(@nospecialize(x) ->
                       Core.Compiler.isexpr(x, :meta) &&
                       x.args[1] == :noinline,
                   CI.code)
    CI.inlineable = !noinline
    CI.ssavaluetypes = length(CI.code)
    Core.Compiler.validate_code(CI)
    return CI
end
const InlinePass = Cassette.@pass transform
Cassette.@context Ctx
const ctx = Cassette.disablehooks(Ctx(pass = InlinePass))
contextualize(f::F) where F = (args...) -> Cassette.overdub(ctx, f, args...)
using InteractiveUtils
code_llvm(foo, Tuple{Int})
code_llvm(contextualize(foo), Tuple{Int})

The example doesn't need the addition, but without it we just get a constant jlapi function. It also doesn't need the inlining pass, but without it the LLVM contains a call to overdub, whereas now it clearly shows a dynamic invocation:

; @ /home/tbesard/Julia/wip2.jl:2 within `foo'
define i64 @julia_foo_15985(i64) {
top:
; ┌ @ int.jl:53 within `+'
%1 = add i64 %0, 42
; └
ret i64 %1
}
; @ /home/tbesard/Julia/wip2.jl:22 within `#7'
define i64 @"julia_#7_16073"(i64) {
top:
%1 = alloca %jl_value_t addrspace(10)*, i32 2
; ┌ @ /home/tbesard/Julia/wip2.jl:2 within `foo'
; │┌ @ essentials.jl:694 within `Val'
%2 = getelementptr %jl_value_t addrspace(10)*, %jl_value_t addrspace(10)** %1, i32 0
store %jl_value_t addrspace(10)* addrspacecast (%jl_value_t* inttoptr (i64 140490737042584 to %jl_value_t*) to %jl_value_t addrspace(10)*), %jl_value_t addrspace(10)** %2
%3 = getelementptr %jl_value_t addrspace(10)*, %jl_value_t addrspace(10)** %1, i32 1
store %jl_value_t addrspace(10)* addrspacecast (%jl_value_t* inttoptr (i64 140490826585840 to %jl_value_t*) to %jl_value_t addrspace(10)*), %jl_value_t addrspace(10)** %3
%4 = call nonnull %jl_value_t addrspace(10)* @jl_apply_generic(%jl_value_t addrspace(10)* addrspacecast (%jl_value_t* inttoptr (i64 140490737042232 to %jl_value_t*) to %jl_value_t addrspace(10)*), %jl_value_t addrspace(10)** %1, i32 2)
; │└
; │┌ @ int.jl:53 within `+'
; ││┌ @ /home/tbesard/Julia/pkg/Cassette/src/overdub.jl:465 within `_overdub_fallback'
; │││┌ @ /home/tbesard/Julia/pkg/Cassette/src/context.jl:445 within `fallback'
; ││││┌ @ /home/tbesard/Julia/pkg/Cassette/src/context.jl:447 within `call'
%5 = add i64 %0, 42
; └└└└└
ret i64 %5
}

Bisected to JuliaLang/julia#31012
265 looks broken:
Also the
This one has a better on-device stack trace:
Some more, that fail for all their types:
Will look into these after my talk (the 265 one especially is disappointing). Having a reproducer with MiniCassette would be great.
Yeah, no hurry, just reporting in that your fix works. Once merged we'll have CI at least.
Never mind, Pkg and I were confused about which version of Cassette we were using.
Force-pushed from b0ef268 to e8de056
Turns out we weren't properly contextualizing all methods. Penalty: 20 additional failures.
Squashed and rebased. Added a

Remaining failures are almost all dynamic calls to jl_f_tuple and jl_f_getfield.
At least one source of those issues is the dynamic dispatch that gets introduced when passing a type. MWE:

using Cassette
Cassette.@context Noop
contextualize(f::F) where F = (args...) -> Cassette.overdub(Noop(), f, args...)
function main()
    a = [0]

    function kernel(ptr)
        unsafe_store!(ptr, 1)
        return
    end
    contextualize(kernel)(pointer(a))
    code_llvm(contextualize(kernel), Tuple{Ptr{Int}})

    function kernel(T, ptr)
        unsafe_store!(ptr, T(1))
        return
    end
    contextualize(kernel)(Int, pointer(a))
    code_llvm(contextualize(kernel), Tuple{Type{Int}, Ptr{Int}})
end

define void @"julia_#34_19961"(i64) {
top:
%1 = inttoptr i64 %0 to i64*
store i64 1, i64* %1, align 1
ret void
}
define void @"julia_#34_19962"(%jl_value_t addrspace(10)* nonnull, i64) {
top:
%2 = alloca %jl_value_t addrspace(10)*, i32 2
%gcframe = alloca %jl_value_t addrspace(10)*, i32 3
%3 = bitcast %jl_value_t addrspace(10)** %gcframe to i8*
call void @llvm.memset.p0i8.i32(i8* %3, i8 0, i32 24, i32 0, i1 false)
%thread_ptr = call i8* asm "movq %fs:0, $0", "=r"()
%ptls_i8 = getelementptr i8, i8* %thread_ptr, i64 -15712
%ptls = bitcast i8* %ptls_i8 to %jl_value_t***
%4 = getelementptr %jl_value_t addrspace(10)*, %jl_value_t addrspace(10)** %gcframe, i32 0
%5 = bitcast %jl_value_t addrspace(10)** %4 to i64*
store i64 2, i64* %5
%6 = getelementptr %jl_value_t**, %jl_value_t*** %ptls, i32 0
%7 = getelementptr %jl_value_t addrspace(10)*, %jl_value_t addrspace(10)** %gcframe, i32 1
%8 = bitcast %jl_value_t addrspace(10)** %7 to %jl_value_t***
%9 = load %jl_value_t**, %jl_value_t*** %6
store %jl_value_t** %9, %jl_value_t*** %8
%10 = bitcast %jl_value_t*** %6 to %jl_value_t addrspace(10)***
store %jl_value_t addrspace(10)** %gcframe, %jl_value_t addrspace(10)*** %10
%11 = bitcast %jl_value_t*** %ptls to i8*
%12 = call noalias nonnull %jl_value_t addrspace(10)* @jl_gc_pool_alloc(i8* %11, i32 1400, i32 16) #1
%13 = bitcast %jl_value_t addrspace(10)* %12 to %jl_value_t addrspace(10)* addrspace(10)*
%14 = getelementptr %jl_value_t addrspace(10)*, %jl_value_t addrspace(10)* addrspace(10)* %13, i64 -1
store %jl_value_t addrspace(10)* addrspacecast (%jl_value_t* inttoptr (i64 140060992907120 to %jl_value_t*) to %jl_value_t addrspace(10)*), %jl_value_t addrspace(10)* addrspace(10)* %14
%15 = bitcast %jl_value_t addrspace(10)* %12 to i64 addrspace(10)*
store i64 %1, i64 addrspace(10)* %15, align 8
%16 = getelementptr %jl_value_t addrspace(10)*, %jl_value_t addrspace(10)** %gcframe, i32 2
store %jl_value_t addrspace(10)* %12, %jl_value_t addrspace(10)** %16
%17 = getelementptr %jl_value_t addrspace(10)*, %jl_value_t addrspace(10)** %2, i32 0
store %jl_value_t addrspace(10)* addrspacecast (%jl_value_t* inttoptr (i64 140060989456256 to %jl_value_t*) to %jl_value_t addrspace(10)*), %jl_value_t addrspace(10)** %17
%18 = getelementptr %jl_value_t addrspace(10)*, %jl_value_t addrspace(10)** %2, i32 1
store %jl_value_t addrspace(10)* %12, %jl_value_t addrspace(10)** %18
%19 = call nonnull %jl_value_t addrspace(10)* @jl_f_tuple(%jl_value_t addrspace(10)* addrspacecast (%jl_value_t* null to %jl_value_t addrspace(10)*), %jl_value_t addrspace(10)** %2, i32 2)
%20 = getelementptr %jl_value_t addrspace(10)*, %jl_value_t addrspace(10)** %gcframe, i32 2
store %jl_value_t addrspace(10)* %19, %jl_value_t addrspace(10)** %20
%21 = getelementptr %jl_value_t addrspace(10)*, %jl_value_t addrspace(10)** %2, i32 0
store %jl_value_t addrspace(10)* %19, %jl_value_t addrspace(10)** %21
%22 = getelementptr %jl_value_t addrspace(10)*, %jl_value_t addrspace(10)** %2, i32 1
store %jl_value_t addrspace(10)* addrspacecast (%jl_value_t* inttoptr (i64 140060897026208 to %jl_value_t*) to %jl_value_t addrspace(10)*), %jl_value_t addrspace(10)** %22
%23 = call nonnull %jl_value_t addrspace(10)* @jl_f_getfield(%jl_value_t addrspace(10)* addrspacecast (%jl_value_t* null to %jl_value_t addrspace(10)*), %jl_value_t addrspace(10)** %2, i32 2)
%24 = bitcast %jl_value_t addrspace(10)* %23 to i64* addrspace(10)*
%25 = load i64*, i64* addrspace(10)* %24, align 8
store i64 1, i64* %25, align 1
%26 = getelementptr %jl_value_t addrspace(10)*, %jl_value_t addrspace(10)** %gcframe, i32 1
%27 = load %jl_value_t addrspace(10)*, %jl_value_t addrspace(10)** %26
%28 = getelementptr %jl_value_t**, %jl_value_t*** %ptls, i32 0
%29 = bitcast %jl_value_t*** %28 to %jl_value_t addrspace(10)**
store %jl_value_t addrspace(10)* %27, %jl_value_t addrspace(10)** %29
ret void
}

code_warntype looks identical:
And there's no inference failure when looking with Cthulhu:
Force-pushed from 17dfd92 to 3c9b279
On 1.1 Cassette should be performant enough for these kinds of transforms.
Fixes https://github.com/JuliaGPU/CUDAnative.jl/issues/27
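To illustrate the approach, here is a minimal sketch, not the PR's actual implementation: kernels are run through a Cassette context, and overdub methods replace selected Base functions with device-side versions. The context and contextualize wrapper mirror the names used above; device_log is a hypothetical stand-in for such a replacement.

using Cassette

Cassette.@context CUDACtx
const cudactx = Cassette.disablehooks(CUDACtx())

# hypothetical stand-in for a device-side implementation of log;
# in CUDAnative this would call into libdevice instead
device_log(x::Float64) = Base.log(x)

# contextual dispatch: inside the context, calls to Base.log are replaced
Cassette.overdub(::CUDACtx, ::typeof(Base.log), x::Float64) = device_log(x)

# wrap a kernel so every call inside it goes through the context
contextualize(f::F) where F = (args...) -> Cassette.overdub(cudactx, f, args...)

kernel(a) = (a[1] = log(a[1]); nothing)
contextualize(kernel)([2.0])  # the log call dispatches to device_log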
@maleadt did you have a branch similar to this around?