-
-
Notifications
You must be signed in to change notification settings - Fork 212
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
RFC: in-place accum
#981
RFC: in-place accum
#981
Conversation
Benchmarks:
|
How do we feel about this? Would it help to do an |
I didn't think about this since. Except to realise that https://github.com/bkamins/ReadOnlyArrays.jl might be better than the version I wrote here. The checks I wrote for
Does (1)/(2b) ever occur, besides Does (2a) ever occur? Since |
A narrower idea is to make in-place accumulation work only for the result of scalar indexing: function accum(x::OneElement{T,N}, ys::OneElement{T,N}...) where {T,N}
z = Buffer(x)
fill!(z.data, zero(T))
z[x.ind...] = x.val
accum(z, ys...)
end
function accum(x::Buffer, ys::OneElement...) # only produced by the above method
for y in ys
x[y.ind...] += y.val
end
x
end
_project(x::AbstractArray, dx::Buffer) = copy(dx) # don't return this type This gets similar speedup on the above examples. I think it ought to be safe. |
Sounds good to me. Provenance tracking of possibly shared arrays has proven to be a consistent thorn in our side, so the less that has to be done the better. |
This is a minimal attempt to add safe in-place accumulation of gradients.
It assumes that any
Δ::DenseArray
may be mutated, and to keep this safe, any rule which duplicatesΔ
should apply aNoWrite
wrapper. This should prevent the problems seen in #962, without such protection.Rules which do
Δ -> (Δ, Δ)
(as for+
) must apply this wrapper to both branches. (Ideally, once one has been used up, the other could be marked safe to mutate. That could be done but is more complicated, and can be added later.) For rules within Zygote, this protection is done by hand.For rules defined by ChainRules, the interface function should check the pointer of
Δ
(and removeNoWrite
), then compare the pointer of what it returns, and re-wrap if necessary. I'm not sure this is done perfectly yet. I'm also not sure if this has performance overhead. Evil test cases would be welcome.It's easy to make the
NoWrite
wrapper disappear on the RHS of broadcasting or reductions. But if it survives to meet*
then it's likely to cause slow generic matmul. I've added quite a few explicit_unprotect
calls to try to avoid this. (This should only matter for rules defined with@adjoint
.)Xref JuliaDiff/ChainRulesCore.jl#350