Hi! This is a great project, and I'm a big fan of both the machine learning applications here and also some of the smaller, helpful structures, in particular base.grids.
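Currently, it is possible to add two GridArrays, but it is not possible to add two GridVariables. For example, this throws an exception:

```python
# Assumes `gd` is the jax_cfd.base.grids module and `centered` is a 1D GridArray
bc = gd.BoundaryConditions((gd.PERIODIC,))
centered_variable = gd.GridVariable(centered, bc)
print(centered_variable + centered_variable)  # raises: GridVariable supports no arithmetic
```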
I'm happy to have a go at implementing this myself, if someone isn't already working on it.
Also, am I correct in thinking that the way to use a `jax.numpy` function on a GridArray is to call it via NumPy? For example, this throws an exception:

```python
import jax.numpy as jnp
print(jnp.abs(centered_array))
```

But this works:

```python
import numpy as np
print(np.abs(centered_array))
```
I assume it's implemented this way because NumPy provides a mix-in (together with its `__array_ufunc__` override hook) that GridArray can use to forward each call to the corresponding `jax.numpy` function, whereas `jax.numpy` has no such mechanism.
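A minimal sketch of that forwarding mechanism, assuming nothing about jax-cfd's actual code: `ToyGridArray` and `BACKEND` are hypothetical names. The class inherits NumPy's `NDArrayOperatorsMixin`, so operators like `+` and calls like `np.abs` are routed to `__array_ufunc__`, which looks up the same-named function on a backend module and rewraps the result. The backend is plain `numpy` so the sketch runs without JAX; pointing it at `jax.numpy` gives the forwarding described above. Calling `jnp.abs(obj)` directly never reaches this hook, because `jax.numpy` functions don't consult `__array_ufunc__`.

```python
import numpy as np
from numpy.lib.mixins import NDArrayOperatorsMixin

# Hypothetical backend to forward to. numpy keeps this sketch runnable
# without JAX; a JAX codebase would use jax.numpy here instead.
BACKEND = np


class ToyGridArray(NDArrayOperatorsMixin):
    """Hypothetical stand-in for a GridArray: raw data plus a grid offset."""

    def __init__(self, data, offset):
        self.data = np.asarray(data)
        self.offset = offset

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        # Only handle plain calls like np.add(x, y); defer on reduce, out=, etc.
        if method != "__call__" or kwargs:
            return NotImplemented
        # All wrapped operands must live at the same offset.
        offsets = {x.offset for x in inputs if isinstance(x, ToyGridArray)}
        if len(offsets) != 1:
            raise ValueError(f"inconsistent offsets: {offsets}")
        # Unwrap to raw arrays; scalars and plain arrays pass through as-is.
        raw = [x.data if isinstance(x, ToyGridArray) else x for x in inputs]
        # np.abs is the ufunc np.absolute, so this looks up "absolute" on BACKEND.
        result = getattr(BACKEND, ufunc.__name__)(*raw)
        return ToyGridArray(result, offsets.pop())


x = ToyGridArray([-1.0, 2.0], offset=(0.5,))
print(np.abs(x).data, (x + x).data)
```

Both `np.abs(x)` and `x + x` return a `ToyGridArray` with the same offset, while combining arrays at different offsets raises a `ValueError`.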
This was an intentional design choice -- GridVariables have boundary conditions, which we don't know how to propagate automatically (unless using periodic boundaries, which aren't really boundary conditions at all). So we only support math on GridArray objects.
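To illustrate the trade-off, here is a minimal sketch of what a deliberately restricted addition could look like. `ToyGridVariable`, `add_grid_variables`, and the string tags are hypothetical and not part of jax-cfd. The only rule encoded is the one from the reply above: periodic boundaries carry over to a sum unchanged, and anything else is rejected rather than guessed.

```python
from dataclasses import dataclass
from typing import Tuple

import numpy as np

PERIODIC = "periodic"  # hypothetical tag, mirroring gd.PERIODIC in the issue


@dataclass(frozen=True)
class ToyGridVariable:
    """Hypothetical stand-in: array data plus one boundary type per axis."""
    data: np.ndarray
    bc: Tuple[str, ...]


def add_grid_variables(a: ToyGridVariable, b: ToyGridVariable) -> ToyGridVariable:
    """Add two variables only when the boundary conditions of the sum are unambiguous.

    If every axis is periodic, the sum is periodic too, so the BCs can be
    carried over. For any other BC type we refuse instead of guessing.
    """
    if a.bc != b.bc:
        raise ValueError(f"mismatched boundary conditions: {a.bc} vs {b.bc}")
    if any(kind != PERIODIC for kind in a.bc):
        raise NotImplementedError(f"cannot infer boundary conditions of a sum for {a.bc}")
    return ToyGridVariable(a.data + b.data, a.bc)


u = ToyGridVariable(np.array([1.0, 2.0]), (PERIODIC,))
print(add_grid_variables(u, u))
```

Refusing non-periodic cases keeps the helper honest: for Dirichlet data the correct boundary values of the sum depend on the inputs' boundary values, which the data array alone does not tell us.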
> Also, am I correct in thinking that the way to use a JaxNumPy function on a GridArray is to call it via NumPy?
This is correct, and I agree it's strange. The reason is simple: NumPy supports overriding its functions for new types (via `__array_ufunc__` and `__array_function__`), but JAX doesn't.
Oh, I see! Yes, addition and multiplication would also work for matching homogeneous Dirichlet/Neumann BCs, but that's a special case, and it wouldn't extend to other functions like sines and cosines. I assume this is also the reason why the "shift" method on a GridVariable returns a GridArray. Thanks for the explanation!
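The special case can be checked directly. Writing Ω for the domain, ∂Ω for its boundary and ∂ₙ for the outward normal derivative (notation introduced here for illustration), with fields u and v:

```latex
\begin{aligned}
&\text{Homogeneous Dirichlet } (u = v = 0 \text{ on } \partial\Omega):
  && u + v = 0,\quad uv = 0,\quad \text{but } \cos u = 1 \neq 0 \text{ on } \partial\Omega \\
&\text{Homogeneous Neumann } (\partial_n u = \partial_n v = 0 \text{ on } \partial\Omega):
  && \partial_n (u + v) = 0,\quad \partial_n (uv) = u\,\partial_n v + v\,\partial_n u = 0 \\
&\text{Inhomogeneous Dirichlet } (u = a,\ v = b \text{ on } \partial\Omega):
  && u + v = a + b \text{ on } \partial\Omega
\end{aligned}
```

Sums and products preserve the homogeneous conditions, but a nonlinear function such as cosine breaks the Dirichlet one, and with inhomogeneous data the sum satisfies neither input's boundary condition.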