Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat(pt): Support fitting_net statistics. #4504

Draft
wants to merge 4 commits into
base: devel
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions deepmd/pt/model/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,7 @@ def wrapped_sampler():
return sampled

self.descriptor.compute_input_stats(wrapped_sampler, stat_file_path)
self.fitting_net.compute_input_stats(wrapped_sampler, stat_file_path)
self.compute_or_load_out_stat(wrapped_sampler, stat_file_path)

def get_dim_fparam(self) -> int:
Expand Down
46 changes: 46 additions & 0 deletions deepmd/pt/model/task/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
abstractmethod,
)
from typing import (
Callable,
Optional,
Union,
)
Expand Down Expand Up @@ -39,6 +40,9 @@
get_index_between_two_maps,
map_atom_exclude_types,
)
from deepmd.utils.path import (
DPPath,
)

dtype = env.GLOBAL_PT_FLOAT_PRECISION
device = env.DEVICE
Expand Down Expand Up @@ -409,6 +413,48 @@
"""Set the FittingNet output dim."""
pass

def compute_input_stats(
    self,
    merged: Union[Callable[[], list[dict]], list[dict]],
    path: Optional["DPPath"] = None,
) -> None:
    """
    Compute the input statistics (e.g. mean and stddev) for the fittings from packed data.

    Parameters
    ----------
    merged : Union[Callable[[], list[dict]], list[dict]]
        - list[dict]: A list of data samples from various data systems.
          Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor`
          originating from the `i`-th data system.
        - Callable[[], list[dict]]: A lazy function that returns data samples in the above format
          only when needed. Since the sampling process can be slow and memory-intensive,
          the lazy function helps by only sampling once.
    path : Optional[DPPath]
        The path to the stat file.
    """
    # Resolve the lazy sampler exactly once; sampling can be expensive.
    sampled = merged() if callable(merged) else merged
    # stat fparam
    if self.numb_fparam > 0:
        # Stack the frame parameters of every frame from every system into
        # a single (nframes, numb_fparam) matrix.
        cat_data = torch.cat([frame["fparam"] for frame in sampled], dim=0)
        cat_data = cat_data.reshape(-1, self.numb_fparam)
        fparam_avg = torch.mean(cat_data, dim=0)
        fparam_std = torch.std(cat_data, dim=0)
        # Guard against a zero standard deviation (constant fparam column),
        # which would otherwise produce an inf inverse.
        # NOTE(review): a single-frame sample still yields NaN std here —
        # confirm whether that case must be supported.
        fparam_std = fparam_std.clamp_min(1e-12)
        fparam_inv_std = 1.0 / fparam_std
        # Copy into the registered buffers. `.to(...)` (instead of wrapping an
        # existing tensor in torch.tensor) avoids the copy-construct warning
        # and matches each buffer's dtype/device.
        self.fparam_avg.copy_(
            fparam_avg.to(dtype=self.fparam_avg.dtype, device=self.fparam_avg.device)
        )
        self.fparam_inv_std.copy_(
            fparam_inv_std.to(
                dtype=self.fparam_inv_std.dtype,
                device=self.fparam_inv_std.device,
            )
        )
    # TODO: stat aparam
Comment on lines +442 to +456
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Handle potential zero or near-zero standard deviation.
Currently, the code divides by fparam_std, potentially leading to inf or NaN values if std == 0. Consider adding a small epsilon or performing a check to avoid division by zero.

 fparam_std = torch.std(cat_data, axis=0)
+epsilon = 1e-12
+fparam_std = torch.where(fparam_std < epsilon, torch.tensor(epsilon, dtype=fparam_std.dtype, device=fparam_std.device), fparam_std)
 fparam_inv_std = 1.0 / fparam_std
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if self.numb_fparam > 0:
cat_data = torch.cat([frame["fparam"] for frame in sampled], dim=0)
cat_data = torch.reshape(cat_data, [-1, self.numb_fparam])
fparam_avg = torch.mean(cat_data, axis=0)
fparam_std = torch.std(cat_data, axis=0)
fparam_inv_std = 1.0 / fparam_std
self.fparam_avg.copy_(
torch.tensor(fparam_avg, device=env.DEVICE, dtype=self.fparam_avg.dtype)
)
self.fparam_inv_std.copy_(
torch.tensor(
fparam_inv_std, device=env.DEVICE, dtype=self.fparam_inv_std.dtype
)
)
# TODO: stat aparam
if self.numb_fparam > 0:
cat_data = torch.cat([frame["fparam"] for frame in sampled], dim=0)
cat_data = torch.reshape(cat_data, [-1, self.numb_fparam])
fparam_avg = torch.mean(cat_data, axis=0)
fparam_std = torch.std(cat_data, axis=0)
epsilon = 1e-12
fparam_std = torch.where(fparam_std < epsilon, torch.tensor(epsilon, dtype=fparam_std.dtype, device=fparam_std.device), fparam_std)
fparam_inv_std = 1.0 / fparam_std
self.fparam_avg.copy_(
torch.tensor(fparam_avg, device=env.DEVICE, dtype=self.fparam_avg.dtype)
)
self.fparam_inv_std.copy_(
torch.tensor(
fparam_inv_std, device=env.DEVICE, dtype=self.fparam_inv_std.dtype
)
)
# TODO: stat aparam


def _extend_f_avg_std(self, xx: torch.Tensor, nb: int) -> torch.Tensor:
return torch.tile(xx.view([1, self.numb_fparam]), [nb, 1])

Expand Down
Loading