Skip to content

Commit

Permalink
Support promolecules with single atom (#135)
Browse files Browse the repository at this point in the history
  • Loading branch information
msricher authored Dec 16, 2024
1 parent abf1008 commit 511a702
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 8 deletions.
17 changes: 14 additions & 3 deletions atomdb/promolecule.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,16 @@

from copy import deepcopy
from itertools import chain, combinations
from numbers import Integral
from numbers import Integral, Number
from operator import itemgetter
from warnings import warn

import numpy as np
from scipy.optimize import linprog

from atomdb.utils import DEFAULT_DATAPATH, DEFAULT_REMOTE, DEFAULT_DATASET, MULTIPLICITIES
from atomdb.periodic import element_number, element_symbol
from atomdb.species import load
from atomdb.utils import DEFAULT_DATAPATH, DEFAULT_DATASET, DEFAULT_REMOTE, MULTIPLICITIES

__all__ = [
"Promolecule",
Expand Down Expand Up @@ -566,6 +566,10 @@ def make_promolecule(
Promolecule instance.
"""
# Convert single coord [x, y, z] to list of coords [[x, y, z]]
coords = np.asarray(coords, dtype=float)
if coords.ndim == 1:
coords = coords.reshape(1, -1)
# Check coordinate units
if units is None or units.lower() == "bohr":
coords = [coord / 1 for coord in coords]
Expand All @@ -574,13 +578,18 @@ def make_promolecule(
else:
raise ValueError(f"Invalid `units` parameter '{units}'; " "must be 'bohr' or 'angstrom'")

# Convert single atnum to list of atnums [atnum]
if isinstance(atnums, (Integral, str)):
atnums = [atnums]
# Get atomic symbols/numbers from inputs
atoms = [element_symbol(atom) for atom in atnums]
atnums = [element_number(atom) for atom in atnums]
atoms = [element_symbol(atom) for atom in atnums]

# Handle default charge parameters
if charges is None:
charges = [0 for _ in atnums]
elif isinstance(charges, Number):
charges = [charges]

# Handle default multiplicity parameters
if mults is None:
Expand All @@ -590,6 +599,8 @@ def make_promolecule(
else:
# set each multiplicity to None
mults = [None for _ in atnums]
elif isinstance(mults, Number):
mults = [mults]

# Construct linear combination of species
promol = Promolecule()
Expand Down
22 changes: 19 additions & 3 deletions atomdb/test/test_promolecule.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,12 @@
#
# --

import pytest

import os
from numbers import Number

import numpy as np

import numpy.testing as npt
import pytest

from atomdb import make_promolecule

Expand Down Expand Up @@ -95,6 +94,16 @@
},
id="Be floating point charge/floating point mult",
),
pytest.param(
{
"atnums": 4,
"charges": 1.2,
"mults": 1.2,
"coords": np.asarray([0.0, 0.0, 0.0], dtype=float),
"dataset": "uhf_augccpvdz",
},
id="Be floating point charge/floating point mult",
),
pytest.param(
{
"atnums": [4],
Expand Down Expand Up @@ -167,6 +176,13 @@ def test_make_promolecule(case):
remotepath=None,
)

if isinstance(atnums, (Number, str)):
atnums = [atnums]
if isinstance(charges, Number):
charges = [charges]
if isinstance(mults, Number):
mults = [mults]

# Check that coefficients add up to (# centers)
npt.assert_allclose(sum(promol.coeffs), len(atnums))

Expand Down
4 changes: 2 additions & 2 deletions website/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
jupyter-book
sphinx
matplotlib
numpy>= 1.10
scipy>=1.0, <=1.10.1
numpy>=1.16
scipy>=1.4

0 comments on commit 511a702

Please sign in to comment.