Commit

Merge branch 'main' of github.com:e3nn/e3nn-jax
mariogeiger committed Jun 6, 2023
2 parents 8599120 + 63ce978, commit c74f260
Showing 1 changed file with 9 additions and 9 deletions.
docs/tuto/nequip.rst: 18 changes (9 additions, 9 deletions)
@@ -333,12 +333,12 @@ As optimizer we will use Adam. This optimizer needs to keep track of the average
     random_key = jax.random.PRNGKey(0)  # change it to get different initializations
 
     # Initialize the model
-    f = Model()
-    w = jax.jit(f.init)(random_key, dataset)
+    model = Model()
+    params = jax.jit(model.init)(random_key, dataset)
 
     # Initialize the optimizer
     opt = optax.adam(1e-3)
-    opt_state = opt.init(w)
+    opt_state = opt.init(params)
 
 
 Let's define the training step. We will use ``jax.jit`` to compile the function and make it faster.
@@ -347,27 +347,27 @@ This function takes as input the model parameters, the optimizer state and the d
 .. jupyter-execute::
 
     @jax.jit
-    def train_step(opt_state, w, dataset):
+    def train_step(opt_state, params, dataset):
         """Perform a single training step."""
         num_graphs = dataset.n_node.shape[0]
 
         # Compute the loss as a function of the parameters
         def fun(w):
-            preds = f.apply(w, dataset).array.squeeze(1)
+            preds = model.apply(w, dataset).array.squeeze(1)
             targets = dataset.globals["energies"]
 
             assert preds.shape == (num_graphs,)
             assert targets.shape == (num_graphs,)
             return loss_fn(preds, targets)
 
         # And take its gradient
-        loss, grad = jax.value_and_grad(fun)(w)
+        loss, grad = jax.value_and_grad(fun)(params)
 
         # Update the parameters and the optimizer state
         updates, opt_state = opt.update(grad, opt_state)
-        w = optax.apply_updates(w, updates)
+        params = optax.apply_updates(params, updates)
 
-        return opt_state, w, loss
+        return opt_state, params, loss
 
 
 Finally, let's train the model for 1000 iterations.
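The helper ``loss_fn`` called inside ``fun`` is defined earlier in the tutorial and is not touched by this commit. As a hedged sketch only, assuming a plain mean-squared error on the per-graph energies (the actual definition in docs/tuto/nequip.rst may differ), it could look like:

    import jax.numpy as jnp

    def loss_fn(preds, targets):
        # Mean squared error between predicted and target energies.
        # Assumed form; the tutorial's actual loss may differ.
        return jnp.mean((preds - targets) ** 2)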
@@ -376,7 +376,7 @@ Finally, let's train the model for 1000 iterations.
 
     losses = []
     for _ in range(1000):
-        opt_state, w, loss = train_step(opt_state, w, dataset)
+        opt_state, params, loss = train_step(opt_state, params, dataset)
         losses.append(loss)
 
 Did it work?
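The tutorial answers "Did it work?" below the region shown in this diff, presumably by inspecting the recorded losses. A minimal way to check, assuming matplotlib is available and ``losses`` holds the per-step losses collected in the loop above (a sketch, not the tutorial's actual plotting code):

    import matplotlib.pyplot as plt

    # Each entry of `losses` is a JAX scalar; convert to Python floats for plotting.
    plt.plot([float(l) for l in losses])
    plt.xlabel("training step")
    plt.ylabel("loss")
    plt.yscale("log")  # training losses often span several orders of magnitude
    plt.show()

A steadily decreasing curve indicates the parameters and optimizer state are being threaded through ``train_step`` correctly.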