Skip to content

Commit

Permalink
switch back to original parFile parser
Browse files Browse the repository at this point in the history
  • Loading branch information
KedoKudo committed Jan 1, 2025
1 parent fdfdf2f commit cfb252a
Show file tree
Hide file tree
Showing 3 changed files with 174 additions and 35 deletions.
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ exclude_lines = [
]

[tool.ruff]
line-length = 120
line-length = 140

[tool.ruff.lint]
select = ["A", "ARG","ASYNC", "E", "F", "I", "N", "UP032", "W"]
select = ["A", "ARG","ASYNC", "E", "F", "I", "UP032", "W"]
49 changes: 16 additions & 33 deletions src/pleiades/sammy/io/parameter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
#!/usr/bin/env python
"""Classes for handling SAMMY parameter files."""
# (originally sammyParFile.py)
# sammyParFile.py
# Version: 1.0
# Authors:
# - Alexander M. Long
Expand All @@ -17,14 +15,11 @@
# How to use:
# - Import the this class with 'from pleiades import sammyParFile'

# General imports
import configparser
import json
import logging
import pathlib
import re
from copy import deepcopy

logger = logging.getLogger(__name__)


class ParFile:
Expand Down Expand Up @@ -505,6 +500,8 @@ def __add__(self, isotope: "ParFile") -> "ParFile":
Returns:
parFile: combined parFile instance
"""
from copy import deepcopy

compound = deepcopy(self)

# only add an isotope if resonances exists in the specified energy range
Expand Down Expand Up @@ -550,11 +547,9 @@ def _parse_spin_group_cards(self) -> None:
"""parse a list of spin_group cards, sort the key-word pairs of groups and channels
Args:
- spin_group_cards (list): list of strings containing the lines associated with
spin-groups and/or spin-channels
- spin_group_cards (list): list of strings containing the lines associated with spin-groups and/or spin-channels
Returns: (list of dicts): list containing groups, each group is a dictionary containing
key-value dicts for spin_groups and channels
Returns: (list of dicts): list containing groups, each group is a dictionary containing key-value dicts for spin_groups and channels
"""
sg_dict = []
lines = (line for line in self._spin_group_cards) # convert to a generator object
Expand Down Expand Up @@ -694,9 +689,7 @@ def _write_channel_radii(self, channel_radii_dict: dict) -> str:

def _read_resonance_params(self, resonance_params_line: str) -> dict:
# parse key-word pairs from a resonance_params line
resonance_params_dict = {
key: resonance_params_line[value] for key, value in self._RESONANCE_PARAMS_FORMAT.items()
}
resonance_params_dict = {key: resonance_params_line[value] for key, value in self._RESONANCE_PARAMS_FORMAT.items()}
return resonance_params_dict

def _write_resonance_params(self, resonance_params_dict: dict) -> str:
Expand Down Expand Up @@ -760,7 +753,7 @@ def _write_misc_tzero(self, misc_tzero_dict: dict) -> str:
new_text[slice_value] = list(str(misc_tzero_dict[key]).ljust(word_length))
return "".join(new_text)

def _write_misc_deltE(self, misc_deltE_dict: dict) -> str: # noqa: N802, N803
def _write_misc_deltE(self, misc_deltE_dict: dict) -> str:
# write a formatted misc_deltE line from dict with the key-word misc_deltE values
new_text = [" "] * 80 # 80 characters long list of spaces to be filled
new_text[:5] = list("DELTE")
Expand Down Expand Up @@ -813,7 +806,7 @@ def bump_group_number(self, increment: int = 0) -> None:
spin_groups = [f"{group[0]['group_number'].strip():>5}" for group in self.parent.data["spin_group"]]

sg_formatted = "".join(spin_groups[:8]).ljust(43)
L = (len(spin_groups) - 8) // 15 if len(spin_groups) > 8 else -1 # # noqa: N806
L = (len(spin_groups) - 8) // 15 if len(spin_groups) > 8 else -1 # number of extra lines needed
for l in range(0, L + 1): # noqa: E741
sg_formatted += "-1\n" + "".join(spin_groups[8 + 15 * l : 8 + 15 * (l + 1)]).ljust(78)
isotope["spin_groups"] = sg_formatted
Expand Down Expand Up @@ -845,7 +838,7 @@ def isotopic_masses_abundance(self) -> None:
spin_groups = [f"{group[0]['group_number'].strip():>5}" for group in self.parent.data["spin_group"]]

sg_formatted = "".join(spin_groups[:8]).ljust(43)
L = (len(spin_groups) - 8) // 15 if len(spin_groups) > 8 else -1 # noqa: N806
L = (len(spin_groups) - 8) // 15 if len(spin_groups) > 8 else -1 # number of extra lines needed
for l in range(0, L + 1): # noqa: E741
sg_formatted += "-1\n" + "".join(spin_groups[8 + 15 * l : 8 + 15 * (l + 1)]).ljust(78)

Expand All @@ -864,14 +857,12 @@ def define_as_element(self, name: str, weight: float = 1.0) -> None:
spin_groups = [f"{group[0]['group_number'].strip():>5}" for group in self.parent.data["spin_group"]]

sg_formatted = "".join(spin_groups[:8]).ljust(43)
L = (len(spin_groups) - 8) // 15 if len(spin_groups) > 8 else -1 # # noqa: N806
L = (len(spin_groups) - 8) // 15 if len(spin_groups) > 8 else -1 # number of extra lines needed
for l in range(0, L + 1): # noqa: E741
sg_formatted += "-1\n" + "".join(spin_groups[8 + 15 * l : 8 + 15 * (l + 1)]).ljust(78)

aw = [float(i["mass_b"]) for i in self.parent.data["particle_pairs"]]
weights = [
float(self.parent.data["isotopic_masses"][i]["abundance"]) for i in self.parent.data["isotopic_masses"]
]
weights = [float(self.parent.data["isotopic_masses"][i]["abundance"]) for i in self.parent.data["isotopic_masses"]]
aw = average(aw, weights=weights)

iso_dict = {
Expand Down Expand Up @@ -1038,9 +1029,7 @@ def normalization(self, **kwargs) -> None:
"vary_exp_decay_bg": 0,
}

self.parent.data["normalization"].update(
{key: value for key, value in kwargs.items() if key in self.parent.data["normalization"]}
)
self.parent.data["normalization"].update({key: value for key, value in kwargs.items() if key in self.parent.data["normalization"]})

def broadening(self, **kwargs) -> None:
"""change or vary broadening parameters and vary flags
Expand Down Expand Up @@ -1074,9 +1063,7 @@ def broadening(self, **kwargs) -> None:
"vary_deltae_us": 0,
}

self.parent.data["broadening"].update(
{key: value for key, value in kwargs.items() if key in self.parent.data["broadening"]}
)
self.parent.data["broadening"].update({key: value for key, value in kwargs.items() if key in self.parent.data["broadening"]})

def misc(self, **kwargs) -> None:
"""change or vary misc parameters and vary flags
Expand Down Expand Up @@ -1118,9 +1105,7 @@ def misc(self, **kwargs) -> None:
"DlnE": "",
}

self.parent.data["misc"].update(
{key: value for key, value in kwargs.items() if key in self.parent.data["misc"]}
)
self.parent.data["misc"].update({key: value for key, value in kwargs.items() if key in self.parent.data["misc"]})

def resolution(self, **kwargs) -> None:
"""change or vary resolution parameters and vary flags
Expand Down Expand Up @@ -1253,9 +1238,7 @@ def resolution(self, **kwargs) -> None:
"chann": "",
}

self.parent.data["resolution"].update(
{key: value for key, value in kwargs.items() if key in self.parent.data["resolution"]}
)
self.parent.data["resolution"].update({key: value for key, value in kwargs.items() if key in self.parent.data["resolution"]})

def vary_all(self, vary=True, data_key="misc_delta"):
"""toggle all vary parameters in a data keyword to either vary/fixed
Expand Down
156 changes: 156 additions & 0 deletions tests/unit/pleiades/sammy/io/test_parameter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
#!/usr/bin/env python
"""Unit tests for SAMMY parameter file handling."""

import pytest

from pleiades.sammy.io.parameter import ParFile


@pytest.fixture
def sample_par_content():
"""Sample parameter file content for testing."""
return (
"SPIN GROUP INFORMATION\n"
" 1 2 2 0.5 1.000000\n" # Fixed format numbers
" 1 PPair1 1 0.5 0.0000 0.0000\n"
" 2 PPair1 1 0.5 0.0000 0.0000\n"
"\n"
"RESONANCE PARAMETERS\n"
" 6.6720+0 2.3000-2 1.4760-3 0.0000+0 0.0000+0 1 1 1 0 0 1\n"
" 2.0870+1 2.3000-2 1.0880-3 0.0000+0 0.0000+0 0 0 1 0 0 1\n"
"\n"
"PARTICLE PAIR DEFINITIONS\n"
"Name=PPair1 Particle a=neutron Particle b=U-238 \n"
" Za=0 Zb=92 Pent=1 Shift=0\n"
" Sa=0.5 Sb=0.0 Ma=1.008664915 Mb=238.0508\n"
"\n"
"0.1\n"
)


@pytest.fixture
def tmp_par_file(tmp_path, sample_par_content):
"""Create a temporary parameter file."""
par_file = tmp_path / "test.par"
par_file.write_text(sample_par_content)
return par_file


@pytest.fixture
def par_file(tmp_par_file):
"""Create ParFile instance for testing."""
return ParFile(filename=str(tmp_par_file)).read()


class TestParFile:
"""Test ParFile class functionality."""

def test_init(self, tmp_par_file):
"""Test ParFile initialization."""
par = ParFile(filename=str(tmp_par_file))
assert par.filename == str(tmp_par_file)
assert par.weight == 1.0
assert par.name == "auto"
assert par.emin == 0.001
assert par.emax == 100

def test_read_file(self, par_file):
"""Test reading parameter file."""
# Check particle pairs
assert len(par_file.data["particle_pairs"]) == 1

# Each spin group can have multiple channels, stored as nested list
assert len(par_file.data["spin_group"]) > 0

# Check first spin group format
first_group = par_file.data["spin_group"][0]
assert "group_number" in first_group[0]
assert "spin" in first_group[0]
assert "isotopic_abundance" in first_group[0]

def test_write_file(self, par_file, tmp_path):
"""Test writing parameter file."""
print(par_file.data)
# First validate input structure
first_group = par_file.data["spin_group"][0][0]
assert first_group["n_entrance_channel"].strip() != ""
assert first_group["n_exit_channel"].strip() != ""

# Test write/read cycle
output_file = tmp_path / "output.par"
par_file.write(output_file)

# Read back and verify
new_par = ParFile(filename=str(output_file)).read()
assert len(new_par.data["resonance_params"]) == len(par_file.data["resonance_params"])
assert len(new_par.data["spin_group"]) == len(par_file.data["spin_group"])


# def test_combine_parameters(self, par_file, tmp_par_file):
# """Test combining parameter files."""
# other_par = ParFile(filename=str(tmp_par_file)).read()
# combined = par_file + other_par

# # Check combined data
# assert len(combined.data["resonance_params"]) == 4 # 2 from each
# assert len(combined.data["spin_group"]) == 2


# class TestUpdate:
# """Test Update class functionality."""

# def test_vary_resonances(self, par_file):
# """Test varying resonance parameters."""
# # Set all resonances to vary
# par_file.update.vary_all_resonances(vary=True)

# for res in par_file.data["resonance_params"]:
# assert res["vary_energy"] == "1"
# assert res["vary_capture_width"] == "1"
# assert res["vary_neutron_width"] == "1"

# def test_limit_energies(self, par_file):
# """Test limiting energy range."""
# # Set energy limits to exclude some resonances
# par_file.update.limit_energies_of_parfile()

# for res in par_file.data["resonance_params"]:
# energy = float(res["reosnance_energy"])
# assert par_file.emin <= energy <= par_file.emax

# def test_isotopic_weight(self, par_file):
# """Test updating isotopic weight."""
# new_weight = 0.5
# par_file.weight = new_weight
# par_file.update.isotopic_weight()

# # Check updated weights
# for group in par_file.data["spin_group"]:
# assert float(group[0]["isotopic_abundance"]) == new_weight

# @pytest.mark.parametrize("vary", [True, False])
# def test_vary_all(self, par_file, vary):
# """Test varying all parameters."""
# par_file.update.vary_all(vary=vary, data_key="normalization")

# # Check normalization vary flags
# norm = par_file.data["normalization"]
# for key in [k for k in norm if k.startswith("vary_")]:
# assert int(norm[key]) == int(vary)


# def test_error_handling(tmp_path):
# """Test error handling for invalid files."""
# # Test invalid file
# with pytest.raises(FileNotFoundError):
# ParFile(filename=str(tmp_path / "nonexistent.par")).read()

# # Test invalid format
# invalid_file = tmp_path / "invalid.par"
# invalid_file.write_text("Invalid content")
# with pytest.raises(Exception):
# ParFile(filename=str(invalid_file)).read()


if __name__ == "__main__":
pytest.main(["-v", __file__])

0 comments on commit cfb252a

Please sign in to comment.