Skip to content

Commit

Permalink
black
Browse files Browse the repository at this point in the history
  • Loading branch information
mariogeiger committed Jun 19, 2023
1 parent a42da2d commit 0fc08ec
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 22 deletions.
22 changes: 19 additions & 3 deletions examples/plot_spherical_harmonics.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,13 @@ def get_cmap(x):
if x == "bwr":
return [[0, "rgb(0,50,255)"], [0.5, "rgb(200,200,200)"], [1, "rgb(255,50,0)"]]
if x == "plasma":
return [[0, "#9F1A9B"], [0.25, "#0D1286"], [0.5, "#000000"], [0.75, "#F58C45"], [1, "#F0F524"]]
return [
[0, "#9F1A9B"],
[0.25, "#0D1286"],
[0.5, "#000000"],
[0.75, "#F58C45"],
[1, "#F0F524"],
]


alpha = jnp.linspace(0, 2 * jnp.pi, 200)
Expand All @@ -16,7 +22,9 @@ def get_cmap(x):
alpha, beta = jnp.meshgrid(alpha, beta, indexing="ij")
vectors = e3nn.angles_to_xyz(alpha, beta)

signal = e3nn.spherical_harmonics("8e", vectors, normalize=True, normalization="component").array
signal = e3nn.spherical_harmonics(
"8e", vectors, normalize=True, normalization="component"
).array
signal = signal[:, :, 8]

data = [
Expand All @@ -32,7 +40,15 @@ def get_cmap(x):
)
]

axis = dict(showbackground=False, showticklabels=False, showgrid=False, zeroline=False, title="", nticks=3, range=[-3, 3])
axis = dict(
showbackground=False,
showticklabels=False,
showgrid=False,
zeroline=False,
title="",
nticks=3,
range=[-3, 3],
)

layout = dict(
width=512,
Expand Down
18 changes: 14 additions & 4 deletions examples/tensor_product_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,17 @@ def tp(x1, x2):
if args.module:
assert not args.extrachannels
assert args.weights
return e3nn.haiku.FullyConnectedTensorProduct(args.irreps_out)(x1, x2, **kwargs)
return e3nn.haiku.FullyConnectedTensorProduct(args.irreps_out)(
x1, x2, **kwargs
)
else:
if args.extrachannels:
assert not args.module
x1 = x1.mul_to_axis() # (batch, channels, irreps)
x2 = x2.mul_to_axis() # (batch, channels, irreps)
x = e3nn.tensor_product(x1[..., :, None, :], x2[..., None, :, :], **kwargs)
x = e3nn.tensor_product(
x1[..., :, None, :], x2[..., None, :, :], **kwargs
)
x = x.reshape(x.shape[:-3] + (-1,) + x.shape[-1:])
x = x.axis_to_mul()
else:
Expand All @@ -91,7 +95,10 @@ def tp(x1, x2):
else:
return x.filter(keep=args.irreps_out)

inputs = (e3nn.normal(args.irreps_in1, k(), (args.batch,)), e3nn.normal(args.irreps_in2, k(), (args.batch,)))
inputs = (
e3nn.normal(args.irreps_in1, k(), (args.batch,)),
e3nn.normal(args.irreps_in2, k(), (args.batch,)),
)
w = tp.init(k(), *inputs)

# Ensure everything is on the GPU (shouldn't be necessary, but just in case)
Expand All @@ -112,7 +119,10 @@ def tp(x1, x2):
# tanh() forces it to realize the grad as a full size matrix rather than expanded (stride 0) ones
f_2 = f
f = jax.value_and_grad(
lambda w, x1, x2: sum(jnp.sum(jnp.tanh(out)) for out in jax.tree_util.tree_leaves(f_2(w, x1, x2)))
lambda w, x1, x2: sum(
jnp.sum(jnp.tanh(out))
for out in jax.tree_util.tree_leaves(f_2(w, x1, x2))
)
)

# compile
Expand Down
22 changes: 15 additions & 7 deletions examples/tetris_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,13 @@ def __call__(self, pos, edge_src, edge_dst):
)

for _ in range(4):
node_feat = MessagePassingConvolutionFlax("32x0e + 32x0o + 16x0e + 8x1e + 8x1o", **kw)(
pos, node_feat, edge_src, edge_dst
)
node_feat = MessagePassingConvolutionFlax(
"32x0e + 32x0o + 16x0e + 8x1e + 8x1o", **kw
)(pos, node_feat, edge_src, edge_dst)
node_feat = e3nn.gate(node_feat)
node_feat = MessagePassingConvolutionFlax("0o + 7x0e", **kw)(pos, node_feat, edge_src, edge_dst)
node_feat = MessagePassingConvolutionFlax("0o + 7x0e", **kw)(
pos, node_feat, edge_src, edge_dst
)

return node_feat.array

Expand All @@ -86,7 +88,9 @@ def update(params, opt_state, pos, edge_src, edge_dst, labels, batch):
grad_fn = jax.value_and_grad(loss_pred, has_aux=True)
(_, pred), grads = grad_fn(params, pos, edge_src, edge_dst, labels, batch)
accuracy_odd = jnp.sign(jnp.round(pred[:, 0])) == labels[:, 0]
accuracy_even = jnp.argmax(pred[:, 1:], axis=1) == jnp.argmax(labels[:, 1:], axis=1)
accuracy_even = jnp.argmax(pred[:, 1:], axis=1) == jnp.argmax(
labels[:, 1:], axis=1
)
accuracy = (jnp.mean(accuracy_odd) + jnp.mean(accuracy_even)) / 2
updates, opt_state = opt.update(grads, opt_state)
params = optax.apply_updates(params, updates)
Expand All @@ -108,13 +112,17 @@ def update(params, opt_state, pos, edge_src, edge_dst, labels, batch):

wall = time.perf_counter()
for it in range(1, steps + 1):
params, opt_state, accuracy, pred = update(params, opt_state, pos, edge_src, edge_dst, labels, batch)
params, opt_state, accuracy, pred = update(
params, opt_state, pos, edge_src, edge_dst, labels, batch
)

print(f"[{it}] accuracy = {100 * accuracy:.0f}%")

if accuracy == 1:
total = time.perf_counter() - wall
print(f"100% accuracy has been reach in {total:.1f}s after {it} iterations ({1000 * total/it:.1f}ms/it).")
print(
f"100% accuracy has been reach in {total:.1f}s after {it} iterations ({1000 * total/it:.1f}ms/it)."
)
break

jnp.set_printoptions(precision=2, suppress=True)
Expand Down
24 changes: 18 additions & 6 deletions examples/tetris_point_jraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,18 +56,24 @@ class Layer(flax.linen.Module):
@flax.linen.compact
def __call__(self, graphs, positions):
target_irreps = e3nn.Irreps(self.target_irreps)
vectors = positions[graphs.receivers] - positions[graphs.senders] # [n_edges, 1e or 1o]
vectors = (
positions[graphs.receivers] - positions[graphs.senders]
) # [n_edges, 1e or 1o]
sh = e3nn.spherical_harmonics(list(range(1, self.sh_lmax + 1)), vectors, True)

def update_edge_fn(edge_features, sender_features, receiver_features, globals):
return e3nn.concatenate([sender_features, e3nn.tensor_product(sender_features, sh)]).regroup()
return e3nn.concatenate(
[sender_features, e3nn.tensor_product(sender_features, sh)]
).regroup()

def update_node_fn(node_features, sender_features, receiver_features, globals):
shortcut = e3nn.flax.Linear(target_irreps, name="shortcut")(node_features)

node_feats = receiver_features / jnp.sqrt(self.avg_num_neighbors)
node_feats = e3nn.flax.Linear(target_irreps, name="linear_pre")(node_feats)
node_feats = e3nn.scalar_activation(node_feats, even_act=jax.nn.gelu, odd_act=jax.nn.tanh)
node_feats = e3nn.scalar_activation(
node_feats, even_act=jax.nn.gelu, odd_act=jax.nn.tanh
)
node_feats = e3nn.flax.Linear(target_irreps, name="linear_post")(node_feats)
return shortcut + node_feats

Expand All @@ -82,8 +88,12 @@ def __call__(self, graphs):
graphs = graphs._replace(nodes=jnp.ones((len(positions), 1)))

ann = 2.0
graphs = Layer("32x0e + 32x0o + 8x1e + 8x1o + 8x2e + 8x2o", ann)(graphs, positions)
graphs = Layer("32x0e + 32x0o + 8x1e + 8x1o + 8x2e + 8x2o", ann)(graphs, positions)
graphs = Layer("32x0e + 32x0o + 8x1e + 8x1o + 8x2e + 8x2o", ann)(
graphs, positions
)
graphs = Layer("32x0e + 32x0o + 8x1e + 8x1o + 8x2e + 8x2o", ann)(
graphs, positions
)
graphs = Layer("0o + 7x0e", ann)(graphs, positions)
return graphs.nodes

Expand Down Expand Up @@ -143,7 +153,9 @@ def update(params, opt_state, graphs):
done = False
wall = time.perf_counter()
for it in range(1, steps + 1):
params, opt_state, loss, accuracy, logits = update(params, opt_state, graphs)
params, opt_state, loss, accuracy, logits = update(
params, opt_state, graphs
)
losses.append(loss)

if not done and accuracy == 1.0:
Expand Down
14 changes: 12 additions & 2 deletions examples/tetris_voxel.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,23 @@ def model(x):
# Shallower and wider convolutions also works

# kw = dict(irreps_sh=Irreps('0e + 1o'), diameter=5.5, num_radial_basis=3, steps=(1.0, 1.0, 1.0))
kw = dict(irreps_sh=e3nn.Irreps("0e + 1o"), diameter=2 * 1.4, num_radial_basis=1, steps=(1.0, 1.0, 1.0))
kw = dict(
irreps_sh=e3nn.Irreps("0e + 1o"),
diameter=2 * 1.4,
num_radial_basis=1,
steps=(1.0, 1.0, 1.0),
)

x = e3nn.IrrepsArray("0e", x[..., None])

# for _ in range(2):
for _ in range(5):
x = g(ConvolutionHaiku(f"{mul0}x0e + {mul0}x0o + {2 * mul1}x0e + {mul1}x1e + {mul1}x1o", **kw)(x))
x = g(
ConvolutionHaiku(
f"{mul0}x0e + {mul0}x0o + {2 * mul1}x0e + {mul1}x1e + {mul1}x1o",
**kw,
)(x)
)

x = ConvolutionHaiku("0o + 7x0e", **kw)(x)

Expand Down

0 comments on commit 0fc08ec

Please sign in to comment.