Zygote.OneElement does not properly reshape #1080
I don't think `OneElement` is something end users need to interact with. Where are we seeing it leak into user code?
In the DiffEq adjoints it's passed back: https://github.com/SciML/DiffEqSensitivity.jl/blob/v6.58.0/src/concrete_solve.jl#L207
```julia
function Base.reshape(x::Flux.Zygote.OneElement,shp::Int)
    prev_shp = Base.to_shape(x.axes)
    @assert prod(prev_shp) == prod(shp)
    linear_ind = LinearIndices(prev_shp)[CartesianIndex(x.ind)]
    Flux.Zygote.OneElement(x.val,(linear_ind,),(Base.OneTo(shp),))
end
function Base.reshape(x::Flux.Zygote.OneElement,shp::Tuple{Int64, Vararg{Int64, N}}) where N
    @assert prod(Base.to_shape(x.axes)) == prod(shp)
    Flux.Zygote.OneElement(x.val,x.ind,((Base.OneTo(s) for s in shp)...,))
end
```

Seems to be what's necessary.
Of course, it's difficult to deal with in the presence of complex adjoint code. Custom adjoints likely need regular arrays so that arbitrary operations in their intermediate computations work. In its current state it's also not going to work neatly with GPUs.
With those reshapes this seems fine:

```julia
using Flux, Zygote, CUDA
Z = randn(rng, Float32, (n,m)) |> gpu
Z .= Flux.Zygote.OneElement(0.1, (2,1), axes(Z))
vec(Z) .= vec(Flux.Zygote.OneElement(0.1, (2,1), axes(Z)))
```
However, `vec(Z) .= -vec(Flux.Zygote.OneElement(0.1, (2,1), axes(Z)))` hits a GPU kernel generation error.
The reshape code looks about right. Wouldn't we need to calculate the resultant indices in the second (
Oh yes, you would.
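```julia
function Base.reshape(x::Flux.Zygote.OneElement,shp::Tuple{Int64, Vararg{Int64, N}}) where N
    prev_shp = Base.to_shape(x.axes)
    @assert prod(prev_shp) == prod(shp)
    # Position of the stored element in column-major (linear) order
    linear_ind = LinearIndices(prev_shp)[CartesianIndex(x.ind)]
    # The same position expressed as a Cartesian index into the new shape
    cartesian_ind = Tuple(CartesianIndices(shp)[linear_ind])
    Flux.Zygote.OneElement(x.val,cartesian_ind,((Base.OneTo(s) for s in shp)...,))
end
```

would probably work.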
@maleadt here's an MWE:

```julia
using Flux, Zygote, CUDA, Random

function Base.reshape(x::Zygote.OneElement,shp::Int)
    prev_shp = Base.to_shape(x.axes)
    @assert prod(prev_shp) == prod(shp)
    linear_ind = LinearIndices(prev_shp)[CartesianIndex(x.ind)]
    Flux.Zygote.OneElement(x.val,(linear_ind,),(Base.OneTo(shp),))
end
function Base.reshape(x::Zygote.OneElement,shp::Tuple{Int64, Vararg{Int64, N}}) where N
    @assert prod(Base.to_shape(x.axes)) == prod(shp)
    Flux.Zygote.OneElement(x.val,x.ind,((Base.OneTo(s) for s in shp)...,))
end

rng = Random.default_rng()
n, m = 3, 2  # example sizes; any n >= 2, m >= 1 fits the index (2,1) below
Z = randn(rng, Float32, (n,m)) |> gpu
Z .= Flux.Zygote.OneElement(0.1, (2,1), axes(Z))
vec(Z) .= vec(Flux.Zygote.OneElement(0.1, (2,1), axes(Z))) # works!
vec(Z) .= -vec(Flux.Zygote.OneElement(0.1, (2,1), axes(Z))) # doesn't work!
```

I'm curious why adding that
The OneElement constructor isn't very defensive, because it was assumed this only had one caller. What's written here should probably be an error, since it constructs something whose
Does Zygote construct such things? If it does, that's a bug which ought to be fixed, and an MWE of that would be appreciated. It's possible that the rule put too little thought into linear indexing, although a naive attempt to fool it does not work:
Edit: here's the obvious way to constrain this type, by explicitly demanding on construction that the axes have the right length. That would turn the hand-built construction above into a `MethodError`:
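A minimal sketch of that guardrail, using a hypothetical stand-in type `GuardedOneElement` rather than Zygote's actual definition: the inner constructor only matches when the index tuple and the axes tuple share the same length `N`, so a mismatched pair has no applicable method.

```julia
# Hypothetical stand-in for Zygote.OneElement, with the guardrail in its type parameters.
struct GuardedOneElement{T,N,I,A} <: AbstractArray{T,N}
    val::T
    ind::I
    axes::A
    # Only accepts an N-tuple of Ints together with an N-tuple of unit ranges;
    # any other combination raises a MethodError at the call site.
    function GuardedOneElement(val::T, ind::I, axes::A) where {T<:Number, N,
            I<:NTuple{N,Int}, A<:NTuple{N,AbstractUnitRange}}
        return new{T,N,I,A}(val, ind, axes)
    end
end

Base.size(x::GuardedOneElement) = map(length, x.axes)
Base.axes(x::GuardedOneElement) = x.axes
# Returns the stored value at `ind` and zero everywhere else.
Base.getindex(x::GuardedOneElement{T,N}, i::Vararg{Int,N}) where {T,N} =
    ifelse(i == x.ind, x.val, zero(T))
```

With this definition, `GuardedOneElement(0.1, (2, 1), (Base.OneTo(6),))`, which pairs a 2-tuple index with a single axis as the second `reshape` method above effectively does, fails at dispatch instead of producing an object whose dimensionality disagrees with its index.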
But the only caller of the constructor is the following function, which explicitly checks that the length of the indices matches
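The kind of check described here can be sketched with a hypothetical helper `onehot_cotangent` (not Zygote's code): it builds the gradient of `x[inds...]` only when `inds` is exactly `ndims(x)` integers. For simplicity it returns a dense array and assumes 1-based axes, whereas Zygote returns a lazy `OneElement`; a real rule would also need a fallback for other index types, where this sketch simply throws.

```julia
# Hypothetical helper: gradient of x[inds...] with respect to x, given the output cotangent dy.
function onehot_cotangent(x::AbstractArray{T,N}, dy::Number, inds::Tuple) where {T,N}
    # Only a full set of N integer indices picks out exactly one element of x.
    inds isa NTuple{N,Int} ||
        throw(ArgumentError("expected $N integer indices, got $(length(inds))"))
    dx = zeros(promote_type(T, typeof(dy)), size(x))  # assumes 1-based axes
    dx[inds...] = dy
    return dx
end
```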
Well, yes, the rule does not cover other indexing schemes or consider the array type (StarWarsArray.jl- or OffsetArrays.jl-esque arrays), but of course that doesn't mean that `reshape`, among others, needs to be fixed.
Sometimes you just need to
Yes, I agree. What you must avoid is constructing objects whose type doesn't match their properties. The `OneElement` constructor presently lacks a guardrail to make this construction an error, which is why you can make one by hand. But if you can get Zygote to make one for you, that would be much more concerning. On a legal `OneElement`, `reshape` works fine:
If you'd like to avoid a ReshapedArray here, it would be easy to add reshape methods like the ones above, of course. Maybe there are cases where this is more efficient? Note that the multi-dimensional one needs more care:
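The extra care amounts to remapping the stored index through the linear (column-major) order. A minimal sketch, with `reshape_index` as a hypothetical helper name: it converts the old Cartesian index to a linear index and then back to a Cartesian index in the new shape, which is the tuple a multi-dimensional `reshape` method would store as the new `ind`.

```julia
# Hypothetical helper: where the element at Cartesian index `ind` of an array of size
# `old_shape` lands after reshaping to `new_shape` (both column-major, 1-based).
function reshape_index(ind::NTuple{N,Int}, old_shape::NTuple{N,Int},
                       new_shape::Tuple{Vararg{Int}}) where {N}
    prod(old_shape) == prod(new_shape) ||
        throw(DimensionMismatch("cannot reshape $old_shape to $new_shape"))
    lin = LinearIndices(old_shape)[ind...]        # position in memory order
    return Tuple(CartesianIndices(new_shape)[lin])  # same position, new shape
end
```

For example, `reshape_index((1, 2), (3, 2), (2, 3))` gives `(2, 2)`, not `(1, 2)`: reusing the old index unchanged, as the second method at the top of this thread does, only happens to be correct when the two coincide.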
Oh, I didn't notice that I had made an illegal `Zygote.OneElement`. It's the same error I get in a pullback, though, which is disconcerting. The MWE that produced it is the first thing in SciML/DiffEqFlux.jl#632.
Is that on Flux master? You say "if we try to say `|> gpu` inside of the rhs function", which means you may need the things fixed in FluxML/Flux.jl#1704 (from almost a month ago). Without that, you may still hit problems like #1005, but I'm not sure that's the problem here. Independently, if you have managed to get Zygote to make an illegal `OneElement`, that's very concerning and would be worth tracking down. Adding the guardrail above (to make such a construction a `MethodError`) might at least produce stack traces closer to the constructor. Is the stack trace for SciML/DiffEqFlux.jl#632 posted somewhere?
It works if you broadcast the
That
Well spotted. Even for normal arrays,
Yeah, that's the downstream fix to that specific problem. I think that means everything here is answered.
MWE: