Scalar indexing error from GPU matmul against Zygote.OneElement #1005

Closed
ChrisRackauckas opened this issue Jun 21, 2021 · 13 comments · Fixed by FluxML/Flux.jl#1704

@ChrisRackauckas
Member

MWE:

using Zygote, CUDA
CUDA.allowscalar(false)
W = CuArray(rand(4,4))
x = Zygote.OneElement(1f0,(1,),axes(rand(4)))
W' * x # Scalar indexing

From:

using DiffEqFlux, Flux, Optim, OrdinaryDiffEq, CUDA, DiffEqSensitivity, Plots
u0 = [1.1; 1.1] |> gpu
tspan = (0.0f0,25.0f0)
ann = FastChain(FastDense(2,16,tanh), FastDense(16,16,tanh), FastDense(16,1))
p1 = initial_params(ann)
p2 = Float32[0.5,-0.5]
p3 = [p1;p2]
θ = Float32[u0;p3]
function dudt_(u,p,t)
    x, y = u
    pend = cpu(p[end-1:end])
    @show typeof(p[1:length(p1)])
    @show typeof(gpu(u))
    @show cpu(ann(gpu(u),p[1:length(p1)]))[1]
    @show pend[1]*y + pend[2]*x
    [cpu(ann(gpu(u),p[1:length(p1)]))[1],pend[1]*y + pend[2]*x]
end
prob = ODEProblem{false}(dudt_,u0,tspan,p3)
function predict_adjoint(θ)
  gpu(Array(solve(prob,Tsit5(),u0=cpu(θ[1:2]),p=θ[3:end],saveat=0.0:1:25.0,sensealg=QuadratureAdjoint())))
end
loss_adjoint(θ) = sum(abs2,predict_adjoint(θ)[2,:].-1)
l = loss_adjoint(θ)
cb = function (θ,l)
  println(l)
  #display(plot(solve(remake(prob,p=Flux.data(p3),u0=Flux.data(u0)),Tsit5(),saveat=0.1),ylim=(0,6)))
  return false
end
loss1 = loss_adjoint(θ)
Zygote.gradient(loss_adjoint,θ)

SciML/DiffEqFlux.jl#571

@ChrisRackauckas
Member Author

@mcabbott

@mcabbott
Member

CUDA.allowscalar(false) means that you can't get a OneElement from the gradient of indexing W directly, as in W[1,1]. The key point of this example is that cpu(...)[1] does the indexing on a CPU array, but in the gradient that OneElement then gets mixed up with GPU objects.
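
For context, a quick way to see where the OneElement comes from (a check of my own, assuming a Zygote version that includes #962, where the pullback of scalar getindex returns a OneElement rather than a dense array):

using Zygote

# The cotangent of scalar indexing is a one-hot-like array, not a dense Vector:
dx, = Zygote.gradient(x -> x[1], rand(Float32, 4))
typeof(dx)  # Zygote.OneElement{Float32, 1, ...} on Zygote versions after #962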

Possible fixes are

  1. To start overloading *(::OneElement, ::AbstractMatrix) etc. These are obviously very simple, and even on the CPU could be made more efficient than generic_matmul. The difficulty is that the dispatch for * is a minefield of type ambiguities.
  2. To use Adapt.jl to translate OneElement to a CuArray, so that the gradient of cpu literally moves it to the GPU. That would pretty much restore the behaviour before #962 (RFC: more efficient ∇getindex).

Re 1. we should also wonder a bit what operations besides * might need overloading.
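
A minimal sketch of what such an overload could look like (my own illustration, not code from this issue or from the gist linked below; the ambiguity problem mentioned above is exactly what this naive version runs into for wrappers like Diagonal):

# Sketch: A * x for a one-hot vector x is just one column of A, scaled.
# This avoids generic_matmul, but adding enough methods like this without
# creating dispatch ambiguities (Diagonal, Adjoint, ...) is the hard part.
function Base.:*(A::AbstractMatrix, x::Zygote.OneElement{<:Number,1})
    i, = x.ind
    return A[:, i] .* x.val
end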

@ChrisRackauckas
Member Author

* and + I think would go pretty far for this.

@DhairyaLGandhi
Member

DhairyaLGandhi commented Jun 21, 2021

CUDA.jl now uses CUDA's allocator, which has much higher overhead than before. It's especially bad for small arrays, possibly hurting Flux more.

@mcabbott
Member

I think this is the actual MWE:

using Zygote, Flux, CUDA
CUDA.allowscalar(false)

Zygote.gradient(x -> cpu(2 .* gpu(x))[1], Float32[1,2,3]) == ([2,0,0],)  
# dot(x::Zygote.OneElement, y::CuArray)

Zygote.gradient(x -> cpu(gpu(x) * gpu(x))[1,2], Float32[1 2 3; 4 5 6; 7 8 9]) == ([2 6 8; 0 2 0; 0 3 0],) 
# generic_matmatmul!(C::CuArray, ..., A::Zygote.OneElement, ...)

And this is the simplest attempt at an Adapt solution, but it doesn't get called. Why?

using Adapt

a34 = Zygote.OneElement(3.4f0, (2,3), axes(rand(3,4)))
adapt(CuArray, a34) # isa OneElement

function Adapt.adapt(AT::Type{<: CUDA.AbstractGPUArray}, A::Zygote.OneElement{T2,N}) where {T2, N}
    B = fill!(similar(AT{T2}, axes(A)), zero(T2))
    CUDA.@allowscalar B[A.ind...] = A.val
    @show A.val
    B
end
adapt(CuArray, a34) # isa CuArray
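
One possible reason it isn't hit elsewhere (my reading, not something established in this thread): Adapt.jl's documented extension points are adapt_structure and adapt_storage; adapt(to, x) itself is just defined as adapt_structure(to, x), so a method added to adapt is only seen by direct calls and is bypassed by anything that goes through adapt_structure/adapt_storage. The same idea on the storage hook would be roughly:

using Adapt

# Sketch: extend adapt_storage, which is what adapt/adapt_structure fall back to,
# rather than adding a method to adapt itself.
function Adapt.adapt_storage(AT::Type{<:CUDA.AbstractGPUArray}, A::Zygote.OneElement{T2,N}) where {T2,N}
    B = fill!(similar(AT{T2}, axes(A)), zero(T2))
    CUDA.@allowscalar B[A.ind...] = A.val
    return B
end

adapt(CuArray, a34)  # isa CuArray, now via adapt_storage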

@mcabbott
Member

For the alternative plan, there are some attempts to overload * methods etc. here:

https://gist.github.com/mcabbott/4ea43bea49a25c198a20f55f590735c4

But, as promised, it gets tricky to avoid ambiguities.

@DhairyaLGandhi
Member

Why does OneElement need arithmetic overloads anyway? It shouldn't leak into user-facing code at all.

@mcabbott
Member

Simpler failure case with no indexing -- did cpu/gpu ever work inside gradients?

julia> using CUDA, Zygote, Flux
julia> CUDA.allowscalar(false)
julia> a = rand(Float32, 4, 4); ca = cu(rand(4, 4));

julia> gradient(x -> sum(abs, cpu(ca * gpu(a * x))), a)
ERROR: ArgumentError: cannot take the CPU address of a CuArray{Float32, 2}

julia> gradient(x -> sum(abs, collect(ca * cu(a * x))), a)
ERROR: ArgumentError: cannot take the CPU address of a CuArray{Float32, 2}

I see there's exactly one test of this in https://github.com/FluxML/Zygote.jl/blob/master/test/cuda.jl, from #929 I think. But it doesn't do the obvious test of the reverse order, which fails:

julia> gradient(x -> sum(cu(x)), [1 2 3.0])[1]
1×3 Matrix{Float32}:
 1.0  1.0  1.0

julia> gradient(x -> sum(cpu(x)), cu([1 2 3.0]))[1]
1×3 Fill{Float32}, with entries equal to 1.0

julia> gradient(x -> sum(abs, cpu(x)), cu([1 2 3.0]))[1]
1×3 Matrix{Float32}:
 1.0  1.0  1.0

@DhairyaLGandhi
Member

It did work just fine. We have movement tests in Flux as well.

@mcabbott
Member

It did work just fine.

Great, on which version exactly did you see gradient(x -> sum(abs, cpu(ca * gpu(a * x))), a) working? Then we can bisect.

We have movement tests in flux as well.

This would be https://github.com/FluxML/Flux.jl/blob/master/test/cuda/cuda.jl ? Which line exactly tests this? Why doesn't it catch gradient(x -> sum(abs, cpu(x)), cu([1 2 3.0]))[1] which seems the most obvious possible test?

@DhairyaLGandhi
Member

Surely we can add tests; besides, the test you suggest can be broken down into smaller chunks (specifically sum(cpu(x)), removing the abs).
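
For reference, a test along these lines (hypothetical, not quoting the existing test files; the CuArray assertion encodes the behaviour being argued for above, not what current versions return) might look like:

using Test, Zygote, Flux, CUDA
CUDA.allowscalar(false)

@testset "gradients through cpu/gpu movement" begin
    # gpu inside the loss, gradient w.r.t. a CPU array
    @test gradient(x -> sum(cu(x)), [1 2 3f0])[1] ≈ [1 1 1]
    # cpu inside the loss, gradient w.r.t. a GPU array -- the reverse order
    g = gradient(x -> sum(abs, cpu(x)), cu([1 2 3f0]))[1]
    @test g isa CuArray   # should not silently come back as a CPU array
end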

@mcabbott
Member

removing the abs

Indeed. You may even notice that I included that case above.

But again, on which version, exactly, did these work?

@DhairyaLGandhi
Member

I would start with before the broadcasting and accumulation changes, IIRC. But I'm going to have to check the versions to see where it's expected to work.
