RFC: in-place accum #981

Closed
wants to merge 8 commits into from


mcabbott
Member

@mcabbott mcabbott commented May 25, 2021

This is a minimal attempt to add safe in-place accumulation of gradients.

It assumes that any Δ::DenseArray may be mutated. To keep this safe, any rule which duplicates Δ should apply a NoWrite wrapper. This should prevent the problems seen in #962, which had no such protection.

Rules which do Δ -> (Δ, Δ) (as for +) must apply this wrapper to both branches. (Ideally, once one has been used up, the other could be marked safe to mutate. That could be done but is more complicated, and can be added later.) For rules within Zygote, this protection is done by hand.
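As a rough illustration of the idea above, here is a minimal, self-contained sketch. It does not use the PR's actual types: `NoWriteSketch`, `accum_sketch` and `plus_pullback_sketch` are hypothetical names invented for this example, and the real `NoWrite` and `accum` in the PR may differ in detail.

```julia
# Hypothetical stand-in for the PR's NoWrite wrapper: a read-only array.
struct NoWriteSketch{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
    parent::A
end
Base.size(x::NoWriteSketch) = size(x.parent)
Base.getindex(x::NoWriteSketch, i::Int...) = x.parent[i...]
Base.parent(x::NoWriteSketch) = x.parent
# No setindex! method is defined, so writing to the wrapper throws.

# Assumed accumulation rule: mutate a plain dense array, otherwise allocate.
accum_sketch(x::DenseArray, y::AbstractArray) = (x .+= y; x)
accum_sketch(x::NoWriteSketch, y::AbstractArray) = x .+ y

# A rule like the pullback of `+` hands the same Δ to both branches,
# so it wraps both copies.
plus_pullback_sketch(Δ) = (NoWriteSketch(Δ), NoWriteSketch(Δ))

Δ = [1.0, 2.0, 3.0]
a, b = plus_pullback_sketch(Δ)
total = accum_sketch(a, b)       # `a` is protected, so this allocates a new Array
total2 = accum_sketch(total, b)  # `total` is a fresh Array, so it is mutated in place
```

Here `Δ` is never written to, while the second accumulation reuses the memory of `total`. The cost of being conservative is one extra allocation at the first accumulation.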

For rules defined by ChainRules, the interface function should check the pointer of Δ (and remove NoWrite), then compare the pointer of what it returns, and re-wrap if necessary. I'm not sure this is done perfectly yet. I'm also not sure if this has performance overhead. Evil test cases would be welcome.

It's easy to make the NoWrite wrapper disappear on the RHS of broadcasting or reductions. But if it survives to meet * then it's likely to cause slow generic matmul. I've added quite a few explicit _unprotect calls to try to avoid this. (This should only matter for rules defined with @adjoint.)

Xref JuliaDiff/ChainRulesCore.jl#350

@mcabbott
Member Author

mcabbott commented Jun 15, 2021

Benchmarks:

julia> f4(x) = sum([x[i]^2 for i in eachindex(x)]);

julia> @btime Zygote.gradient(f4, $(collect(1:1000)))
  749.750 μs (7068 allocations: 15.77 MiB)  # v0.6.11
  455.959 μs (3071 allocations: 8.04 MiB)   # v0.6.12, with 962
  93.833 μs (2074 allocations: 472.20 KiB)  # this PR
([2, 4, 6, 8, 10, 12, 14, 16, 18, 20  …  1982, 1984, 1986, 1988, 1990, 1992, 1994, 1996, 1998, 2000],)

julia> function _evalpoly(x, p)
           N = length(p)
           ex = p[end]
           for i in N-1:-1:1
               ex = muladd(x, ex, p[i])
           end
           ex
       end
_evalpoly (generic function with 1 method)

julia> x, p = rand(), randn(10000);

julia> @btime _evalpoly(x, p);
  21.791 μs (1 allocation: 16 bytes)

julia> @btime Zygote.gradient(_evalpoly, x, p);
  197.007 ms (680107 allocations: 1.52 GiB)    # v0.6.11
  146.587 ms (660107 allocations: 792.75 MiB)  # v0.6.12, with 962
  62.367 ms (640111 allocations: 35.76 MiB)    # this PR

julia> @btime Zygote.gradient(x -> sum(abs2, net(x)), $(rand(50,50,50,50)));
  1.115 s (8266 allocations: 3.07 GiB)    # v0.6.11
  1.049 s (8187 allocations: 954.08 MiB)  # v0.6.12
  1.013 s (9138 allocations: 954.11 MiB)  # this PR

@ToucheSir
Member

How do we feel about this? Would it help to do an @adjoint -> rrule conversion first so that _unprotect is no longer required?

@mcabbott
Member Author

I haven't thought about this since, except to realise that https://github.com/bkamins/ReadOnlyArrays.jl might be better than the version I wrote here.

The checks I wrote for function (s::ZBack)(dy) try to handle these cases:

  1. Rules returning (Δ, Δ), like for +: If any two gradients agree, wrap them both.
  2. Rules receiving a wrapped gradient: Always unwrap before calling the rrule. Then re-wrap if (2a) the same answer emerges, or else (2b) if any two gradients agree.
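The two cases listed above can be sketched roughly as follows. This is not the PR's `ZBack` code: `Guarded`, `unwrap`, `same_array` and `call_rule_sketch` are hypothetical names, and object identity (`===`) is used here as a simplification in place of the pointer comparison described earlier in the thread. Identity will not notice views or reshapes that share memory.

```julia
# Hypothetical read-only marker; only the wrapping logic matters here.
struct Guarded{A<:AbstractArray}
    parent::A
end
unwrap(x) = x
unwrap(x::Guarded) = x.parent

# Two gradients "agree" when they are the very same array object.
same_array(a::AbstractArray, b::AbstractArray) = a === b
same_array(a, b) = false

function call_rule_sketch(back, dy)
    was_guarded = dy isa Guarded
    raw = unwrap(dy)                 # always unwrap before calling the rule
    grads = back(raw)
    n = length(grads)
    ntuple(n) do i
        g = grads[i]
        # Case 2a: a protected input came straight back out.
        passed_through = was_guarded && same_array(g, raw)
        # Cases 1 and 2b: two of the returned gradients are the same array.
        duplicated = any(j -> j != i && same_array(g, grads[j]), 1:n)
        (passed_through || duplicated) ? Guarded(g) : g
    end
end

Δ = [1.0, 2.0]
plus_back(dy) = (dy, dy)      # like the pullback of +
scale_back(dy) = (2 .* dy,)   # returns a fresh array
ident_back(dy) = (dy,)        # returns its input unchanged

g1 = call_rule_sketch(plus_back, Δ)            # case 1: both wrapped
g2 = call_rule_sketch(scale_back, Guarded(Δ))  # fresh result: left unwrapped
g3 = call_rule_sketch(ident_back, Guarded(Δ))  # case 2a: re-wrapped
g4 = call_rule_sketch(ident_back, Δ)           # unprotected input stays unprotected
```

In this sketch a fresh array returned by a rule is treated as safe to mutate, which is what makes the later in-place accumulation possible.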

Does (1)/(2b) ever occur, besides +? Maybe sum on an array of arrays is another possible case. I'm not sure this is caught, and it seems tricky.

Does (2a) ever occur?

Since accum is recursive, this accumulation will also mutate arrays inside the structural gradient of non-array objects. Are any of these ever shared? The existing checks will not notice.
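To make the concern about shared arrays concrete, here is a small sketch of how recursive in-place accumulation could go wrong. The functions `accum_inplace_sketch`, `accum_copy_sketch` and `sum_three` are hypothetical, and the NamedTuple gradients with an aliased `weight` field are an assumed scenario, not one observed in Zygote.

```julia
# Hypothetical in-place accumulation that recurses into NamedTuples.
accum_inplace_sketch(x::Array, y::Array) = (x .+= y; x)
accum_inplace_sketch(x::NamedTuple, y::NamedTuple) = map(accum_inplace_sketch, x, y)

# Out-of-place version for comparison.
accum_copy_sketch(x::Array, y::Array) = x .+ y
accum_copy_sketch(x::NamedTuple, y::NamedTuple) = map(accum_copy_sketch, x, y)

# Add up three structural gradients whose `weight` fields alias one array.
function sum_three(accum)
    shared = [1.0, 1.0]
    g1 = (weight = shared,)
    g2 = (weight = shared,)
    g3 = (weight = shared,)
    accum(accum(g1, g2), g3).weight
end

expected = sum_three(accum_copy_sketch)     # 1 + 1 + 1 in each entry
corrupted = sum_three(accum_inplace_sketch) # the first step overwrites g3's data too
```

The in-place version doubles the shared array twice, giving 4 rather than 3 in each entry, because the third gradient has already been modified by the time it is added.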

@mcabbott
Member Author

A narrower idea is to make in-place accumulation work only for the result of scalar indexing:

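# When several one-element gradients meet, start accumulating into a Buffer.
function accum(x::OneElement{T,N}, ys::OneElement{T,N}...) where {T,N}
    z = Buffer(x)
    fill!(z.data, zero(T))  # the Buffer's storage starts uninitialised
    z[x.ind...] = x.val
    accum(z, ys...)
end
# Keep adding one-element gradients into the existing Buffer, in place.
function accum(x::Buffer, ys::OneElement...)  # only produced by the above method
    for y in ys
        x[y.ind...] += y.val
    end
    x
end
_project(x::AbstractArray, dx::Buffer) = copy(dx)  # don't return this type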

This gets a similar speedup on the above examples. I think it ought to be safe. Buffer is really just a flag here. I should think about 2nd derivatives too.
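For readers without Zygote's internals to hand, the narrower idea can be sketched in a self-contained way. `OneHotGrad` is a hypothetical stand-in for `OneElement`, and `accum_sketch` folds everything into one call that returns a plain `Array`, rather than using `Buffer` as a marker across repeated `accum` calls as in the snippet above.

```julia
# Hypothetical stand-in for OneElement: zero everywhere except at `ind`.
struct OneHotGrad{T,N} <: AbstractArray{T,N}
    val::T
    ind::NTuple{N,Int}
    dims::NTuple{N,Int}
end
Base.size(x::OneHotGrad) = x.dims
Base.getindex(x::OneHotGrad{T,N}, i::Vararg{Int,N}) where {T,N} =
    i == x.ind ? x.val : zero(T)

# Accumulate many one-element gradients into a single dense array,
# instead of materialising a new dense array for every pairwise sum.
function accum_sketch(x::OneHotGrad{T,N}, ys::OneHotGrad{T,N}...) where {T,N}
    z = zeros(T, x.dims)
    z[x.ind...] = x.val
    for y in ys
        z[y.ind...] += y.val   # in-place update of the one array we own
    end
    z
end

# Contributions like those from sum(x[i]^2 for i in 1:3) at x = [1, 2, 3],
# plus one repeated index to show that values add up.
x = [1.0, 2.0, 3.0]
gs = [OneHotGrad(2 * x[i], (i,), (3,)) for i in 1:3]
extra = OneHotGrad(1.0, (2,), (3,))
dense = accum_sketch(gs..., extra)
```

Since the only array that is mutated is one that the accumulation itself created, no other rule can hold a reference to it, which is why this variant avoids the aliasing questions raised earlier.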

@ToucheSir
Member

Sounds good to me. Provenance tracking of possibly shared arrays has proven to be a consistent thorn in our side, so the less that has to be done the better.
