diff --git a/src/fairchem/core/models/equiformer_v2/edge_rot_mat.py b/src/fairchem/core/models/equiformer_v2/edge_rot_mat.py deleted file mode 100644 index c83cc31433..0000000000 --- a/src/fairchem/core/models/equiformer_v2/edge_rot_mat.py +++ /dev/null @@ -1,55 +0,0 @@ -from __future__ import annotations - -import logging - -import torch - - -def init_edge_rot_mat(edge_distance_vec): - edge_vec_0 = edge_distance_vec - edge_vec_0_distance = torch.sqrt(torch.sum(edge_vec_0**2, dim=1)) - - # Make sure the atoms are far enough apart - # assert torch.min(edge_vec_0_distance) < 0.0001 - if torch.min(edge_vec_0_distance) < 0.0001: - logging.error(f"Error edge_vec_0_distance: {torch.min(edge_vec_0_distance)}") - - norm_x = edge_vec_0 / (edge_vec_0_distance.view(-1, 1)) - - edge_vec_2 = torch.rand_like(edge_vec_0) - 0.5 - edge_vec_2 = edge_vec_2 / (torch.sqrt(torch.sum(edge_vec_2**2, dim=1)).view(-1, 1)) - # Create two rotated copys of the random vectors in case the random vector is aligned with norm_x - # With two 90 degree rotated vectors, at least one should not be aligned with norm_x - edge_vec_2b = edge_vec_2.clone() - edge_vec_2b[:, 0] = -edge_vec_2[:, 1] - edge_vec_2b[:, 1] = edge_vec_2[:, 0] - edge_vec_2c = edge_vec_2.clone() - edge_vec_2c[:, 1] = -edge_vec_2[:, 2] - edge_vec_2c[:, 2] = edge_vec_2[:, 1] - vec_dot_b = torch.abs(torch.sum(edge_vec_2b * norm_x, dim=1)).view(-1, 1) - vec_dot_c = torch.abs(torch.sum(edge_vec_2c * norm_x, dim=1)).view(-1, 1) - - vec_dot = torch.abs(torch.sum(edge_vec_2 * norm_x, dim=1)).view(-1, 1) - edge_vec_2 = torch.where(torch.gt(vec_dot, vec_dot_b), edge_vec_2b, edge_vec_2) - vec_dot = torch.abs(torch.sum(edge_vec_2 * norm_x, dim=1)).view(-1, 1) - edge_vec_2 = torch.where(torch.gt(vec_dot, vec_dot_c), edge_vec_2c, edge_vec_2) - - vec_dot = torch.abs(torch.sum(edge_vec_2 * norm_x, dim=1)) - # Check the vectors aren't aligned - assert torch.max(vec_dot) < 0.99 - - norm_z = torch.cross(norm_x, edge_vec_2, dim=1) - norm_z = norm_z / (torch.sqrt(torch.sum(norm_z**2, dim=1, keepdim=True))) - norm_z = norm_z / (torch.sqrt(torch.sum(norm_z**2, dim=1)).view(-1, 1)) - norm_y = torch.cross(norm_x, norm_z, dim=1) - norm_y = norm_y / (torch.sqrt(torch.sum(norm_y**2, dim=1, keepdim=True))) - - # Construct the 3D rotation matrix - norm_x = norm_x.view(-1, 3, 1) - norm_y = -norm_y.view(-1, 3, 1) - norm_z = norm_z.view(-1, 3, 1) - - edge_rot_mat_inv = torch.cat([norm_z, norm_x, norm_y], dim=2) - edge_rot_mat = torch.transpose(edge_rot_mat_inv, 1, 2) - - return edge_rot_mat.detach() diff --git a/src/fairchem/core/models/equiformer_v2/equiformer_v2.py b/src/fairchem/core/models/equiformer_v2/equiformer_v2.py index 98e21a77f1..d96b1ca9ad 100644 --- a/src/fairchem/core/models/equiformer_v2/equiformer_v2.py +++ b/src/fairchem/core/models/equiformer_v2/equiformer_v2.py @@ -15,6 +15,7 @@ GraphModelMixin, HeadInterface, ) +from fairchem.core.models.escn.edge_rot_mat import init_edge_rot_mat from fairchem.core.models.scn.smearing import GaussianSmearing with contextlib.suppress(ImportError): @@ -23,7 +24,6 @@ import typing -from .edge_rot_mat import init_edge_rot_mat from .gaussian_rbf import GaussianRadialBasisLayer from .input_block import EdgeDegreeEmbedding from .layer_norm import ( @@ -443,9 +443,7 @@ def forward(self, data: Batch) -> dict[str, torch.Tensor]: ############################################################### # Compute 3x3 rotation matrix per edge - edge_rot_mat = self._init_edge_rot_mat( - data, graph.edge_index, graph.edge_distance_vec - ) + edge_rot_mat = init_edge_rot_mat(graph.edge_distance_vec) # Initialize the WignerD matrices and other values for spherical harmonic calculations for i in range(self.num_resolutions): @@ -569,10 +567,6 @@ def _init_gp_partitions( edge_distance_vec, ) - # Initialize the edge rotation matrics - def _init_edge_rot_mat(self, data, edge_index, edge_distance_vec): - return init_edge_rot_mat(edge_distance_vec) - @property def num_params(self): return sum(p.numel() for p in self.parameters()) diff --git a/src/fairchem/core/models/equiformer_v2/equiformer_v2_deprecated.py b/src/fairchem/core/models/equiformer_v2/equiformer_v2_deprecated.py index 1da2ed3adb..40a5f1d5d7 100644 --- a/src/fairchem/core/models/equiformer_v2/equiformer_v2_deprecated.py +++ b/src/fairchem/core/models/equiformer_v2/equiformer_v2_deprecated.py @@ -11,13 +11,13 @@ from fairchem.core.common.registry import registry from fairchem.core.common.utils import conditional_grad from fairchem.core.models.base import GraphModelMixin +from fairchem.core.models.escn.edge_rot_mat import init_edge_rot_mat from fairchem.core.models.scn.smearing import GaussianSmearing with contextlib.suppress(ImportError): pass -from .edge_rot_mat import init_edge_rot_mat from .gaussian_rbf import GaussianRadialBasisLayer from .input_block import EdgeDegreeEmbedding from .layer_norm import ( @@ -484,9 +484,7 @@ def forward(self, data): ############################################################### # Compute 3x3 rotation matrix per edge - edge_rot_mat = self._init_edge_rot_mat( - data, graph.edge_index, graph.edge_distance_vec - ) + edge_rot_mat = init_edge_rot_mat(graph.edge_distance_vec) # Initialize the WignerD matrices and other values for spherical harmonic calculations for i in range(self.num_resolutions): @@ -618,10 +616,6 @@ def forward(self, data): return outputs - # Initialize the edge rotation matrics - def _init_edge_rot_mat(self, data, edge_index, edge_distance_vec): - return init_edge_rot_mat(edge_distance_vec) - @property def num_params(self): return sum(p.numel() for p in self.parameters()) diff --git a/src/fairchem/core/models/escn/edge_rot_mat.py b/src/fairchem/core/models/escn/edge_rot_mat.py new file mode 100644 index 0000000000..72c960a4a9 --- /dev/null +++ b/src/fairchem/core/models/escn/edge_rot_mat.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +import logging +import math + +import torch + + +# Algorithm from Ken Whatmough (https://math.stackexchange.com/users/918128/ken-whatmough) +def vec3_to_perp_vec3(v): + """ + Small proof: + input = x y z + output = s(x)|z| s(y)|z| -s(z)(|x|+|y|) + + input dot output + = x*s(x)*|z| + y*s(y)*|z| - z*s(z)*|x| - z*s(z)*|y| + a*s(a)=|a| , + = |x|*|z| + |y|*|z| - |z|*|x| - |z|*|y| = 0 + + """ + return torch.hstack( + [ + v[:, [2]].copysign(v[:, [0, 1]]), + -v[:, [0, 1]].copysign(v[:, [2]]).sum(axis=1, keepdim=True), + ] + ) + + +# https://en.wikipedia.org/wiki/Rodrigues'_rotation_formula#Matrix_notation +def vec3_rotate_around_axis(v, axis, thetas): + # v_rot= v + (sTheta)*(axis X v) + (1-cTheta)*(axis X (axis X v)) + Kv = torch.cross(axis, v, dim=1) + KKv = torch.cross(axis, Kv, dim=1) + s_theta = torch.sin(thetas) + c_theta = torch.cos(thetas) + return v + s_theta * Kv + (1 - c_theta) * KKv + + +def init_edge_rot_mat(edge_distance_vec): + edge_vec_0 = edge_distance_vec.detach() + edge_vec_0_distance = torch.linalg.norm(edge_vec_0, axis=1, keepdim=True) + # Make sure the atoms are far enough apart + # assert torch.min(edge_vec_0_distance) < 0.0001 + if torch.min(edge_vec_0_distance) < 0.0001: + logging.error(f"Error edge_vec_0_distance: {torch.min(edge_vec_0_distance)}") + + norm_x = edge_vec_0 / edge_vec_0_distance + + perp_to_norm_x = vec3_to_perp_vec3(norm_x) + random_rotated_in_plane_perp_to_norm_x = vec3_rotate_around_axis( + perp_to_norm_x, + norm_x, + torch.rand((norm_x.shape[0], 1), device=norm_x.device) * 2 * math.pi, + ) + + norm_z = random_rotated_in_plane_perp_to_norm_x / torch.linalg.norm( + random_rotated_in_plane_perp_to_norm_x, axis=1, keepdim=True + ) + + norm_y = torch.cross(norm_x, norm_z, dim=1) + norm_y /= torch.linalg.norm(norm_y, dim=1, keepdim=True) + + # Construct the 3D rotation matrix + norm_x = norm_x.view(-1, 1, 3) + norm_y = -norm_y.view(-1, 1, 3) + norm_z = norm_z.view(-1, 1, 3) + return torch.cat([norm_z, norm_x, norm_y], dim=1).contiguous() diff --git a/src/fairchem/core/models/escn/escn.py b/src/fairchem/core/models/escn/escn.py index c17b8bda71..9aeb31ac60 100644 --- a/src/fairchem/core/models/escn/escn.py +++ b/src/fairchem/core/models/escn/escn.py @@ -15,6 +15,10 @@ import torch import torch.nn as nn +from fairchem.core.models.escn.edge_rot_mat import ( + init_edge_rot_mat, +) + if typing.TYPE_CHECKING: from torch_geometric.data.batch import Batch @@ -89,6 +93,7 @@ def __init__( show_timing_info: bool = False, resolution: int | None = None, activation_checkpoint: bool | None = False, + edge_rot_mat: str = "og", ) -> None: if mmax_list is None: mmax_list = [2] @@ -248,9 +253,7 @@ def forward(self, data): ############################################################### # Compute 3x3 rotation matrix per edge - edge_rot_mat = self._init_edge_rot_mat( - data, graph.edge_index, graph.edge_distance_vec - ) + edge_rot_mat = init_edge_rot_mat(graph.edge_distance_vec) # Initialize the WignerD matrices and other values for spherical harmonic calculations self.SO3_edge_rot = nn.ModuleList() @@ -365,63 +368,6 @@ def forward(self, data): return outputs - # Initialize the edge rotation matrics - def _init_edge_rot_mat(self, data, edge_index, edge_distance_vec): - edge_vec_0 = edge_distance_vec - edge_vec_0_distance = torch.sqrt(torch.sum(edge_vec_0**2, dim=1)) - - # Make sure the atoms are far enough apart - if torch.min(edge_vec_0_distance) < 0.0001: - logging.error( - f"Error edge_vec_0_distance: {torch.min(edge_vec_0_distance)}" - ) - (minval, minidx) = torch.min(edge_vec_0_distance, 0) - logging.error( - f"Error edge_vec_0_distance: {minidx} {edge_index[0, minidx]} {edge_index[1, minidx]} {data.pos[edge_index[0, minidx]]} {data.pos[edge_index[1, minidx]]}" - ) - - norm_x = edge_vec_0 / (edge_vec_0_distance.view(-1, 1)) - - edge_vec_2 = torch.rand_like(edge_vec_0) - 0.5 - edge_vec_2 = edge_vec_2 / ( - torch.sqrt(torch.sum(edge_vec_2**2, dim=1)).view(-1, 1) - ) - # Create two rotated copys of the random vectors in case the random vector is aligned with norm_x - # With two 90 degree rotated vectors, at least one should not be aligned with norm_x - edge_vec_2b = edge_vec_2.clone() - edge_vec_2b[:, 0] = -edge_vec_2[:, 1] - edge_vec_2b[:, 1] = edge_vec_2[:, 0] - edge_vec_2c = edge_vec_2.clone() - edge_vec_2c[:, 1] = -edge_vec_2[:, 2] - edge_vec_2c[:, 2] = edge_vec_2[:, 1] - vec_dot_b = torch.abs(torch.sum(edge_vec_2b * norm_x, dim=1)).view(-1, 1) - vec_dot_c = torch.abs(torch.sum(edge_vec_2c * norm_x, dim=1)).view(-1, 1) - - vec_dot = torch.abs(torch.sum(edge_vec_2 * norm_x, dim=1)).view(-1, 1) - edge_vec_2 = torch.where(torch.gt(vec_dot, vec_dot_b), edge_vec_2b, edge_vec_2) - vec_dot = torch.abs(torch.sum(edge_vec_2 * norm_x, dim=1)).view(-1, 1) - edge_vec_2 = torch.where(torch.gt(vec_dot, vec_dot_c), edge_vec_2c, edge_vec_2) - - vec_dot = torch.abs(torch.sum(edge_vec_2 * norm_x, dim=1)) - # Check the vectors aren't aligned - assert torch.max(vec_dot) < 0.99 - - norm_z = torch.cross(norm_x, edge_vec_2, dim=1) - norm_z = norm_z / (torch.sqrt(torch.sum(norm_z**2, dim=1, keepdim=True))) - norm_z = norm_z / (torch.sqrt(torch.sum(norm_z**2, dim=1)).view(-1, 1)) - norm_y = torch.cross(norm_x, norm_z, dim=1) - norm_y = norm_y / (torch.sqrt(torch.sum(norm_y**2, dim=1, keepdim=True))) - - # Construct the 3D rotation matrix - norm_x = norm_x.view(-1, 3, 1) - norm_y = -norm_y.view(-1, 3, 1) - norm_z = norm_z.view(-1, 3, 1) - - edge_rot_mat_inv = torch.cat([norm_z, norm_x, norm_y], dim=2) - edge_rot_mat = torch.transpose(edge_rot_mat_inv, 1, 2) - - return edge_rot_mat.detach() - @property def num_params(self) -> int: return sum(p.numel() for p in self.parameters()) @@ -445,9 +391,7 @@ def forward(self, data: Batch) -> dict[str, torch.Tensor]: ############################################################### # Compute 3x3 rotation matrix per edge - edge_rot_mat = self._init_edge_rot_mat( - data, graph.edge_index, graph.edge_distance_vec - ) + edge_rot_mat = init_edge_rot_mat(graph.edge_distance_vec) # Initialize the WignerD matrices and other values for spherical harmonic calculations self.SO3_edge_rot = nn.ModuleList() diff --git a/tests/core/models/__snapshots__/test_equiformer_v2_deprecated.ambr b/tests/core/models/__snapshots__/test_equiformer_v2_deprecated.ambr index d374d616e1..352a5859d2 100644 --- a/tests/core/models/__snapshots__/test_equiformer_v2_deprecated.ambr +++ b/tests/core/models/__snapshots__/test_equiformer_v2_deprecated.ambr @@ -6,7 +6,7 @@ # --- # name: TestEquiformerV2.test_ddp.1 Approx( - array([0.12408739], dtype=float32), + array([-0.00897979], dtype=float32), rtol=0.001, atol=0.001 ) @@ -19,7 +19,7 @@ # --- # name: TestEquiformerV2.test_ddp.3 Approx( - array([ 1.4928584e-03, -7.4167408e-05, 2.9909366e-03], dtype=float32), + array([-0.00893596, -0.00290774, -0.02622147], dtype=float32), rtol=0.001, atol=0.001 ) @@ -31,7 +31,7 @@ # --- # name: TestEquiformerV2.test_energy_force_shape.1 Approx( - array([0.12408739], dtype=float32), + array([-0.00897979], dtype=float32), rtol=0.001, atol=0.001 ) @@ -44,7 +44,7 @@ # --- # name: TestEquiformerV2.test_energy_force_shape.3 Approx( - array([ 1.4928584e-03, -7.4167408e-05, 2.9909366e-03], dtype=float32), + array([-0.00893596, -0.00290774, -0.02622147], dtype=float32), rtol=0.001, atol=0.001 ) @@ -56,7 +56,7 @@ # --- # name: TestEquiformerV2.test_gp.1 Approx( - array([0.12408739], dtype=float32), + array([-0.02495257], dtype=float32), rtol=0.001, atol=0.001 ) @@ -69,7 +69,7 @@ # --- # name: TestEquiformerV2.test_gp.3 Approx( - array([ 1.4928661e-03, -7.4134863e-05, 2.9909245e-03], dtype=float32), + array([ 0.00203055, -0.00042872, -0.00279118], dtype=float32), rtol=0.001, atol=0.001 )