
Zygote.OneElement does not properly reshape #1080

Closed
ChrisRackauckas opened this issue Sep 26, 2021 · 20 comments


@ChrisRackauckas
Member

ChrisRackauckas commented Sep 26, 2021

MWE:

using Flux, Random
rng = Random.default_rng(); n, m = 2, 3  # sizes assumed for illustration
Z = randn(rng, Float32, (n,m))
Z .= Flux.Zygote.OneElement(0.1, (2,), axes(Z))
vec(Z) .= vec(Flux.Zygote.OneElement(0.1, (2,), axes(Z)))
@DhairyaLGandhi
Member

I don't think OneElement is something end users need to interact with. Where are we seeing it leak into user code?

@ChrisRackauckas
Member Author

In the DiffEq adjoints it's passed back: https://github.com/SciML/DiffEqSensitivity.jl/blob/v6.58.0/src/concrete_solve.jl#L207

@ChrisRackauckas
Member Author

function Base.reshape(x::Flux.Zygote.OneElement,shp::Int)
    prev_shp = Base.to_shape(x.axes)
    @assert prod(prev_shp) == prod(shp)
    linear_ind = LinearIndices(prev_shp)[CartesianIndex(x.ind)]
    Flux.Zygote.OneElement(x.val,(linear_ind,),(Base.OneTo(shp),))
end

function Base.reshape(x::Flux.Zygote.OneElement,shp::Tuple{Int64, Vararg{Int64, N}}) where N
    @assert prod(Base.to_shape(x.axes)) == prod(shp)
    Flux.Zygote.OneElement(x.val,x.ind,((Base.OneTo(s) for s in shp)...,))
end

Seems to be what's necessary.

@DhairyaLGandhi
Member

It's of course difficult to deal with in the presence of complex adjoint code. We likely need regular arrays to handle arbitrary operations on them in the intermediate computations of custom adjoints. It's also not going to work neatly with GPUs in its current state.

@ChrisRackauckas
Member Author

With those reshapes this seems fine:

using Flux, Zygote, CUDA, Random
rng = Random.default_rng(); n, m = 2, 3  # sizes assumed for illustration
Z = randn(rng, Float32, (n,m)) |> gpu
Z .= Flux.Zygote.OneElement(0.1, (2,1), axes(Z))
vec(Z) .= vec(Flux.Zygote.OneElement(0.1, (2,1), axes(Z)))

@ChrisRackauckas
Member Author

However,

vec(Z) .= -vec(Flux.Zygote.OneElement(0.1, (2,1), axes(Z)))

hits a GPU kernel generation error.

@DhairyaLGandhi
Member

DhairyaLGandhi commented Sep 26, 2021

The reshape code looks about right. Wouldn't we need to calculate the resultant indices in the second (Tuple{Int...}) case?
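For the Tuple case the bookkeeping is: map the old Cartesian index to a column-major linear index, then map that linear index back to a Cartesian index in the new shape. A minimal sketch using only Base, assuming a 2×3 array with the nonzero at (2,1) reshaped to 1×6 (the sizes are just for illustration):

```julia
# Index bookkeeping for reshaping a one-hot array, using only Base.
# The shapes and index below are illustrative assumptions.
old_shape = (2, 3)
new_shape = (1, 6)
old_ind = (2, 1)   # Cartesian index of the nonzero entry in the 2×3 array

# Column-major linear position of that entry
linear = LinearIndices(old_shape)[old_ind...]

# The same linear position, expressed as a Cartesian index in the new shape
new_ind = Tuple(CartesianIndices(new_shape)[linear])

println(linear)    # 2
println(new_ind)   # (1, 2)
```

Reusing the old index unchanged would keep (2,1), which lies outside a 1×6 array.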

@ChrisRackauckas
Member Author

oh yes you would.

@ChrisRackauckas
Member Author

function Base.reshape(x::Flux.Zygote.OneElement,shp::Tuple{Int64, Vararg{Int64, N}}) where N
    prev_shp = Base.to_shape(x.axes)
    @assert prod(prev_shp) == prod(shp)
    # old Cartesian index -> linear index -> Cartesian index in the new shape
    linear_ind = LinearIndices(prev_shp)[CartesianIndex(x.ind)]
    cartesian_ind = Tuple(CartesianIndices(shp)[linear_ind])
    Flux.Zygote.OneElement(x.val,cartesian_ind,((Base.OneTo(s) for s in shp)...,))
end

would probably work.

@ChrisRackauckas
Member Author

@maleadt here's an MWE:

using Flux, Zygote, CUDA
function Base.reshape(x::Zygote.OneElement,shp::Int)
    prev_shp = Base.to_shape(x.axes)
    @assert prod(prev_shp) == prod(shp)
    linear_ind = LinearIndices(prev_shp)[CartesianIndex(x.ind)]
    Flux.Zygote.OneElement(x.val,(linear_ind,),(Base.OneTo(shp),))
end

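function Base.reshape(x::Zygote.OneElement,shp::Tuple{Int64, Vararg{Int64, N}}) where N
    prev_shp = Base.to_shape(x.axes)
    @assert prod(prev_shp) == prod(shp)
    # recompute the index for the new shape via a linear index
    linear_ind = LinearIndices(prev_shp)[CartesianIndex(x.ind)]
    cartesian_ind = Tuple(CartesianIndices(shp)[linear_ind])
    Flux.Zygote.OneElement(x.val,cartesian_ind,((Base.OneTo(s) for s in shp)...,))
end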

using Random
rng = Random.default_rng(); n, m = 2, 3  # sizes assumed for illustration
Z = randn(rng, Float32, (n,m)) |> gpu
Z .= Flux.Zygote.OneElement(0.1, (2,1), axes(Z))
vec(Z) .= vec(Flux.Zygote.OneElement(0.1, (2,1), axes(Z))) # works!

vec(Z) .= -vec(Flux.Zygote.OneElement(0.1, (2,1), axes(Z))) # doesn't work!

I'm curious why adding that - makes it so it cannot compile on GPUs anymore.

@mcabbott
Member

mcabbott commented Sep 28, 2021

The OneElement constructor isn't very defensive, because it was assumed to have only one caller. What's written here should probably be an error, since it constructs something whose ndims doesn't match its axes, and it is no surprise that this confuses many Base functions downstream:

julia> using Zygote

julia> Z = rand(2,3);

julia> ze = Zygote.OneElement(0.1, (2,), axes(Z))  # this is a broken object, should be illegal!
2×3 Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}:
 0.0
 0.1

julia> ndims(ze)
1

julia> size(ze)  # doesn't match ndims & type
(2, 3)

julia> @which vec(ze)
vec(a::AbstractVector) in Base at abstractarraymath.jl:42

julia> vec(ze) === ze  # does nothing at all, because it believes ndims
true
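The same dispatch confusion can be reproduced without Zygote. In the sketch below, `Mismatch` is a hypothetical type whose type parameter claims one dimension while `size` reports two, mirroring the hand-built `OneElement(0.1, (2,), axes(Z))`. Since `ndims` trusts the type, `vec` picks the `AbstractVector` method and returns the object untouched:

```julia
# Hypothetical, deliberately inconsistent array type: the type parameter
# says N = 1, but `size` reports two dimensions.
struct Mismatch <: AbstractArray{Float64,1} end

Base.size(::Mismatch) = (2, 3)   # disagrees with N = 1
Base.getindex(::Mismatch, i::Int) = i == 2 ? 0.1 : 0.0

x = Mismatch()
println(ndims(x))       # 1, read from the type parameter
println(size(x))        # (2, 3)
println(vec(x) === x)   # true: vec(::AbstractVector) returns its argument as-is
```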

Does Zygote construct such things? If it does, that's a bug which ought to be fixed, and an MWE of that would be appreciated. It's possible that the rule put too little thought into linear indexing, although a naive attempt to fool it does not work:

julia> gradient(x -> x[1], rand(2,3))[1]  # linear indexing avoids this path
2×3 Matrix{Float64}:
 1.0  0.0  0.0
 0.0  0.0  0.0

julia> gradient(x -> x[1,1], rand(2,3))[1]
2×3 Zygote.OneElement{Float64, 2, Tuple{Int64, Int64}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}:
 1.0  0.0  0.0
 0.0  0.0  0.0

Edit: here's the obvious way to constrain this thing, by explicitly demanding that the axes have the right length on construction. That would turn the above construction by hand into a method error:

julia> struct OneElement{T,N,I,A} <: AbstractArray{T,N}
         val::T
         ind::I
         axes::A
         OneElement(val::T, ind::I, axes::A) where {T<:Number, I<:NTuple{N,Int}, A<:NTuple{N,<:AbstractUnitRange}} where {N} = new{T,N,I,A}(val, ind, axes)
       end

julia> Base.size(A::OneElement) = map(length, A.axes)

julia> Base.axes(A::OneElement) = A.axes

julia> Base.getindex(A::OneElement{T,N}, i::Vararg{Int,N}) where {T,N} = ifelse(i==A.ind, A.val, zero(T))

julia> OneElement(0.1, (2,), axes(Z))
ERROR: MethodError: no method matching OneElement(::Float64, ::Tuple{Int64}, ::Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}})

But the only caller of the constructor is the following function, which explicitly checks that the length of the indices matches N. For this to construct an illegal object, I think you would need an x whose ndims does not match its axes.

∇getindex(x::AbstractArray{T,N}, inds) where {T,N} = dy -> begin
  if inds isa NTuple{N,Int} && T <: Number
    dx = OneElement(dy, inds, axes(x))
  elseif ...

@DhairyaLGandhi
Member

DhairyaLGandhi commented Sep 28, 2021

Well, yes, the rule does not cover different indexing techniques or consider the array type (StarWarsArray.jl- or OffsetArrays.jl-esque arrays), but that of course doesn't mean that reshape, among others, needs to be fixed.

@ChrisRackauckas
Member Author

Sometimes you just need to vec values in a pullback. I don't think it makes sense to say, as a blanket rule, that vec should never be used in any pullback function. The problem is that the incoming pullback value (the delta) is sometimes a Zygote.OneElement, so algorithms that want to vec it will fail.

@mcabbott
Member

mcabbott commented Sep 28, 2021

Yes I agree vec should always be safe, and understand why you want it. I don't see anyone saying you must avoid this.

What you must avoid is constructing objects whose type doesn't match their properties. The OneElement constructor presently lacks a guardrail to make this construction an error, which is why you can make one by hand. But if you can make Zygote make one for you, that would be much more concerning.

On a legal OneElement, reshape works fine:

julia> ze2 = Zygote.OneElement(0.1, (2,1), axes(Z))  # this is a legal object
2×3 Zygote.OneElement{Float64, 2, Tuple{Int64, Int64}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}:
 0.0  0.0  0.0
 0.1  0.0  0.0

julia> vec(ze2)
6-element reshape(::Zygote.OneElement{Float64, 2, Tuple{Int64, Int64}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, 6) with eltype Float64:
 0.0
 0.1
 0.0
 0.0
 0.0
 0.0

If you'd like to avoid a ReshapedArray here, it would be easy to add reshape methods like the ones above, of course. Maybe there are cases where this is more efficient? Note that the multi-dimensional one needs more care:

julia> reshape(ze2, (1,6))  # with the above Base.reshape definition
1×6 Zygote.OneElement{Float64, 2, Tuple{Int64, Int64}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}:
 0.0  0.0  0.0  0.0  0.0  0.0

julia> reshape(collect(ze2), (1,6)) 
1×6 Matrix{Float64}:
 0.0  0.1  0.0  0.0  0.0  0.0
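Putting the guardrail and an index-recomputing reshape together, a self-contained sketch might look like this. `OneElem` is a hypothetical stand-in written for illustration, not Zygote's actual `OneElement`. Its inner constructor requires the index tuple and the axes tuple to share the same length N, and `reshape` goes through a linear index so the result agrees with `reshape(collect(x), shp)`:

```julia
# Hypothetical stand-in for Zygote.OneElement, for illustration only.
struct OneElem{T,N,I,A} <: AbstractArray{T,N}
    val::T
    ind::I
    axes::A
    # Guardrail: index and axes must both have length N, otherwise MethodError.
    function OneElem(val::T, ind::NTuple{N,Int}, axes::NTuple{N,AbstractUnitRange}) where {T<:Number,N}
        new{T,N,typeof(ind),typeof(axes)}(val, ind, axes)
    end
end

Base.size(x::OneElem) = map(length, x.axes)
Base.axes(x::OneElem) = x.axes
Base.getindex(x::OneElem{T,N}, i::Vararg{Int,N}) where {T,N} = i == x.ind ? x.val : zero(T)

# Reshape via a linear index, so the nonzero lands in the right place.
function Base.reshape(x::OneElem, shp::Tuple{Int,Vararg{Int}})
    old_shape = size(x)
    prod(old_shape) == prod(shp) ||
        throw(DimensionMismatch("cannot reshape $old_shape into $shp"))
    linear = LinearIndices(old_shape)[x.ind...]
    new_ind = Tuple(CartesianIndices(shp)[linear])
    OneElem(x.val, new_ind, map(Base.OneTo, shp))
end

Z = zeros(2, 3)
x = OneElem(0.1, (2, 1), axes(Z))   # legal: two indices, two axes
r = reshape(x, (1, 6))              # stays a OneElem, with index (1, 2)
v = vec(x)                          # vec ends up calling reshape(x, (6,))

rejected = try
    OneElem(0.1, (2,), axes(Z))     # one index, two axes
    false
catch err
    err isa MethodError
end
```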

@ChrisRackauckas
Member Author

Oh, I didn't see that I had made an illegal Zygote.OneElement. It's the same error I get in a pullback though, which is disconcerting. The MWE that produced it is the first thing in SciML/DiffEqFlux.jl#632.

@mcabbott
Member

Is that on Flux master? You say "if we try to say |> gpu inside of the rhs function", which means you may need the fixes in FluxML/Flux.jl#1704 (almost a month ago). Without them you may still hit problems like #1005.

But I don't know that that's the problem. Independently of that, if you have managed to get Zygote to make an illegal OneElement, that's very concerning and worth figuring out. Adding the guardrail above (to make such a construction a MethodError) might at least produce stack traces closer to the constructor.

Is that stack trace for SciML/DiffEqFlux.jl#632 posted somewhere?

@maleadt
Contributor

maleadt commented Oct 5, 2021

I'm curious why adding that - makes it so it cannot compile on GPUs anymore.

julia> typeof(vec(Flux.Zygote.OneElement(0.1, (2,1), axes(Z))))
Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}

julia> typeof(-vec(Flux.Zygote.OneElement(0.1, (2,1), axes(Z))))
Vector{Float64} (alias for Array{Float64, 1})

It works if you broadcast the -:

vec(Z) .= .-vec(Flux.Zygote.OneElement(0.1, (2,1), axes(Z)))

@ChrisRackauckas
Member Author

That .- seems to be the key. Thanks.

@mcabbott
Member

mcabbott commented Oct 6, 2021

Well spotted. Even for normal arrays, z .= -x always allocates a copy, which z .= .-x avoids, but this is super easy to miss. It might be friendly to provide methods for OneElement that avoid this, since each is just one line. Maybe z .= -2x too. I can't think of any downsides.
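The one-line methods suggested here could look roughly like the sketch below. It uses a hypothetical 1-D `OneHotVec` type rather than Zygote's `OneElement`, so the names are assumptions. The idea is that defining unary `-` and scalar `*` to return the same lazy type keeps `z .= -x` and `z .= -2x` from building a dense temporary. Without such methods, Base's unary `-` on a generic `AbstractArray` broadcasts into a new `Array`, which is consistent with the `Vector{Float64}` shown above.

```julia
# Hypothetical 1-D one-hot vector, standing in for Zygote.OneElement.
struct OneHotVec{T<:Number} <: AbstractVector{T}
    val::T
    ind::Int
    len::Int
end

Base.size(x::OneHotVec) = (x.len,)
Base.getindex(x::OneHotVec{T}, i::Int) where {T} = i == x.ind ? x.val : zero(T)

# One-line methods that keep -x and a*x one-hot instead of allocating a dense Vector.
Base.:-(x::OneHotVec) = OneHotVec(-x.val, x.ind, x.len)
Base.:*(a::Number, x::OneHotVec) = OneHotVec(a * x.val, x.ind, x.len)

x = OneHotVec(0.1, 2, 6)

z1 = zeros(6); z1 .= -x    # -x is now a OneHotVec, so no dense temporary
z2 = zeros(6); z2 .= .-x   # fused broadcast, never needed a temporary
z3 = zeros(6); z3 .= -2x   # -2x also stays a OneHotVec via the methods above
```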

@ChrisRackauckas
Member Author

Yeah, that's the downstream fix to that specific problem. I think that means everything here is answered.


4 participants