Debugging vanishing gradients in implicit fix-point differentiation #466

Closed · alessandrosimon opened this issue Jul 6, 2023 · 3 comments
Labels: question (Further information is requested)

alessandrosimon commented Jul 6, 2023

I have a neural network in Flax that is basically a function expansion in a non-linear basis set; imagine something like

f(x, p) = p_1 * sin(x) + p_2 * exp(x) + ...

and I want to find parameters such that, for a given X, the function f(·, p) has a fixed point at X. The idea was to use Anderson Acceleration together with implicit differentiation to tune the parameters. The problem is that the gradient of the fixed point z = f(z, p) with respect to p is zero. I checked the output of the solver (verbose=True) and it successfully finds a solution in fewer than maxiter iterations.
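
For reference, here is a minimal sketch of the setup (not the actual network: the basis f, the loss, and the numbers are placeholders, and it assumes jaxopt's AndersonAcceleration API):

```python
import jax
import jax.numpy as jnp
from jaxopt import AndersonAcceleration

def f(z, p):
    # Toy basis expansion standing in for the Flax network.
    return p[0] * jnp.sin(z) + p[1] * jnp.exp(-z)

def fixed_point(p, z0):
    # Anderson Acceleration with implicit differentiation enabled (the default).
    solver = AndersonAcceleration(fixed_point_fun=f, maxiter=100, tol=1e-6,
                                  implicit_diff=True)
    return solver.run(z0, p).params

def loss(p, z0, X):
    # Penalize the distance between the found fixed point and the target X.
    z_star = fixed_point(p, z0)
    return jnp.sum((z_star - X) ** 2)

p = jnp.array([0.5, 0.3])
z0 = jnp.array([0.1])
X = jnp.array([0.7])
print(jax.grad(loss)(p, z0, X))  # in my actual (larger) setup this comes out as all zeros
```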

The normal gradient operation through the network seems to work just fine: if I take the returned fixed-point solution and do one further iteration manually, I do get a non-vanishing gradient with respect to the parameters. I could look at the generated jaxpr, but the procedure is quite long/complicated, so I don't think it would help much.
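
The manual check looks roughly like this (continuing the sketch above; the stop_gradient is just my way of isolating the single extra step):

```python
def loss_one_step(p, z0, X):
    # Block gradients through the solver, then differentiate one explicit
    # application of f at the returned fixed-point solution.
    z_star = jax.lax.stop_gradient(fixed_point(p, z0))
    z_next = f(z_star, p)
    return jnp.sum((z_next - X) ** 2)

print(jax.grad(loss_one_step)(p, z0, X))  # non-zero, so f itself backpropagates fine
```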

mblondel (Collaborator) commented Jul 7, 2023

Do you have a minimal example?

mblondel added the question label on Jul 7, 2023
alessandrosimon (Author) commented:

"Unfortunately" the method seems to work for a small test network. I think the problem lies in some later transformation that I apply to the output of the described layer. What I don't understand is that the gradient is non-zero for a simple 'single step' evaluation but vanishes as soon as a second iteration is done (setting maxiter=2 for example).

alessandrosimon (Author) commented:

OK, I think I know what the problem was. During the backward pass, when the linear system for the inverse-Jacobian product is solved (by default with solve_cg in linear_solve.py), the solver doesn't find a solution and just returns its initial starting value, which is a zero vector; that zero solution then shows up as a zero gradient. I guess it would be nice if there were some kind of error message in the case of non-convergence.
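
In case it helps anyone else, a sketch of one possible workaround: pass a different linear solver for the backward pass. This assumes AndersonAcceleration accepts an implicit_diff_solve callable (as other jaxopt solvers do) and that jaxopt.linear_solve.solve_gmres is available; please double-check against your jaxopt version.

```python
from jaxopt import AndersonAcceleration, linear_solve

# GMRES does not require the backward-pass linear system to be symmetric
# positive definite (plain CG does); the fixed-point Jacobian I - df/dz is
# generally not symmetric.
solver = AndersonAcceleration(
    fixed_point_fun=f,   # same fixed-point function as above
    maxiter=100,
    tol=1e-6,
    implicit_diff=True,
    implicit_diff_solve=linear_solve.solve_gmres,
)
```

If the chosen solve function forwards keyword arguments, its tolerance and iteration limit for the backward solve could also be tightened via functools.partial.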
