I have a neural network in Flax that is essentially a function expansion in a non-linear basis set, e.g.

f(x, p) = p_1 * sin(x) + p_2 * exp(x) + ...

and I want to find parameters p such that, for a given X, the function f(., p) has a fixed point at X. The idea was to use Anderson Acceleration together with implicit differentiation to tune the parameters. The problem is that the gradient of the fixed point z = f(z, p) with respect to p is zero. I checked the output of the solver (verbose=True) and it successfully finds a solution in fewer than maxiter iterations.
The normal gradient operation through the network seems to work just fine: if I take the returned fixed-point solution and do one further iteration manually, I do get a non-vanishing gradient with respect to the parameters. I could look at the generated jaxpr, but the procedure is quite long and complicated, so I don't think it would help much.
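For concreteness, here is a minimal sketch of the setup (my assumptions, not the original Flax code: a two-term basis as the fixed-point map, jaxopt's AndersonAcceleration, and implicit differentiation of the fixed point with respect to the parameters):

```python
# Minimal sketch (assumed, not the original code): a two-term basis expansion
# as the fixed-point map, solved with AndersonAcceleration and differentiated
# implicitly w.r.t. the parameters p.
import jax
import jax.numpy as jnp
from jaxopt import AndersonAcceleration


def f(z, p):
    # f(z, p) = p_1 * sin(z) + p_2 * exp(z)
    return p[0] * jnp.sin(z) + p[1] * jnp.exp(z)


solver = AndersonAcceleration(fixed_point_fun=f, maxiter=100, tol=1e-6,
                              implicit_diff=True, verbose=True)


def fixed_point(p, z_init):
    # Returns z* with z* = f(z*, p); gradients flow through the implicit
    # function theorem, not through the unrolled iterations.
    return solver.run(z_init, p).params


p = jnp.array([0.5, 0.1])
z_init = jnp.zeros(1)

z_star = fixed_point(p, z_init)
# This Jacobian dz*/dp is the quantity that comes back as all zeros.
print(jax.jacobian(fixed_point)(p, z_init))
```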
"Unfortunately" the method seems to work for a small test network. I think the problem lies in some later transformation that I apply to the output of the described layer. What I don't understand is that the gradient is non-zero for a simple 'single step' evaluation but vanishes as soon as a second iteration is done (setting maxiter=2 for example).
OK, I think I know what the problem was. During the backward pass, the linear system for the inverse-Jacobian product (solved by default with solve_cg in linear_solve.py) doesn't converge, and the solver just returns its initial guess, which is a zero vector. I guess it would be nice if there were some kind of error message in the case of non-convergence.
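If the backward-pass linear solver really is the culprit, one possible workaround (hedged: based on the current jaxopt API, not tested against the original model) is to pass a different routine via implicit_diff_solve, since conjugate gradient assumes a symmetric positive-definite operator while the fixed-point system involving I - df/dz generally is not one:

```python
# Possible workaround (assumption, not verified on the original setup):
# swap the implicit-diff linear solver for one that does not require a
# symmetric positive-definite operator.
from functools import partial
from jaxopt import linear_solve

solver = AndersonAcceleration(
    fixed_point_fun=f,
    maxiter=100,
    tol=1e-6,
    implicit_diff=True,
    # GMRES with a small ridge term instead of the CG default reported above;
    # linear_solve.solve_normal_cg is another option.
    implicit_diff_solve=partial(linear_solve.solve_gmres, ridge=1e-6),
)
```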