Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mark some arrays as safe for accumulation #578

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/ChainRulesCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ export ProjectTo, canonicalize, unthunk # tangent operations
export add!!, is_inplaceable_destination # gradient accumulation operations
export ignore_derivatives, @ignore_derivatives
# tangents
export Tangent, NoTangent, InplaceableThunk, Thunk, ZeroTangent, AbstractZero, AbstractThunk
export Tangent, NoTangent, InplaceableThunk, Thunk, ZeroTangent, AbstractZero, AbstractThunk, AccumThunk

include("compat.jl")
include("debug_mode.jl")
Expand Down
22 changes: 20 additions & 2 deletions src/accumulation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,33 @@ function add!!(x, t::InplaceableThunk)
debug_add!(x, t)
end
else
x + t
x + unthunk(t)
end
end

add!!(x::AbstractArray, y::Thunk) = add!!(x, unthunk(y))

function add!!(x::AbstractArray{<:Any,N}, y::AbstractArray{<:Any,N}) where {N}
# add!!(x::AbstractArray, y::AccumThunk) = add!!(x, unthunk(y)) # not sure! This may be less efficient than fallback

function add!!(x::AbstractArray, y::AccumThunk)
return if is_inplaceable_destination(x)
x .+= y
else
# We are free to mutate the other way...
add!!(y.value, x)
end
end

function add!!(x::AbstractArray{<:Any,N}, y::AbstractArray{<:Any,N}) where {N}
return if is_inplaceable_destination(x)
if !debug_mode()
x .+= y
else
z = x + y
# Now write junk into x, to test that nothing is relying on mutation, only using returned value:
x .*= NaN
z
end
else
x + y
end
Expand Down
46 changes: 45 additions & 1 deletion src/tangent_arithmetic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,6 @@ Base.complex(::ZeroTangent, ::ZeroTangent) = ZeroTangent()
Base.complex(::ZeroTangent, i::Real) = complex(oftype(i, 0), i)
Base.complex(r::Real, ::ZeroTangent) = complex(r)

Base.:+(a::AbstractThunk, b::AbstractThunk) = unthunk(a) + unthunk(b)
Base.:*(a::AbstractThunk, b::AbstractThunk) = unthunk(a) * unthunk(b)
for T in (:Tangent, :Any)
@eval Base.:+(a::AbstractThunk, b::$T) = unthunk(a) + b
Expand Down Expand Up @@ -154,3 +153,48 @@ for T in (:Number,)
@eval Base.:*(s::$T, tangent::Tangent) = map(x -> s * x, tangent)
@eval Base.:*(tangent::Tangent, s::$T) = map(x -> x * s, tangent)
end

# Accumulation
# While many weird operations above may never be called, accumulation of gradients is one of
# the big sources of memory allocation in AD, and is the entire reason InplaceableThunks exist.
# Here we try to mark any array known to be safe-to-mutate by wrapping it with AccumThunk.

Base.:+(a::AbstractThunk, b::AbstractThunk) = maybe_accumthunk(unthunk(a) + unthunk(b))
# Try not to put this wrapper on non-arrays
maybe_accumthunk(a) = is_inplaceable_destination(a) ? AccumThunk(a) : a

Base.:+(a::AbstractThunk, b::AbstractArray) = AccumThunk(unthunk(a) + b)
Base.:+(a::AbstractArray, b::AbstractThunk) = AccumThunk(a + unthunk(b))

Base.:+(a::AccumThunk, b::AbstractArray) = AccumThunk(add!!(a.value, b))
Base.:+(a::AbstractArray, b::AccumThunk) = AccumThunk(add!!(b.value, a))

Base.:+(a::AccumThunk, b::AbstractThunk) = maybe_accumthunk(add!!(a.value, b))
Base.:+(a::AbstractThunk, b::AccumThunk) = maybe_accumthunk(add!!(b.value, a))

function Base.:+(a::AccumThunk, b::AccumThunk)
return if is_inplaceable_destination(a.value)
AccumThunk(add!!(a.value, b.value))
elseif is_inplaceable_destination(b.value)
AccumThunk(add!!(b.value, a.value))
else # no point keeping this type:
a.value + b.value
end
end


#=

# You could go further and assume any result of unthunk is safe to mutate,
# something like this:

# Base.:+(a::AbstractThunk, b::AbstractThunk) = maybe_accumthunk(add!!(unthunk(a), b))

Base.:+(a::InplaceableThunk, b::AbstractThunk) = AccumThunk(add!!(unthunk(b), b))
Base.:+(a::AbstractThunk, b::InplaceableThunk) = AccumThunk(add!!(unthunk(a), b))
Base.:+(a::InplaceableThunk, b::InplaceableThunk) = AccumThunk(add!!(unthunk(a), b))

Base.:+(a::AccumThunk, b::InplaceableThunk) = maybe_accumthunk(add!!(a.value, b))
Base.:+(a::InplaceableThunk, b::AccumThunk) = maybe_accumthunk(add!!(b.value, a))

=#
111 changes: 111 additions & 0 deletions src/tangent_types/thunks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,16 @@ function Base.showerror(io::IO, e::MutateThunkException)
return nothing
end

#####
##### Operations which un-thunk automatically
#####

# Note the if you use an object which might be thunked in two places,
# you should *always* call `unthunk` manually first, once, to avoid un-thunking twice.

# Maybe the docs should have a list of exactly what operations do un-thunk automatically...
# do we really need so many?

Base.Broadcast.broadcastable(x::AbstractThunk) = broadcastable(unthunk(x))

@inline function Base.iterate(x::AbstractThunk)
Expand Down Expand Up @@ -138,6 +148,11 @@ macro thunk(body)
func = Expr(:->, Expr(:tuple), Expr(:block, __source__, body))
return :(Thunk($(esc(func))))
end
# macro thunk(s::Symbol)
# @warn "Applying `@thunk` to a single symbol does nothing, as there is no calculation to defer."
# # But should it perhaps do something, if we also regard thunks as marking safe-to-mutate?
# return esc(s)
# end

"""
unthunk(x)
Expand All @@ -157,6 +172,7 @@ Base.transpose(x::AbstractThunk) = @thunk(transpose(unthunk(x)))

"""
Thunk(()->v)

A thunk is a deferred computation.
It wraps a zero argument closure that when invoked returns a tangent.
`@thunk(v)` is a macro that expands into `Thunk(()->v)`.
Expand Down Expand Up @@ -212,6 +228,10 @@ end

Base.convert(::Type{<:Thunk}, a::AbstractZero) = @thunk(a)

#####
##### `InplaceableThunk`
#####

"""
InplaceableThunk(add!::Function, val::Thunk)

Expand Down Expand Up @@ -244,3 +264,94 @@ function Base.show(io::IO, x::InplaceableThunk)
show(io, x.val)
print(io, ")")
end


#####
##### `AccumThunk`
#####

"""
AccumThunk(value) <: AbstractThunk

This isn't a delayed computation, but is instead a marker that its contents is known to be safe
to mutate during gradient accumulation. At present it is produced by adding two thunks,
allowing any further addition to keep mutating. Anything downstream which wants an array must
already know to `unthunk`, which is why this is `<: AbstractThunk`.

Ideally it would be produced by adding two Arrays too, but that's impossible in CR's design.
It might be good for many rules which produce a known-safe Array to wrap it in this.

If we may assume/demand that the result of `@thunk` is always a new array, too,
then more cases can mutate. And then it would make sense for `@thunk A` on one Symbol
to produce an `AccumThunk`, promoting `@thunk` to have two meanings. But not yet done.
"""
struct AccumThunk{T} <: AbstractThunk
value::T
end

@inline unthunk(x::AccumThunk) = x.value

function Base.show(io::IO, x::AccumThunk)
print(io, "AccumThunk(")
str = sprint(show, x.value, context = io)
if length(str) < 80
print(io, str)
else
print(io, first(str, 70), "...")
end
print(io, ")")
end


#=

julia> using ChainRules, ChainRulesCore, Diffractor

julia> _getindex(x...) = getindex(x...); # use CR's rule:
julia> function ChainRules.rrule(::typeof(_getindex), x::AbstractArray, inds...)
function getindex_pullback(dy)
nots = map(Returns(NoTangent()), inds)
return (NoTangent(), ChainRules.thunked_∇getindex(x, dy, inds...), nots...)
end
return x[inds...], getindex_pullback
end

julia> Diffractor.gradient(x -> _getindex(x,1), [1,2,3.0]) # calls unthunk on final answer
([1.0, 0.0, 0.0],)

julia> @btime Diffractor.gradient(x -> _getindex(x,1), $(rand(128 * 100)));
min 1.012 μs, mean 11.103 μs (2 allocations, 100.05 KiB)

julia> @btime Diffractor.gradient(x -> _getindex(x,1)+_getindex(x,2), $(rand(128 * 100)));
min 7.625 μs, mean 46.941 μs (6 allocations, 300.14 KiB) # unthunk, unthunk, add -- unchanged

julia> @btime Diffractor.gradient(x -> _getindex(x,1)+_getindex(x,2)+_getindex(x,3), $(rand(128 * 100)));
min 16.791 μs, mean 67.720 μs (10 allocations, 500.23 KiB) # before
min 8.625 μs, mean 44.642 μs (6 allocations, 300.14 KiB) # after

min 1.036 μs, mean 12.684 μs (2 allocations, 100.05 KiB) # with stronger assumption, overwrite any thunk
Comment on lines +322 to +332
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some benchmarks.


# Same example as https://github.com/FluxML/Zygote.jl/pull/981#issuecomment-861079488
# originally https://github.com/FluxML/Zygote.jl/issues/644

julia> function _evalpoly(x, p)
N = length(p)
ex = _getindex(p, length(p))
for i in N-1:-1:1
ex = muladd(x, ex, _getindex(p, i))
end
ex
end
_evalpoly (generic function with 1 method)

julia> x, p = rand(), randn(10000);

julia> @btime _evalpoly(x, p);
min 20.375 μs, mean 20.553 μs (1 allocation, 16 bytes)

julia> @btime Diffractor.gradient(_evalpoly, x, p);
min 566.669 ms, mean 585.185 ms (1174329 allocations, 2.44 GiB) # before
min 376.376 ms, mean 384.314 ms (1144338 allocations, 975.62 MiB) # after

=#

30 changes: 28 additions & 2 deletions test/accumulation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,10 @@
end
end

@testset "AbstractThunk $(typeof(thunk))" for thunk in (
@testset "add!!(array, $(typeof(thunk)))" for thunk in (
@thunk(-1.0 * ones(2, 2)),
InplaceableThunk(x -> x .-= ones(2, 2), @thunk(-1.0 * ones(2, 2))),
AccumThunk(-ones(2, 2))
)
@testset "in place" begin
accumuland = [1.0 2.0; 3.0 4.0]
Expand All @@ -101,14 +102,18 @@
@test ret === accumuland # must be same object
end

@test unthunk(thunk) == -ones(2, 2) # AccumThunk has not been mutated

@testset "out of place" begin
accumuland = @SMatrix [1.0 2.0; 3.0 4.0]

ret = add!!(accumuland, thunk)
@test ret == [0.0 1.0; 2.0 3.0] # must return right answer
@test ret !== accumuland # must not be same object
@test accumuland == [1.0 2.0; 3.0 4.0] # must not have mutated
@test accumuland == [1.0 2.0; 3.0 4.0] # cannot ever be mutated
end

unthunk(thunk) # AccumThunk may have been mutated, test has no opinion?
end

@testset "not actually inplace but said it was" begin
Expand Down Expand Up @@ -137,4 +142,25 @@
msg_equal = sprint(showerror, BadInplaceException(ithunk, [22], [22]))
@test occursin("equal", msg_equal)
end

@testset "thunk + thunk" begin
s1 = @thunk([1.0]) + @thunk([2.0]) + @thunk([3.0])
@test unthunk(s1) == [6]
@test s1 isa AccumThunk

list = [[1.0], @thunk([1.0]), InplaceableThunk(x -> x .+ 1, @thunk [1.0]), AccumThunk([1.0])]
for x in list, y in list
z = deepcopy(x) + deepcopy(y)
@test unthunk(z) == [2]
@test z isa AccumThunk || (x isa Array && y isa Array)
end

triv = [1.0, @thunk(1.0), AccumThunk(1.0)]
for x in triv, y in triv
z = x + y
@test unthunk(z) === 2.0
@test z isa Float64 || (x isa AccumThunk && y isa AccumThunk)
# How much to se care about not applying these wrappers when not useful?
end
end
end