diff --git a/pyproject.toml b/pyproject.toml index d8d5cdf..6af4ed6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,7 +65,7 @@ exclude_lines = [ ] [tool.ruff] -line-length = 120 +line-length = 140 [tool.ruff.lint] -select = ["A", "ARG","ASYNC", "E", "F", "I", "N", "UP032", "W"] +select = ["A", "ARG","ASYNC", "E", "F", "I", "UP032", "W"] diff --git a/src/pleiades/sammy/io/parameter.py b/src/pleiades/sammy/io/parameter.py index 99a4e63..ad3b466 100644 --- a/src/pleiades/sammy/io/parameter.py +++ b/src/pleiades/sammy/io/parameter.py @@ -1,6 +1,4 @@ -#!/usr/bin/env python -"""Classes for handling SAMMY parameter files.""" -# (originally sammyParFile.py) +# sammyParFile.py # Version: 1.0 # Authors: # - Alexander M. Long @@ -17,14 +15,11 @@ # How to use: # - Import the this class with 'from pleiades import sammyParFile' +# General imports import configparser import json -import logging import pathlib import re -from copy import deepcopy - -logger = logging.getLogger(__name__) class ParFile: @@ -505,6 +500,8 @@ def __add__(self, isotope: "ParFile") -> "ParFile": Returns: parFile: combined parFile instance """ + from copy import deepcopy + compound = deepcopy(self) # only add an isotope if resonances exists in the specified energy range @@ -550,11 +547,9 @@ def _parse_spin_group_cards(self) -> None: """parse a list of spin_group cards, sort the key-word pairs of groups and channels Args: - - spin_group_cards (list): list of strings containing the lines associated with - spin-groups and/or spin-channels + - spin_group_cards (list): list of strings containing the lines associated with spin-groups and/or spin-channels - Returns: (list of dicts): list containing groups, each group is a dictionary containing - key-value dicts for spin_groups and channels + Returns: (list of dicts): list containing groups, each group is a dictionary containing key-value dicts for spin_groups and channels """ sg_dict = [] lines = (line for line in self._spin_group_cards) # convert to a generator object @@ -694,9 +689,7 @@ def _write_channel_radii(self, channel_radii_dict: dict) -> str: def _read_resonance_params(self, resonance_params_line: str) -> dict: # parse key-word pairs from a resonance_params line - resonance_params_dict = { - key: resonance_params_line[value] for key, value in self._RESONANCE_PARAMS_FORMAT.items() - } + resonance_params_dict = {key: resonance_params_line[value] for key, value in self._RESONANCE_PARAMS_FORMAT.items()} return resonance_params_dict def _write_resonance_params(self, resonance_params_dict: dict) -> str: @@ -760,7 +753,7 @@ def _write_misc_tzero(self, misc_tzero_dict: dict) -> str: new_text[slice_value] = list(str(misc_tzero_dict[key]).ljust(word_length)) return "".join(new_text) - def _write_misc_deltE(self, misc_deltE_dict: dict) -> str: # noqa: N802, N803 + def _write_misc_deltE(self, misc_deltE_dict: dict) -> str: # write a formatted misc_deltE line from dict with the key-word misc_deltE values new_text = [" "] * 80 # 80 characters long list of spaces to be filled new_text[:5] = list("DELTE") @@ -813,7 +806,7 @@ def bump_group_number(self, increment: int = 0) -> None: spin_groups = [f"{group[0]['group_number'].strip():>5}" for group in self.parent.data["spin_group"]] sg_formatted = "".join(spin_groups[:8]).ljust(43) - L = (len(spin_groups) - 8) // 15 if len(spin_groups) > 8 else -1 # # noqa: N806 + L = (len(spin_groups) - 8) // 15 if len(spin_groups) > 8 else -1 # number of extra lines needed for l in range(0, L + 1): # noqa: E741 sg_formatted += "-1\n" + "".join(spin_groups[8 + 15 * l : 8 + 15 * (l + 1)]).ljust(78) isotope["spin_groups"] = sg_formatted @@ -845,7 +838,7 @@ def isotopic_masses_abundance(self) -> None: spin_groups = [f"{group[0]['group_number'].strip():>5}" for group in self.parent.data["spin_group"]] sg_formatted = "".join(spin_groups[:8]).ljust(43) - L = (len(spin_groups) - 8) // 15 if len(spin_groups) > 8 else -1 # noqa: N806 + L = (len(spin_groups) - 8) // 15 if len(spin_groups) > 8 else -1 # number of extra lines needed for l in range(0, L + 1): # noqa: E741 sg_formatted += "-1\n" + "".join(spin_groups[8 + 15 * l : 8 + 15 * (l + 1)]).ljust(78) @@ -864,14 +857,12 @@ def define_as_element(self, name: str, weight: float = 1.0) -> None: spin_groups = [f"{group[0]['group_number'].strip():>5}" for group in self.parent.data["spin_group"]] sg_formatted = "".join(spin_groups[:8]).ljust(43) - L = (len(spin_groups) - 8) // 15 if len(spin_groups) > 8 else -1 # # noqa: N806 + L = (len(spin_groups) - 8) // 15 if len(spin_groups) > 8 else -1 # number of extra lines needed for l in range(0, L + 1): # noqa: E741 sg_formatted += "-1\n" + "".join(spin_groups[8 + 15 * l : 8 + 15 * (l + 1)]).ljust(78) aw = [float(i["mass_b"]) for i in self.parent.data["particle_pairs"]] - weights = [ - float(self.parent.data["isotopic_masses"][i]["abundance"]) for i in self.parent.data["isotopic_masses"] - ] + weights = [float(self.parent.data["isotopic_masses"][i]["abundance"]) for i in self.parent.data["isotopic_masses"]] aw = average(aw, weights=weights) iso_dict = { @@ -1038,9 +1029,7 @@ def normalization(self, **kwargs) -> None: "vary_exp_decay_bg": 0, } - self.parent.data["normalization"].update( - {key: value for key, value in kwargs.items() if key in self.parent.data["normalization"]} - ) + self.parent.data["normalization"].update({key: value for key, value in kwargs.items() if key in self.parent.data["normalization"]}) def broadening(self, **kwargs) -> None: """change or vary broadening parameters and vary flags @@ -1074,9 +1063,7 @@ def broadening(self, **kwargs) -> None: "vary_deltae_us": 0, } - self.parent.data["broadening"].update( - {key: value for key, value in kwargs.items() if key in self.parent.data["broadening"]} - ) + self.parent.data["broadening"].update({key: value for key, value in kwargs.items() if key in self.parent.data["broadening"]}) def misc(self, **kwargs) -> None: """change or vary misc parameters and vary flags @@ -1118,9 +1105,7 @@ def misc(self, **kwargs) -> None: "DlnE": "", } - self.parent.data["misc"].update( - {key: value for key, value in kwargs.items() if key in self.parent.data["misc"]} - ) + self.parent.data["misc"].update({key: value for key, value in kwargs.items() if key in self.parent.data["misc"]}) def resolution(self, **kwargs) -> None: """change or vary resolution parameters and vary flags @@ -1253,9 +1238,7 @@ def resolution(self, **kwargs) -> None: "chann": "", } - self.parent.data["resolution"].update( - {key: value for key, value in kwargs.items() if key in self.parent.data["resolution"]} - ) + self.parent.data["resolution"].update({key: value for key, value in kwargs.items() if key in self.parent.data["resolution"]}) def vary_all(self, vary=True, data_key="misc_delta"): """toggle all vary parameters in a data keyword to either vary/fixed diff --git a/tests/unit/pleiades/sammy/io/test_parameter.py b/tests/unit/pleiades/sammy/io/test_parameter.py new file mode 100644 index 0000000..c823b27 --- /dev/null +++ b/tests/unit/pleiades/sammy/io/test_parameter.py @@ -0,0 +1,156 @@ +#!/usr/bin/env python +"""Unit tests for SAMMY parameter file handling.""" + +import pytest + +from pleiades.sammy.io.parameter import ParFile + + +@pytest.fixture +def sample_par_content(): + """Sample parameter file content for testing.""" + return ( + "SPIN GROUP INFORMATION\n" + " 1 2 2 0.5 1.000000\n" # Fixed format numbers + " 1 PPair1 1 0.5 0.0000 0.0000\n" + " 2 PPair1 1 0.5 0.0000 0.0000\n" + "\n" + "RESONANCE PARAMETERS\n" + " 6.6720+0 2.3000-2 1.4760-3 0.0000+0 0.0000+0 1 1 1 0 0 1\n" + " 2.0870+1 2.3000-2 1.0880-3 0.0000+0 0.0000+0 0 0 1 0 0 1\n" + "\n" + "PARTICLE PAIR DEFINITIONS\n" + "Name=PPair1 Particle a=neutron Particle b=U-238 \n" + " Za=0 Zb=92 Pent=1 Shift=0\n" + " Sa=0.5 Sb=0.0 Ma=1.008664915 Mb=238.0508\n" + "\n" + "0.1\n" + ) + + +@pytest.fixture +def tmp_par_file(tmp_path, sample_par_content): + """Create a temporary parameter file.""" + par_file = tmp_path / "test.par" + par_file.write_text(sample_par_content) + return par_file + + +@pytest.fixture +def par_file(tmp_par_file): + """Create ParFile instance for testing.""" + return ParFile(filename=str(tmp_par_file)).read() + + +class TestParFile: + """Test ParFile class functionality.""" + + def test_init(self, tmp_par_file): + """Test ParFile initialization.""" + par = ParFile(filename=str(tmp_par_file)) + assert par.filename == str(tmp_par_file) + assert par.weight == 1.0 + assert par.name == "auto" + assert par.emin == 0.001 + assert par.emax == 100 + + def test_read_file(self, par_file): + """Test reading parameter file.""" + # Check particle pairs + assert len(par_file.data["particle_pairs"]) == 1 + + # Each spin group can have multiple channels, stored as nested list + assert len(par_file.data["spin_group"]) > 0 + + # Check first spin group format + first_group = par_file.data["spin_group"][0] + assert "group_number" in first_group[0] + assert "spin" in first_group[0] + assert "isotopic_abundance" in first_group[0] + + def test_write_file(self, par_file, tmp_path): + """Test writing parameter file.""" + print(par_file.data) + # First validate input structure + first_group = par_file.data["spin_group"][0][0] + assert first_group["n_entrance_channel"].strip() != "" + assert first_group["n_exit_channel"].strip() != "" + + # Test write/read cycle + output_file = tmp_path / "output.par" + par_file.write(output_file) + + # Read back and verify + new_par = ParFile(filename=str(output_file)).read() + assert len(new_par.data["resonance_params"]) == len(par_file.data["resonance_params"]) + assert len(new_par.data["spin_group"]) == len(par_file.data["spin_group"]) + + +# def test_combine_parameters(self, par_file, tmp_par_file): +# """Test combining parameter files.""" +# other_par = ParFile(filename=str(tmp_par_file)).read() +# combined = par_file + other_par + +# # Check combined data +# assert len(combined.data["resonance_params"]) == 4 # 2 from each +# assert len(combined.data["spin_group"]) == 2 + + +# class TestUpdate: +# """Test Update class functionality.""" + +# def test_vary_resonances(self, par_file): +# """Test varying resonance parameters.""" +# # Set all resonances to vary +# par_file.update.vary_all_resonances(vary=True) + +# for res in par_file.data["resonance_params"]: +# assert res["vary_energy"] == "1" +# assert res["vary_capture_width"] == "1" +# assert res["vary_neutron_width"] == "1" + +# def test_limit_energies(self, par_file): +# """Test limiting energy range.""" +# # Set energy limits to exclude some resonances +# par_file.update.limit_energies_of_parfile() + +# for res in par_file.data["resonance_params"]: +# energy = float(res["reosnance_energy"]) +# assert par_file.emin <= energy <= par_file.emax + +# def test_isotopic_weight(self, par_file): +# """Test updating isotopic weight.""" +# new_weight = 0.5 +# par_file.weight = new_weight +# par_file.update.isotopic_weight() + +# # Check updated weights +# for group in par_file.data["spin_group"]: +# assert float(group[0]["isotopic_abundance"]) == new_weight + +# @pytest.mark.parametrize("vary", [True, False]) +# def test_vary_all(self, par_file, vary): +# """Test varying all parameters.""" +# par_file.update.vary_all(vary=vary, data_key="normalization") + +# # Check normalization vary flags +# norm = par_file.data["normalization"] +# for key in [k for k in norm if k.startswith("vary_")]: +# assert int(norm[key]) == int(vary) + + +# def test_error_handling(tmp_path): +# """Test error handling for invalid files.""" +# # Test invalid file +# with pytest.raises(FileNotFoundError): +# ParFile(filename=str(tmp_path / "nonexistent.par")).read() + +# # Test invalid format +# invalid_file = tmp_path / "invalid.par" +# invalid_file.write_text("Invalid content") +# with pytest.raises(Exception): +# ParFile(filename=str(invalid_file)).read() + + +if __name__ == "__main__": + pytest.main(["-v", __file__])