From e4ff1ab7df7cb50a0397beb40954a2432de33429 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 24 Aug 2022 21:41:34 -0400 Subject: [PATCH 1/2] small upgrade to add!! --- src/accumulation.jl | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/accumulation.jl b/src/accumulation.jl index dc4ccd3bf..9dc2e0631 100644 --- a/src/accumulation.jl +++ b/src/accumulation.jl @@ -20,7 +20,7 @@ function add!!(x, t::InplaceableThunk) debug_add!(x, t) end else - x + t + x + unthunk(t) end end @@ -28,7 +28,14 @@ add!!(x::AbstractArray, y::Thunk) = add!!(x, unthunk(y)) function add!!(x::AbstractArray{<:Any,N}, y::AbstractArray{<:Any,N}) where {N} return if is_inplaceable_destination(x) - x .+= y + if !debug_mode() + x .+= y + else + z = x + y + # Now write junk into x, to test that nothing is relying on mutation, only using returned value: + x .*= NaN + z + end else x + y end From 4d48df8b8913255ccdc84a6e9e969b881287dcf7 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 24 Aug 2022 23:33:03 -0400 Subject: [PATCH 2/2] AccumThunk --- src/ChainRulesCore.jl | 2 +- src/accumulation.jl | 11 ++++ src/tangent_arithmetic.jl | 46 ++++++++++++++- src/tangent_types/thunks.jl | 111 ++++++++++++++++++++++++++++++++++++ test/accumulation.jl | 30 +++++++++- 5 files changed, 196 insertions(+), 4 deletions(-) diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index b75d8eff5..36fef0aa6 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -15,7 +15,7 @@ export ProjectTo, canonicalize, unthunk # tangent operations export add!!, is_inplaceable_destination # gradient accumulation operations export ignore_derivatives, @ignore_derivatives # tangents -export Tangent, NoTangent, InplaceableThunk, Thunk, ZeroTangent, AbstractZero, AbstractThunk +export Tangent, NoTangent, InplaceableThunk, Thunk, ZeroTangent, AbstractZero, AbstractThunk, AccumThunk 
include("compat.jl") include("debug_mode.jl") diff --git a/src/accumulation.jl b/src/accumulation.jl index 9dc2e0631..372a514a1 100644 --- a/src/accumulation.jl +++ b/src/accumulation.jl @@ -26,6 +26,17 @@ end add!!(x::AbstractArray, y::Thunk) = add!!(x, unthunk(y)) +# add!!(x::AbstractArray, y::AccumThunk) = add!!(x, unthunk(y)) # not sure! This may be less efficient than fallback + +function add!!(x::AbstractArray, y::AccumThunk) + return if is_inplaceable_destination(x) + x .+= y + else + # We are free to mutate the other way... + add!!(y.value, x) + end +end + function add!!(x::AbstractArray{<:Any,N}, y::AbstractArray{<:Any,N}) where {N} return if is_inplaceable_destination(x) if !debug_mode() diff --git a/src/tangent_arithmetic.jl b/src/tangent_arithmetic.jl index 439f0ac8f..0a5178b99 100644 --- a/src/tangent_arithmetic.jl +++ b/src/tangent_arithmetic.jl @@ -116,7 +116,6 @@ Base.complex(::ZeroTangent, ::ZeroTangent) = ZeroTangent() Base.complex(::ZeroTangent, i::Real) = complex(oftype(i, 0), i) Base.complex(r::Real, ::ZeroTangent) = complex(r) -Base.:+(a::AbstractThunk, b::AbstractThunk) = unthunk(a) + unthunk(b) Base.:*(a::AbstractThunk, b::AbstractThunk) = unthunk(a) * unthunk(b) for T in (:Tangent, :Any) @eval Base.:+(a::AbstractThunk, b::$T) = unthunk(a) + b @@ -154,3 +153,48 @@ for T in (:Number,) @eval Base.:*(s::$T, tangent::Tangent) = map(x -> s * x, tangent) @eval Base.:*(tangent::Tangent, s::$T) = map(x -> x * s, tangent) end + +# Accumulation +# While many weird operations above may never be called, accumulation of gradients is one of +# the big sources of memory allocation in AD, and is the entire reason InplaceableThunks exist. +# Here we try to mark any array known to be safe-to-mutate by wrapping it with AccumThunk. + +Base.:+(a::AbstractThunk, b::AbstractThunk) = maybe_accumthunk(unthunk(a) + unthunk(b)) +# Try not to put this wrapper on non-arrays +maybe_accumthunk(a) = is_inplaceable_destination(a) ? 
AccumThunk(a) : a + +Base.:+(a::AbstractThunk, b::AbstractArray) = AccumThunk(unthunk(a) + b) +Base.:+(a::AbstractArray, b::AbstractThunk) = AccumThunk(a + unthunk(b)) + +Base.:+(a::AccumThunk, b::AbstractArray) = AccumThunk(add!!(a.value, b)) +Base.:+(a::AbstractArray, b::AccumThunk) = AccumThunk(add!!(b.value, a)) + +Base.:+(a::AccumThunk, b::AbstractThunk) = maybe_accumthunk(add!!(a.value, b)) +Base.:+(a::AbstractThunk, b::AccumThunk) = maybe_accumthunk(add!!(b.value, a)) + +function Base.:+(a::AccumThunk, b::AccumThunk) + return if is_inplaceable_destination(a.value) + AccumThunk(add!!(a.value, b.value)) + elseif is_inplaceable_destination(b.value) + AccumThunk(add!!(b.value, a.value)) + else # no point keeping this type: + a.value + b.value + end +end + + +#= + +# You could go further and assume any result of unthunk is safe to mutate, +# something like this: + +# Base.:+(a::AbstractThunk, b::AbstractThunk) = maybe_accumthunk(add!!(unthunk(a), b)) + +Base.:+(a::InplaceableThunk, b::AbstractThunk) = AccumThunk(add!!(unthunk(b), a)) +Base.:+(a::AbstractThunk, b::InplaceableThunk) = AccumThunk(add!!(unthunk(a), b)) +Base.:+(a::InplaceableThunk, b::InplaceableThunk) = AccumThunk(add!!(unthunk(a), b)) + +Base.:+(a::AccumThunk, b::InplaceableThunk) = maybe_accumthunk(add!!(a.value, b)) +Base.:+(a::InplaceableThunk, b::AccumThunk) = maybe_accumthunk(add!!(b.value, a)) + +=# diff --git a/src/tangent_types/thunks.jl b/src/tangent_types/thunks.jl index 8baa006e8..4d18e9dd1 100644 --- a/src/tangent_types/thunks.jl +++ b/src/tangent_types/thunks.jl @@ -7,6 +7,16 @@ function Base.showerror(io::IO, e::MutateThunkException) return nothing end +##### +##### Operations which un-thunk automatically +##### + +# Note that if you use an object which might be thunked in two places, +# you should *always* call `unthunk` manually first, once, to avoid un-thunking twice. + +# Maybe the docs should have a list of exactly what operations do un-thunk automatically...
+# do we really need so many? + Base.Broadcast.broadcastable(x::AbstractThunk) = broadcastable(unthunk(x)) @inline function Base.iterate(x::AbstractThunk) @@ -138,6 +148,11 @@ macro thunk(body) func = Expr(:->, Expr(:tuple), Expr(:block, __source__, body)) return :(Thunk($(esc(func)))) end +# macro thunk(s::Symbol) +# @warn "Applying `@thunk` to a single symbol does nothing, as there is no calculation to defer." +# # But should it perhaps do something, if we also regard thunks as marking safe-to-mutate? +# return esc(s) +# end """ unthunk(x) @@ -157,6 +172,7 @@ Base.transpose(x::AbstractThunk) = @thunk(transpose(unthunk(x))) """ Thunk(()->v) + A thunk is a deferred computation. It wraps a zero argument closure that when invoked returns a tangent. `@thunk(v)` is a macro that expands into `Thunk(()->v)`. @@ -212,6 +228,10 @@ end Base.convert(::Type{<:Thunk}, a::AbstractZero) = @thunk(a) +##### +##### `InplaceableThunk` +##### + """ InplaceableThunk(add!::Function, val::Thunk) @@ -244,3 +264,94 @@ function Base.show(io::IO, x::InplaceableThunk) show(io, x.val) print(io, ")") end + + +##### +##### `AccumThunk` +##### + +""" + AccumThunk(value) <: AbstractThunk + +This isn't a delayed computation, but is instead a marker that its contents are known to be safe +to mutate during gradient accumulation. At present it is produced by adding two thunks, +allowing any further addition to keep mutating. Anything downstream which wants an array must +already know to `unthunk`, which is why this is `<: AbstractThunk`. + +Ideally it would be produced by adding two Arrays too, but that's impossible in CR's design. +It might be good for many rules which produce a known-safe Array to wrap it in this. + +If we may assume/demand that the result of `@thunk` is always a new array, too, +then more cases can mutate. And then it would make sense for `@thunk A` on one Symbol +to produce an `AccumThunk`, promoting `@thunk` to have two meanings. But not yet done.
+""" +struct AccumThunk{T} <: AbstractThunk + value::T +end + +@inline unthunk(x::AccumThunk) = x.value + +function Base.show(io::IO, x::AccumThunk) + print(io, "AccumThunk(") + str = sprint(show, x.value, context = io) + if length(str) < 80 + print(io, str) + else + print(io, first(str, 70), "...") + end + print(io, ")") +end + + +#= + +julia> using ChainRules, ChainRulesCore, Diffractor + +julia> _getindex(x...) = getindex(x...); # use CR's rule: +julia> function ChainRules.rrule(::typeof(_getindex), x::AbstractArray, inds...) + function getindex_pullback(dy) + nots = map(Returns(NoTangent()), inds) + return (NoTangent(), ChainRules.thunked_∇getindex(x, dy, inds...), nots...) + end + return x[inds...], getindex_pullback + end + +julia> Diffractor.gradient(x -> _getindex(x,1), [1,2,3.0]) # calls unthunk on final answer +([1.0, 0.0, 0.0],) + +julia> @btime Diffractor.gradient(x -> _getindex(x,1), $(rand(128 * 100))); + min 1.012 μs, mean 11.103 μs (2 allocations, 100.05 KiB) + +julia> @btime Diffractor.gradient(x -> _getindex(x,1)+_getindex(x,2), $(rand(128 * 100))); + min 7.625 μs, mean 46.941 μs (6 allocations, 300.14 KiB) # unthunk, unthunk, add -- unchanged + +julia> @btime Diffractor.gradient(x -> _getindex(x,1)+_getindex(x,2)+_getindex(x,3), $(rand(128 * 100))); + min 16.791 μs, mean 67.720 μs (10 allocations, 500.23 KiB) # before + min 8.625 μs, mean 44.642 μs (6 allocations, 300.14 KiB) # after + + min 1.036 μs, mean 12.684 μs (2 allocations, 100.05 KiB) # with stronger assumption, overwrite any thunk + +# Same example as https://github.com/FluxML/Zygote.jl/pull/981#issuecomment-861079488 +# originally https://github.com/FluxML/Zygote.jl/issues/644 + +julia> function _evalpoly(x, p) + N = length(p) + ex = _getindex(p, length(p)) + for i in N-1:-1:1 + ex = muladd(x, ex, _getindex(p, i)) + end + ex + end +_evalpoly (generic function with 1 method) + +julia> x, p = rand(), randn(10000); + +julia> @btime _evalpoly(x, p); + min 20.375 μs, mean 20.553 μs (1 
allocation, 16 bytes) + +julia> @btime Diffractor.gradient(_evalpoly, x, p); + min 566.669 ms, mean 585.185 ms (1174329 allocations, 2.44 GiB) # before + min 376.376 ms, mean 384.314 ms (1144338 allocations, 975.62 MiB) # after + +=# + diff --git a/test/accumulation.jl b/test/accumulation.jl index 597105d32..e5ef31fbe 100644 --- a/test/accumulation.jl +++ b/test/accumulation.jl @@ -90,9 +90,10 @@ end end - @testset "AbstractThunk $(typeof(thunk))" for thunk in ( + @testset "add!!(array, $(typeof(thunk)))" for thunk in ( @thunk(-1.0 * ones(2, 2)), InplaceableThunk(x -> x .-= ones(2, 2), @thunk(-1.0 * ones(2, 2))), + AccumThunk(-ones(2, 2)) ) @testset "in place" begin accumuland = [1.0 2.0; 3.0 4.0] @@ -101,14 +102,18 @@ @test ret === accumuland # must be same object end + @test unthunk(thunk) == -ones(2, 2) # AccumThunk has not been mutated + @testset "out of place" begin accumuland = @SMatrix [1.0 2.0; 3.0 4.0] ret = add!!(accumuland, thunk) @test ret == [0.0 1.0; 2.0 3.0] # must return right answer @test ret !== accumuland # must not be same object - @test accumuland == [1.0 2.0; 3.0 4.0] # must not have mutated + @test accumuland == [1.0 2.0; 3.0 4.0] # cannot ever be mutated end + + unthunk(thunk) # AccumThunk may have been mutated, test has no opinion? 
+ end @testset "not actually inplace but said it was" begin @@ -137,4 +142,25 @@ msg_equal = sprint(showerror, BadInplaceException(ithunk, [22], [22])) @test occursin("equal", msg_equal) end + + @testset "thunk + thunk" begin + s1 = @thunk([1.0]) + @thunk([2.0]) + @thunk([3.0]) + @test unthunk(s1) == [6] + @test s1 isa AccumThunk + + list = [[1.0], @thunk([1.0]), InplaceableThunk(x -> x .+ 1, @thunk [1.0]), AccumThunk([1.0])] + for x in list, y in list + z = deepcopy(x) + deepcopy(y) + @test unthunk(z) == [2] + @test z isa AccumThunk || (x isa Array && y isa Array) + end + + triv = [1.0, @thunk(1.0), AccumThunk(1.0)] + for x in triv, y in triv + z = x + y + @test unthunk(z) === 2.0 + @test z isa Float64 || (x isa AccumThunk && y isa AccumThunk) + # How much do we care about not applying these wrappers when not useful? + end + end end