Scalar indexing error from GPU matmul against Zygote.OneElement #1005
Possible fixes are:

1. An Adapt rule that converts `OneElement` to a dense `CuArray` when it is moved to the GPU (attempted below).
2. Overloading the offending operations (`*`, `dot`, …) for `OneElement` against `CuArray` (see the gist linked below).

Re 1., we should also wonder a bit what operations besides `*` and `dot` would need to trigger the conversion.
CUDA.jl now uses CUDA's own allocator, which has much higher overhead than before. It's especially bad for small arrays, possibly hurting more of Flux.
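To make the overhead claim concrete, here is a rough way one might measure it (the sizes are illustrative, not from the thread): for tiny arrays the time is dominated by allocator and launch overhead rather than by the data.

```julia
using BenchmarkTools, CUDA

# Small allocations are overhead-bound; large ones are bandwidth-bound.
@btime CUDA.@sync CUDA.zeros(Float32, 4);
@btime CUDA.@sync CUDA.zeros(Float32, 4_000_000);
```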
I think this is the actual MWE:

```julia
using Zygote, Flux, CUDA
CUDA.allowscalar(false)

Zygote.gradient(x -> cpu(2 .* gpu(x))[1], Float32[1,2,3]) == ([2,0,0],)
# dot(x::Zygote.OneElement, y::CuArray)

Zygote.gradient(x -> cpu(gpu(x) * gpu(x))[1,2], Float32[1 2 3; 4 5 6; 7 8 9]) == ([2 6 8; 0 2 0; 0 3 0],)
# generic_matmatmul!(C::CuArray, ..., A::Zygote.OneElement, ...)
```
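For context on where the `OneElement` comes from: Zygote's pullback for scalar `getindex` returns this lazy one-hot array instead of materialising a dense gradient. A quick way to see it (exact behaviour depends on the Zygote version):

```julia
using Zygote

# The cotangent of x[1,2] is represented lazily as a OneElement,
# not as a dense zeros-matrix with one entry set:
y, back = Zygote.pullback(x -> x[1,2], rand(Float32, 3, 3))
typeof(back(1f0)[1])  # Zygote.OneElement{...} on recent Zygote versions
```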
And this is the simplest attempt at an Adapt solution, but it doesn't get called. Why?

```julia
using Adapt

a34 = Zygote.OneElement(3.4f0, (2,3), axes(rand(3,4)))
adapt(CuArray, a34)  # isa OneElement

function Adapt.adapt(AT::Type{<:CUDA.AbstractGPUArray}, A::Zygote.OneElement{T2,N}) where {T2,N}
    B = fill!(similar(AT{T2}, axes(A)), zero(T2))
    CUDA.@allowscalar B[A.ind...] = A.val
    @show A.val
    B
end

adapt(CuArray, a34)  # isa CuArray
```
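A possible explanation, offered as a guess: Adapt's documented extension points are `adapt_structure` and `adapt_storage` (`adapt(to, x)` just forwards to them), and CUDA.jl's `cu` calls `adapt` with an internal adaptor object rather than with `CuArray` itself, so a method added to `adapt` keyed on `Type{<:AbstractGPUArray}` is never reached from the gradient pipeline. A sketch of the same conversion on the conventional hook, reusing `a34` from above:

```julia
using Adapt, CUDA, Zygote

# Untested sketch: the same conversion, on adapt_structure instead of adapt.
# (This still only fires for adapt(CuArray, x)-style calls; hooking the gpu/cu
# pipeline would additionally need a method for CUDA's internal adaptor type.)
function Adapt.adapt_structure(AT::Type{<:CUDA.AbstractGPUArray}, A::Zygote.OneElement{T,N}) where {T,N}
    B = fill!(similar(AT{T}, axes(A)), zero(T))   # dense zeros on the GPU
    CUDA.@allowscalar B[A.ind...] = A.val         # write the single nonzero
    return B
end

adapt(CuArray, a34)  # isa CuArray, now via the structure hook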
For the alternative plan, there are some attempts to overload the relevant methods here: https://gist.github.com/mcabbott/4ea43bea49a25c198a20f55f590735c4. But, as promised, it gets tricky to avoid ambiguities.
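To illustrate the ambiguity problem, here is a toy sketch (not the gist's actual code): each overload of `*` for `OneElement` against a GPU array has to out-dispatch both LinearAlgebra's generic methods and its wrapper types, and every missing combination either errors or silently falls back to the scalar-indexing path.

```julia
using LinearAlgebra, CUDA, Zygote

# Densify the OneElement on the GPU before multiplying, so CUBLAS handles it:
_cu_dense(A::Zygote.OneElement) = cu(collect(A))

Base.:*(A::Zygote.OneElement{<:Any,2}, B::CuMatrix) = _cu_dense(A) * B
Base.:*(A::CuMatrix, B::Zygote.OneElement{<:Any,2}) = A * _cu_dense(B)

# ...but Adjoint/Transpose wrappers, matrix-vector products, 5-arg mul!, etc.
# each need their own methods, and it is easy to create method ambiguities
# against LinearAlgebra's own *(AbstractMatrix, AbstractMatrix) family.
```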
Simpler failure case with no indexing -- did this ever work?
I see there's exactly one test of this in https://github.com/FluxML/Zygote.jl/blob/master/test/cuda.jl, from #929 I think. But it doesn't do the obvious test of the reverse order, which fails:
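Purely as a hypothetical guess at what "the reverse order" means (my assumption: swapping the operands of `*`):

```julia
using LinearAlgebra, CUDA, Zygote
CUDA.allowscalar(false)

oe = Zygote.OneElement(1f0, (1,2), axes(rand(3,3)))
g  = CUDA.rand(3, 3)

g * oe   # presumably the order the existing test exercises
oe * g   # the reverse order: a mixed CPU/GPU matmul falls back to
         # generic_matmatmul! and errors on scalar indexing
```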
It did work just fine. We have movement tests in Flux as well.
Great, on which version exactly did you see this work?
This would be https://github.com/FluxML/Flux.jl/blob/master/test/cuda/cuda.jl? Which line exactly tests this? Why doesn't it catch these failures?
Surely we can add tests; besides, the test you suggest can be broken down into smaller chunks (specifically …).
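A sketch of what such smaller chunks might look like, exercising the failing methods directly instead of through a full `gradient` call (hypothetical tests, not from the thread):

```julia
using Test, LinearAlgebra, CUDA, Zygote
CUDA.allowscalar(false)

oev = Zygote.OneElement(1f0, (1,), axes(rand(3)))
oem = Zygote.OneElement(1f0, (1,2), axes(rand(3,3)))

# Each currently throws a scalar-indexing error; marked broken until fixed:
@test_broken dot(oev, CUDA.rand(3)) isa Float32
@test_broken oem * CUDA.rand(3, 3) isa CuArray
@test_broken CUDA.rand(3, 3) * oem isa CuArray
```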
Indeed. You may even notice that I included that case above. But again, on which version, exactly, did these work?
I would start before the broadcasting and accumulation changes, IIRC. But I'm going to have to check the versions to see where it's expected to work.
MWE: see SciML/DiffEqFlux.jl#571.