Mark some arrays as safe for accumulation #578
julia> @btime Diffractor.gradient(x -> _getindex(x,1), $(rand(128 * 100)));
  min 1.012 μs, mean 11.103 μs (2 allocations, 100.05 KiB)

julia> @btime Diffractor.gradient(x -> _getindex(x,1)+_getindex(x,2), $(rand(128 * 100)));
  min 7.625 μs, mean 46.941 μs (6 allocations, 300.14 KiB)  # unthunk, unthunk, add -- unchanged

julia> @btime Diffractor.gradient(x -> _getindex(x,1)+_getindex(x,2)+_getindex(x,3), $(rand(128 * 100)));
  min 16.791 μs, mean 67.720 μs (10 allocations, 500.23 KiB)  # before
  min 8.625 μs, mean 44.642 μs (6 allocations, 300.14 KiB)  # after

  min 1.036 μs, mean 12.684 μs (2 allocations, 100.05 KiB)  # with stronger assumption, overwrite any thunk
Some benchmarks.
src/accumulation.jl
Outdated
Before ChainRulesCore 1.16, it would guess `true` for most wrappers based on `parent`,
but this is not safe, e.g. it will lead to an error with ReadOnlyArrays.jl.
This branch is on top of #577, mostly to make tests pass.
Codecov Report
Base: 93.11% // Head: 93.11% // No change to project coverage 👍
Additional details and impacted files
@@           Coverage Diff           @@
##             main     #578   +/-   ##
=======================================
  Coverage   93.11%   93.11%
=======================================
  Files          15       15
  Lines         901      901
=======================================
  Hits          839      839
  Misses         62       62
BTW, status is that I think it's probably too unsafe to load this onto It would be better to have a dedicated function for accumulation that is not used in the sloppy use of "maybe this will unthunk automatically".
How would such a function be distinct from
Whatever else it does, I think Few are actually unsafe, essentially just The safer approach seems to be opt-in. The result of
Thanks for the explanation. Is there something that can be done on the rule level to annotate one as being safe like
Unsafe like Done manually? You could wrap both in a ReadOnlyArray. But forgetting this means silent wrong answers. Done automatically? I messed around in FluxML/Zygote.jl#981 with making the Z-to-CR bridge look for repeated pointers... but I don't know how well this can work. Someone is going to have a rule in the wild which does
Force-pushed from ddd9b46 to 4d48df8.
But yes, a lot of thinking remains to be done about whether in-place accumulation is ever safe in the presence of aliasing.
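To make the aliasing concern concrete, here is a small sketch in plain Julia, with no ChainRulesCore involved. The function `add_pullback` is a hypothetical stand-in for the pullback of `z = x + y` that hands the same `dz` array to both inputs; it is not the rule from any package.

```julia
# Hypothetical pullback of z = x + y: both inputs receive the *same* array.
add_pullback(dz) = (dz, dz)

dz = [1.0, 1.0]
dx, dy = add_pullback(dz)

# Suppose x also gets a contribution from elsewhere, accumulated in place.
dx .+= [10.0, 10.0]

# dy was never touched explicitly, but it shares memory with dx,
# so it now holds [11.0, 11.0] instead of the correct [1.0, 1.0].
@assert dx === dy
@assert dy == [11.0, 11.0]

# Out-of-place accumulation allocates a new array and leaves the alias intact.
dz2 = [1.0, 1.0]
dx2, dy2 = add_pullback(dz2)
dx2 = dx2 + [10.0, 10.0]
@assert dy2 == [1.0, 1.0]
```

Nothing errors in the first half; the wrong value for `dy` is silent, which is why the question of when mutation is allowed matters.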
This strikes me as being quite far from status quo. Flux allows for shared parameters (and dealing with this is 99% of the headache in Optimisers.jl) but their gradients will almost never be aliased, as they are freshly allocated by This PR wants to design around that. Assume any array you get from an Can this be safe? What are the edge cases?
More broadly, am I right to think the "functional" approach of how The alternative would be to allocate all buffers on the forward pass. Then there would never be a need for thunks, you would always update the existing array. I think this is the path taken by Tracker, and if I understand right also by Enzyme. (For scalars, thunks save work on inactive paths, but this would presumably track activity.)
How would preallocation interact with something like
You'd have to zero them all before the second backward pass. (After copying the contents to one slice of the output.) Tracker has will give an error if you accidentally call
It appears Tracker has some capability for auto-accumulation? https://github.com/FluxML/Tracker.jl/blob/7ab871f4e4d6410e98bb1d5f527e512eb912aff8/src/back.jl#L48-L58. Surely it's not as easy as always using
As far as I know `InplaceableThunk` never does anything at present, since gradients are accumulated with `+`, and this un-thunks before adding.

This PR marks some arrays as safe to mutate, and then mutates them when adding another array, or an `InplaceableThunk`. The only arrays it marks are ones produced by adding a thunk to something, as this is sure to be a new array. This is no help for an accumulation of 2 terms, but helps with the 3rd and after, and should make N terms O(1).

#539 explores a different rule, assuming any `@thunk` expands to something safe to mutate. That helps with the 2nd `Thunk`. Both could be done.

The marker it uses is a new AbstractThunk. This is nice because consumers of gradients must already know to call `unthunk` if they need an array.

But in fact there's a small zoo of functions which automatically un-thunk. I'm not sure these are a great idea, as using them without thinking means you are likely to un-thunk twice or N times by accident. And using them on purpose requires memorising what's in the zoo and what's not. This PR adds one more problem: If you were relying on `+` to un-thunk for you, it will now assume you are accumulating, mutate, and wrap the answer in another thunk.

The other snag is that a rule like that for `z = x+y` whose pullback passes the same `dz` to both `dx` and `dy` risks wrong answers if it passes this new marker through. (Whereas earlier `Thunks` were safe.) The rrule in CR does not appear to have this problem, as `reshape` automatically unthunks (and will be called N times). Rules for `.+` and maybe `.-` may need care... I believe CR's will simply fail here since `ndims(@thunk [1,2])` is an error. Perhaps no `rrule` should ever transmit a thunk unchanged, and this could be automated somewhere?

Maybe this isn't quite safe enough.

What was the plan for `InplaceableThunk`? Was anything written when it was designed?
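A minimal sketch of the marking idea described above, written in plain Julia so it runs without ChainRulesCore. All names here (`LazyGrad`, `Owned`, `value`, `accum`, `onehot_grad`) are hypothetical stand-ins and not the PR's API: `LazyGrad` plays the role of a thunk, `Owned` plays the role of the new "safe to mutate" marker, and `value` plays the role of `unthunk`.

```julia
# Hypothetical stand-ins; none of these names exist in ChainRulesCore.
struct LazyGrad{F}              # plays the role of a thunk
    f::F
end

struct Owned{A<:AbstractArray}  # plays the role of the "safe to mutate" marker
    data::A
end

value(x::AbstractArray) = x
value(x::LazyGrad) = x.f()
value(x::Owned) = x.data

# Adding a thunk to something: `.+` allocates a fresh array that nothing else
# can reference yet, so the result is marked as safe to mutate.
accum(a::Union{AbstractArray,LazyGrad}, b::LazyGrad) = Owned(value(a) .+ value(b))

# Marked array plus anything: add in place and keep the marker.
function accum(a::Owned, b)
    a.data .+= value(b)
    return a
end

# Gradient of x -> x[i] for a length-n input, delayed like a thunk.
onehot_grad(i, n) = LazyGrad(() -> (v = zeros(n); v[i] = 1.0; v))

s2 = accum(onehot_grad(1, 4), onehot_grad(2, 4))  # 2 terms: allocates the sum
buf = value(s2)
s3 = accum(s2, onehot_grad(3, 4))                 # 3rd term: reuses `buf`

@assert value(s3) === buf
@assert value(s3) == [1.0, 1.0, 1.0, 0.0]
```

In this sketch the third term still materialises its own array before it is added; avoiding that as well would need something in the role of `InplaceableThunk`, which is left out here.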