Skip to content

Commit

Permalink
Sans2d script transform (#439)
Browse files Browse the repository at this point in the history
* Add SANS2D transforms

* Ruff linting fix

* Formatting and linting commit

* Update fia_api/scripts/transforms/sans_transform.py

* Update fia_api/scripts/transforms/sans_transform.py

Co-authored-by: keiranjprice101 <[email protected]>

---------

Co-authored-by: github-actions <[email protected]>
Co-authored-by: keiranjprice101 <[email protected]>
  • Loading branch information
3 people authored Jan 27, 2025
1 parent 1303c4e commit 1ff4c48
Show file tree
Hide file tree
Showing 5 changed files with 190 additions and 13 deletions.
6 changes: 3 additions & 3 deletions fia_api/scripts/transforms/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
import logging

from fia_api.scripts.transforms.iris_transform import IrisTransform
from fia_api.scripts.transforms.loq_transform import LoqTransform
from fia_api.scripts.transforms.mari_transforms import MariTransform
from fia_api.scripts.transforms.osiris_transform import OsirisTransform
from fia_api.scripts.transforms.sans_transform import SansTransform
from fia_api.scripts.transforms.test_transforms import TestTransform
from fia_api.scripts.transforms.tosca_transform import ToscaTransform
from fia_api.scripts.transforms.transform import MissingTransformError, Transform
Expand All @@ -29,8 +29,8 @@ def get_transform_for_instrument(instrument: str) -> Transform:
return ToscaTransform()
case "osiris":
return OsirisTransform()
case "loq":
return LoqTransform()
case "loq" | "sans2d":
return SansTransform()
case "iris":
return IrisTransform()
case "test":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,22 @@


# mypy: disable-error-code="operator, index"
class LoqTransform(Transform):
class SansTransform(Transform):
"""
LoqTransform applies modifications to LOQ instrument scripts based on reduction input parameters in a Job
SansTransform applies modifications to SANS instrument scripts based on reduction input parameters in a Job
entity.
"""

def apply(self, script: PreScript, job: Job) -> None: # noqa: C901, PLR0912
logger.info("Beginning LOQ transform for job %s...", job.id)
logger.info("Beginning %s transform for job %s...", job.instrument, job.id)
lines = script.value.splitlines()
# MyPY does not believe ColumnElement[JSONB] is indexable, despite JSONB implementing the Indexable mixin
# If you get here in the future, try removing the type ignore and see if it passes with newer mypy
for index, line in enumerate(lines):
if "/extras/loq/MaskFile.toml" in line and "user_file" in job.inputs:
lines[index] = line.replace("/extras/loq/MaskFile.toml", job.inputs["user_file"])
if f"/extras/{job.instrument.instrument_name.lower()}/MaskFile.toml" in line and "user_file" in job.inputs:
lines[index] = line.replace(
f"/extras/{job.instrument.instrument_name.lower()}/MaskFile.toml", job.inputs["user_file"]
)
continue
if "run_number" in job.inputs and self._replace_input(
line, lines, index, "sample_scatter", job.inputs["run_number"]
Expand Down
3 changes: 3 additions & 0 deletions test/scripts/transforms/test_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from fia_api.scripts.transforms.iris_transform import IrisTransform
from fia_api.scripts.transforms.mari_transforms import MariTransform
from fia_api.scripts.transforms.osiris_transform import OsirisTransform
from fia_api.scripts.transforms.sans_transform import SansTransform
from fia_api.scripts.transforms.test_transforms import TestTransform
from fia_api.scripts.transforms.tosca_transform import ToscaTransform
from fia_api.scripts.transforms.transform import MissingTransformError
Expand All @@ -21,6 +22,8 @@
("test", TestTransform),
("osiris", OsirisTransform),
("iris", IrisTransform),
("loq", SansTransform),
("sans2d", SansTransform),
],
)
def test_transform_factory(name, expected_transform):
Expand Down
12 changes: 7 additions & 5 deletions test/scripts/transforms/test_loq_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pytest

from fia_api.scripts.pre_script import PreScript
from fia_api.scripts.transforms.loq_transform import LoqTransform
from fia_api.scripts.transforms.sans_transform import SansTransform


@pytest.fixture()
Expand Down Expand Up @@ -64,8 +64,9 @@ def reduction_1():
"sample_height": 8.0,
"sample_width": 8.0,
"slice_wavs": "[1.0, 2.0, 3.0]",
"phi_limits_list": "[(-20, 20), (30, 160)]",
"phi_limits": "[(-20, 20), (30, 160)]",
}
mock.instrument.instrument_name = "LOQ"
return mock


Expand All @@ -84,8 +85,9 @@ def reduction_2():
"sample_height": 8.0,
"sample_width": 8.0,
"slice_wavs": "[1.0, 2.0, 3.0]",
"phi_limits_list": "[(-20, 20), (30, 160)]",
"phi_limits": "[(-20, 20), (30, 160)]",
}
mock.instrument.instrument_name = "LOQ"
return mock


Expand All @@ -96,7 +98,7 @@ def test_loq_transform_apply(script, reduction_1):
:param reduction_1: The reduction fixture
:return: None
"""
transform = LoqTransform()
transform = SansTransform()

original_lines = script.value.splitlines()
transform.apply(script, reduction_1)
Expand Down Expand Up @@ -134,7 +136,7 @@ def test_loq_transform_apply_with_optionals(script, reduction_2):
:param reduction_2: The reduction fixture
:return: None
"""
transform = LoqTransform()
transform = SansTransform()

original_lines = script.value.splitlines()
transform.apply(script, reduction_2)
Expand Down
170 changes: 170 additions & 0 deletions test/scripts/transforms/test_sans2d_transforms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
"""
Test cases for LoqTransform
"""

from unittest.mock import Mock

import pytest

from fia_api.scripts.pre_script import PreScript
from fia_api.scripts.transforms.sans_transform import SansTransform


@pytest.fixture()
def script():
"""
LoqTransform PreScript fixture
:return:
"""
return PreScript(
value="""
import math
import numpy
import csv
import datetime
from mantid.kernel import ConfigService
from mantid.simpleapi import RenameWorkspace, SaveRKH, SaveNXcanSAS, GroupWorkspaces, mtd, ConjoinWorkspaces
from mantid import config
from sans.user_file.toml_parsers.toml_reader import TomlReader
import sans.command_interface.ISISCommandInterface as ici
# Setup by rundetection
user_file = "/extras/sans2d/MaskFile.toml"
sample_scatter = 110754
sample_transmission = None
sample_direct = None
can_scatter = None
can_transmission = None
can_direct = None
sample_thickness = 1.0
sample_geometry = "Disc"
sample_height = 8.0
sample_width = 8.0
slice_wavs = [1.75, 2.75, 3.75, 4.75, 5.75, 6.75, 8.75, 10.75, 12.5]
phi_limits_list = [(-30, 30), (60, 120)]
"""
)


@pytest.fixture()
def reduction_1():
"""
Reduction fixture
:return:
"""
mock = Mock()
mock.inputs = {
"user_file": "/extras/sans2d/BestMaskFile.toml",
"run_number": 10,
"scatter_transmission": 9,
"scatter_direct": 3,
"can_scatter": 5,
"can_transmission": 4,
"can_direct": 3,
"sample_thickness": 2.0,
"sample_geometry": "Disc",
"sample_height": 8.0,
"sample_width": 8.0,
"slice_wavs": "[1.0, 2.0, 3.0]",
"phi_limits": "[(-20, 20), (30, 160)]",
}
mock.instrument.instrument_name = "SANS2D"
return mock


@pytest.fixture()
def reduction_2():
"""
Reduction fixture
:return:
"""
mock = Mock()
mock.inputs = {
"user_file": "/extras/sans2d/BestMaskFile.toml",
"run_number": 5,
"sample_thickness": 2.0,
"sample_geometry": "Disc",
"sample_height": 8.0,
"sample_width": 8.0,
"slice_wavs": "[1.0, 2.0, 3.0]",
"phi_limits": "[(-20, 20), (30, 160)]",
}
mock.instrument.instrument_name = "SANS2D"
return mock


def test_sans2d_transform_apply(script, reduction_1):
"""
Test loq transform applies correct updates to script
:param script: The script fixture
:param reduction_1: The reduction fixture
:return: None
"""
transform = SansTransform()

original_lines = script.value.splitlines()
transform.apply(script, reduction_1)
updated_lines = script.value.splitlines()
assert len(original_lines) == len(updated_lines)
replacements = {
"user_file": 'user_file = "/extras/sans2d/BestMaskFile.toml"',
"sample_scatter": "sample_scatter = 10",
"sample_transmission": "sample_transmission = 9",
"sample_direct": "sample_direct = 3",
"can_scatter": "can_scatter = 5",
"can_transmission": "can_transmission = 4",
"can_direct": "can_direct = 3",
"sample_thickness": "sample_thickness = 2.0",
"sample_geometry": 'sample_geometry = "Disc"',
"sample_height": "sample_height = 8.0",
"sample_width": "sample_width = 8.0",
"slice_wavs": "slice_wavs = [1.0, 2.0, 3.0]",
"phi_limits": "phi_limits_list = [(-20, 20), (30, 160)]",
}

for index, line in enumerate(updated_lines):
for key, expected_line in replacements.items():
if line.startswith(key):
assert line == expected_line
break
else:
assert line == original_lines[index]


def test_sans2d_transform_apply_with_optionals(script, reduction_2):
"""
Test loq transform applies correct updates to script
:param script: The script fixture
:param reduction_2: The reduction fixture
:return: None
"""
transform = SansTransform()

original_lines = script.value.splitlines()
transform.apply(script, reduction_2)
updated_lines = script.value.splitlines()
assert len(original_lines) == len(updated_lines)
replacements = {
"user_file": 'user_file = "/extras/sans2d/BestMaskFile.toml"',
"sample_scatter": "sample_scatter = 5",
"sample_transmission": "sample_transmission = None",
"sample_direct": "sample_direct = None",
"can_scatter": "can_scatter = None",
"can_transmission": "can_transmission = None",
"can_direct": "can_direct = None",
"sample_thickness": "sample_thickness = 2.0",
"sample_geometry": 'sample_geometry = "Disc"',
"sample_height": "sample_height = 8.0",
"sample_width": "sample_width = 8.0",
"slice_wavs": "slice_wavs = [1.0, 2.0, 3.0]",
"phi_limits": "phi_limits_list = [(-20, 20), (30, 160)]",
}

for index, line in enumerate(updated_lines):
for key, expected_line in replacements.items():
if line.startswith(key):
assert line == expected_line
break
else:
assert line == original_lines[index]

0 comments on commit 1ff4c48

Please sign in to comment.