Skip to content

Commit

Permalink
split by species working
Browse files Browse the repository at this point in the history
  • Loading branch information
shinkle-lanl committed Dec 3, 2024
1 parent 6dd3a91 commit 6eb79a0
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 8 deletions.
27 changes: 23 additions & 4 deletions examples/InPSNAPExample.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,13 +104,32 @@

from hippynn import plotting

plot_maker = plotting.PlotMaker(
plotters = [
plotting.Hist2D.compare(sys_energy, saved=True),
plotting.Hist2D.compare(en_peratom, saved=True),
plotting.Hist2D.compare(force,saved=True),
plotting.SensitivityPlot(network.torch_module.sensitivity_layers[0], saved="sense_0.pdf"),
plot_every=plot_every,
)
]

## More advanced graph and plotter usage to display forces separately for each species
from hippynn.graphs.nodes.indexers import SpeciesIndexed
force_pred_by_species = SpeciesIndexed("ForceBySpeciesPredicted", parents=(force.pred,))
force_true_by_species = SpeciesIndexed("ForceBySpeciesTrue", parents=(force.true,))

for true_idxed, pred_idxed in zip(force_true_by_species.children, force_pred_by_species.children):
species_num = true_idxed.name.split("_")[-1] # Parse the node name to find the species value
plotters.append(
plotting.Hist2D(
x_var=true_idxed,
y_var=pred_idxed,
xlabel=f"true Forces, species {species_num}",
ylabel=f"predicted Forces, species {species_num}",
saved=f"Hist2D-ForceBySpecies{species_num}.pdf",
)
)
## end

plot_maker = plotting.PlotMaker(*plotters, plot_every=plot_every)

from hippynn.experiment.assembly import assemble_for_training

Expand Down Expand Up @@ -147,7 +166,7 @@
if v.dtype not in [np.float64,np.float32,np.int64]:
del arrays[k]

database.make_trainvalidtest_split(test_size=0.2, valid_size=0.4)
database.make_trainvalidtest_split(test_size=test_size, valid_size=valid_size)
### End pre-processing

# Now that we have a database and a model, we can
Expand Down
2 changes: 1 addition & 1 deletion hippynn/graphs/nodes/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from .base import Node, SingleNode, InputNode, LossInputNode, LossPredNode, LossTrueNode, _BaseNode

# Node that provides multiple outputs
from .multi import MultiNode
from .multi import MultiNode, IndexNode

# Optional mixins for simplifying the process of defining BaseNode subclasses
from .definition_helpers import AutoKw, AutoNoKw, ExpandParents
37 changes: 35 additions & 2 deletions hippynn/graphs/nodes/indexers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
Nodes for indexing information.
"""
from .tags import Encoder, AtomIndexer
from .base import SingleNode, AutoNoKw, AutoKw, find_unique_relative, MultiNode, ExpandParents, _BaseNode
from .base import SingleNode, AutoNoKw, AutoKw, find_unique_relative, MultiNode, ExpandParents, _BaseNode, IndexNode
from .base.node_functions import NodeNotFound
from .inputs import SpeciesNode

# Index generating functions need access to appropriately raise this
from ..indextypes import IdxType
from ..indextypes.reduce_funcs import index_type_coercion
from ...layers import indexers as index_modules


Expand Down Expand Up @@ -208,4 +209,36 @@ def __init__(self, name, parents, length, vmin, vmax, module="auto", **kwargs):
self._output_index_state = parents[0]._index_state
self.module_kwargs = {"length": length, "vmin": vmin, "vmax": vmax}

super().__init__(name, parents, module=module, **kwargs)
super().__init__(name, parents, module=module, **kwargs)

class SpeciesIndexed(AutoNoKw, SingleNode, ExpandParents):
_input_names = "values", "onehot_encoding"
_auto_module_class = index_modules.SpeciesIndex
_index_state = IdxType.Atoms

@_parent_expander.match(_BaseNode)
def expansion0(self, node_to_index, species_set, **kwargs):
atom_node_to_index = index_type_coercion(node_to_index, IdxType.Atoms)
onehot = find_unique_relative(atom_node_to_index, OneHotEncoder)
self.species_set = species_set or onehot.species_set
return atom_node_to_index, onehot.encoding

# add asserts for parent expansion
_parent_expander.assertlen(2)
_parent_expander.get_main_outputs()
_parent_expander.require_idx_states(IdxType.Atoms, IdxType.Atoms)

def __init__(self, name, parents, *args, module="auto", species_set=None, **kwargs):
parents = self.expand_parents(parents, species_set=species_set)
super().__init__(name, parents, *args, module=module, **kwargs)

nonzero_species = [species for species in self.species_set if species != 0]
self.species_to_idx = {species: idx for idx, species in enumerate(nonzero_species)}

self.children = tuple(
IndexNode(name=f"{name}_{species}", parents=(self,), index=idx, index_state=IdxType.Atoms)
for species, idx in self.species_to_idx.items()
)

def with_species_equal(self, z_value):
return self.children[self.species_to_idx(z_value)]
12 changes: 11 additions & 1 deletion hippynn/layers/indexers.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,4 +270,14 @@ def forward(self, values):
values = values[...,None]
x = values - self.bins
histo = torch.exp(-((x / self.sigma) ** 2) / 4)
return torch.flatten(histo, end_dim=1)
return torch.flatten(histo, end_dim=1)

class SpeciesIndex(torch.nn.Module):
def forward(self, values, onehot_encoding):
n_species = onehot_encoding.shape[1]
values_by_species = []
for i in range(n_species):
species_mask = onehot_encoding[:,i]
species_values = values[species_mask]
values_by_species.append(species_values)
return values_by_species

0 comments on commit 6eb79a0

Please sign in to comment.