diff --git a/tetris_jax_baseline.py b/tetris_jax_baseline.py index 70f9e9a..a8f93bc 100644 --- a/tetris_jax_baseline.py +++ b/tetris_jax_baseline.py @@ -68,8 +68,6 @@ def update_edge_fn(edge_features, sender_features, receiver_features, globals): def update_node_fn(node_features, sender_features, receiver_features, globals): node_feats = receiver_features / self.denominator node_feats = e3nn.flax.Linear(target_irreps, name="linear_pre")(node_feats) - node_feats = e3nn.scalar_activation(node_feats) - node_feats = e3nn.flax.Linear(target_irreps, name="linear_post")(node_feats) shortcut = e3nn.flax.Linear( node_feats.irreps, name="shortcut", force_irreps_out=True )(node_features)