Skip to content

Commit

Permalink
Amazingly, example seems to be running properly but its too slow beca…
Browse files Browse the repository at this point in the history
…use each PSCCurrent does the whole Linv * psi solve required to solve for all the currents, and only need to do one of these each iteration. So probably need to make a class PSCArray that has a list of current objects and just keeps them updated.
  • Loading branch information
akaptano committed Dec 20, 2024
1 parent 2fc5da8 commit 18540a0
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 21 deletions.
23 changes: 9 additions & 14 deletions examples/3_Advanced/coil_force_optimization/passive_coils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@
range_param = "half period"
nphi = 32
ntheta = 32
poff = 1.5
coff = 1.5
poff = 2.0
coff = 1.0
s = SurfaceRZFourier.from_vmec_input(filename, range=range_param, nphi=nphi, ntheta=ntheta)
s_inner = SurfaceRZFourier.from_vmec_input(filename, range=range_param, nphi=nphi * 4, ntheta=ntheta * 4)
s_outer = SurfaceRZFourier.from_vmec_input(filename, range=range_param, nphi=nphi * 4, ntheta=ntheta * 4)
Expand Down Expand Up @@ -81,8 +81,8 @@ def initialize_coils_QA(TEST_DIR, s):

# generate planar TF coils
ncoils = 3
R0 = s.get_rc(0, 0) * 1
R1 = s.get_rc(1, 0) * 3
R0 = s.get_rc(0, 0) * 1.4
R1 = s.get_rc(1, 0) * 7
order = 4

from simsopt.mhd.vmec import Vmec
Expand Down Expand Up @@ -132,7 +132,7 @@ def initialize_coils_QA(TEST_DIR, s):
aa = 0.05
bb = 0.05

Nx = 3
Nx = 4
Ny = Nx
Nz = Nx
# Create the initial coils:
Expand Down Expand Up @@ -186,6 +186,7 @@ def initialize_coils_QA(TEST_DIR, s):
a_list = np.ones(len(base_curves)) * a
b_list = np.ones(len(base_curves)) * a
base_currents = [PSCCurrent(base_curves, bs_TF, a_list, b_list, i) for i in range(ncoils)]
print(base_currents[0].get_value())
coils = coils_via_symmetries(base_curves, base_currents, s.nfp, True)
base_coils = coils[:ncoils]
[c.current.fix_all() for c in base_coils] # Fix all the current dofs which are fake anyways
Expand Down Expand Up @@ -230,18 +231,12 @@ def pointData_forces_torques(coils, allcoils, aprimes, bprimes, nturns_list):
CS_THRESHOLD = 1.5
CS_WEIGHT = 1e2
# Weight for the Coil Coil forces term
FORCE_WEIGHT = Weight(1e-34) # 1e-34 Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons
FORCE_WEIGHT = Weight(0.0) # 1e-34 Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons
FORCE_WEIGHT2 = Weight(0.0) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons
TORQUE_WEIGHT = Weight(0.0) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons
TORQUE_WEIGHT2 = Weight(0.0) # 1e-22 Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons
# Directory for output
OUT_DIR = ("./QA_minimal_TForder{:d}_n{:d}_p{:.2e}_c{:.2e}_lw{:.2e}_lt{:.2e}_lkw{:.2e}" +
"_cct{:.2e}_ccw{:.2e}_cst{:.2e}_csw{:.2e}_fw{:.2e}_fww{:2e}_tw{:.2e}_tww{:2e}/").format(
base_curves_TF[0].order, ncoils, poff, coff, LENGTH_WEIGHT.value, LENGTH_TARGET, LINK_WEIGHT,
CC_THRESHOLD, CC_WEIGHT, CS_THRESHOLD, CS_WEIGHT, FORCE_WEIGHT.value,
FORCE_WEIGHT2.value,
TORQUE_WEIGHT.value,
TORQUE_WEIGHT2.value)
OUT_DIR = "./passive_coils/"
if os.path.exists(OUT_DIR):
shutil.rmtree(OUT_DIR)
os.makedirs(OUT_DIR, exist_ok=True)
Expand Down Expand Up @@ -340,9 +335,9 @@ def pointData_forces_torques(coils, allcoils, aprimes, bprimes, nturns_list):
if TORQUE_WEIGHT2.value > 0.0:
JF += TORQUE_WEIGHT2 * Jtorque2

print(JF.dof_names)

def fun(dofs):
print(JF.dof_names)
JF.x = dofs
J = JF.J()
grad = JF.dJ()
Expand Down
35 changes: 32 additions & 3 deletions src/simsopt/field/coil.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,13 +116,16 @@ class PSCCurrent(sopp.Current, CurrentBase):

def __init__(self, psc_curves, biot_savart_TF, a_list, b_list, index, downsample=1, cross_section='circular', dofs=None, **kwargs):
self.psc_curves = psc_curves # Should include all the psc_curves
# save original TF evaluation points!
self.biot_savart_TF = biot_savart_TF
self.eval_points = self.biot_savart_TF.get_points_cart_ref()
self.a_list = a_list
self.b_list = b_list
self.downsample = downsample
self.cross_section = cross_section
self.index = index
gammas = jnp.array([c.gamma() for c in psc_curves])
self.biot_savart_TF.set_points(gammas.reshape(-1, 3))
self.biot_savart_TF.set_points(gammas[:, ::downsample, :].reshape(-1, 3))
gammadashs = jnp.array([c.gammadash() for c in psc_curves])
quadpoints = jnp.array([c.quadpoints for c in psc_curves])
A_ext = biot_savart_TF.A()
Expand All @@ -148,6 +151,7 @@ def __init__(self, psc_curves, biot_savart_TF, a_list, b_list, index, downsample
jacfwd(self.J_jax, argnums=2)(gammas, gammadashs, A_ext, downsample),
**args
)
self.biot_savart_TF.set_points(self.eval_points)
sopp.Current.__init__(self, current)
if dofs is None:
CurrentBase.__init__(self, external_dof_setter=sopp.Current.set_dofs,
Expand All @@ -160,7 +164,7 @@ def vjp(self, v_current):
gammas = jnp.array([c.gamma() for c in self.psc_curves])
gammadashs = jnp.array([c.gammadash() for c in self.psc_curves])
quadpoints = jnp.array([c.quadpoints for c in self.psc_curves])

Check failure on line 166 in src/simsopt/field/coil.py

View workflow job for this annotation

GitHub Actions / CI (3.9)

Ruff (F841)

src/simsopt/field/coil.py:166:9: F841 Local variable `quadpoints` is assigned to but never used
self.biot_savart_TF.set_points(gammas.reshape(-1, 3))
self.biot_savart_TF.set_points(gammas[:, ::self.downsample, :].reshape(-1, 3))
A_ext = self.biot_savart_TF.A()
args = [
gammas,
Expand All @@ -181,12 +185,37 @@ def vjp(self, v_current):
# should be associated with the TF curves?
# A_vjp returns Derivatives depending on the TF curve and TF coils
vjp3 = sum([self.biot_savart_TF.A_vjp(dJ_dA[i]) for i, c in enumerate(self.biot_savart_TF.coils)])
self.biot_savart_TF.set_points(self.eval_points)

#### ABSOLUTELY ESSENTIAL LINES BELOW
# Otherwise optimizable references multiply
# like crazy as number of coils increases

# self.biot_savart_TF._children = set()
# for c in self.psc_curves:
# c._children = set()
# for c in self.biot_savart_TF.coils:
# c._children = set()
# c.curve._children = set()
# c.current._children = set()

return vjp1 + vjp2 + vjp3

@property
def current(self):
return self.J_jax(*args)[i]
gammas = jnp.array([c.gamma() for c in self.psc_curves])
gammadashs = jnp.array([c.gammadash() for c in self.psc_curves])
quadpoints = jnp.array([c.quadpoints for c in self.psc_curves])

Check failure on line 208 in src/simsopt/field/coil.py

View workflow job for this annotation

GitHub Actions / CI (3.9)

Ruff (F841)

src/simsopt/field/coil.py:208:9: F841 Local variable `quadpoints` is assigned to but never used
self.biot_savart_TF.set_points(gammas.reshape(-1, 3))
A_ext = self.biot_savart_TF.A()
args = [
gammas,
gammadashs,
A_ext,
self.downsample
]
self.biot_savart_TF.set_points(self.eval_points)
return self.J_jax(*args)[self.index]
# return self.get_value()


Expand Down
1 change: 0 additions & 1 deletion src/simsopt/field/force.py
Original file line number Diff line number Diff line change
Expand Up @@ -1961,7 +1961,6 @@ def coil_coil_inductances_full_pure(gammas, gammadashs, quadpoints, a_list, b_li
, inplace=False)
return 1e-7 * Lij


def coil_coil_inductances_inv_pure(gammas, gammadashs, quadpoints, a_list, b_list, downsample, cross_section):
# Lij is symmetric positive definite so has a cholesky decomposition
C = jnp.linalg.cholesky(coil_coil_inductances_full_pure(gammas, gammadashs, quadpoints, a_list, b_list, downsample, cross_section))
Expand Down
1 change: 0 additions & 1 deletion src/simsopt/field/magneticfield.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,6 @@ def _set_points_cb(self):
bf.set_points_cart(self.get_points_cart_ref())

def _B_impl(self, B):
# print(B.shape, [bf.B() for bf in self.Bfields], np.shape(np.sum([bf.B() for bf in self.Bfields], axis=0)))
B[:] = np.sum([bf.B() for bf in self.Bfields], axis=0)

def _dB_by_dX_impl(self, dB):
Expand Down
4 changes: 2 additions & 2 deletions src/simsopt/geo/curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -1111,8 +1111,8 @@ def create_planar_curves_between_two_toroidal_surfaces(

# Initialize a bunch of circular coils with same normal vector
for ic in range(ncoils):
alpha2 = np.pi / 2.0
delta2 = 0.0
alpha2 = np.random.rand(1) * np.pi - np.pi / 2.0
delta2 = np.random.rand(1) * np.pi
calpha2 = np.cos(alpha2)
salpha2 = np.sin(alpha2)
cdelta2 = np.cos(delta2)
Expand Down

0 comments on commit 18540a0

Please sign in to comment.