diff --git a/examples/InPSNAPExample.py b/examples/InPSNAPExample.py index 61c04588..9e470f92 100644 --- a/examples/InPSNAPExample.py +++ b/examples/InPSNAPExample.py @@ -104,13 +104,32 @@ from hippynn import plotting - plot_maker = plotting.PlotMaker( + plotters = [ plotting.Hist2D.compare(sys_energy, saved=True), plotting.Hist2D.compare(en_peratom, saved=True), plotting.Hist2D.compare(force,saved=True), plotting.SensitivityPlot(network.torch_module.sensitivity_layers[0], saved="sense_0.pdf"), - plot_every=plot_every, - ) + ] + + ## More advanced graph and plotter usage to display forces separately for each species + from hippynn.graphs.nodes.indexers import SpeciesIndexed + force_pred_by_species = SpeciesIndexed("ForceBySpeciesPredicted", parents=(force.pred,)) + force_true_by_species = SpeciesIndexed("ForceBySpeciesTrue", parents=(force.true,)) + + for true_idxed, pred_idxed in zip(force_true_by_species.children, force_pred_by_species.children): + species_num = true_idxed.name.split("_")[-1] # Parse the node name to find the species value + plotters.append( + plotting.Hist2D( + x_var=true_idxed, + y_var=pred_idxed, + xlabel=f"true Forces, species {species_num}", + ylabel=f"predicted Forces, species {species_num}", + saved=f"Hist2D-ForceBySpecies{species_num}.pdf", + ) + ) + ## end + + plot_maker = plotting.PlotMaker(*plotters, plot_every=plot_every) from hippynn.experiment.assembly import assemble_for_training @@ -147,7 +166,7 @@ if v.dtype not in [np.float64,np.float32,np.int64]: del arrays[k] - database.make_trainvalidtest_split(test_size=0.2, valid_size=0.4) + database.make_trainvalidtest_split(test_size=test_size, valid_size=valid_size) ### End pre-processing # Now that we have a database and a model, we can diff --git a/hippynn/graphs/nodes/base/__init__.py b/hippynn/graphs/nodes/base/__init__.py index fd0d7544..b106fd10 100644 --- a/hippynn/graphs/nodes/base/__init__.py +++ b/hippynn/graphs/nodes/base/__init__.py @@ -7,7 +7,7 @@ from .base import Node, SingleNode, InputNode, LossInputNode, LossPredNode, LossTrueNode, _BaseNode # Node that provides multiple outputs -from .multi import MultiNode +from .multi import MultiNode, IndexNode # Optional mixins for simplifying the process of defining BaseNode subclasses from .definition_helpers import AutoKw, AutoNoKw, ExpandParents diff --git a/hippynn/graphs/nodes/indexers.py b/hippynn/graphs/nodes/indexers.py index ffb20fed..e7253b66 100644 --- a/hippynn/graphs/nodes/indexers.py +++ b/hippynn/graphs/nodes/indexers.py @@ -2,12 +2,13 @@ Nodes for indexing information. """ from .tags import Encoder, AtomIndexer -from .base import SingleNode, AutoNoKw, AutoKw, find_unique_relative, MultiNode, ExpandParents, _BaseNode +from .base import SingleNode, AutoNoKw, AutoKw, find_unique_relative, MultiNode, ExpandParents, _BaseNode, IndexNode from .base.node_functions import NodeNotFound from .inputs import SpeciesNode # Index generating functions need access to appropriately raise this from ..indextypes import IdxType +from ..indextypes.reduce_funcs import index_type_coercion from ...layers import indexers as index_modules @@ -208,4 +209,36 @@ def __init__(self, name, parents, length, vmin, vmax, module="auto", **kwargs): self._output_index_state = parents[0]._index_state self.module_kwargs = {"length": length, "vmin": vmin, "vmax": vmax} - super().__init__(name, parents, module=module, **kwargs) \ No newline at end of file + super().__init__(name, parents, module=module, **kwargs) + +class SpeciesIndexed(AutoNoKw, SingleNode, ExpandParents): + _input_names = "values", "onehot_encoding" + _auto_module_class = index_modules.SpeciesIndex + _index_state = IdxType.Atoms + + @_parent_expander.match(_BaseNode) + def expansion0(self, node_to_index, species_set, **kwargs): + atom_node_to_index = index_type_coercion(node_to_index, IdxType.Atoms) + onehot = find_unique_relative(atom_node_to_index, OneHotEncoder) + self.species_set = species_set or onehot.species_set + return atom_node_to_index, onehot.encoding + + # add asserts for parent expansion + _parent_expander.assertlen(2) + _parent_expander.get_main_outputs() + _parent_expander.require_idx_states(IdxType.Atoms, IdxType.Atoms) + + def __init__(self, name, parents, *args, module="auto", species_set=None, **kwargs): + parents = self.expand_parents(parents, species_set=species_set) + super().__init__(name, parents, *args, module=module, **kwargs) + + nonzero_species = [species for species in self.species_set if species != 0] + self.species_to_idx = {species: idx for idx, species in enumerate(nonzero_species)} + + self.children = tuple( + IndexNode(name=f"{name}_{species}", parents=(self,), index=idx, index_state=IdxType.Atoms) + for species, idx in self.species_to_idx.items() + ) + + def with_species_equal(self, z_value): + return self.children[self.species_to_idx(z_value)] \ No newline at end of file diff --git a/hippynn/layers/indexers.py b/hippynn/layers/indexers.py index e2335324..2b8faaee 100644 --- a/hippynn/layers/indexers.py +++ b/hippynn/layers/indexers.py @@ -270,4 +270,14 @@ def forward(self, values): values = values[...,None] x = values - self.bins histo = torch.exp(-((x / self.sigma) ** 2) / 4) - return torch.flatten(histo, end_dim=1) \ No newline at end of file + return torch.flatten(histo, end_dim=1) + +class SpeciesIndex(torch.nn.Module): + def forward(self, values, onehot_encoding): + n_species = onehot_encoding.shape[1] + values_by_species = [] + for i in range(n_species): + species_mask = onehot_encoding[:,i] + species_values = values[species_mask] + values_by_species.append(species_values) + return values_by_species \ No newline at end of file