diff --git a/cmeutils/gsd_utils.py b/cmeutils/gsd_utils.py index 3f843a5..dd13ffa 100644 --- a/cmeutils/gsd_utils.py +++ b/cmeutils/gsd_utils.py @@ -1,9 +1,12 @@ +import warnings from tempfile import NamedTemporaryFile import freud import gsd.hoomd import hoomd +import networkx as nx import numpy as np +from boltons.setutils import IndexedSet from cmeutils.geometry import moit @@ -374,3 +377,248 @@ def xml_to_gsd(xmlfile, gsdfile): snap.bonds.group = bonds newt.append(snap) print(f"XML data written to {gsdfile}") + + +def identify_snapshot_connections(snapshot): + """Identify angle and dihedral connections in a snapshot from bonds. + + Parameters + ---------- + snapshot : gsd.hoomd.Frame + The snapshot to read in. + + Returns + ------- + gsd.hoomd.Frame + The snapshot with angle and dihedral information added. + """ + if snapshot.bonds.N == 0: + warnings.warn( + "No bonds found in snapshot, hence, no angles or " + "dihedrals will be identified." + ) + return snapshot + bond_groups = snapshot.bonds.group + connection_matches = _find_connections(bond_groups) + + if connection_matches["angles"]: + _fill_connection_info( + snapshot=snapshot, + connections=connection_matches["angles"], + type_="angles", + ) + if connection_matches["dihedrals"]: + _fill_connection_info( + snapshot=snapshot, + connections=connection_matches["dihedrals"], + type_="dihedrals", + ) + return snapshot + + +def _fill_connection_info(snapshot, connections, type_): + p_types = snapshot.particles.types + p_typeid = snapshot.particles.typeid + _connection_types = [] + _connection_typeid = [] + for conn in connections: + conn_sites = [p_types[p_typeid[i]] for i in conn] + sorted_conn_sites = _sort_connection_by_name(conn_sites, type_) + type = "-".join(sorted_conn_sites) + # check if type not in angle_types and types_inv not in angle_types: + if type not in _connection_types: + _connection_types.append(type) + _connection_typeid.append( + max(_connection_typeid) + 1 if _connection_typeid else 0 + ) 
+ else: + _connection_typeid.append(_connection_types.index(type)) + + if type_ == "angles": + snapshot.angles.N = len(connections) + snapshot.angles.M = 3 + snapshot.angles.group = connections + snapshot.angles.types = _connection_types + snapshot.angles.typeid = _connection_typeid + elif type_ == "dihedrals": + snapshot.dihedrals.N = len(connections) + snapshot.dihedrals.M = 4 + snapshot.dihedrals.group = connections + snapshot.dihedrals.types = _connection_types + snapshot.dihedrals.typeid = _connection_typeid + + +# The following functions are obtained from gmso/utils/connectivity.py with +# minor modifications. +def _sort_connection_by_name(conn_sites, type_): + if type_ == "angles": + site1, site3 = sorted([conn_sites[0], conn_sites[2]]) + return [site1, conn_sites[1], site3] + elif type_ == "dihedrals": + site1, site2, site3, site4 = conn_sites + if site2 > site3 or (site2 == site3 and site1 > site4): + return [site4, site3, site2, site1] + else: + return [site1, site2, site3, site4] + + +def _find_connections(bonds): + """Identify all possible connections within a topology.""" + compound = nx.Graph() + + for b in bonds: + compound.add_edge(b[0], b[1]) + + compound_line_graph = nx.line_graph(compound) + + angle_matches = _detect_connections(compound_line_graph, type_="angle") + dihedral_matches = _detect_connections( + compound_line_graph, type_="dihedral" + ) + + return { + "angles": angle_matches, + "dihedrals": dihedral_matches, + } + + +def _detect_connections(compound_line_graph, type_="angle"): + EDGES = { + "angle": ((0, 1),), + "dihedral": ((0, 1), (1, 2)), + } + + connection = nx.Graph() + for edge in EDGES[type_]: + assert len(edge) == 2, "Edges should be of length 2" + connection.add_edge(edge[0], edge[1]) + + matcher = nx.algorithms.isomorphism.GraphMatcher( + compound_line_graph, connection + ) + + formatter_fns = { + "angle": _format_subgraph_angle, + "dihedral": _format_subgraph_dihedral, + } + + conn_matches = IndexedSet() + for m in 
matcher.subgraph_isomorphisms_iter(): + new_connection = formatter_fns[type_](m) + conn_matches.add(new_connection) + if conn_matches: + conn_matches = _trim_duplicates(conn_matches) + + # Do more sorting of individual connection + sorted_conn_matches = list() + for match in conn_matches: + if match[0] < match[-1]: + sorted_conn = match + else: + sorted_conn = match[::-1] + sorted_conn_matches.append(list(sorted_conn)) + + # Final sorting the whole list + if type_ == "angle": + return sorted( + sorted_conn_matches, + key=lambda angle: ( + angle[1], + angle[0], + angle[2], + ), + ) + elif type_ == "dihedral": + return sorted( + sorted_conn_matches, + key=lambda dihedral: ( + dihedral[1], + dihedral[2], + dihedral[0], + dihedral[3], + ), + ) + + +def _get_sorted_by_n_connections(m): + """Return sorted by n connections for the matching graph.""" + small = nx.Graph() + for k, v in m.items(): + small.add_edge(k[0], k[1]) + return sorted(small.adj, key=lambda x: len(small[x])), small + + +def _format_subgraph_angle(m): + """Format the angle subgraph. + + Since we are matching compound line graphs, + back out the actual nodes, not just the edges + + Parameters + ---------- + m : dict + keys are the compound line graph nodes + Values are the sub-graph matches (to the angle, dihedral, or improper) + + Returns + ------- + connection : list of nodes, in order of bonding + (start, middle, end) + """ + (sort_by_n_connections, _) = _get_sorted_by_n_connections(m) + ends = sorted([sort_by_n_connections[0], sort_by_n_connections[1]]) + middle = sort_by_n_connections[2] + return ( + ends[0], + middle, + ends[1], + ) + + +def _format_subgraph_dihedral(m): + """Format the dihedral subgraph. 
+ + Since we are matching compound line graphs, + back out the actual nodes, not just the edges + + Parameters + ---------- + m : dict + keys are the compound line graph nodes + Values are the sub-graph matches (to the angle, dihedral, or improper) + top : gmso.Topology + The original Topology + + Returns + ------- + connection : list of nodes, in order of bonding + (start, mid1, mid2, end) + """ + (sort_by_n_connections, small) = _get_sorted_by_n_connections(m) + start = sort_by_n_connections[0] + if sort_by_n_connections[2] in small.neighbors(start): + mid1 = sort_by_n_connections[2] + mid2 = sort_by_n_connections[3] + else: + mid1 = sort_by_n_connections[3] + mid2 = sort_by_n_connections[2] + + end = sort_by_n_connections[1] + return (start, mid1, mid2, end) + + +def _trim_duplicates(all_matches): + """Remove redundant sub-graph matches. + + Is there a better way to do this? Like when we format the subgraphs, + can we impose an ordering so it's easier to eliminate redundant matches? + """ + trimmed_list = IndexedSet() + for match in all_matches: + if ( + match + and match not in trimmed_list + and match[::-1] not in trimmed_list + ): + trimmed_list.add(match) + return trimmed_list diff --git a/cmeutils/tests/assets/pekk-cg.gsd b/cmeutils/tests/assets/pekk-cg.gsd new file mode 100644 index 0000000..1a15bdf Binary files /dev/null and b/cmeutils/tests/assets/pekk-cg.gsd differ diff --git a/cmeutils/tests/base_test.py b/cmeutils/tests/base_test.py index f9a93ee..a75210b 100644 --- a/cmeutils/tests/base_test.py +++ b/cmeutils/tests/base_test.py @@ -53,6 +53,10 @@ def p3ht_gsd(self): def p3ht_cg_gsd(self): return path.join(asset_dir, "p3ht-cg.gsd") + @pytest.fixture + def pekk_cg_gsd(self): + return path.join(asset_dir, "pekk-cg.gsd") + @pytest.fixture def mapping(self): return np.loadtxt(path.join(asset_dir, "mapping.txt"), dtype=int) diff --git a/cmeutils/tests/test_gsd.py b/cmeutils/tests/test_gsd.py index e54cc62..46cf35a 100644 --- a/cmeutils/tests/test_gsd.py 
# Test methods for ``identify_snapshot_connections``. They belong inside the
# test class of cmeutils/tests/test_gsd.py (class header not visible here) and
# need ``from gmso.external import from_mbuild, to_gsd_snapshot`` as well as
# ``identify_snapshot_connections`` imported from ``cmeutils.gsd_utils``.
def test_identify_snapshot_connections_benzene(self):
    """Detected angles/dihedrals of benzene agree with gmso's own."""

    def angle_key(angle):
        # Central particle first, then the two ends.
        return (angle[1], angle[0], angle[2])

    def dihedral_key(dihedral):
        # Use all four indices so that dihedrals sharing three particles
        # still get a unique, input-order-independent position.
        return (dihedral[1], dihedral[2], dihedral[0], dihedral[3])

    benzene = mb.load("c1ccccc1", smiles=True)
    topology = from_mbuild(benzene)
    no_connection_snapshot, _ = to_gsd_snapshot(topology)
    assert no_connection_snapshot.bonds.N == 12
    assert no_connection_snapshot.angles.N == 0
    assert no_connection_snapshot.dihedrals.N == 0
    updated_snapshot = identify_snapshot_connections(no_connection_snapshot)

    # Reference: let gmso identify the connections on the topology itself.
    topology.identify_connections()
    topology_snapshot, _ = to_gsd_snapshot(topology)

    assert updated_snapshot.angles.N == topology_snapshot.angles.N
    assert np.array_equal(
        sorted(updated_snapshot.angles.group, key=angle_key),
        sorted(topology_snapshot.angles.group, key=angle_key),
    )
    assert sorted(updated_snapshot.angles.types) == sorted(
        topology_snapshot.angles.types
    )
    assert len(updated_snapshot.angles.typeid) == len(
        topology_snapshot.angles.typeid
    )

    assert updated_snapshot.dihedrals.N == topology_snapshot.dihedrals.N
    assert np.array_equal(
        sorted(updated_snapshot.dihedrals.group, key=dihedral_key),
        sorted(topology_snapshot.dihedrals.group, key=dihedral_key),
    )
    assert sorted(updated_snapshot.dihedrals.types) == sorted(
        topology_snapshot.dihedrals.types
    )
    assert len(updated_snapshot.dihedrals.typeid) == len(
        topology_snapshot.dihedrals.typeid
    )


def test_identify_connection_thiophene(self):
    """Thiophene yields the expected angle and dihedral type names."""
    thiophene = mb.load("c1cscc1", smiles=True)
    topology = from_mbuild(thiophene)
    no_connection_snapshot, _ = to_gsd_snapshot(topology)
    updated_snapshot = identify_snapshot_connections(no_connection_snapshot)
    assert updated_snapshot.angles.N == 13
    assert sorted(updated_snapshot.angles.types) == sorted(
        ["C-S-C", "H-C-S", "C-C-H", "C-C-S", "C-C-C"]
    )

    assert updated_snapshot.dihedrals.N == 16
    assert sorted(updated_snapshot.dihedrals.types) == sorted(
        [
            "C-C-C-H",
            "C-C-C-C",
            "H-C-C-H",
            "H-C-S-C",
            "H-C-C-S",
            "C-C-S-C",
            "C-C-C-S",
        ]
    )


def test_identify_connection_no_dihedrals(self):
    """Methane has angles but no dihedrals; dihedral data stays unset."""
    methane = mb.load("C", smiles=True)
    topology = from_mbuild(methane)
    no_connection_snapshot, _ = to_gsd_snapshot(topology)
    assert no_connection_snapshot.bonds.N != 0
    assert no_connection_snapshot.angles.N == 0
    assert no_connection_snapshot.dihedrals.N == 0
    updated_snapshot = identify_snapshot_connections(no_connection_snapshot)
    assert updated_snapshot.angles.N == 6
    assert updated_snapshot.angles.types == ["H-C-H"]
    assert updated_snapshot.angles.typeid == [0, 0, 0, 0, 0, 0]
    assert updated_snapshot.dihedrals.N == 0
    assert updated_snapshot.dihedrals.types is None
    assert updated_snapshot.dihedrals.typeid is None


def test_identify_connection_no_connections(self):
    """A snapshot without bonds triggers a warning and is left unchanged."""
    snapshot = gsd.hoomd.Frame()
    snapshot.particles.N = 2
    snapshot.particles.types = ["A", "B"]
    snapshot.particles.typeid = [0, 1]
    with pytest.warns(UserWarning):
        updated_snapshot = identify_snapshot_connections(snapshot)
    assert updated_snapshot.bonds.N == 0
    assert updated_snapshot.angles.N == 0
    assert updated_snapshot.dihedrals.N == 0


def test_identify_connections_pekk_cg(self, pekk_cg_gsd):
    """Connections are identified on a coarse-grained PEKK snapshot."""
    with gsd.hoomd.open(pekk_cg_gsd) as traj:
        snap = traj[0]
        assert snap.angles.types == []
        snap_with_connections = identify_snapshot_connections(snap)
        assert "K-E-K" in snap_with_connections.angles.types
        assert "E-K-K" in snap_with_connections.angles.types
        assert "K-E-K-K" in snap_with_connections.dihedrals.types
        assert "E-K-K-E" in snap_with_connections.dihedrals.types