Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support for spin, DeltaSpin and LAMMPS SPIN #728

Open
wants to merge 14 commits into
base: devel
Choose a base branch
from
11 changes: 9 additions & 2 deletions dpdata/deepmd/comp.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ def _load_set(folder, nopbc: bool):
cells = np.zeros((coords.shape[0], 3, 3))
else:
cells = np.load(os.path.join(folder, "box.npy"))
return cells, coords
spins = _cond_load_data(os.path.join(folder, "spin.npy"))
return cells, coords, spins


def to_system_data(folder, type_map=None, labels=True):
Expand All @@ -38,13 +39,18 @@ def to_system_data(folder, type_map=None, labels=True):
sets = sorted(glob.glob(os.path.join(folder, "set.*")))
all_cells = []
all_coords = []
all_spins = []
for ii in sets:
cells, coords = _load_set(ii, data.get("nopbc", False))
cells, coords, spins = _load_set(ii, data.get("nopbc", False))
nframes = np.reshape(cells, [-1, 3, 3]).shape[0]
all_cells.append(np.reshape(cells, [nframes, 3, 3]))
all_coords.append(np.reshape(coords, [nframes, -1, 3]))
if spins is not None:
all_spins.append(np.reshape(spins, [nframes, -1, 3]))
data["cells"] = np.concatenate(all_cells, axis=0)
data["coords"] = np.concatenate(all_coords, axis=0)
if len(all_spins) > 0:
data["spins"] = np.concatenate(all_spins, axis=0)
# allow custom dtypes
if labels:
dtypes = dpdata.system.LabeledSystem.DTYPES
Expand All @@ -59,6 +65,7 @@ def to_system_data(folder, type_map=None, labels=True):
"orig",
"cells",
"coords",
"spins",
"real_atom_names",
"nopbc",
):
Expand Down
1 change: 1 addition & 0 deletions dpdata/deepmd/raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def to_system_data(folder, type_map=None, labels=True):
"orig",
"cells",
"coords",
"spins",
"real_atom_types",
"real_atom_names",
"nopbc",
Expand Down
49 changes: 49 additions & 0 deletions dpdata/lammps/dump.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,42 @@
) # Convert scaled coordinates back to Cartesian coordinates with wrapping at periodic boundary conditions


def get_spintype(keys):
    """Find which set of spin columns is present among ``keys``.

    Two column layouts are recognised, checked in this order:

    1. ``sp``, ``spx``, ``spy``, ``spz``
    2. ``c_spin[1]``, ``c_spin[2]``, ``c_spin[3]``, ``c_spin[4]``

    Parameters
    ----------
    keys : list of str
        Column names to search, e.g. the split header of an ATOMS block.

    Returns
    -------
    list of str or None
        The four column names of the first layout that is completely
        contained in ``keys``, or None when neither layout is complete.
    """
    candidate_layouts = (
        ["sp", "spx", "spy", "spz"],
        ["c_spin[1]", "c_spin[2]", "c_spin[3]", "c_spin[4]"],
    )
    for layout in candidate_layouts:
        # A layout only counts when every one of its columns is available;
        # a partial match (e.g. "sp" without "spz") is treated as no spin.
        if all(name in keys for name in layout):
            return layout
    return None

Check warning on line 141 in dpdata/lammps/dump.py

View check run for this annotation

Codecov / codecov/patch

dpdata/lammps/dump.py#L136-L141

Added lines #L136 - L141 were not covered by tests

Comment on lines +135 to +142
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add unit tests for the get_spintype function to ensure correctness

The newly added get_spintype function is not covered by unit tests. Implementing tests for this function will help verify that it accurately identifies different spin types based on the provided keys, enhancing the reliability of spin data processing.

Would you like assistance in creating unit tests for this function?

Tools
GitHub Check: codecov/patch

[warning] 136-141: dpdata/lammps/dump.py#L136-L141
Added lines #L136 - L141 were not covered by tests


def safe_get_spin_force(lines):
blk, head = _get_block(lines, "ATOMS")
keys = head.split()
sp_type = get_spintype(keys)
assert sp_type is not None, "Dump file does not contain spin!"
id_idx = keys.index("id") - 2
sp = keys.index(sp_type[0]) - 2
spx = keys.index(sp_type[1]) - 2
spy = keys.index(sp_type[2]) - 2
spz = keys.index(sp_type[3]) - 2
sp_force = []
for ii in blk:
words = ii.split()
sp_force.append(

Check warning on line 157 in dpdata/lammps/dump.py

View check run for this annotation

Codecov / codecov/patch

dpdata/lammps/dump.py#L145-L157

Added lines #L145 - L157 were not covered by tests
[
float(words[id_idx]),
float(words[sp]),
float(words[spx]),
float(words[spy]),
float(words[spz]),
]
)
sp_force.sort()
sp_force = np.array(sp_force)[:, 1:]
return sp_force

Check warning on line 168 in dpdata/lammps/dump.py

View check run for this annotation

Codecov / codecov/patch

dpdata/lammps/dump.py#L166-L168

Added lines #L166 - L168 were not covered by tests

Comment on lines +144 to +169
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add unit tests for safe_get_spin_force to validate spin force extraction

The safe_get_spin_force function is crucial for extracting spin force data from the dump files, but it currently lacks test coverage. Adding unit tests will ensure it handles various spin data formats correctly and robustly detects cases where spin data is missing or malformed.

I can help create unit tests for this function if you'd like.

Tools
GitHub Check: codecov/patch

[warning] 145-157: dpdata/lammps/dump.py#L145-L157
Added lines #L145 - L157 were not covered by tests


[warning] 166-168: dpdata/lammps/dump.py#L166-L168
Added lines #L166 - L168 were not covered by tests


def get_dumpbox(lines):
blk, h = _get_block(lines, "BOX BOUNDS")
bounds = np.zeros([3, 2])
Expand Down Expand Up @@ -216,6 +252,12 @@
system["cells"] = [np.array(cell)]
system["atom_types"] = get_atype(lines, type_idx_zero=type_idx_zero)
system["coords"] = [safe_get_posi(lines, cell, np.array(orig), unwrap)]
contain_spin = False
blk, head = _get_block(lines, "ATOMS")
if "sp" in head:
contain_spin = True
spin_force = safe_get_spin_force(lines)
system["spins"] = [spin_force[:, :1] * spin_force[:, 1:4]]

Check warning on line 260 in dpdata/lammps/dump.py

View check run for this annotation

Codecov / codecov/patch

dpdata/lammps/dump.py#L258-L260

Added lines #L258 - L260 were not covered by tests
Comment on lines +255 to +260
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use get_spintype for consistent spin data detection

Currently, the code checks if "sp" in head to determine the presence of spin data. Since spin data may be indicated by different keys (e.g., "sp", "spx", "c_spin[1]"), using the get_spintype function ensures consistent detection across various formats.

Apply this diff to improve spin data detection:

         contain_spin = False
         blk, head = _get_block(lines, "ATOMS")
-        if "sp" in head:
+        keys = head.split()
+        sp_type = get_spintype(keys)
+        if sp_type is not None:
             contain_spin = True
             spin_force = safe_get_spin_force(lines)
             system["spins"] = [spin_force[:, :1] * spin_force[:, 1:4]]
Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
contain_spin = False
blk, head = _get_block(lines, "ATOMS")
if "sp" in head:
contain_spin = True
spin_force = safe_get_spin_force(lines)
system["spins"] = [spin_force[:, :1] * spin_force[:, 1:4]]
contain_spin = False
blk, head = _get_block(lines, "ATOMS")
keys = head.split()
sp_type = get_spintype(keys)
if sp_type is not None:
contain_spin = True
spin_force = safe_get_spin_force(lines)
system["spins"] = [spin_force[:, :1] * spin_force[:, 1:4]]
Tools
GitHub Check: codecov/patch

[warning] 258-260: dpdata/lammps/dump.py#L258-L260
Added lines #L258 - L260 were not covered by tests

for ii in range(1, len(array_lines)):
bounds, tilt = get_dumpbox(array_lines[ii])
orig, cell = dumpbox2box(bounds, tilt)
Expand All @@ -228,6 +270,13 @@
system["coords"].append(
safe_get_posi(array_lines[ii], cell, np.array(orig), unwrap)[idx]
)
if contain_spin:
system["spins"].append(

Check warning on line 274 in dpdata/lammps/dump.py

View check run for this annotation

Codecov / codecov/patch

dpdata/lammps/dump.py#L274

Added line #L274 was not covered by tests
safe_get_spin_force(array_lines[ii])[:, :1]
* safe_get_spin_force(array_lines[ii])[:, 1:4]
)
Comment on lines +273 to +277
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Optimize by storing the result of safe_get_spin_force

In the loop, safe_get_spin_force(array_lines[ii]) is called twice, leading to unnecessary computation. Storing the result in a variable enhances performance by avoiding redundant function calls.

Apply this diff to optimize the code:

             if contain_spin:
+                spin_force = safe_get_spin_force(array_lines[ii])
                 system["spins"].append(
-                    safe_get_spin_force(array_lines[ii])[:, :1]
-                    * safe_get_spin_force(array_lines[ii])[:, 1:4]
+                    spin_force[:, :1] * spin_force[:, 1:4]
                 )
Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if contain_spin:
system["spins"].append(
safe_get_spin_force(array_lines[ii])[:, :1]
* safe_get_spin_force(array_lines[ii])[:, 1:4]
)
if contain_spin:
spin_force = safe_get_spin_force(array_lines[ii])
system["spins"].append(
spin_force[:, :1] * spin_force[:, 1:4]
)
Tools
GitHub Check: codecov/patch

[warning] 274-274: dpdata/lammps/dump.py#L274
Added line #L274 was not covered by tests

if contain_spin:
system["spins"] = np.array(system["spins"])

Check warning on line 279 in dpdata/lammps/dump.py

View check run for this annotation

Codecov / codecov/patch

dpdata/lammps/dump.py#L279

Added line #L279 was not covered by tests
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add tests to verify spin data array conversion

Converting system["spins"] to a NumPy array is essential for ensuring data consistency, but this line is not covered by tests. Consider adding tests to validate that spin data is correctly aggregated and converted after processing multiple blocks.

Would you like assistance in creating tests for this functionality?

Tools
GitHub Check: codecov/patch

[warning] 279-279: dpdata/lammps/dump.py#L279
Added line #L279 was not covered by tests

system["cells"] = np.array(system["cells"])
system["coords"] = np.array(system["coords"])
return system
Expand Down
70 changes: 63 additions & 7 deletions dpdata/lammps/lmp.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,19 @@
return np.array(posis)


def get_spins(lines):
atom_lines = get_atoms(lines)
if len(atom_lines[0].split()) < 8:
return None
spins_ori = []
spins_norm = []
for ii in atom_lines:
spins_ori.append([float(jj) for jj in ii.split()[5:8]])
spins_norm.append([float(jj) for jj in ii.split()[-1:]])
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Simplify extraction of spins_norm by removing unnecessary list comprehension

In line 138, you can simplify the extraction of spins_norm by removing the list comprehension since you're extracting a single float value. This enhances code readability.

Apply this diff to simplify the code:

-        spins_norm.append([float(jj) for jj in ii.split()[-1:]])
+        spins_norm.append(float(ii.split()[-1]))
Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
spins_norm.append([float(jj) for jj in ii.split()[-1:]])
spins_norm.append(float(ii.split()[-1]))

spins = np.array(spins_ori) * np.array(spins_norm)
return spins

Check warning on line 140 in dpdata/lammps/lmp.py

View check run for this annotation

Codecov / codecov/patch

dpdata/lammps/lmp.py#L134-L140

Added lines #L134 - L140 were not covered by tests


def get_lmpbox(lines):
box_info = []
tilt = np.zeros(3)
Expand All @@ -151,6 +164,9 @@
def system_data(lines, type_map=None, type_idx_zero=True):
system = {}
system["atom_numbs"] = get_natoms_vec(lines)
spins = get_spins(lines)
if spins is not None:
system["spins"] = np.array([spins])

Check warning on line 169 in dpdata/lammps/lmp.py

View check run for this annotation

Codecov / codecov/patch

dpdata/lammps/lmp.py#L169

Added line #L169 was not covered by tests
system["atom_names"] = []
if type_map is None:
for ii in range(len(system["atom_numbs"])):
Expand Down Expand Up @@ -216,14 +232,54 @@
+ ptr_float_fmt
+ "\n"
)
for ii in range(natoms):
ret += coord_fmt % (
ii + 1,
system["atom_types"][ii] + 1,
system["coords"][f_idx][ii][0] - system["orig"][0],
system["coords"][f_idx][ii][1] - system["orig"][1],
system["coords"][f_idx][ii][2] - system["orig"][2],
if "spins" in system.keys():
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Simplify dictionary key check by removing unnecessary .keys() call

In line 235, you can check for the existence of a key in a dictionary directly using if "spins" in system: without calling .keys(). This is more efficient and Pythonic.

Apply this diff to simplify the condition:

-    if "spins" in system.keys():
+    if "spins" in system:
Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if "spins" in system.keys():
if "spins" in system:
Tools
Ruff

235-235: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)

coord_fmt = (

Check warning on line 236 in dpdata/lammps/lmp.py

View check run for this annotation

Codecov / codecov/patch

dpdata/lammps/lmp.py#L236

Added line #L236 was not covered by tests
coord_fmt.strip("\n")
+ " "
+ ptr_float_fmt
+ " "
+ ptr_float_fmt
+ " "
+ ptr_float_fmt
+ " "
+ ptr_float_fmt
+ "\n"
)
spins_norm = np.linalg.norm(system["spins"][f_idx], axis=1)

Check warning on line 248 in dpdata/lammps/lmp.py

View check run for this annotation

Codecov / codecov/patch

dpdata/lammps/lmp.py#L248

Added line #L248 was not covered by tests
for ii in range(natoms):
if "spins" in system.keys():
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Simplify dictionary key check by removing unnecessary .keys() call

Similarly, in line 250, you can simplify the key check in the same manner.

Apply this diff:

-        if "spins" in system.keys():
+        if "spins" in system:
Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if "spins" in system.keys():
if "spins" in system:
Tools
Ruff

250-250: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)

if spins_norm[ii] != 0:
ret += coord_fmt % (

Check warning on line 252 in dpdata/lammps/lmp.py

View check run for this annotation

Codecov / codecov/patch

dpdata/lammps/lmp.py#L251-L252

Added lines #L251 - L252 were not covered by tests
ii + 1,
system["atom_types"][ii] + 1,
system["coords"][f_idx][ii][0] - system["orig"][0],
system["coords"][f_idx][ii][1] - system["orig"][1],
system["coords"][f_idx][ii][2] - system["orig"][2],
system["spins"][f_idx][ii][0] / spins_norm[ii],
system["spins"][f_idx][ii][1] / spins_norm[ii],
system["spins"][f_idx][ii][2] / spins_norm[ii],
spins_norm[ii],
)
else:
ret += coord_fmt % (

Check warning on line 264 in dpdata/lammps/lmp.py

View check run for this annotation

Codecov / codecov/patch

dpdata/lammps/lmp.py#L264

Added line #L264 was not covered by tests
ii + 1,
system["atom_types"][ii] + 1,
system["coords"][f_idx][ii][0] - system["orig"][0],
system["coords"][f_idx][ii][1] - system["orig"][1],
system["coords"][f_idx][ii][2] - system["orig"][2],
system["spins"][f_idx][ii][0],
system["spins"][f_idx][ii][1],
system["spins"][f_idx][ii][2] + 1,
spins_norm[ii],
)
Comment on lines +251 to +274
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Potential logic error when handling zero spin norms

In the from_system_data function, when spins_norm[ii] is zero, the code adds 1 to the z-component of the spin vector on line 272. This may be unintended and could lead to incorrect spin values being written. Please verify if adding 1 to the z-component is necessary.

If unintended, apply this diff to correct the code:

-                        system["spins"][f_idx][ii][2] + 1,
+                        system["spins"][f_idx][ii][2],
Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if spins_norm[ii] != 0:
ret += coord_fmt % (
ii + 1,
system["atom_types"][ii] + 1,
system["coords"][f_idx][ii][0] - system["orig"][0],
system["coords"][f_idx][ii][1] - system["orig"][1],
system["coords"][f_idx][ii][2] - system["orig"][2],
system["spins"][f_idx][ii][0] / spins_norm[ii],
system["spins"][f_idx][ii][1] / spins_norm[ii],
system["spins"][f_idx][ii][2] / spins_norm[ii],
spins_norm[ii],
)
else:
ret += coord_fmt % (
ii + 1,
system["atom_types"][ii] + 1,
system["coords"][f_idx][ii][0] - system["orig"][0],
system["coords"][f_idx][ii][1] - system["orig"][1],
system["coords"][f_idx][ii][2] - system["orig"][2],
system["spins"][f_idx][ii][0],
system["spins"][f_idx][ii][1],
system["spins"][f_idx][ii][2] + 1,
spins_norm[ii],
)
if spins_norm[ii] != 0:
ret += coord_fmt % (
ii + 1,
system["atom_types"][ii] + 1,
system["coords"][f_idx][ii][0] - system["orig"][0],
system["coords"][f_idx][ii][1] - system["orig"][1],
system["coords"][f_idx][ii][2] - system["orig"][2],
system["spins"][f_idx][ii][0] / spins_norm[ii],
system["spins"][f_idx][ii][1] / spins_norm[ii],
system["spins"][f_idx][ii][2] / spins_norm[ii],
spins_norm[ii],
)
else:
ret += coord_fmt % (
ii + 1,
system["atom_types"][ii] + 1,
system["coords"][f_idx][ii][0] - system["orig"][0],
system["coords"][f_idx][ii][1] - system["orig"][1],
system["coords"][f_idx][ii][2] - system["orig"][2],
system["spins"][f_idx][ii][0],
system["spins"][f_idx][ii][1],
system["spins"][f_idx][ii][2],
spins_norm[ii],
)
Tools
GitHub Check: codecov/patch

[warning] 251-252: dpdata/lammps/lmp.py#L251-L252
Added lines #L251 - L252 were not covered by tests


[warning] 264-264: dpdata/lammps/lmp.py#L264
Added line #L264 was not covered by tests

else:
ret += coord_fmt % (
ii + 1,
system["atom_types"][ii] + 1,
system["coords"][f_idx][ii][0] - system["orig"][0],
system["coords"][f_idx][ii][1] - system["orig"][1],
system["coords"][f_idx][ii][2] - system["orig"][2],
)
return ret


Expand Down
107 changes: 107 additions & 0 deletions dpdata/plugins/vasp_deltaspin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
from __future__ import annotations

import re

import numpy as np

import dpdata.vasp_deltaspin.outcar
import dpdata.vasp_deltaspin.poscar
from dpdata.format import Format
from dpdata.utils import uniq_atom_names


@Format.register("vasp_deltaspin/poscar")
@Format.register("vasp_deltaspin/contcar")
class VASPPoscarFormat(Format):
@Format.post("rot_lower_triangular")
def from_system(self, file_name, **kwargs):
with open(file_name) as fp:
lines = [line.rstrip("\n") for line in fp]
with open(file_name[:-6] + "INCAR") as fp:
lines_incar = [line.rstrip("\n") for line in fp]
data = dpdata.vasp_deltaspin.poscar.to_system_data(lines, lines_incar)
data = uniq_atom_names(data)
return data

Check warning on line 24 in dpdata/plugins/vasp_deltaspin.py

View check run for this annotation

Codecov / codecov/patch

dpdata/plugins/vasp_deltaspin.py#L18-L24

Added lines #L18 - L24 were not covered by tests
Comment on lines +18 to +24
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ensure robust error handling when opening files

Currently, the code opens file_name and file_name[:-6] + "INCAR" without handling potential exceptions if the files do not exist or cannot be accessed. This may lead to unhandled exceptions and crash the program.

Consider adding error handling to manage file-related exceptions and provide meaningful error messages.

+        import os
+        if not os.path.exists(file_name):
+            raise FileNotFoundError(f"File not found: {file_name}")
+        if not os.path.exists(file_name[:-6] + "INCAR"):
+            raise FileNotFoundError(f"File not found: {file_name[:-6] + 'INCAR'}")
+
         with open(file_name) as fp:
             lines = [line.rstrip("\n") for line in fp]
         with open(file_name[:-6] + "INCAR") as fp:
             lines_incar = [line.rstrip("\n") for line in fp]

Committable suggestion was skipped due to low confidence.

Tools
GitHub Check: codecov/patch

[warning] 18-24: dpdata/plugins/vasp_deltaspin.py#L18-L24
Added lines #L18 - L24 were not covered by tests


def to_system(self, data, file_name, frame_idx=0, **kwargs):
"""Dump the system in vasp POSCAR format.

Parameters
----------
data : dict
The system data
file_name : str
The output file name
frame_idx : int
The index of the frame to dump
**kwargs : dict
other parameters
"""
w_str, m_str = VASPStringFormat().to_system(data, frame_idx=frame_idx)
with open(file_name, "w") as fp:
fp.write(w_str)

Check warning on line 42 in dpdata/plugins/vasp_deltaspin.py

View check run for this annotation

Codecov / codecov/patch

dpdata/plugins/vasp_deltaspin.py#L40-L42

Added lines #L40 - L42 were not covered by tests

with open(file_name[:-6] + "INCAR") as fp:
tmp_incar = fp.read()
res_incar = re.sub(

Check warning on line 46 in dpdata/plugins/vasp_deltaspin.py

View check run for this annotation

Codecov / codecov/patch

dpdata/plugins/vasp_deltaspin.py#L44-L46

Added lines #L44 - L46 were not covered by tests
Comment on lines +44 to +46
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add error handling when reading 'INCAR' file

When opening file_name[:-6] + "INCAR", there is no check to handle situations where the file might not exist or be inaccessible, which could result in an unhandled exception.

Include error handling to gracefully handle potential file access issues.

+        import os
+        incar_path = file_name[:-6] + "INCAR"
+        if not os.path.exists(incar_path):
+            raise FileNotFoundError(f"File not found: {incar_path}")
+
         with open(incar_path) as fp:
             tmp_incar = fp.read()
Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
with open(file_name[:-6] + "INCAR") as fp:
tmp_incar = fp.read()
res_incar = re.sub(
import os
incar_path = file_name[:-6] + "INCAR"
if not os.path.exists(incar_path):
raise FileNotFoundError(f"File not found: {incar_path}")
with open(incar_path) as fp:
tmp_incar = fp.read()
res_incar = re.sub(
Tools
GitHub Check: codecov/patch

[warning] 44-46: dpdata/plugins/vasp_deltaspin.py#L44-L46
Added lines #L44 - L46 were not covered by tests

r"MAGMOM[\s\S]*?\n\nM_CONST[\s\S]*?\n\n", m_str, tmp_incar, re.S
)
Comment on lines +46 to +48
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use 'flags' as a keyword argument in 're.sub'

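The signature is re.sub(pattern, repl, string, count=0, flags=0), so the fourth positional argument is count, not flags. The current call therefore passes re.S as the maximum number of substitutions and never sets the flag. Pass the flags parameter as a keyword argument so the call does what it appears to do.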

Apply this diff to address the concern:

         res_incar = re.sub(
-            r"MAGMOM[\s\S]*?\n\nM_CONST[\s\S]*?\n\n", m_str, tmp_incar, re.S
+            r"MAGMOM[\s\S]*?\n\nM_CONST[\s\S]*?\n\n", m_str, tmp_incar, flags=re.S
         )
Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
res_incar = re.sub(
r"MAGMOM[\s\S]*?\n\nM_CONST[\s\S]*?\n\n", m_str, tmp_incar, re.S
)
res_incar = re.sub(
r"MAGMOM[\s\S]*?\n\nM_CONST[\s\S]*?\n\n", m_str, tmp_incar, flags=re.S
)
Tools
Ruff

46-48: re.sub should pass count and flags as keyword arguments to avoid confusion due to unintuitive argument positions

(B034)

with open(file_name[:-6] + "INCAR", "w") as fp:
fp.write(res_incar)

Check warning on line 50 in dpdata/plugins/vasp_deltaspin.py

View check run for this annotation

Codecov / codecov/patch

dpdata/plugins/vasp_deltaspin.py#L49-L50

Added lines #L49 - L50 were not covered by tests


@Format.register("vasp/string")
class VASPStringFormat(Format):
def to_system(self, data, frame_idx=0, **kwargs):
"""Dump the system in vasp POSCAR format string.

Parameters
----------
data : dict
The system data
frame_idx : int
The index of the frame to dump
**kwargs : dict
other parameters
"""
assert frame_idx < len(data["coords"])
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Replace 'assert' with exception handling for input validation

Using assert statements for input validation is not recommended in production code because assertions can be disabled with optimization flags (-O), potentially skipping critical checks.

Replace the assert statement with explicit exception handling to ensure the validation is always performed.

-        assert frame_idx < len(data["coords"])
+        if frame_idx >= len(data["coords"]):
+            raise IndexError(f"frame_idx {frame_idx} is out of range.")
Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
assert frame_idx < len(data["coords"])
if frame_idx >= len(data["coords"]):
raise IndexError(f"frame_idx {frame_idx} is out of range.")
Tools
GitHub Check: codecov/patch

[warning] 67-68: dpdata/plugins/vasp_deltaspin.py#L67-L68
Added lines #L67 - L68 were not covered by tests

return dpdata.vasp_deltaspin.poscar.from_system_data(data, frame_idx)

Check warning on line 68 in dpdata/plugins/vasp_deltaspin.py

View check run for this annotation

Codecov / codecov/patch

dpdata/plugins/vasp_deltaspin.py#L67-L68

Added lines #L67 - L68 were not covered by tests


# rotate the system to lammps convention
@Format.register("vasp_deltaspin/outcar")
class VASPOutcarFormat(Format):
@Format.post("rot_lower_triangular")
def from_labeled_system(
self, file_name, begin=0, step=1, convergence_check=True, **kwargs
):
data = {}
ml = kwargs.get("ml", False)
(

Check warning on line 80 in dpdata/plugins/vasp_deltaspin.py

View check run for this annotation

Codecov / codecov/patch

dpdata/plugins/vasp_deltaspin.py#L78-L80

Added lines #L78 - L80 were not covered by tests
data["atom_names"],
data["atom_numbs"],
data["atom_types"],
data["cells"],
data["coords"],
data["spins"],
data["energies"],
data["forces"],
data["mag_forces"],
tmp_virial,
) = dpdata.vasp_deltaspin.outcar.get_frames(
file_name,
begin=begin,
step=step,
ml=ml,
convergence_check=convergence_check,
)
if tmp_virial is not None:
data["virials"] = tmp_virial

Check warning on line 99 in dpdata/plugins/vasp_deltaspin.py

View check run for this annotation

Codecov / codecov/patch

dpdata/plugins/vasp_deltaspin.py#L98-L99

Added lines #L98 - L99 were not covered by tests
# scale virial to the unit of eV
if "virials" in data:
v_pref = 1 * 1e3 / 1.602176621e6
for ii in range(data["cells"].shape[0]):
vol = np.linalg.det(np.reshape(data["cells"][ii], [3, 3]))
data["virials"][ii] *= v_pref * vol
data = uniq_atom_names(data)
return data

Check warning on line 107 in dpdata/plugins/vasp_deltaspin.py

View check run for this annotation

Codecov / codecov/patch

dpdata/plugins/vasp_deltaspin.py#L101-L107

Added lines #L101 - L107 were not covered by tests
10 changes: 10 additions & 0 deletions dpdata/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ class System:
DataType(
"coords", np.ndarray, (Axis.NFRAMES, Axis.NATOMS, 3), deepmd_name="coord"
),
DataType("spins", np.ndarray, (Axis.NFRAMES, Axis.NATOMS, 3), required=False),
DataType(
"real_atom_types", np.ndarray, (Axis.NFRAMES, Axis.NATOMS), required=False
),
Expand Down Expand Up @@ -712,6 +713,10 @@ def affine_map(self, trans, f_idx: int | numbers.Integral = 0):
assert np.linalg.det(trans) != 0
self.data["cells"][f_idx] = np.matmul(self.data["cells"][f_idx], trans)
self.data["coords"][f_idx] = np.matmul(self.data["coords"][f_idx], trans)
try:
self.data["spins"][f_idx] = np.matmul(self.data["spins"][f_idx], trans)
except:
pass
Comment on lines +716 to +719
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use contextlib.suppress(Exception) and avoid bare except.

Consider the following improvements:

  1. Replace the try-except-pass pattern with contextlib.suppress, naming only the exception that is expected here (KeyError, raised when the system has no "spins" entry), for a more explicit and concise way to handle that specific exception.
  2. Avoid using a bare except as it can catch unintended exceptions and make debugging difficult. Specify the expected exception type instead.

Apply this diff to address the suggestions:

-        try:
-            self.data["spins"][f_idx] = np.matmul(self.data["spins"][f_idx], trans)
-        except:
-            pass
+        with contextlib.suppress(KeyError):
+            self.data["spins"][f_idx] = np.matmul(self.data["spins"][f_idx], trans)
Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
try:
self.data["spins"][f_idx] = np.matmul(self.data["spins"][f_idx], trans)
except:
pass
with contextlib.suppress(KeyError):
self.data["spins"][f_idx] = np.matmul(self.data["spins"][f_idx], trans)
Tools
Ruff

716-719: Use contextlib.suppress(Exception) instead of try-except-pass

Replace with contextlib.suppress(Exception)

(SIM105)


718-718: Do not use bare except

(E722)


@post_funcs.register("shift_orig_zero")
def _shift_orig_zero(self):
Expand Down Expand Up @@ -1210,6 +1215,9 @@ class LabeledSystem(System):
DataType(
"forces", np.ndarray, (Axis.NFRAMES, Axis.NATOMS, 3), deepmd_name="force"
),
DataType(
"mag_forces", np.ndarray, (Axis.NFRAMES, Axis.NATOMS, 3), required=False
),
DataType(
"virials",
np.ndarray,
Expand Down Expand Up @@ -1793,3 +1801,5 @@ def to_format(self, *args, **kwargs):


add_format_methods()

# %%
Empty file.
Loading