
Reverse-mode VJPs with mixed scalar+vector GPU code #632

Closed
ChrisRackauckas opened this issue Sep 28, 2021 · 8 comments

Comments

@ChrisRackauckas
Member

Very specific issue for a very specific case, but here's how it shows up. An example model:

using OrdinaryDiffEq
using DiffEqSensitivity
using LinearAlgebra
using Flux
using CUDA
using Random

rng = MersenneTwister(1234)
m = 32
n = 16
Z = randn(rng, Float32, (n,m)) |> gpu
𝒯 = 2.0
Δτ = 0.1
ca_init = [zeros(1) ; ones(m)] |> gpu

function f!(ċȧ, ca, Z, t)
  a = ca[2:end]

  a_unit = a / sum(a)
  w_unit = Z*a_unit
  Ka_unit = Z'*w_unit
  z_unit = dot(abs.(Ka_unit), a_unit)
  aKa_over_z = a .* Ka_unit / z_unit
  ċȧ[1] = sum(aKa_over_z) / m
  ċȧ[2:end] = -abs.(aKa_over_z)
end

function c(Z)
  prob = ODEProblem(f!, ca_init, (0.,𝒯), Z, saveat=Δτ)
  sol = solve(prob, Tsit5(), sensealg=BacksolveAdjoint(), saveat=Δτ)
  #try this:
  #return last(sol.u)[1]
  #or this:
  return sol.u[20][1]
end

println("forward:", c(Z))
println("backward: ", Flux.Zygote.gradient(c, Z))

The reason this fails at first is that ReverseDiff.jl does not work on GPUs. So we should first improve the automatic VJP choice to take that into account. Cool, but then we end up in a larger problem: in-place differentiation requires one of the scalarizing reverse-mode tape forms, and those won't work on GPUs no matter what. Enzyme could be a solution, but @wsmoses how close is it to automatically handling CUDA.jl kernels?
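For context, here is a minimal CPU-only sketch (mine, not from the thread) of what the "scalarizing tape" VJP looks like with ReverseDiff: it records every scalar operation onto a tape, which is exactly what cannot be done with a CuArray.

```julia
using ReverseDiff

# A VJP v' * (∂f/∂u) via ReverseDiff's scalar-level tape.
# Every elementwise operation is recorded as individual scalar nodes,
# so this strategy is CPU-only by construction.
f(u, p) = p .* u

p  = [2.0, 3.0, 4.0]
u0 = [1.0, 1.0, 1.0]
v  = [1.0, 0.0, 0.0]   # the "v" of the VJP (pullback seed)

vjp = ReverseDiff.gradient(u -> sum(v .* f(u, p)), u0)
# vjp == [2.0, 0.0, 0.0]
```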

So okay, we could do what we do with Neural ODEs on GPUs: namely, make it out-of-place and use Zygote VJPs. Out-of-place form:

using OrdinaryDiffEq
using DiffEqSensitivity
using LinearAlgebra
using Flux
using CUDA
using Random

rng = MersenneTwister(1234)
m = 32
n = 16
Z = randn(rng, Float32, (n,m)) |> gpu
𝒯 = 2.0
Δτ = 0.1
ca_init = [zeros(1) ; ones(m)] |> gpu

function f(ca, Z, t)
  a = ca[2:end]

  a_unit = a / sum(a)
  w_unit = Z*a_unit
  Ka_unit = Z'*w_unit
  z_unit = dot(abs.(Ka_unit), a_unit)
  aKa_over_z = a .* Ka_unit / z_unit
  [sum(aKa_over_z) / m; -abs.(aKa_over_z)]
end

function c(Z)
  prob = ODEProblem(f, ca_init, (0.,𝒯), Z, saveat=Δτ)
  sol = solve(prob, Tsit5(), sensealg=BacksolveAdjoint(autojacvec=ZygoteVJP()), saveat=Δτ)
  #try this:
  return last(sol.u)[1]
  #or this:
  #return sol.u[20][1]
end

println("forward:", c(Z))
println("backward: ", Flux.Zygote.gradient(c, Z))

But this hits two issues. First of all, [sum(aKa_over_z) / m; -abs.(aKa_over_z)] is surprisingly not on the GPU, which seems to be a CUDA.jl issue:

JuliaGPU/CUDA.jl#1162

But secondly, if we try to apply |> gpu inside the RHS function (which would be slow, but should hopefully work), then we hit:

FluxML/Zygote.jl#1080

So fully GPU algorithms work in this form, but algorithms that mix in scalar values have no nice, differentiable way of recreating a GPU-based array; hence this issue.

@DhairyaLGandhi just an interesting thing to note.
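One possible differentiable workaround (my sketch, not something tried in this thread): keep the scalar head on-device by computing it with a reducing sum(...; dims = 1), which returns a 1-element array of the same container type as its input, then vcat it with the vector part. Whether the result stays on the GPU under Zygote depends on the vcat adjoint, which is exactly what the linked issues touch.

```julia
using LinearAlgebra

# Sketch: avoid the CPU round-trip by never materializing a bare scalar.
# sum(x; dims = 1) on a CuArray returns a 1-element CuArray, so vcat should
# stay on-device (assuming the vcat adjoint cooperates).
function f_nosplat(ca, Z, t)
    m = size(Z, 2)
    a = ca[2:end]
    a_unit = a / sum(a)
    w_unit = Z * a_unit
    Ka_unit = Z' * w_unit
    z_unit = dot(abs.(Ka_unit), a_unit)
    aKa_over_z = a .* Ka_unit / z_unit
    vcat(sum(aKa_over_z; dims = 1) ./ m, -abs.(aKa_over_z))
end
```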

@DhairyaLGandhi
Member

Could you elaborate what you mean by gpu being slow? Do you mean in the sense you incur network overhead to copy to vram or something else?

@DhairyaLGandhi
Member

What if we refactor f! and handle only the mutation using Buffer or an adjoint? Buffer does not have to be slow if you're only assigning into it in chunks.
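A sketch of what that refactor might look like (mine, not code from the thread; it assumes Zygote.Buffer supports the ranged assignment on the array type in use):

```julia
using Zygote
using LinearAlgebra

# Replace the in-place writes with a Zygote.Buffer: a write-once scratch
# array that Zygote can differentiate through, frozen by copy() at the end.
function f_buf(ca, Z, t)
    m = size(Z, 2)
    a = ca[2:end]
    a_unit = a / sum(a)
    w_unit = Z * a_unit
    Ka_unit = Z' * w_unit
    z_unit = dot(abs.(Ka_unit), a_unit)
    aKa_over_z = a .* Ka_unit / z_unit
    buf = Zygote.Buffer(ca)
    buf[1] = sum(aKa_over_z) / m        # scalar head
    buf[2:length(ca)] = -abs.(aKa_over_z)  # chunked assignment, per the suggestion
    copy(buf)                            # Buffer -> ordinary array, differentiable
end
```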

@ChrisRackauckas
Member Author

What if we refactor f! and handle only the mutation using Buffer or an adjoint? Buffer does not have to be slow if you're only assigning into it in chunks.

That could be worth a try. I didn't try any Buffer solutions here.

Could you elaborate what you mean by gpu being slow? Do you mean in the sense you incur network overhead to copy to vram or something else?

I mean

function f(ca, Z, t)
  a = ca[2:end]

  a_unit = a / sum(a)
  w_unit = Z*a_unit
  Ka_unit = Z'*w_unit
  z_unit = dot(abs.(Ka_unit), a_unit)
  aKa_over_z = a .* Ka_unit / z_unit
  [sum(aKa_over_z) / m; -abs.(aKa_over_z)] |> gpu
end

If we are making an array on the CPU and sending it to the GPU on every f call, that would get expensive since the operations are all O(n); the cost might even be dominated by the transfer. So the issue is really that we have a GPU array and a scalar, and we want to assemble them into a GPU array without first creating one on the CPU. Maybe @maleadt has an idea for this, or an argument for why it cannot be done.

@maleadt
Contributor

maleadt commented Oct 5, 2021

So the issue is really that we have a GPU array and a scalar, and we somehow want to make that on the GPU without first creating an array on the CPU.

Yeah, that's not going to perform well. I guess CuArray could be made to support some available space before and after the array such that mutating operations like that could perform well, but that seems like opening another can of worms (how much space to reserve? should all GPU arrays have this, or do we need special constructors again? etc).

Is it not an option to do this manually, by over-allocating and e.g. using a view for the current data?
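A minimal CPU illustration of that manual over-allocation idea (names are hypothetical, not from the thread): reserve headroom in the backing buffer so the scalar head can be "prepended" by widening a view rather than reallocating.

```julia
# Allocate once, with spare space at the front; the live data is a view.
m = 32
headroom = 1
backing = zeros(Float32, headroom + m)
a = view(backing, headroom+1:headroom+m)
a .= 1f0

# "Prepending" the scalar just writes into the headroom and widens the view,
# with no copy and no new allocation.
backing[1] = 0.5f0
ca = view(backing, 1:headroom+m)   # now [scalar; a]
```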

@ChrisRackauckas
Member Author

With SciML/SciMLSensitivity.jl#498 and JuliaGPU/GPUArrays.jl#379 together, the simple out-of-place code works:

using OrdinaryDiffEq
using DiffEqSensitivity
using LinearAlgebra
using Flux
using CUDA
using Random

rng = MersenneTwister(1234)
m = 32
n = 16
Z = randn(rng, Float32, (n,m)) |> gpu
𝒯 = 2.0
Δτ = 0.1
ca_init = [zeros(1) ; ones(m)] |> gpu

function f(ca, Z, t)
  a = ca[2:end]

  a_unit = a / sum(a)
  w_unit = Z*a_unit
  Ka_unit = Z'*w_unit
  z_unit = dot(abs.(Ka_unit), a_unit)
  aKa_over_z = a .* Ka_unit / z_unit
  [sum(aKa_over_z) / m; -abs.(aKa_over_z)]
end

function c(Z)
  prob = ODEProblem(f, ca_init, (0.,𝒯), Z, saveat=Δτ)
  sol = solve(prob, Tsit5(), sensealg=BacksolveAdjoint(), saveat=Δτ)
  #try this:
  return last(sol.u)[1]
  #or this:
  #return sol.u[20][1]
end

println("forward:", c(Z))
println("backward: ", Flux.Zygote.gradient(c, Z))

@ChrisRackauckas
Member Author

In-place has been isolated to EnzymeAD/Enzyme.jl#144

@ChrisRackauckas
Member Author

Yeah, that's not going to perform well. I guess CuArray could be made to support some available space before and after the array such that mutating operations like that could perform well, but that seems like opening another can of worms (how much space to reserve? should all GPU arrays have this, or do we need special constructors again? etc).

The problem there is really just mutation in general + Zygote. The better solution then is probably to get the Enzyme+CUDA.jl stack working together, which I know @wsmoses has looked into.
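For reference, a CPU sketch (mine, with a hypothetical g!; the autodiff call assumes the current Enzyme.jl API) of how an Enzyme reverse-mode VJP through an in-place RHS works: seeding the output shadow with v makes Enzyme accumulate v' * ∂f/∂u into the input shadow, mutation and all.

```julia
using Enzyme

# Hypothetical in-place RHS: du = p .* u .^ 2
function g!(du, u, p)
    du .= p .* u .^ 2
    return nothing
end

u  = [1.0, 2.0]
p  = [3.0, 4.0]
du = zeros(2)
v  = [1.0, 1.0]          # the "v" of the VJP
du_shadow = copy(v)      # seed on the output
u_shadow  = zeros(2)     # receives the VJP result

Enzyme.autodiff(Reverse, g!, Const,
                Duplicated(du, du_shadow),
                Duplicated(u, u_shadow),
                Const(p))
# u_shadow now holds v .* 2 .* p .* u
```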

@ChrisRackauckas
Member Author

All that's left is upstreamed.
