diff --git a/model.py b/model.py index 04710bb..b63a1fa 100644 --- a/model.py +++ b/model.py @@ -1,9 +1,5 @@ import torch -torch.jit.script = lambda x: x import torch.nn as nn -import e3nn -e3nn.set_optimization_defaults(jit_script_fx=False) - from e3nn import o3 from torch_geometric.data import Data